mautrix-signal/pkg/signalmeow/storageservice.go

306 lines
11 KiB
Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package signalmeow
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
)
// SyncStorage fetches all records from the storage service and applies the
// relevant ones to the local store.
//
// Contact records update the matching recipient's profile key, profile name
// (only when none is set yet), system contact name and phone number. GroupV2
// records have their master key stored. Account records are only trace-logged,
// and GroupV1 and story distribution list records are ignored.
//
// Nothing is returned: every failure is logged and, for per-record failures,
// processing continues with the next record.
func (cli *Client) SyncStorage(ctx context.Context) {
	log := zerolog.Ctx(ctx).With().Str("action", "sync storage").Logger()
	// TODO only fetch changed entries
	update, err := cli.FetchStorage(ctx, cli.Store.MasterKey, 0, nil)
	if err != nil {
		log.Err(err).Msg("Failed to fetch storage")
		return
	} else if update == nil {
		// FetchStorage returns (nil, nil) when no manifest was returned, in
		// which case there are no records to iterate over.
		log.Debug().Msg("No storage manifest returned, nothing to sync")
		return
	}
	// parseUUID treats unparsable (including empty) IDs as absent rather than
	// trusting whatever value accompanies a parse error.
	parseUUID := func(raw string) uuid.UUID {
		parsed, parseErr := uuid.Parse(raw)
		if parseErr != nil {
			return uuid.Nil
		}
		return parsed
	}
	for _, record := range update.NewRecords {
		switch data := record.StorageRecord.GetRecord().(type) {
		case *signalpb.StorageRecord_Contact:
			log.Trace().Any("contact_record", data.Contact).Msg("Handling contact record")
			aci := parseUUID(data.Contact.Aci)
			pni := parseUUID(data.Contact.Pni)
			if aci == uuid.Nil && pni == uuid.Nil {
				log.Warn().
					Str("raw_aci", data.Contact.Aci).
					Str("raw_pni", data.Contact.Pni).
					Str("raw_e164", data.Contact.E164).
					Msg("Storage service has contact record with no ACI or PNI")
				continue
			}
			contact := data.Contact
			_, err = cli.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, func(recipient *types.Recipient) (changed bool, err error) {
				// Only accept profile keys of the expected length.
				if len(contact.ProfileKey) == libsignalgo.ProfileKeyLength {
					newProfileKey := libsignalgo.ProfileKey(contact.ProfileKey)
					changed = changed || recipient.Profile.Key != newProfileKey
					recipient.Profile.Key = newProfileKey
				}
				// The name from the contact record is only a fallback: it never
				// overwrites a profile name that is already set.
				if recipient.Profile.Name == "" && (contact.GivenName != "" || contact.FamilyName != "") {
					changed = true
					recipient.Profile.Name = strings.TrimSpace(fmt.Sprintf("%s %s", contact.GivenName, contact.FamilyName))
				}
				if contact.SystemGivenName != "" || contact.SystemFamilyName != "" {
					newContactName := strings.TrimSpace(fmt.Sprintf("%s %s", contact.SystemGivenName, contact.SystemFamilyName))
					// Only report a change when the name actually differs,
					// matching how the profile key and E164 are handled.
					changed = changed || recipient.ContactName != newContactName
					recipient.ContactName = newContactName
				}
				if contact.E164 != "" {
					changed = changed || recipient.E164 != contact.E164
					recipient.E164 = contact.E164
				}
				return changed, nil
			})
			if err != nil {
				log.Err(err).
					Stringer("aci", aci).
					Stringer("pni", pni).
					Msg("Failed to update contact")
			}
		case *signalpb.StorageRecord_GroupV2:
			if len(data.GroupV2.MasterKey) != libsignalgo.GroupMasterKeyLength {
				log.Warn().Msg("Invalid group master key length")
				continue
			}
			masterKey := libsignalgo.GroupMasterKey(data.GroupV2.MasterKey)
			groupID, err := cli.StoreMasterKey(ctx, masterKeyFromBytes(masterKey))
			if err != nil {
				log.Err(err).Msg("Failed to store group master key from storage service")
			} else {
				log.Debug().Stringer("group_id", groupID).Msg("Stored group master key from storage service")
			}
		case *signalpb.StorageRecord_Account:
			log.Trace().Any("account_record", data.Account).Msg("Found account record")
			// There's probably some useful data here
		case *signalpb.StorageRecord_GroupV1, *signalpb.StorageRecord_StoryDistributionList:
			// irrelevant data
		default:
			log.Warn().Str("type", fmt.Sprintf("%T", data)).Msg("Unknown storage record type")
		}
	}
}
// StorageUpdate is the result of FetchStorage: the manifest version that was
// fetched and how its contents differ from the keys the caller already had.
// All keys are base64 (standard encoding) strings of the raw storage identifiers.
type StorageUpdate struct {
// Version is the version reported by the decrypted manifest.
Version uint64
// NewRecords are the decrypted records for manifest keys that were not in
// the caller's existing keys.
NewRecords []*DecryptedStorageRecord
// RemovedRecords are the caller's existing keys that are no longer listed
// in the manifest.
RemovedRecords []string
// MissingRecords are keys that were requested from the storage service but
// not included in its response.
MissingRecords []string
}
// FetchStorage downloads the storage manifest and every record listed in it
// that the caller does not already have.
//
// masterKey is used to derive the storage service key. currentVersion is
// forwarded to the manifest request (0 requests the manifest unconditionally).
// existingKeys are the base64 (standard encoding) storage keys the caller
// already knows; the slice is not modified.
//
// It returns (nil, nil) when the manifest fetch reports that there is nothing
// newer than currentVersion. Otherwise the returned StorageUpdate lists the
// newly fetched records, the existing keys that are gone from the manifest,
// and the keys the server did not return.
func (cli *Client) FetchStorage(ctx context.Context, masterKey []byte, currentVersion uint64, existingKeys []string) (*StorageUpdate, error) {
	storageKey := deriveStorageServiceKey(masterKey)
	manifest, err := cli.fetchStorageManifest(ctx, storageKey, currentVersion)
	if err != nil {
		return nil, err
	} else if manifest == nil {
		return nil, nil
	}
	newKeys := manifestRecordToMap(manifest.GetIdentifiers())
	// Sort and deduplicate a copy so the caller's slice is left untouched and
	// a duplicated key can't be reported as removed twice.
	existingKeys = slices.Clone(existingKeys)
	slices.Sort(existingKeys)
	existingKeys = slices.Compact(existingKeys)
	removedKeys := make([]string, 0)
	for _, key := range existingKeys {
		if _, isStillThere := newKeys[key]; isStillThere {
			// Already known locally, so it doesn't need to be fetched again.
			delete(newKeys, key)
		} else {
			removedKeys = append(removedKeys, key)
		}
	}
	newRecords, missingKeys, err := cli.fetchStorageRecords(ctx, storageKey, newKeys)
	if err != nil {
		return nil, err
	}
	return &StorageUpdate{
		Version:        manifest.GetVersion(),
		NewRecords:     newRecords,
		RemovedRecords: removedKeys,
		MissingRecords: missingKeys,
	}, nil
}
// manifestRecordToMap indexes manifest identifiers by the base64 (standard
// encoding) form of their raw key, with the identifier's type as the value.
// When the same raw key occurs more than once, the later entry wins.
func manifestRecordToMap(manifest []*signalpb.ManifestRecord_Identifier) map[string]signalpb.ManifestRecord_Identifier_Type {
	typesByKey := make(map[string]signalpb.ManifestRecord_Identifier_Type, len(manifest))
	for idx := range manifest {
		identifier := manifest[idx]
		encodedKey := base64.StdEncoding.EncodeToString(identifier.GetRaw())
		typesByKey[encodedKey] = identifier.GetType()
	}
	return typesByKey
}
// deriveStorageServiceKey returns the HMAC-SHA256 of the fixed label
// "Storage Service Encryption", keyed with masterKey. The result is the key
// from which the manifest and item keys are derived.
func deriveStorageServiceKey(masterKey []byte) []byte {
	const label = "Storage Service Encryption"
	mac := hmac.New(sha256.New, masterKey)
	mac.Write([]byte(label))
	return mac.Sum(nil)
}
// deriveStorageManifestKey returns the HMAC-SHA256 of "Manifest_<version>"
// (version in decimal), keyed with storageKey. It is the key used to decrypt
// the manifest of that version.
func deriveStorageManifestKey(storageKey []byte, version uint64) []byte {
	mac := hmac.New(sha256.New, storageKey)
	// The label is formatted straight into the MAC; Must covers the error
	// return of Fprintf.
	exerrors.Must(fmt.Fprintf(mac, "Manifest_%d", version))
	digest := mac.Sum(nil)
	return digest
}
// deriveStorageItemKey returns the HMAC-SHA256 of "Item_<itemID>", keyed with
// storageKey. Callers in this file pass the base64 (standard encoding) form of
// the item's raw key as itemID.
func deriveStorageItemKey(storageKey []byte, itemID string) []byte {
	mac := hmac.New(sha256.New, storageKey)
	// The label is formatted straight into the MAC; Must covers the error
	// return of Fprintf.
	exerrors.Must(fmt.Fprintf(mac, "Item_%s", itemID))
	digest := mac.Sum(nil)
	return digest
}
// MaxReadStorageRecords is the maximum number of storage keys requested from
// the storage service in a single read operation; larger key sets are split
// into chunks of this size.
//
// The value is taken from Signal Desktop:
// https://github.com/signalapp/Signal-Desktop/blob/v6.44.0/ts/services/storageConstants.ts
const MaxReadStorageRecords = 2500
// DecryptedStorageRecord is a single storage service item after decryption
// and protobuf decoding.
type DecryptedStorageRecord struct {
// ItemType is the identifier type the manifest listed for this item.
ItemType signalpb.ManifestRecord_Identifier_Type
// StorageID is the base64 (standard encoding) form of the item's raw key.
StorageID string
// StorageRecord is the decrypted and unmarshaled record payload.
StorageRecord *signalpb.StorageRecord
}
// fetchStorageManifest downloads and decrypts the storage manifest.
//
// When greaterThanVersion is non-zero it is appended to the request path as
// "/version/<n>"; a 204 No Content response then means the caller is already
// up to date and (nil, nil) is returned. Any status other than 200 or 204 is
// an error.
//
// The manifest is decrypted with the key derived from storageKey and the
// version carried in the encrypted envelope, then unmarshaled into a
// ManifestRecord.
func (cli *Client) fetchStorageManifest(ctx context.Context, storageKey []byte, greaterThanVersion uint64) (*signalpb.ManifestRecord, error) {
	storageCreds, err := cli.getStorageCredentials(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch credentials: %w", err)
	}
	path := "/v1/storage/manifest"
	if greaterThanVersion > 0 {
		path += fmt.Sprintf("/version/%d", greaterThanVersion)
	}
	resp, err := web.SendHTTPRequest(ctx, http.MethodGet, path, &web.HTTPReqOpt{
		Username:    &storageCreds.Username,
		Password:    &storageCreds.Password,
		ContentType: web.ContentTypeProtobuf,
		Host:        web.StorageHostname,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to fetch storage manifest: %w", err)
	}
	// Close the body on every return path so the connection isn't leaked.
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusNoContent {
		// Already up to date
		return nil, nil
	} else if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d fetching storage manifest", resp.StatusCode)
	}
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read storage manifest response: %w", err)
	}
	var encryptedManifest signalpb.StorageManifest
	if err = proto.Unmarshal(respBody, &encryptedManifest); err != nil {
		return nil, fmt.Errorf("failed to unmarshal encrypted storage manifest: %w", err)
	}
	// The manifest key depends on the version reported by the envelope itself.
	manifestKey := deriveStorageManifestKey(storageKey, encryptedManifest.GetVersion())
	decryptedManifestBytes, err := decryptBytes(manifestKey, encryptedManifest.GetValue())
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt storage manifest: %w", err)
	}
	var manifestRecord signalpb.ManifestRecord
	if err = proto.Unmarshal(decryptedManifestBytes, &manifestRecord); err != nil {
		return nil, fmt.Errorf("failed to unmarshal decrypted manifest record: %w", err)
	}
	return &manifestRecord, nil
}
// fetchStorageRecords downloads, decrypts and decodes the storage items named
// by the keys of inputRecords (base64, standard encoding), requesting them in
// chunks of at most MaxReadStorageRecords keys.
//
// It returns the decoded records in the order the server returned them, plus
// the keys that were requested but not returned. inputRecords is consumed:
// every key that was returned is deleted from it. An item whose key was not
// requested (or was already returned once) is an error, as is any decoding,
// fetching, decryption or unmarshaling failure.
func (cli *Client) fetchStorageRecords(ctx context.Context, storageKey []byte, inputRecords map[string]signalpb.ManifestRecord_Identifier_Type) ([]*DecryptedStorageRecord, []string, error) {
	rawKeys := make([][]byte, 0, len(inputRecords))
	for encodedKey := range inputRecords {
		rawKey, err := base64.StdEncoding.DecodeString(encodedKey)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to decode storage key %s: %w", encodedKey, err)
		}
		rawKeys = append(rawKeys, rawKey)
	}

	// Request the keys chunk by chunk, consuming the pending slice from the front.
	encryptedItems := make([]*signalpb.StorageItem, 0, len(inputRecords))
	for pending := rawKeys; len(pending) > 0; {
		chunkSize := MaxReadStorageRecords
		if chunkSize > len(pending) {
			chunkSize = len(pending)
		}
		fetched, err := cli.fetchStorageItemsChunk(ctx, pending[:chunkSize])
		if err != nil {
			return nil, nil, err
		}
		encryptedItems = append(encryptedItems, fetched...)
		pending = pending[chunkSize:]
	}

	decoded := make([]*DecryptedStorageRecord, 0, len(encryptedItems))
	for idx, item := range encryptedItems {
		storageID := base64.StdEncoding.EncodeToString(item.GetKey())
		plaintext, err := decryptBytes(deriveStorageItemKey(storageKey, storageID), item.GetValue())
		if err != nil {
			return nil, nil, fmt.Errorf("failed to decrypt storage item #%d (%s): %w", idx+1, storageID, err)
		}
		record := &signalpb.StorageRecord{}
		if err = proto.Unmarshal(plaintext, record); err != nil {
			return nil, nil, fmt.Errorf("failed to unmarshal decrypted storage item #%d (%s): %w", idx+1, storageID, err)
		}
		itemType, requested := inputRecords[storageID]
		if !requested {
			return nil, nil, fmt.Errorf("received unexpected storage item at index #%d: %s", idx+1, storageID)
		}
		// Whatever is left in inputRecords after the loop was not returned by the server.
		delete(inputRecords, storageID)
		decoded = append(decoded, &DecryptedStorageRecord{
			ItemType:      itemType,
			StorageID:     storageID,
			StorageRecord: record,
		})
	}
	return decoded, maps.Keys(inputRecords), nil
}
// fetchStorageItemsChunk requests the encrypted storage items for recordKeys
// (raw, undecoded keys) in a single read operation and returns the items from
// the response. The items are still encrypted; the caller decrypts them.
//
// Any status other than 200 OK is reported as an error.
func (cli *Client) fetchStorageItemsChunk(ctx context.Context, recordKeys [][]byte) ([]*signalpb.StorageItem, error) {
	storageCreds, err := cli.getStorageCredentials(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch credentials: %w", err)
	}
	reqBody, err := proto.Marshal(&signalpb.ReadOperation{ReadKey: recordKeys})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal read operation: %w", err)
	}
	resp, err := web.SendHTTPRequest(ctx, http.MethodPut, "/v1/storage/read", &web.HTTPReqOpt{
		Username:    &storageCreds.Username,
		Password:    &storageCreds.Password,
		Body:        reqBody,
		ContentType: web.ContentTypeProtobuf,
		Host:        web.StorageHostname,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to fetch storage records: %w", err)
	}
	// Close the body on every return path so the connection isn't leaked.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d fetching storage records", resp.StatusCode)
	}
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read storage records response: %w", err)
	}
	var storageItems signalpb.StorageItems
	if err = proto.Unmarshal(respBody, &storageItems); err != nil {
		return nil, fmt.Errorf("failed to unmarshal storage records: %w", err)
	}
	return storageItems.GetItems(), nil
}