// mautrix-signal/pkg/signalmeow/store/container.go
//
// (Source listing metadata: 213 lines, 6.9 KiB, Go.)

package store
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
"go.mau.fi/mautrix-signal/pkg/signalmeow/store/upgrades"
)
// Compile-time check that *Container satisfies DeviceStore.
var _ DeviceStore = (*Container)(nil)

// DeviceStore saves signalmeow devices and looks them up by ACI or PNI.
type DeviceStore interface {
	// PutDevice saves the given device data.
	PutDevice(ctx context.Context, device *DeviceData) error
	// DeviceByACI returns the device whose ACI UUID matches aci.
	// The Container implementation returns (nil, nil) when no device matches.
	DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error)
	// DeviceByPNI returns the device whose PNI UUID matches pni.
	// The Container implementation returns (nil, nil) when no device matches.
	DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error)
}
// Container is a wrapper for a SQL database that can contain multiple signalmeow sessions.
type Container struct {
	db *dbutil.Database
}

// NewStore wraps db in a Container.
//
// The Container uses a child database handle whose schema upgrades come from
// upgrades.Table, with the version bookkeeping table named
// "signalmeow_version". Queries are logged through log.
func NewStore(db *dbutil.Database, log dbutil.DatabaseLogger) *Container {
	child := db.Child("signalmeow_version", upgrades.Table, log)
	return &Container{db: child}
}
// getAllDevicesQuery selects every signalmeow_device column that scanDevice
// reads, in the same order scanDevice scans them. The single-device lookups
// below append a WHERE clause to it.
const getAllDevicesQuery = `
SELECT
aci_uuid, aci_identity_key_pair, registration_id,
pni_uuid, pni_identity_key_pair, pni_registration_id,
device_id, number, password, master_key, account_record
FROM signalmeow_device
`

// getDeviceQuery selects the device with the given ACI UUID ($1).
const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1"

// deviceByPNIQuery selects the device with the given PNI UUID ($1).
// The leading space matches getDeviceQuery. Without it, the clause would only
// stay valid while getAllDevicesQuery happens to end in whitespace.
const deviceByPNIQuery = getAllDevicesQuery + " WHERE pni_uuid=$1"
// Upgrade applies any pending schema upgrades to the container's database,
// using the upgrade table configured in NewStore.
func (c *Container) Upgrade(ctx context.Context) error {
	err := c.db.Upgrade(ctx)
	return err
}
// scanDevice reads one signalmeow_device row and builds a Device from it.
// The row's columns must be in the order selected by getAllDevicesQuery.
//
// Decoding steps:
//   - Both serialized identity key pairs are decoded with libsignalgo.
//   - An empty master key is normalized to nil.
//   - A non-empty account_record is unmarshalled as a signalpb.AccountRecord.
//
// Store wiring: the returned Device uses SQL-backed stores keyed by its ACI.
// Prekey and session stores are scoped per service ID (ACI or PNI). The
// remaining stores share a single sqlStore.
//
// Scan errors are wrapped with %w, so single-row callers can still detect
// sql.ErrNoRows with errors.Is.
func (c *Container) scanDevice(row dbutil.Scannable) (*Device, error) {
	var (
		device                                        Device
		aciKeyPairBytes, pniKeyPairBytes, recordBytes []byte
	)
	if err := row.Scan(
		&device.ACI, &aciKeyPairBytes, &device.ACIRegistrationID,
		&device.PNI, &pniKeyPairBytes, &device.PNIRegistrationID,
		&device.DeviceID, &device.Number, &device.Password, &device.MasterKey,
		&recordBytes,
	); err != nil {
		return nil, fmt.Errorf("failed to scan session: %w", err)
	}

	var err error
	if device.ACIIdentityKeyPair, err = libsignalgo.DeserializeIdentityKeyPair(aciKeyPairBytes); err != nil {
		return nil, fmt.Errorf("failed to deserialize ACI identity key pair: %w", err)
	}
	if device.PNIIdentityKeyPair, err = libsignalgo.DeserializeIdentityKeyPair(pniKeyPairBytes); err != nil {
		return nil, fmt.Errorf("failed to deserialize PNI identity key pair: %w", err)
	}

	// An empty master_key column is treated the same as an absent one.
	if len(device.MasterKey) == 0 {
		device.MasterKey = nil
	}

	if len(recordBytes) > 0 {
		record := &signalpb.AccountRecord{}
		if err = proto.Unmarshal(recordBytes, record); err != nil {
			return nil, fmt.Errorf("failed to unmarshal account record: %w", err)
		}
		device.AccountRecord = record
	}

	accountStore := &sqlStore{Container: c, AccountID: device.ACI}
	aciScoped := &scopedSQLStore{Container: c, AccountID: device.ACI, ServiceID: device.ACIServiceID()}
	pniScoped := &scopedSQLStore{Container: c, AccountID: device.ACI, ServiceID: device.PNIServiceID()}

	device.ACIPreKeyStore, device.ACISessionStore = aciScoped, aciScoped
	device.PNIPreKeyStore, device.PNISessionStore = pniScoped, pniScoped

	// Each identity store wraps the shared sqlStore with that identity's own
	// key pair and registration ID.
	device.ACIIdentityStore = &sqlIdentityStore{
		sqlStore:            accountStore,
		OwnKeyPair:          device.ACIIdentityKeyPair,
		LocalRegistrationID: uint32(device.ACIRegistrationID),
	}
	device.PNIIdentityStore = &sqlIdentityStore{
		sqlStore:            accountStore,
		OwnKeyPair:          device.PNIIdentityKeyPair,
		LocalRegistrationID: uint32(device.PNIRegistrationID),
	}

	device.IdentityKeyStore = accountStore
	device.SenderKeyStore = accountStore
	device.GroupStore = accountStore
	device.RecipientStore = accountStore
	device.DeviceStore = accountStore
	return &device, nil
}
// GetAllDevices loads every device stored in the database.
//
// It returns an empty, non-nil slice when no devices are stored. It returns
// nil and the error in three cases:
//   - the query fails;
//   - a row fails to scan or decode;
//   - iterating the result set fails.
func (c *Container) GetAllDevices(ctx context.Context) ([]*Device, error) {
	rows, err := c.db.Query(ctx, getAllDevicesQuery)
	if err != nil {
		return nil, fmt.Errorf("failed to query sessions: %w", err)
	}
	defer rows.Close()

	devices := make([]*Device, 0)
	for rows.Next() {
		device, scanErr := c.scanDevice(rows)
		if scanErr != nil {
			// Do not hand back a partially populated list alongside an error.
			return nil, scanErr
		}
		devices = append(devices, device)
	}
	// rows.Next reports false both at the end of the result set and on failure.
	// Without this check, an interrupted iteration would look like a short but
	// successful read.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate sessions: %w", err)
	}
	return devices, nil
}
// DeviceByACI finds the device with the specified ACI UUID in the database.
// If no device matches, it returns (nil, nil). Any other query, scan or decode
// error is returned as-is.
func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) {
	device, err := c.scanDevice(c.db.QueryRow(ctx, getDeviceQuery, aci))
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	return device, err
}
// DeviceByPNI finds the device with the specified PNI UUID in the database.
// If no device matches, it returns (nil, nil). Any other query, scan or decode
// error is returned unchanged.
func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) {
	row := c.db.QueryRow(ctx, deviceByPNIQuery, pni)
	device, err := c.scanDevice(row)
	if errors.Is(err, sql.ErrNoRows) {
		// A missing row means "not found", not failure. scanDevice already
		// returned a nil device in this case, so only the error is cleared.
		err = nil
	}
	return device, err
}
const (
	// deleteDeviceQuery removes the device row with the given ACI UUID ($1).
	deleteDeviceQuery = `DELETE FROM signalmeow_device WHERE aci_uuid=$1`

	// insertDeviceQuery upserts a device row keyed by aci_uuid. On conflict it
	// overwrites every other column with the newly supplied values.
	// Parameter order matches the column list ($1..$11).
	insertDeviceQuery = `
INSERT INTO signalmeow_device (
aci_uuid, aci_identity_key_pair, registration_id,
pni_uuid, pni_identity_key_pair, pni_registration_id,
device_id, number, password, master_key, account_record
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (aci_uuid) DO UPDATE SET
aci_identity_key_pair=excluded.aci_identity_key_pair,
registration_id=excluded.registration_id,
pni_uuid=excluded.pni_uuid,
pni_identity_key_pair=excluded.pni_identity_key_pair,
pni_registration_id=excluded.pni_registration_id,
device_id=excluded.device_id,
number=excluded.number,
password=excluded.password,
master_key=excluded.master_key,
account_record=excluded.account_record
`
)
var (
	// ErrDeviceIDMustBeSet is returned by PutDevice and DeleteDevice when the
	// device's ACI UUID is not yet known. Device rows are keyed by aci_uuid,
	// so neither operation can proceed without it.
	ErrDeviceIDMustBeSet = errors.New("device aci_uuid must be known before accessing database")
)
// PutDevice stores the given device in this database. It inserts a new row,
// or updates the existing row with the same ACI UUID.
//
// It returns ErrDeviceIDMustBeSet if device.ACI is uuid.Nil. Failures to
// serialize the identity key pairs or the account record, and database
// errors, are returned wrapped with context (%w), so errors.Is and errors.As
// still see the underlying cause. Serialization and insert failures are also
// logged through the logger in ctx.
func (c *Container) PutDevice(ctx context.Context, device *DeviceData) error {
	if device.ACI == uuid.Nil {
		return ErrDeviceIDMustBeSet
	}
	aciIdentityKeyPair, err := device.ACIIdentityKeyPair.Serialize()
	if err != nil {
		zerolog.Ctx(ctx).Err(err).Msg("failed to serialize aci identity key pair")
		return fmt.Errorf("failed to serialize ACI identity key pair: %w", err)
	}
	pniIdentityKeyPair, err := device.PNIIdentityKeyPair.Serialize()
	if err != nil {
		zerolog.Ctx(ctx).Err(err).Msg("failed to serialize pni identity key pair")
		return fmt.Errorf("failed to serialize PNI identity key pair: %w", err)
	}
	// A nil AccountRecord is stored as a nil byte slice; scanDevice treats an
	// empty account_record as "no record".
	var accountRecordBytes []byte
	if device.AccountRecord != nil {
		accountRecordBytes, err = proto.Marshal(device.AccountRecord)
		if err != nil {
			return fmt.Errorf("failed to marshal account record: %w", err)
		}
	}
	_, err = c.db.Exec(ctx, insertDeviceQuery,
		device.ACI, aciIdentityKeyPair, device.ACIRegistrationID,
		device.PNI, pniIdentityKeyPair, device.PNIRegistrationID,
		device.DeviceID, device.Number, device.Password, device.MasterKey,
		accountRecordBytes,
	)
	if err != nil {
		zerolog.Ctx(ctx).Err(err).Msg("failed to insert device")
		return fmt.Errorf("failed to insert device: %w", err)
	}
	return nil
}
// DeleteDevice removes the row for the given device's ACI UUID from this
// database.
//
// It returns ErrDeviceIDMustBeSet if device.ACI is uuid.Nil. The number of
// affected rows is not checked, so deleting an unknown device is not an error.
func (c *Container) DeleteDevice(ctx context.Context, device *DeviceData) error {
	aci := device.ACI
	if aci == uuid.Nil {
		return ErrDeviceIDMustBeSet
	}
	if _, err := c.db.Exec(ctx, deleteDeviceQuery, aci); err != nil {
		return err
	}
	return nil
}