mautrix-signal/pkg/signalmeow/store/recipient_store.go

340 lines
11 KiB
Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"go.mau.fi/util/dbutil"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
)
// RecipientStore persists what this client knows about other Signal users:
// their ACI/PNI identifiers, phone number, contact info and profile data.
type RecipientStore interface {
	// Profile keys.
	// NOTE(review): these are implemented outside this file — see their
	// implementations for exact semantics.
	LoadProfileKey(ctx context.Context, theirACI uuid.UUID) (*libsignalgo.ProfileKey, error)
	StoreProfileKey(ctx context.Context, theirACI uuid.UUID, key libsignalgo.ProfileKey) error
	MyProfileKey(ctx context.Context) (*libsignalgo.ProfileKey, error)

	// Read-only lookups.
	LoadRecipientByE164(ctx context.Context, e164 string) (*types.Recipient, error)
	LoadAllContacts(ctx context.Context) ([]*types.Recipient, error)

	// Writes.
	LoadAndUpdateRecipient(ctx context.Context, aci, pni uuid.UUID, updater RecipientUpdaterFunc) (*types.Recipient, error)
	UpdateRecipientE164(ctx context.Context, aci, pni uuid.UUID, e164 string) (*types.Recipient, error)
	StoreRecipient(ctx context.Context, recipient *types.Recipient) error
}

// Compile-time check that *sqlStore satisfies RecipientStore.
var _ RecipientStore = (*sqlStore)(nil)
// SQL statements backing the sqlStore recipient methods in this file. Every
// statement is scoped to a single account through account_id = $1.
const (
// getAllRecipientsQuery selects every recipient column for the account, in exactly
// the order scanRecipient scans them. The other SELECT queries below are built by
// appending an "AND ..." filter to it. That only works because this literal ends
// with a newline after the WHERE clause.
getAllRecipientsQuery = `
SELECT
aci_uuid,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash,
profile_key,
profile_name,
profile_about,
profile_about_emoji,
profile_avatar_path,
profile_fetched_at,
needs_pni_signature
FROM signalmeow_recipients
WHERE account_id = $1
`
// getAllRecipientsWithNameOrPhoneQuery keeps only rows that have a non-empty contact
// name, profile name or phone number (used by LoadAllContacts).
getAllRecipientsWithNameOrPhoneQuery = getAllRecipientsQuery + `AND (contact_name <> '' OR profile_name <> '' OR e164_number <> '')`
// Single-identifier lookups; $2 is the ACI or the PNI respectively.
getRecipientByACIQuery = getAllRecipientsQuery + `AND aci_uuid = $2`
getRecipientByPNIQuery = getAllRecipientsQuery + `AND pni_uuid = $2`
// getRecipientByACIOrPNIQuery matches rows by ACI ($2) or by PNI ($3). Passing the
// all-zero UUID for either argument disables that half of the condition.
getRecipientByACIOrPNIQuery = getAllRecipientsQuery + `AND (($2<>'00000000-0000-0000-0000-000000000000' AND aci_uuid = $2) OR ($3<>'00000000-0000-0000-0000-000000000000' AND pni_uuid = $3))`
// getRecipientByPhoneQuery looks rows up by e164_number ($2).
getRecipientByPhoneQuery = getAllRecipientsQuery + `AND e164_number = $2`
// deleteRecipientByPNIQuery removes the account's row whose pni_uuid is $2.
deleteRecipientByPNIQuery = `DELETE FROM signalmeow_recipients WHERE account_id = $1 AND pni_uuid = $2`
// upsertACIRecipientQuery inserts a full recipient row, or overwrites every column
// of the existing row with the same (account_id, aci_uuid).
upsertACIRecipientQuery = `
INSERT INTO signalmeow_recipients (
account_id,
aci_uuid,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash,
profile_key,
profile_name,
profile_about,
profile_about_emoji,
profile_avatar_path,
profile_fetched_at,
needs_pni_signature
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (account_id, aci_uuid) DO UPDATE SET
pni_uuid = excluded.pni_uuid,
e164_number = excluded.e164_number,
contact_name = excluded.contact_name,
contact_avatar_hash = excluded.contact_avatar_hash,
profile_key = excluded.profile_key,
profile_name = excluded.profile_name,
profile_about = excluded.profile_about,
profile_about_emoji = excluded.profile_about_emoji,
profile_avatar_path = excluded.profile_avatar_path,
profile_fetched_at = excluded.profile_fetched_at,
needs_pni_signature = excluded.needs_pni_signature
`
// upsertPNIRecipientQuery inserts or updates a row keyed by (account_id, pni_uuid).
// It only writes the phone number and contact fields. On conflict, the other
// columns of the existing row are left unchanged.
upsertPNIRecipientQuery = `
INSERT INTO signalmeow_recipients (
account_id,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash
)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id, pni_uuid) DO UPDATE SET
e164_number = excluded.e164_number,
contact_name = excluded.contact_name,
contact_avatar_hash = excluded.contact_avatar_hash
`
)
// scanRecipient decodes one row shaped like getAllRecipientsQuery into a Recipient.
//
// A missing row (sql.ErrNoRows) yields (nil, nil), so single-row lookups can report
// "not found" without an error. NULL ACI/PNI columns become uuid.Nil, and a NULL
// fetch timestamp leaves Profile.FetchedAt at its zero value. The stored profile
// key is only used when it is exactly libsignalgo.ProfileKeyLength bytes long;
// otherwise Profile.Key stays zero.
func scanRecipient(row dbutil.Scannable) (*types.Recipient, error) {
	var (
		out          types.Recipient
		aciCol       uuid.NullUUID
		pniCol       uuid.NullUUID
		rawKey       []byte
		fetchedAtCol sql.NullInt64
	)
	// The destination order must match the column order of getAllRecipientsQuery.
	scanErr := row.Scan(
		&aciCol,
		&pniCol,
		&out.E164,
		&out.ContactName,
		&out.ContactAvatar.Hash,
		&rawKey,
		&out.Profile.Name,
		&out.Profile.About,
		&out.Profile.AboutEmoji,
		&out.Profile.AvatarPath,
		&fetchedAtCol,
		&out.NeedsPNISignature,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, nil
	case scanErr != nil:
		return nil, scanErr
	}

	out.ACI, out.PNI = aciCol.UUID, pniCol.UUID
	if fetchedAtCol.Valid {
		out.Profile.FetchedAt = time.UnixMilli(fetchedAtCol.Int64)
	}
	if len(rawKey) == libsignalgo.ProfileKeyLength {
		out.Profile.Key = libsignalgo.ProfileKey(rawKey)
	}
	return &out, nil
}
// LoadRecipientByACI returns this account's recipient whose ACI is theirUUID.
// When no row matches, the result is (nil, nil), because scanRecipient maps
// sql.ErrNoRows to that.
func (s *sqlStore) LoadRecipientByACI(ctx context.Context, theirUUID uuid.UUID) (*types.Recipient, error) {
	row := s.db.QueryRow(ctx, getRecipientByACIQuery, s.AccountID, theirUUID)
	return scanRecipient(row)
}
// LoadRecipientByPNI returns this account's recipient whose PNI is theirUUID.
// When no row matches, the result is (nil, nil), because scanRecipient maps
// sql.ErrNoRows to that.
func (s *sqlStore) LoadRecipientByPNI(ctx context.Context, theirUUID uuid.UUID) (*types.Recipient, error) {
	row := s.db.QueryRow(ctx, getRecipientByPNIQuery, s.AccountID, theirUUID)
	return scanRecipient(row)
}
// RecipientUpdaterFunc mutates a recipient in place during LoadAndUpdateRecipient.
// It reports whether it changed anything. LoadAndUpdateRecipient uses that flag to
// decide whether an existing single row must be written back; merged rows are
// always written. A non-nil error is returned from the transaction callback
// (presumably rolling the transaction back — confirm against dbutil.DoTxn).
type RecipientUpdaterFunc func(recipient *types.Recipient) (changed bool, err error)
// mergeRecipients collapses two rows describing the same person, one owning the
// ACI and one owning the PNI, into a single ACI-keyed row, and persists it.
//
// The row with a non-nil ACI is kept; if first has no ACI, the arguments are
// swapped. The kept row takes over:
//   - the other row's PNI;
//   - its E164, when that is non-empty;
//   - its contact name and avatar, when the kept row has none.
//
// The updater then runs on the merged recipient. Its "changed" flag is ignored
// because the merged row is always written. The other row is deleted by its PNI,
// and the merged row is stored.
//
// On failure it returns a nil recipient together with the wrapped error, so
// callers never see a half-merged value.
//
// NOTE(review): if the PNI row also carries its own (different) ACI, deleting it
// by PNI discards that other ACI's stored data as well. Confirm this is intended.
func (s *sqlStore) mergeRecipients(ctx context.Context, first, second *types.Recipient, updater RecipientUpdaterFunc) (*types.Recipient, error) {
	if first.ACI == uuid.Nil {
		first, second = second, first
	}
	first.PNI = second.PNI
	if second.E164 != "" {
		first.E164 = second.E164
	}
	if first.ContactName == "" {
		first.ContactName = second.ContactName
	}
	if first.ContactAvatar.Hash == "" {
		first.ContactAvatar = second.ContactAvatar
	}
	if _, err := updater(first); err != nil {
		return nil, fmt.Errorf("failed to run updater function: %w", err)
	}
	// Delete the duplicate row before storing. The other row still holds this PNI,
	// and (account_id, pni_uuid) is a conflict key (see upsertPNIRecipientQuery).
	// Storing first would therefore collide with it.
	if err := s.DeleteRecipientByPNI(ctx, first.PNI); err != nil {
		return nil, fmt.Errorf("failed to delete duplicate PNI row: %w", err)
	}
	if err := s.StoreRecipient(ctx, first); err != nil {
		return nil, fmt.Errorf("failed to store merged row: %w", err)
	}
	return first, nil
}
// LoadAndUpdateRecipient finds the recipient identified by aci and/or pni, lets
// updater modify it, and writes the result back. At least one of aci and pni must
// be non-nil. All work happens while holding s.contactLock, inside a single
// database transaction.
//
// What happens depends on how many stored rows match:
//   - none: a new Recipient{ACI: aci, PNI: pni} is passed to updater and is
//     always stored;
//   - one: updater runs on that row. A missing ACI, and a missing or different
//     PNI, are taken from the arguments. The row is stored only if something
//     changed;
//   - two (a row owning aci plus a separate row owning pni): the rows are merged
//     with mergeRecipients, keeping the row that owns aci.
//
// A nil updater is treated as a no-op. On error the returned recipient is nil.
func (s *sqlStore) LoadAndUpdateRecipient(ctx context.Context, aci, pni uuid.UUID, updater RecipientUpdaterFunc) (outRecipient *types.Recipient, outErr error) {
	if aci == uuid.Nil && pni == uuid.Nil {
		return nil, fmt.Errorf("no ACI or PNI provided in LoadAndUpdateRecipient call")
	}
	if updater == nil {
		updater = func(recipient *types.Recipient) (bool, error) {
			return false, nil
		}
	}
	s.contactLock.Lock()
	defer s.contactLock.Unlock()
	outErr = s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
		entries, err := s.loadRecipientRowsForUpdate(ctx, aci, pni)
		if err != nil {
			return err
		}
		switch len(entries) {
		case 0:
			outRecipient, err = s.updateSingleRecipientRow(ctx, nil, aci, pni, updater)
			return err
		case 1:
			outRecipient, err = s.updateSingleRecipientRow(ctx, entries[0], aci, pni, updater)
			return err
		case 2:
			// The query returns rows in no particular order. mergeRecipients keeps its
			// first argument whenever that argument has an ACI. Pass the row owning
			// the requested ACI first; otherwise a PNI row carrying some other ACI
			// could be kept instead.
			first, second := entries[0], entries[1]
			if second.ACI == aci {
				first, second = second, first
			}
			if outRecipient, err = s.mergeRecipients(ctx, first, second, updater); err != nil {
				return fmt.Errorf("failed to merge recipient rows for ACI %s and PNI %s: %w", aci, pni, err)
			}
			return nil
		default:
			return fmt.Errorf("got more than two recipient rows for ACI %s and PNI %s", aci, pni)
		}
	})
	if outErr != nil {
		// Never return a half-updated recipient alongside an error.
		return nil, outErr
	}
	return outRecipient, nil
}

// loadRecipientRowsForUpdate fetches every row matching aci or pni: zero, one or
// two rows. Only the combined ACI-or-PNI query is locked with FOR UPDATE, and only
// on Postgres. The caller guarantees that at least one identifier is non-nil.
func (s *sqlStore) loadRecipientRowsForUpdate(ctx context.Context, aci, pni uuid.UUID) ([]*types.Recipient, error) {
	var entry *types.Recipient
	var err error
	switch {
	case aci != uuid.Nil && pni != uuid.Nil:
		query := getRecipientByACIOrPNIQuery
		if s.db.Dialect == dbutil.Postgres {
			query += " FOR UPDATE"
		}
		return dbutil.ConvertRowFn[*types.Recipient](scanRecipient).
			NewRowIter(s.db.Query(ctx, query, s.AccountID, aci, pni)).
			AsList()
	case aci != uuid.Nil:
		entry, err = s.LoadRecipientByACI(ctx, aci)
	default:
		entry, err = s.LoadRecipientByPNI(ctx, pni)
	}
	if err != nil || entry == nil {
		return nil, err
	}
	return []*types.Recipient{entry}, nil
}

// updateSingleRecipientRow handles LoadAndUpdateRecipient when at most one row
// matched. existing is that row, or nil to start from a fresh
// Recipient{ACI: aci, PNI: pni}. It runs updater, fills in identifiers from
// aci/pni, and stores the recipient if it is new or anything changed.
func (s *sqlStore) updateSingleRecipientRow(ctx context.Context, existing *types.Recipient, aci, pni uuid.UUID, updater RecipientUpdaterFunc) (*types.Recipient, error) {
	isNew := existing == nil
	recipient := existing
	if isNew {
		recipient = &types.Recipient{ACI: aci, PNI: pni}
	}
	changed, err := updater(recipient)
	if err != nil {
		return nil, fmt.Errorf("failed to run updater function: %w", err)
	}
	// SQL only supports one ON CONFLICT clause, which means StoreRecipient will key on the ACI if it's present.
	// If we're adding an ACI to a PNI row, just delete the PNI row first to avoid conflicts on the PNI key.
	if recipient.PNI != uuid.Nil && recipient.ACI == uuid.Nil && aci != uuid.Nil {
		if err = s.DeleteRecipientByPNI(ctx, recipient.PNI); err != nil {
			return nil, fmt.Errorf("failed to delete old PNI row: %w", err)
		}
	}
	// The caller asserts that pni belongs to this recipient, so record it even if
	// the row held a different (stale) PNI. This matches mergeRecipients, which
	// always adopts the PNI. No other row can hold pni at this point; otherwise the
	// combined lookup would have returned two rows.
	if pni != uuid.Nil && recipient.PNI != pni {
		recipient.PNI = pni
		changed = true
	}
	if recipient.ACI == uuid.Nil && aci != uuid.Nil {
		recipient.ACI = aci
		changed = true
	}
	// NOTE(review): if the matched row was found by pni but carries a different
	// non-nil ACI, aci is not recorded and that other recipient is returned.
	// Confirm this is intended.
	if changed || isNew {
		if err = s.StoreRecipient(ctx, recipient); err != nil {
			return nil, fmt.Errorf("failed to store updated recipient row: %w", err)
		}
	}
	return recipient, nil
}
// UpdateRecipientE164 sets the phone number of the recipient identified by
// aci/pni, going through LoadAndUpdateRecipient. The updater only reports a change
// when the stored number differs from e164.
func (s *sqlStore) UpdateRecipientE164(ctx context.Context, aci, pni uuid.UUID, e164 string) (*types.Recipient, error) {
	setPhone := func(recipient *types.Recipient) (bool, error) {
		if recipient.E164 == e164 {
			return false, nil
		}
		recipient.E164 = e164
		return true, nil
	}
	return s.LoadAndUpdateRecipient(ctx, aci, pni, setPhone)
}
// LoadRecipientByE164 looks up this account's recipient by its e164_number
// column. When no row matches, the result is (nil, nil).
func (s *sqlStore) LoadRecipientByE164(ctx context.Context, e164 string) (*types.Recipient, error) {
	row := s.db.QueryRow(ctx, getRecipientByPhoneQuery, s.AccountID, e164)
	return scanRecipient(row)
}
// LoadAllContacts returns every recipient of this account that has at least one
// of: a contact name, a profile name, or a phone number.
func (s *sqlStore) LoadAllContacts(ctx context.Context) ([]*types.Recipient, error) {
	rows, queryErr := s.db.Query(ctx, getAllRecipientsWithNameOrPhoneQuery, s.AccountID)
	iter := dbutil.NewRowIterWithError(rows, scanRecipient, queryErr)
	return iter.AsList()
}
// DeleteRecipientByPNI removes this account's recipient row whose PNI is pni, if
// such a row exists. Deleting a missing row is not an error.
func (s *sqlStore) DeleteRecipientByPNI(ctx context.Context, pni uuid.UUID) error {
	if _, err := s.db.Exec(ctx, deleteRecipientByPNIQuery, s.AccountID, pni); err != nil {
		return err
	}
	return nil
}
// nullableUUID wraps u in a uuid.NullUUID that is invalid (stored as NULL) when u
// is uuid.Nil and valid otherwise.
func nullableUUID(u uuid.UUID) uuid.NullUUID {
	if u == uuid.Nil {
		return uuid.NullUUID{}
	}
	return uuid.NullUUID{UUID: u, Valid: true}
}
// StoreRecipient upserts recipient for this account.
//
// Recipients with an ACI are keyed by (account_id, aci_uuid), and every column is
// written; a nil PNI is stored as NULL. Recipients that only have a PNI are keyed
// by (account_id, pni_uuid), and only the phone number and contact fields are
// written. A recipient with neither identifier is rejected with an error.
func (s *sqlStore) StoreRecipient(ctx context.Context, recipient *types.Recipient) (err error) {
	switch {
	case recipient.ACI != uuid.Nil:
		_, err = s.db.Exec(ctx, upsertACIRecipientQuery,
			s.AccountID,
			recipient.ACI,
			nullableUUID(recipient.PNI),
			recipient.E164,
			recipient.ContactName,
			recipient.ContactAvatar.Hash,
			recipient.Profile.Key.Slice(),
			recipient.Profile.Name,
			recipient.Profile.About,
			recipient.Profile.AboutEmoji,
			recipient.Profile.AvatarPath,
			dbutil.UnixMilliPtr(recipient.Profile.FetchedAt),
			recipient.NeedsPNISignature,
		)
	case recipient.PNI != uuid.Nil:
		_, err = s.db.Exec(ctx, upsertPNIRecipientQuery,
			s.AccountID,
			recipient.PNI,
			recipient.E164,
			recipient.ContactName,
			recipient.ContactAvatar.Hash,
		)
	default:
		err = fmt.Errorf("no ACI or PNI provided in StoreRecipient call")
	}
	return err
}