340 lines
11 KiB
Go
340 lines
11 KiB
Go
// mautrix-signal - A Matrix-signal puppeting bridge.
|
|
// Copyright (C) 2023 Scott Weber
|
|
// Copyright (C) 2024 Tulir Asokan
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.mau.fi/util/dbutil"
|
|
|
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
|
)
|
|
|
|
type RecipientStore interface {
|
|
LoadProfileKey(ctx context.Context, theirACI uuid.UUID) (*libsignalgo.ProfileKey, error)
|
|
StoreProfileKey(ctx context.Context, theirACI uuid.UUID, key libsignalgo.ProfileKey) error
|
|
MyProfileKey(ctx context.Context) (*libsignalgo.ProfileKey, error)
|
|
|
|
LoadAndUpdateRecipient(ctx context.Context, aci, pni uuid.UUID, updater RecipientUpdaterFunc) (*types.Recipient, error)
|
|
LoadRecipientByE164(ctx context.Context, e164 string) (*types.Recipient, error)
|
|
StoreRecipient(ctx context.Context, recipient *types.Recipient) error
|
|
UpdateRecipientE164(ctx context.Context, aci, pni uuid.UUID, e164 string) (*types.Recipient, error)
|
|
|
|
LoadAllContacts(ctx context.Context) ([]*types.Recipient, error)
|
|
}
|
|
|
|
// Compile-time check that *sqlStore implements RecipientStore.
var _ RecipientStore = (*sqlStore)(nil)
|
|
|
|
// SQL used by the recipient store. Every statement is scoped to one account
// through account_id = $1. All SELECT variants share getAllRecipientsQuery's
// column list, whose order must match the Scan order in scanRecipient.
const (
	// getAllRecipientsQuery selects every recipient column for one account.
	// The literal ends with a newline, so the filtered variants below can
	// append "AND ..." directly.
	getAllRecipientsQuery = `
SELECT
aci_uuid,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash,
profile_key,
profile_name,
profile_about,
profile_about_emoji,
profile_avatar_path,
profile_fetched_at,
needs_pni_signature
FROM signalmeow_recipients
WHERE account_id = $1
`

	// getAllRecipientsWithNameOrPhoneQuery keeps only recipients that have a
	// contact name, profile name or phone number.
	getAllRecipientsWithNameOrPhoneQuery = getAllRecipientsQuery + `AND (contact_name <> '' OR profile_name <> '' OR e164_number <> '')`

	// getRecipientByACIQuery filters by ACI ($2).
	getRecipientByACIQuery = getAllRecipientsQuery + `AND aci_uuid = $2`

	// getRecipientByPNIQuery filters by PNI ($2).
	getRecipientByPNIQuery = getAllRecipientsQuery + `AND pni_uuid = $2`

	// getRecipientByACIOrPNIQuery matches rows whose ACI equals $2 or whose PNI
	// equals $3. A nil (all-zero) UUID argument is excluded by the <> guards,
	// so it never matches anything.
	getRecipientByACIOrPNIQuery = getAllRecipientsQuery + `AND (($2<>'00000000-0000-0000-0000-000000000000' AND aci_uuid = $2) OR ($3<>'00000000-0000-0000-0000-000000000000' AND pni_uuid = $3))`

	// getRecipientByPhoneQuery filters by phone number ($2).
	getRecipientByPhoneQuery = getAllRecipientsQuery + `AND e164_number = $2`

	// deleteRecipientByPNIQuery removes the account's row with PNI $2.
	deleteRecipientByPNIQuery = `DELETE FROM signalmeow_recipients WHERE account_id = $1 AND pni_uuid = $2`

	// upsertACIRecipientQuery inserts a full recipient row. On conflict with an
	// existing (account_id, aci_uuid) row, it overwrites every other column.
	upsertACIRecipientQuery = `
INSERT INTO signalmeow_recipients (
account_id,
aci_uuid,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash,
profile_key,
profile_name,
profile_about,
profile_about_emoji,
profile_avatar_path,
profile_fetched_at,
needs_pni_signature
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (account_id, aci_uuid) DO UPDATE SET
pni_uuid = excluded.pni_uuid,
e164_number = excluded.e164_number,
contact_name = excluded.contact_name,
contact_avatar_hash = excluded.contact_avatar_hash,
profile_key = excluded.profile_key,
profile_name = excluded.profile_name,
profile_about = excluded.profile_about,
profile_about_emoji = excluded.profile_about_emoji,
profile_avatar_path = excluded.profile_avatar_path,
profile_fetched_at = excluded.profile_fetched_at,
needs_pni_signature = excluded.needs_pni_signature
`

	// upsertPNIRecipientQuery inserts a PNI-only row. On conflict with an
	// existing (account_id, pni_uuid) row, it updates only the phone number and
	// contact name/avatar. aci_uuid and the profile columns are not written.
	upsertPNIRecipientQuery = `
INSERT INTO signalmeow_recipients (
account_id,
pni_uuid,
e164_number,
contact_name,
contact_avatar_hash
)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id, pni_uuid) DO UPDATE SET
e164_number = excluded.e164_number,
contact_name = excluded.contact_name,
contact_avatar_hash = excluded.contact_avatar_hash
`
)
|
|
|
|
func scanRecipient(row dbutil.Scannable) (*types.Recipient, error) {
|
|
var recipient types.Recipient
|
|
var aci, pni uuid.NullUUID
|
|
var profileKey []byte
|
|
var profileFetchedAt sql.NullInt64
|
|
err := row.Scan(
|
|
&aci,
|
|
&pni,
|
|
&recipient.E164,
|
|
&recipient.ContactName,
|
|
&recipient.ContactAvatar.Hash,
|
|
&profileKey,
|
|
&recipient.Profile.Name,
|
|
&recipient.Profile.About,
|
|
&recipient.Profile.AboutEmoji,
|
|
&recipient.Profile.AvatarPath,
|
|
&profileFetchedAt,
|
|
&recipient.NeedsPNISignature,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
recipient.ACI = aci.UUID
|
|
recipient.PNI = pni.UUID
|
|
if profileFetchedAt.Valid {
|
|
recipient.Profile.FetchedAt = time.UnixMilli(profileFetchedAt.Int64)
|
|
}
|
|
if len(profileKey) == libsignalgo.ProfileKeyLength {
|
|
recipient.Profile.Key = libsignalgo.ProfileKey(profileKey)
|
|
}
|
|
return &recipient, err
|
|
}
|
|
|
|
func (s *sqlStore) LoadRecipientByACI(ctx context.Context, theirUUID uuid.UUID) (*types.Recipient, error) {
|
|
return scanRecipient(s.db.QueryRow(ctx, getRecipientByACIQuery, s.AccountID, theirUUID))
|
|
}
|
|
|
|
func (s *sqlStore) LoadRecipientByPNI(ctx context.Context, theirUUID uuid.UUID) (*types.Recipient, error) {
|
|
return scanRecipient(s.db.QueryRow(ctx, getRecipientByPNIQuery, s.AccountID, theirUUID))
|
|
}
|
|
|
|
// RecipientUpdaterFunc mutates recipient in place. It reports changed=true when
// the recipient must be written back, and a non-nil err aborts the surrounding
// load-and-update operation.
type RecipientUpdaterFunc func(recipient *types.Recipient) (changed bool, err error)
|
|
|
|
func (s *sqlStore) mergeRecipients(ctx context.Context, first, second *types.Recipient, updater RecipientUpdaterFunc) (*types.Recipient, error) {
|
|
if first.ACI == uuid.Nil {
|
|
first, second = second, first
|
|
}
|
|
first.PNI = second.PNI
|
|
if second.E164 != "" {
|
|
first.E164 = second.E164
|
|
}
|
|
if first.ContactName == "" {
|
|
first.ContactName = second.ContactName
|
|
}
|
|
if first.ContactAvatar.Hash == "" {
|
|
first.ContactAvatar = second.ContactAvatar
|
|
}
|
|
_, err := updater(first)
|
|
if err != nil {
|
|
return first, fmt.Errorf("failed to run updater function: %w", err)
|
|
}
|
|
err = s.DeleteRecipientByPNI(ctx, first.PNI)
|
|
if err != nil {
|
|
return first, fmt.Errorf("failed to delete duplicate PNI row: %w", err)
|
|
}
|
|
err = s.StoreRecipient(ctx, first)
|
|
if err != nil {
|
|
return first, fmt.Errorf("failed to store merged row: %w", err)
|
|
}
|
|
return first, nil
|
|
}
|
|
|
|
// LoadAndUpdateRecipient loads the recipient identified by aci and/or pni,
// passes it to updater, and persists the result. It runs inside one database
// transaction while holding s.contactLock.
//
// At least one of aci and pni must be non-nil. A nil updater is treated as a
// no-op. Depending on how many rows are found:
//   - none: a new recipient with the given aci/pni is created, passed to
//     updater and stored;
//   - one: updater runs on it. A nil PNI or ACI is filled in from the
//     arguments, and the row is written back if anything changed;
//   - two (only possible when both aci and pni are given): the rows are
//     combined by mergeRecipients, which also runs updater;
//   - more than two: an error is returned.
//
// NOTE(review): outRecipient may be non-nil even when outErr is non-nil
// (e.g. the updater or a write failed). Its contents were then presumably
// rolled back by DoTxn, so callers should not use it after an error.
// NOTE(review): in the one-row case, an existing non-nil PNI or ACI that
// differs from the argument is kept, not overwritten.
func (s *sqlStore) LoadAndUpdateRecipient(ctx context.Context, aci, pni uuid.UUID, updater RecipientUpdaterFunc) (outRecipient *types.Recipient, outErr error) {
	if aci == uuid.Nil && pni == uuid.Nil {
		return nil, fmt.Errorf("no ACI or PNI provided in LoadAndUpdateRecipient call")
	}
	if updater == nil {
		updater = func(recipient *types.Recipient) (bool, error) {
			return false, nil
		}
	}
	// Serializes load-modify-store cycles within this process.
	s.contactLock.Lock()
	defer s.contactLock.Unlock()
	outErr = s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
		var entries []*types.Recipient
		var err error
		if aci != uuid.Nil && pni != uuid.Nil {
			// Both identifiers known: fetch the ACI row and the PNI row, which
			// may be separate rows that need merging.
			query := getRecipientByACIOrPNIQuery
			if s.db.Dialect == dbutil.Postgres {
				// Lock the matched rows for the rest of the transaction
				// (SQLite has no FOR UPDATE).
				query += " FOR UPDATE"
			}
			entries, err = dbutil.ConvertRowFn[*types.Recipient](scanRecipient).
				NewRowIter(s.db.Query(ctx, query, s.AccountID, aci, pni)).
				AsList()
		} else if aci != uuid.Nil {
			var entry *types.Recipient
			entry, err = s.LoadRecipientByACI(ctx, aci)
			if entry != nil {
				entries = []*types.Recipient{entry}
			}
		} else if pni != uuid.Nil {
			var entry *types.Recipient
			entry, err = s.LoadRecipientByPNI(ctx, pni)
			if entry != nil {
				entries = []*types.Recipient{entry}
			}
		} else {
			// Unreachable: the guard at the top rejects aci == pni == Nil.
			panic("impossible case")
		}
		if err != nil {
			return err
		} else if len(entries) > 2 {
			return fmt.Errorf("got more than two recipient rows for ACI %s and PNI %s", aci, pni)
		} else if len(entries) < 2 {
			if len(entries) == 0 {
				outRecipient = &types.Recipient{
					ACI: aci,
					PNI: pni,
				}
			} else {
				outRecipient = entries[0]
			}
			changed, err := updater(outRecipient)
			if err != nil {
				return fmt.Errorf("failed to run updater function: %w", err)
			}
			// SQL only supports one ON CONFLICT clause, which means StoreRecipient will key on the ACI if it's present.
			// If we're adding an ACI to a PNI row, just delete the PNI row first to avoid conflicts on the PNI key.
			if outRecipient.PNI != uuid.Nil && outRecipient.ACI == uuid.Nil && aci != uuid.Nil {
				err = s.DeleteRecipientByPNI(ctx, outRecipient.PNI)
				if err != nil {
					return fmt.Errorf("failed to delete old PNI row: %w", err)
				}
			}
			// Fill in identifiers the row was missing. Existing values are
			// left alone.
			if outRecipient.PNI == uuid.Nil && pni != uuid.Nil {
				outRecipient.PNI = pni
				changed = true
			}
			if outRecipient.ACI == uuid.Nil && aci != uuid.Nil {
				outRecipient.ACI = aci
				changed = true
			}
			// A freshly created recipient is always stored, even if unchanged.
			if changed || len(entries) == 0 {
				err = s.StoreRecipient(ctx, outRecipient)
				if err != nil {
					return fmt.Errorf("failed to store updated recipient row: %w", err)
				}
			}
			return nil
		} else if outRecipient, err = s.mergeRecipients(ctx, entries[0], entries[1], updater); err != nil {
			return fmt.Errorf("failed to merge recipient rows for ACI %s and PNI %s: %w", aci, pni, err)
		} else {
			return nil
		}
	})
	return
}
|
|
|
|
func (s *sqlStore) UpdateRecipientE164(ctx context.Context, aci, pni uuid.UUID, e164 string) (*types.Recipient, error) {
|
|
return s.LoadAndUpdateRecipient(ctx, aci, pni, func(recipient *types.Recipient) (bool, error) {
|
|
if recipient.E164 != e164 {
|
|
recipient.E164 = e164
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
})
|
|
}
|
|
func (s *sqlStore) LoadRecipientByE164(ctx context.Context, e164 string) (*types.Recipient, error) {
|
|
return scanRecipient(s.db.QueryRow(ctx, getRecipientByPhoneQuery, s.AccountID, e164))
|
|
}
|
|
|
|
func (s *sqlStore) LoadAllContacts(ctx context.Context) ([]*types.Recipient, error) {
|
|
rows, err := s.db.Query(ctx, getAllRecipientsWithNameOrPhoneQuery, s.AccountID)
|
|
return dbutil.NewRowIterWithError(rows, scanRecipient, err).AsList()
|
|
}
|
|
|
|
func (s *sqlStore) DeleteRecipientByPNI(ctx context.Context, pni uuid.UUID) error {
|
|
_, err := s.db.Exec(ctx, deleteRecipientByPNIQuery, s.AccountID, pni)
|
|
return err
|
|
}
|
|
|
|
func nullableUUID(u uuid.UUID) uuid.NullUUID {
|
|
return uuid.NullUUID{UUID: u, Valid: u != uuid.Nil}
|
|
}
|
|
|
|
func (s *sqlStore) StoreRecipient(ctx context.Context, recipient *types.Recipient) (err error) {
|
|
if recipient.ACI != uuid.Nil {
|
|
_, err = s.db.Exec(
|
|
ctx,
|
|
upsertACIRecipientQuery,
|
|
s.AccountID,
|
|
recipient.ACI,
|
|
nullableUUID(recipient.PNI),
|
|
recipient.E164,
|
|
recipient.ContactName,
|
|
recipient.ContactAvatar.Hash,
|
|
recipient.Profile.Key.Slice(),
|
|
recipient.Profile.Name,
|
|
recipient.Profile.About,
|
|
recipient.Profile.AboutEmoji,
|
|
recipient.Profile.AvatarPath,
|
|
dbutil.UnixMilliPtr(recipient.Profile.FetchedAt),
|
|
recipient.NeedsPNISignature,
|
|
)
|
|
} else if recipient.PNI != uuid.Nil {
|
|
_, err = s.db.Exec(
|
|
ctx,
|
|
upsertPNIRecipientQuery,
|
|
s.AccountID,
|
|
recipient.PNI,
|
|
recipient.E164,
|
|
recipient.ContactName,
|
|
recipient.ContactAvatar.Hash,
|
|
)
|
|
} else {
|
|
err = fmt.Errorf("no ACI or PNI provided in StoreRecipient call")
|
|
}
|
|
return
|
|
}
|