joshuar-go-hass-agent/internal/preferences/prefs_test.go

405 lines
9.7 KiB
Go

// Copyright (c) 2024 Joshua Rich <joshua.rich@gmail.com>
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
//nolint:paralleltest
//revive:disable:unused-receiver,comment-spacings
package preferences
import (
"context"
"os"
"path/filepath"
"reflect"
"sync"
"testing"
"github.com/adrg/xdg"
"github.com/stretchr/testify/require"
)
func deviceEqual(t *testing.T, got, want *Device) bool {
t.Helper()
switch {
case !reflect.DeepEqual(got.AppName, want.AppName):
t.Error("appName does not match")
return false
case !reflect.DeepEqual(got.AppVersion, want.AppVersion):
t.Error("appVersion does not match")
return false
case !reflect.DeepEqual(got.OsName, want.OsName):
t.Error("distro does not match")
return false
case !reflect.DeepEqual(got.OsVersion, want.OsVersion):
t.Errorf("distroVersion does not match: got %s want %s", got.OsVersion, want.OsVersion)
return false
case !reflect.DeepEqual(got.Name, want.Name):
t.Error("hostname does not match")
return false
case !reflect.DeepEqual(got.Model, want.Model):
t.Error("hwModel does not match")
return false
case !reflect.DeepEqual(got.Manufacturer, want.Manufacturer):
t.Error("hwVendor does not match")
return false
}
return true
}
func preferencesEqual(t *testing.T, got, want *Preferences) bool {
t.Helper()
switch {
case !deviceEqual(t, got.Device, want.Device):
t.Error("device does not match")
return false
case !reflect.DeepEqual(got.Hass, want.Hass):
t.Error("hass preferences do not match")
return false
case !reflect.DeepEqual(got.Registration, want.Registration):
t.Error("registration preferences do not match")
return false
case !reflect.DeepEqual(got.MQTT, want.MQTT):
t.Errorf("mqtt preferences do not match")
return false
case !reflect.DeepEqual(got.Version, want.Version):
t.Error("version does not match")
return false
case !reflect.DeepEqual(got.Registered, want.Registered):
t.Error("registered does not match")
return false
case !reflect.DeepEqual(got.file, want.file):
t.Error("file does not match")
return false
}
return true
}
func TestPreferences_Validate(t *testing.T) {
validPrefs := DefaultPreferences(filepath.Join(t.TempDir(), preferencesFile))
type fields struct {
mu *sync.Mutex
MQTT *MQTT
Registration *Registration
Hass *Hass
Device *Device
Version string
file string
Registered bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "valid",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: validPrefs.Registration,
Hass: validPrefs.Hass,
Device: validPrefs.Device,
Version: AppVersion,
},
},
{
name: "required field missing",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: validPrefs.Registration,
Hass: validPrefs.Hass,
Device: validPrefs.Device,
},
wantErr: true,
},
{
name: "invalid field value",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: &Registration{Server: "notaurl", Token: "somestring"},
Hass: validPrefs.Hass,
Device: validPrefs.Device,
Version: AppVersion,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Preferences{
mu: tt.fields.mu,
MQTT: tt.fields.MQTT,
Registration: tt.fields.Registration,
Hass: tt.fields.Hass,
Device: tt.fields.Device,
Version: tt.fields.Version,
Registered: tt.fields.Registered,
file: tt.fields.file,
}
if err := p.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Preferences.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPreferences_Save(t *testing.T) {
validPrefs := DefaultPreferences(filepath.Join(t.TempDir(), preferencesFile))
type fields struct {
mu *sync.Mutex
MQTT *MQTT
Registration *Registration
Hass *Hass
Device *Device
Version string
file string
Registered bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "valid preferences",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: validPrefs.Registration,
Hass: validPrefs.Hass,
Device: validPrefs.Device,
Version: AppVersion,
file: validPrefs.file,
},
},
{
name: "invalid preferences",
wantErr: true,
},
{
name: "unwriteable preferences path",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: validPrefs.Registration,
Hass: validPrefs.Hass,
Device: validPrefs.Device,
Version: AppVersion,
file: "/",
},
wantErr: true,
},
{
name: "missing preferences path",
fields: fields{
MQTT: validPrefs.MQTT,
Registration: validPrefs.Registration,
Hass: validPrefs.Hass,
Device: validPrefs.Device,
Version: AppVersion,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Preferences{
mu: tt.fields.mu,
MQTT: tt.fields.MQTT,
Registration: tt.fields.Registration,
Hass: tt.fields.Hass,
Device: tt.fields.Device,
Version: tt.fields.Version,
Registered: tt.fields.Registered,
file: tt.fields.file,
}
if err := p.Save(); (err != nil) != tt.wantErr {
t.Errorf("Preferences.Save() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPreferences_GetMQTTPreferences(t *testing.T) {
validPrefs := DefaultPreferences(filepath.Join(t.TempDir(), preferencesFile))
type fields struct {
mu *sync.Mutex
MQTT *MQTT
Registration *Registration
Hass *Hass
Device *Device
Version string
file string
Registered bool
}
tests := []struct {
want *MQTT
name string
fields fields
}{
{
name: "valid MQTT prefs",
fields: fields{
MQTT: validPrefs.MQTT,
},
want: validPrefs.MQTT,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Preferences{
mu: tt.fields.mu,
MQTT: tt.fields.MQTT,
Registration: tt.fields.Registration,
Hass: tt.fields.Hass,
Device: tt.fields.Device,
Version: tt.fields.Version,
Registered: tt.fields.Registered,
file: tt.fields.file,
}
if got := p.GetMQTTPreferences(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Preferences.GetMQTTPreferences() = %v, want %v", got, tt.want)
}
})
}
}
func TestLoad(t *testing.T) {
appID := "go-hass-agent-test"
ctx := AppIDToContext(context.TODO(), appID)
noExistingPrefsPath := t.TempDir()
noExistingPrefs := DefaultPreferences(filepath.Join(noExistingPrefsPath, appID, preferencesFile))
invalidFilePath := t.TempDir()
err := checkPath(filepath.Join(invalidFilePath, appID))
require.NoError(t, err)
err = os.WriteFile(filepath.Join(invalidFilePath, appID, preferencesFile), []byte(`invalid`), 0o600)
require.NoError(t, err)
existingFilePath := t.TempDir()
existingPrefs := DefaultPreferences(filepath.Join(existingFilePath, appID, preferencesFile))
err = existingPrefs.Save()
require.NoError(t, err)
type args struct {
path string
}
tests := []struct {
wantErrType error
want *Preferences
name string
args args
wantErr bool
}{
{
name: "new file",
args: args{path: noExistingPrefsPath},
want: noExistingPrefs,
wantErr: true,
wantErrType: ErrNoPreferences,
},
{
name: "invalid file",
args: args{path: invalidFilePath},
wantErr: true,
wantErrType: ErrFileContents,
},
{
name: "existing file",
args: args{path: existingFilePath},
want: existingPrefs,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
xdg.ConfigHome = tt.args.path
got, err := Load(ctx)
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
require.ErrorIs(t, err, tt.wantErrType)
return
}
if !preferencesEqual(t, got, tt.want) {
t.Errorf("Load() = %v, want %v", got, tt.want)
}
})
}
}
func TestReset(t *testing.T) {
appID := "go-hass-agent-test"
ctx := AppIDToContext(context.TODO(), appID)
existingFilePath := t.TempDir()
existingPrefs := DefaultPreferences(filepath.Join(existingFilePath, appID, preferencesFile))
err := existingPrefs.Save()
require.NoError(t, err)
type args struct {
path string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid path",
args: args{path: existingFilePath},
},
{
name: "invalid path",
args: args{path: filepath.Join(t.TempDir(), "nonexistent")},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
xdg.ConfigHome = tt.args.path
if err := Reset(ctx); (err != nil) != tt.wantErr {
t.Errorf("Reset() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_checkPath(t *testing.T) {
type args struct {
path string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "exists",
args: args{path: t.TempDir()},
},
{
name: "does not exist",
args: args{path: filepath.Join(t.TempDir(), "notexists")},
},
{
name: "unwriteable",
args: args{path: "/notexists"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := checkPath(tt.args.path); (err != nil) != tt.wantErr {
t.Errorf("checkPath() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}