332 lines
12 KiB
Go
332 lines
12 KiB
Go
// mautrix-signal - A Matrix-signal puppeting bridge.
|
|
// Copyright (C) 2025 Tulir Asokan
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
"go.mau.fi/util/dbutil"
|
|
"go.mau.fi/util/ptr"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
|
)
|
|
|
|
type BackupChat struct {
|
|
*backuppb.Chat
|
|
TotalMessages int
|
|
LatestMessageID uint64
|
|
}
|
|
|
|
type BackupStore interface {
|
|
AddBackupRecipient(ctx context.Context, recipient *backuppb.Recipient) error
|
|
AddBackupChat(ctx context.Context, chat *backuppb.Chat) error
|
|
AddBackupChatItem(ctx context.Context, item *backuppb.ChatItem) error
|
|
RecalculateChatCounts(ctx context.Context) error
|
|
ClearBackup(ctx context.Context) error
|
|
|
|
GetBackupRecipient(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error)
|
|
GetBackupChatByUserID(ctx context.Context, userID libsignalgo.ServiceID) (*BackupChat, error)
|
|
GetBackupChatByGroupID(ctx context.Context, groupID types.GroupIdentifier) (*BackupChat, error)
|
|
GetBackupChats(ctx context.Context) ([]*BackupChat, error)
|
|
GetBackupChatItems(ctx context.Context, chatID uint64, anchor time.Time, forward bool, limit int) ([]*backuppb.ChatItem, error)
|
|
DeleteBackupChat(ctx context.Context, chatID uint64) error
|
|
DeleteBackupChatItems(ctx context.Context, chatID uint64, minTime time.Time) error
|
|
}
|
|
|
|
// Compile-time assertion that *sqlStore implements BackupStore.
var _ BackupStore = (*sqlStore)(nil)
|
|
|
|
// SQL statements used by the backup store. Every statement is scoped to one
// account via the account_id column, which is always bound as $1.
const (
	// Inserts one recipient row; aci_uuid, pni_uuid and group_master_key are
	// the lookup columns used by the chat queries below.
	addBackupRecipientQuery = `
		INSERT INTO signalmeow_backup_recipient (account_id, recipient_id, aci_uuid, pni_uuid, group_master_key, data)
		VALUES ($1, $2, $3, $4, $5, $6)
	`
	// Inserts one chat row referencing its recipient.
	addBackupChatQuery = `
		INSERT INTO signalmeow_backup_chat (account_id, chat_id, recipient_id, data)
		VALUES ($1, $2, $3, $4)
	`
	// Inserts one message row; conflicting rows are silently skipped.
	addBackupChatItemQuery = `
		INSERT INTO signalmeow_backup_message (account_id, chat_id, sender_id, message_id, data)
		VALUES ($1, $2, $3, $4, $5)
		ON CONFLICT DO NOTHING
	`

	// Fetches the serialized recipient with the given recipient ID.
	getBackupRecipientQuery = `
		SELECT data FROM signalmeow_backup_recipient WHERE account_id=$1 AND recipient_id=$2
	`
	// Fetches the chat whose recipient has the given ACI.
	getBackupChatByACIQuery = `
		SELECT chat.data, chat.latest_message_id, chat.total_message_count FROM signalmeow_backup_recipient rcp
		INNER JOIN signalmeow_backup_chat chat ON rcp.account_id=chat.account_id AND rcp.recipient_id=chat.recipient_id
		WHERE rcp.account_id=$1 AND rcp.aci_uuid=$2
	`
	// Fetches the chat whose recipient has the given PNI.
	getBackupChatByPNIQuery = `
		SELECT chat.data, chat.latest_message_id, chat.total_message_count FROM signalmeow_backup_recipient rcp
		INNER JOIN signalmeow_backup_chat chat ON rcp.account_id=chat.account_id AND rcp.recipient_id=chat.recipient_id
		WHERE rcp.account_id=$1 AND rcp.pni_uuid=$2
	`
	// Fetches the chat of a group by resolving the group identifier to its
	// master key through signalmeow_groups.
	getBackupChatByGroupIDQuery = `
		SELECT chat.data, chat.latest_message_id, chat.total_message_count FROM signalmeow_groups g
		INNER JOIN signalmeow_backup_recipient rcp ON g.account_id=rcp.account_id AND g.master_key=rcp.group_master_key
		INNER JOIN signalmeow_backup_chat chat ON rcp.account_id=chat.account_id AND rcp.recipient_id=chat.recipient_id
		WHERE g.account_id=$1 AND g.group_identifier=$2
	`
	// Fetches every chat of the account.
	getAllBackupChatsQuery = `
		SELECT data, latest_message_id, total_message_count
		FROM signalmeow_backup_chat
		WHERE account_id=$1
	`
	// Fetches up to $5 messages of a chat with $3 < message_id < $4 (both
	// bounds exclusive), newest first.
	getBackupChatItemsQuery = `
		SELECT data
		FROM signalmeow_backup_message
		WHERE account_id=$1 AND chat_id=$2 AND message_id > $3 AND message_id < $4
		ORDER BY message_id DESC
		LIMIT $5
	`
	// Deletes a single chat row.
	deleteBackupChatQuery = `
		DELETE FROM signalmeow_backup_chat WHERE account_id=$1 AND chat_id=$2
	`
	// Deletes the messages of a chat with message_id >= $3.
	deleteBackupChatItemsQuery = `
		DELETE FROM signalmeow_backup_message WHERE account_id=$1 AND chat_id=$2 AND message_id >= $3
	`
	// Recomputes latest_message_id (highest message_id) and
	// total_message_count for every chat of the account.
	recalculateChatCountsQuery = `
		UPDATE signalmeow_backup_chat
		SET latest_message_id = (
			SELECT message_id
			FROM signalmeow_backup_message
			WHERE account_id=signalmeow_backup_chat.account_id AND chat_id=signalmeow_backup_chat.chat_id
			ORDER BY message_id DESC
			LIMIT 1
		),
			total_message_count = (
				SELECT COUNT(*)
				FROM signalmeow_backup_message
				WHERE account_id=signalmeow_backup_chat.account_id AND chat_id=signalmeow_backup_chat.chat_id
			)
		WHERE account_id=$1
	`
)
|
|
|
|
func tryCastUUID(b []byte) uuid.UUID {
|
|
if len(b) == 16 {
|
|
return uuid.UUID(b)
|
|
}
|
|
return uuid.Nil
|
|
}
|
|
|
|
func (s *sqlStore) AddBackupRecipient(ctx context.Context, recipient *backuppb.Recipient) error {
|
|
recipientData, err := proto.Marshal(recipient)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal recipient %d: %w", recipient.Id, err)
|
|
}
|
|
var aci, pni uuid.UUID
|
|
var groupMasterKey types.SerializedGroupMasterKey
|
|
switch dest := recipient.Destination.(type) {
|
|
case *backuppb.Recipient_Contact:
|
|
aci = tryCastUUID(dest.Contact.Aci)
|
|
pni = tryCastUUID(dest.Contact.Pni)
|
|
// TODO save identity key + trust level
|
|
if aci != uuid.Nil || pni != uuid.Nil {
|
|
_, err := s.LoadAndUpdateRecipient(ctx, aci, pni, func(recipient *types.Recipient) (changed bool, err error) {
|
|
oldRecipient := ptr.Clone(recipient)
|
|
if dest.Contact.E164 != nil {
|
|
recipient.E164 = fmt.Sprintf("+%d", *dest.Contact.E164)
|
|
}
|
|
if len(dest.Contact.ProfileKey) == libsignalgo.ProfileKeyLength {
|
|
recipient.Profile.Key = libsignalgo.ProfileKey(dest.Contact.ProfileKey)
|
|
}
|
|
if dest.Contact.ProfileGivenName != nil || dest.Contact.ProfileFamilyName != nil {
|
|
recipient.Profile.Name = strings.TrimSpace(fmt.Sprintf("%s %s", dest.Contact.GetProfileGivenName(), dest.Contact.GetProfileFamilyName()))
|
|
}
|
|
changed = oldRecipient.E164 != recipient.E164 ||
|
|
oldRecipient.Profile.Key != recipient.Profile.Key ||
|
|
oldRecipient.Profile.Name != recipient.Profile.Name
|
|
return
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save info for recipient %d: %w", recipient.Id, err)
|
|
}
|
|
} else if dest.Contact.GetRegistered() != nil {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Uint64("recipient_id", recipient.Id).
|
|
Any("entry", dest.Contact).
|
|
Msg("Both ACI and PNI are invalid for registered contact recipient")
|
|
}
|
|
case *backuppb.Recipient_Group:
|
|
groupMasterKey = types.SerializedGroupMasterKey(base64.StdEncoding.EncodeToString(dest.Group.MasterKey))
|
|
if len(dest.Group.MasterKey) == libsignalgo.GroupMasterKeyLength {
|
|
gid, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier()
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).
|
|
Uint64("recipient_id", recipient.Id).
|
|
Msg("Failed to get group identifier from master key")
|
|
} else if err = s.StoreMasterKey(ctx, types.GroupIdentifier(base64.StdEncoding.EncodeToString(gid[:])), groupMasterKey); err != nil {
|
|
return fmt.Errorf("failed to save group master key for recipient %d: %w", recipient.Id, err)
|
|
}
|
|
} else {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Uint64("recipient_id", recipient.Id).
|
|
Msg("Invalid group master key length")
|
|
}
|
|
case *backuppb.Recipient_Self:
|
|
aci = s.AccountID
|
|
default:
|
|
}
|
|
_, err = s.db.Exec(ctx, addBackupRecipientQuery, s.AccountID, recipient.Id, ptr.NonZero(aci), ptr.NonZero(pni), ptr.NonZero(groupMasterKey), recipientData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add backup recipient %d: %w", recipient.Id, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *sqlStore) AddBackupChat(ctx context.Context, chat *backuppb.Chat) error {
|
|
chatData, err := proto.Marshal(chat)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal chat %d: %w", chat.Id, err)
|
|
}
|
|
_, err = s.db.Exec(ctx, addBackupChatQuery, s.AccountID, chat.Id, chat.RecipientId, chatData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add backup chat %d: %w", chat.Id, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *sqlStore) AddBackupChatItem(ctx context.Context, item *backuppb.ChatItem) error {
|
|
itemData, err := proto.Marshal(item)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal chat item %d: %w", item.DateSent, err)
|
|
}
|
|
_, err = s.db.Exec(ctx, addBackupChatItemQuery, s.AccountID, item.ChatId, item.AuthorId, item.DateSent, itemData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add backup chat item %d: %w", item.DateSent, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *sqlStore) ClearBackup(ctx context.Context) error {
|
|
_, err := s.db.Exec(ctx, "DELETE FROM signalmeow_backup_message WHERE account_id=$1", s.AccountID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.Exec(ctx, "DELETE FROM signalmeow_backup_chat WHERE account_id=$1", s.AccountID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.Exec(ctx, "DELETE FROM signalmeow_backup_recipient WHERE account_id=$1", s.AccountID)
|
|
return err
|
|
}
|
|
|
|
func scanProto[T proto.Message](row dbutil.Scannable) (val T, err error) {
|
|
var data []byte
|
|
err = row.Scan(&data)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = nil
|
|
}
|
|
return
|
|
}
|
|
val = val.ProtoReflect().New().Interface().(T)
|
|
err = proto.Unmarshal(data, val)
|
|
return
|
|
}
|
|
|
|
func scanChat(row dbutil.Scannable) (*BackupChat, error) {
|
|
var data []byte
|
|
var latestMessageID, totalMessageCount sql.NullInt64
|
|
err := row.Scan(&data, &latestMessageID, &totalMessageCount)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
var chat backuppb.Chat
|
|
err = proto.Unmarshal(data, &chat)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &BackupChat{
|
|
Chat: &chat,
|
|
TotalMessages: int(totalMessageCount.Int64),
|
|
LatestMessageID: uint64(totalMessageCount.Int64),
|
|
}, nil
|
|
}
|
|
|
|
var chatScanner = dbutil.ConvertRowFn[*BackupChat](scanChat)
|
|
var messageScanner = dbutil.ConvertRowFn[*backuppb.ChatItem](scanProto[*backuppb.ChatItem])
|
|
|
|
func (s *sqlStore) GetBackupRecipient(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) {
|
|
return scanProto[*backuppb.Recipient](s.db.QueryRow(ctx, getBackupRecipientQuery, s.AccountID, recipientID))
|
|
}
|
|
|
|
func (s *sqlStore) GetBackupChatByUserID(ctx context.Context, userID libsignalgo.ServiceID) (*BackupChat, error) {
|
|
query := getBackupChatByACIQuery
|
|
if userID.Type == libsignalgo.ServiceIDTypePNI {
|
|
query = getBackupChatByPNIQuery
|
|
}
|
|
return scanChat(s.db.QueryRow(ctx, query, s.AccountID, userID.UUID))
|
|
}
|
|
|
|
func (s *sqlStore) GetBackupChatByGroupID(ctx context.Context, groupID types.GroupIdentifier) (*BackupChat, error) {
|
|
return scanChat(s.db.QueryRow(ctx, getBackupChatByGroupIDQuery, s.AccountID, groupID))
|
|
}
|
|
|
|
func (s *sqlStore) GetBackupChats(ctx context.Context) ([]*BackupChat, error) {
|
|
return chatScanner.NewRowIter(s.db.Query(ctx, getAllBackupChatsQuery, s.AccountID)).AsList()
|
|
}
|
|
|
|
func (s *sqlStore) GetBackupChatItems(ctx context.Context, chatID uint64, anchor time.Time, forward bool, limit int) ([]*backuppb.ChatItem, error) {
|
|
var minTS, maxTS int64
|
|
if anchor.IsZero() {
|
|
maxTS = time.Now().Add(24 * time.Hour).UnixMilli()
|
|
} else if forward {
|
|
minTS = anchor.UnixMilli()
|
|
maxTS = time.Now().Add(24 * time.Hour).UnixMilli()
|
|
} else {
|
|
maxTS = anchor.UnixMilli()
|
|
}
|
|
return messageScanner.NewRowIter(s.db.Query(ctx, getBackupChatItemsQuery, s.AccountID, chatID, minTS, maxTS, limit)).AsList()
|
|
}
|
|
|
|
func (s *sqlStore) DeleteBackupChatItems(ctx context.Context, chatID uint64, minTime time.Time) error {
|
|
anchorTS := minTime.UnixMilli()
|
|
if minTime.IsZero() {
|
|
anchorTS = 0
|
|
}
|
|
_, err := s.db.Exec(ctx, deleteBackupChatItemsQuery, s.AccountID, chatID, anchorTS)
|
|
return err
|
|
}
|
|
|
|
func (s *sqlStore) DeleteBackupChat(ctx context.Context, chatID uint64) error {
|
|
_, err := s.db.Exec(ctx, deleteBackupChatQuery, s.AccountID, chatID)
|
|
return err
|
|
}
|
|
|
|
func (s *sqlStore) RecalculateChatCounts(ctx context.Context) error {
|
|
_, err := s.db.Exec(ctx, recalculateChatCountsQuery, s.AccountID)
|
|
return err
|
|
}
|