// mautrix-signal/pkg/signalmeow/keys.go
//
// 648 lines
// 21 KiB
// Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package signalmeow
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"strings"
"time"
"github.com/rs/zerolog"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
)
// PREKEY_BATCH_SIZE is the number of EC prekeys, and separately of Kyber
// prekeys, that the batch generators in this file top a store up to. A server
// count below half of this value triggers generation and upload.
const PREKEY_BATCH_SIZE = 100

// GeneratedPreKeys groups everything RegisterPreKeys sends in one upload
// request.
type GeneratedPreKeys struct {
	// PreKeys are the EC prekey records to upload.
	PreKeys []*libsignalgo.PreKeyRecord
	// KyberPreKeys are the Kyber prekey records to upload (sent as "pqPreKeys").
	KyberPreKeys []*libsignalgo.KyberPreKeyRecord
	// IdentityKey is the serialized public identity key; it is sent
	// base64-encoded.
	IdentityKey []byte
}
// GenerateAndRegisterPreKeys tops up the EC and Kyber prekeys held in pks and
// then uploads the store's prekeys via RegisterAllPreKeys.
//
// A server count of 0 is passed to both batch generators, so each one takes its
// "generate and upload" path and fills the store up to PREKEY_BATCH_SIZE keys
// if it holds fewer.
func (cli *Client) GenerateAndRegisterPreKeys(ctx context.Context, pks store.PreKeyStore) error {
	if _, genErr := cli.GenerateAndSaveNextPreKeyBatch(ctx, pks, 0); genErr != nil {
		return fmt.Errorf("failed to generate and save next prekey batch: %w", genErr)
	}
	if _, genErr := cli.GenerateAndSaveNextKyberPreKeyBatch(ctx, pks, 0); genErr != nil {
		return fmt.Errorf("failed to generate and save next kyber prekey batch: %w", genErr)
	}
	// Upload everything currently in the store, not only the keys generated above.
	if regErr := cli.RegisterAllPreKeys(ctx, pks); regErr != nil {
		return fmt.Errorf("failed to register prekey batches: %w", regErr)
	}
	return nil
}
// RegisterAllPreKeys uploads every EC prekey and every normal Kyber prekey held
// in pks, together with the matching public identity key. The ACI or PNI
// identity is chosen from pks.GetServiceID().
//
// It returns an error if the store holds no keys of either kind.
func (cli *Client) RegisterAllPreKeys(ctx context.Context, pks store.PreKeyStore) error {
	pni := pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI
	identityKeyPair := cli.Store.ACIIdentityKeyPair
	if pni {
		identityKeyPair = cli.Store.PNIIdentityKeyPair
	}

	// Load everything the database currently holds for this identity.
	ecKeys, err := pks.AllPreKeys(ctx)
	if err != nil {
		return fmt.Errorf("failed to get all prekeys: %w", err)
	}
	pqKeys, err := pks.AllNormalKyberPreKeys(ctx)
	if err != nil {
		return fmt.Errorf("failed to get all kyber prekeys: %w", err)
	}
	if len(ecKeys) == 0 && len(pqKeys) == 0 {
		return fmt.Errorf("no prekeys to upload")
	}

	serializedIdentity, err := identityKeyPair.GetPublicKey().Serialize()
	if err != nil {
		return fmt.Errorf("failed to serialize identity key: %w", err)
	}
	batch := &GeneratedPreKeys{
		PreKeys:      ecKeys,
		KyberPreKeys: pqKeys,
		IdentityKey:  serializedIdentity,
	}

	// The upload authenticates as "<ACI>.<deviceID>" even for the PNI identity.
	authUser := fmt.Sprintf("%s.%d", cli.Store.ACI, cli.Store.DeviceID)
	logger := zerolog.Ctx(ctx).With().Str("action", "register prekeys").Logger()
	logger.Debug().
		Int("num_prekeys", len(ecKeys)).
		Int("num_kyber_prekeys", len(pqKeys)).
		Msg("Registering prekeys")

	if err = RegisterPreKeys(ctx, batch, pni, authUser, cli.Store.Password); err != nil {
		return fmt.Errorf("failed to register prekeys: %w", err)
	}
	return nil
}
// GenerateAndSaveNextPreKeyBatch decides, from the number of EC prekeys the
// server reports (serverCount), whether the EC prekeys should be (re)uploaded.
// If the local store holds fewer than PREKEY_BATCH_SIZE keys, it generates and
// stores enough new ones to reach that size.
//
// It returns true (upload needed) when serverCount is below
// PREKEY_BATCH_SIZE/2, or when the server reports more keys than the store
// holds. Otherwise it returns false without generating anything.
func (cli *Client) GenerateAndSaveNextPreKeyBatch(ctx context.Context, pks store.PreKeyStore, serverCount int) (bool, error) {
	storeCount, nextPreKeyID, err := pks.GetNextPreKeyID(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get next prekey ID: %w", err)
	}

	logger := zerolog.Ctx(ctx)
	serverLow := serverCount < PREKEY_BATCH_SIZE/2
	switch {
	case serverLow && storeCount >= PREKEY_BATCH_SIZE:
		logger.Warn().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Store is full, but server is not, reuploading EC prekeys without generating more")
	case serverLow:
		logger.Info().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Generating and uploading EC prekeys")
	case uint32(serverCount) > storeCount:
		logger.Warn().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Server has more EC prekeys than store, reuploading")
	default:
		logger.Debug().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("EC prekey count is good")
		return false, nil
	}

	// Only generate when the store is short; a full store is reuploaded as-is.
	if storeCount < PREKEY_BATCH_SIZE {
		for _, record := range GeneratePreKeys(nextPreKeyID, PREKEY_BATCH_SIZE-storeCount) {
			if err = pks.StorePreKey(ctx, 0, record); err != nil {
				return false, fmt.Errorf("failed to save prekey: %w", err)
			}
		}
	}
	return true, nil
}
// GenerateAndSaveNextKyberPreKeyBatch is the Kyber counterpart of
// GenerateAndSaveNextPreKeyBatch. It uses the server's Kyber prekey count
// (serverCount) to decide whether the Kyber prekeys should be (re)uploaded.
// If the store holds fewer than PREKEY_BATCH_SIZE, it generates new keys to
// reach that size, signed with the ACI or PNI identity key that matches pks.
//
// It returns true when an upload is needed and false when the count is fine.
func (cli *Client) GenerateAndSaveNextKyberPreKeyBatch(ctx context.Context, pks store.PreKeyStore, serverCount int) (bool, error) {
	signingKeyPair := cli.Store.ACIIdentityKeyPair
	if pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI {
		signingKeyPair = cli.Store.PNIIdentityKeyPair
	}

	storeCount, nextKyberPreKeyID, err := pks.GetNextKyberPreKeyID(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get next kyber prekey ID: %w", err)
	}

	logger := zerolog.Ctx(ctx)
	serverLow := serverCount < PREKEY_BATCH_SIZE/2
	switch {
	case serverLow && storeCount >= PREKEY_BATCH_SIZE:
		logger.Warn().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Store is full, but server is not, reuploading kyber prekeys without generating more")
	case serverLow:
		logger.Info().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Generating and uploading kyber prekeys")
	case uint32(serverCount) > storeCount:
		logger.Warn().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Server has more kyber prekeys than store, reuploading")
	default:
		logger.Debug().
			Int("server_count", serverCount).
			Uint32("store_count", storeCount).
			Msg("Kyber prekey count is good")
		return false, nil
	}

	// Only generate when the store is short; a full store is reuploaded as-is.
	if storeCount < PREKEY_BATCH_SIZE {
		for _, record := range GenerateKyberPreKeys(nextKyberPreKeyID, PREKEY_BATCH_SIZE-storeCount, signingKeyPair) {
			if err = pks.StoreKyberPreKey(ctx, 0, record); err != nil {
				return false, fmt.Errorf("failed to save kyber prekey: %w", err)
			}
		}
	}
	return true, nil
}
// GeneratePreKeys creates count EC prekey records with consecutive key IDs,
// starting at startKeyID.
//
// It panics if count exceeds PREKEY_BATCH_SIZE or if any libsignal call fails.
func GeneratePreKeys(startKeyID uint32, count uint32) []*libsignalgo.PreKeyRecord {
	if count > PREKEY_BATCH_SIZE {
		panic("count must be less than or equal to PREKEY_BATCH_SIZE")
	}
	records := make([]*libsignalgo.PreKeyRecord, 0, count)
	// NOTE(review): endKeyID is computed in uint32 and wraps if startKeyID is
	// close to the maximum; then fewer (possibly zero) keys are produced.
	endKeyID := startKeyID + count
	for id := startKeyID; id < endKeyID; id++ {
		priv, err := libsignalgo.GeneratePrivateKey()
		if err != nil {
			panic(fmt.Errorf("error generating private key: %w", err))
		}
		record, err := libsignalgo.NewPreKeyRecordFromPrivateKey(id, priv)
		if err != nil {
			panic(fmt.Errorf("error creating prekey record: %w", err))
		}
		records = append(records, record)
	}
	return records
}
// GenerateKyberPreKeys creates count Kyber prekey records with consecutive key
// IDs starting at startKeyID. Each record holds a fresh Kyber key pair whose
// serialized public key is signed with identityKeyPair's private key, and is
// stamped with the time of its creation.
//
// It panics if count exceeds PREKEY_BATCH_SIZE or if any libsignal call fails.
func GenerateKyberPreKeys(startKeyID uint32, count uint32, identityKeyPair *libsignalgo.IdentityKeyPair) []*libsignalgo.KyberPreKeyRecord {
	if count > PREKEY_BATCH_SIZE {
		panic("count must be less than or equal to PREKEY_BATCH_SIZE")
	}
	records := make([]*libsignalgo.KyberPreKeyRecord, 0, count)
	// NOTE(review): same uint32 wrap caveat as GeneratePreKeys for endKeyID.
	endKeyID := startKeyID + count
	for id := startKeyID; id < endKeyID; id++ {
		pair, err := libsignalgo.KyberKeyPairGenerate()
		if err != nil {
			panic(fmt.Errorf("error generating kyber key pair: %w", err))
		}
		pub, err := pair.GetPublicKey()
		if err != nil {
			panic(fmt.Errorf("error getting kyber public key: %w", err))
		}
		pubBytes, err := pub.Serialize()
		if err != nil {
			panic(fmt.Errorf("error serializing kyber public key: %w", err))
		}
		sig, err := identityKeyPair.GetPrivateKey().Sign(pubBytes)
		if err != nil {
			panic(fmt.Errorf("error signing kyber public key: %w", err))
		}
		record, err := libsignalgo.NewKyberPreKeyRecord(id, time.Now(), pair, sig)
		if err != nil {
			panic(fmt.Errorf("error creating kyber prekey record: %w", err))
		}
		records = append(records, record)
	}
	return records
}
// GenerateSignedPreKey creates a signed prekey record with ID startSignedKeyId.
// The record holds a fresh EC private key, and its serialized public key is
// signed with identityKeyPair's private key. The timestamp is taken right
// after the private key is generated.
//
// It panics if any libsignal call fails.
func GenerateSignedPreKey(startSignedKeyId uint32, identityKeyPair *libsignalgo.IdentityKeyPair) *libsignalgo.SignedPreKeyRecord {
	priv, err := libsignalgo.GeneratePrivateKey()
	if err != nil {
		panic(fmt.Errorf("error generating private key: %w", err))
	}
	createdAt := time.Now()

	pub, err := priv.GetPublicKey()
	if err != nil {
		panic(fmt.Errorf("error getting public key: %w", err))
	}
	pubBytes, err := pub.Serialize()
	if err != nil {
		panic(fmt.Errorf("error serializing public key: %w", err))
	}
	sig, err := identityKeyPair.GetPrivateKey().Sign(pubBytes)
	if err != nil {
		panic(fmt.Errorf("error signing public key: %w", err))
	}

	record, err := libsignalgo.NewSignedPreKeyRecordFromPrivateKey(startSignedKeyId, createdAt, priv, sig)
	if err != nil {
		panic(fmt.Errorf("error creating signed prekey record: %w", err))
	}
	return record
}
// PreKeyToJSON converts an EC prekey record into the JSON object used by the
// prekey upload: {"keyId": <id>, "publicKey": <base64 serialized public key>}.
func PreKeyToJSON(preKey *libsignalgo.PreKeyRecord) (map[string]interface{}, error) {
	keyID, err := preKey.GetID()
	if err != nil {
		return nil, fmt.Errorf("failed to get ID: %w", err)
	}
	pub, err := preKey.GetPublicKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get public key: %w", err)
	}
	pubBytes, err := pub.Serialize()
	if err != nil {
		return nil, fmt.Errorf("failed to serialize public key: %w", err)
	}
	return map[string]interface{}{
		"keyId":     keyID,
		"publicKey": base64.StdEncoding.EncodeToString(pubBytes),
	}, nil
}
// SignedPreKeyToJSON converts a signed prekey record into a JSON-ready map with
// "keyId", the base64 serialized "publicKey", and the base64 "signature".
func SignedPreKeyToJSON(signedPreKey *libsignalgo.SignedPreKeyRecord) (map[string]interface{}, error) {
	keyID, err := signedPreKey.GetID()
	if err != nil {
		return nil, fmt.Errorf("failed to get ID: %w", err)
	}
	pub, err := signedPreKey.GetPublicKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get public key: %w", err)
	}
	pubBytes, err := pub.Serialize()
	if err != nil {
		return nil, fmt.Errorf("failed to serialize public key: %w", err)
	}
	sig, err := signedPreKey.GetSignature()
	if err != nil {
		return nil, fmt.Errorf("failed to get signature: %w", err)
	}
	return map[string]interface{}{
		"keyId":     keyID,
		"publicKey": base64.StdEncoding.EncodeToString(pubBytes),
		"signature": base64.StdEncoding.EncodeToString(sig),
	}, nil
}
// KyberPreKeyToJSON converts a Kyber prekey record into the JSON object used in
// the "pqPreKeys" list of the prekey upload: "keyId", the base64 serialized
// "publicKey", and the base64 "signature".
func KyberPreKeyToJSON(kyberPreKey *libsignalgo.KyberPreKeyRecord) (map[string]interface{}, error) {
	keyID, err := kyberPreKey.GetID()
	if err != nil {
		return nil, fmt.Errorf("failed to get ID: %w", err)
	}
	pub, err := kyberPreKey.GetPublicKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get public key: %w", err)
	}
	pubBytes, err := pub.Serialize()
	if err != nil {
		return nil, fmt.Errorf("failed to serialize public key: %w", err)
	}
	sig, err := kyberPreKey.GetSignature()
	if err != nil {
		return nil, fmt.Errorf("failed to get signature: %w", err)
	}
	return map[string]interface{}{
		"keyId":     keyID,
		"publicKey": base64.StdEncoding.EncodeToString(pubBytes),
		"signature": base64.StdEncoding.EncodeToString(sig),
	}, nil
}
// errPrekeyUpload422 is returned by RegisterPreKeys when the server rejects the
// upload with HTTP 422 (Unprocessable Entity). StartKeyCheckLoop checks for it
// with errors.Is.
var errPrekeyUpload422 = errors.New("http 422 while registering prekeys")

// RegisterPreKeys uploads generatedPreKeys with an HTTP PUT to the keys
// endpoint. The endpoint is for the PNI identity when pni is true and the ACI
// identity otherwise; the request authenticates with username/password.
//
// The request body holds "preKeys" (EC), "pqPreKeys" (Kyber) and the base64
// "identityKey". It returns errPrekeyUpload422 (unwrapped) on a 422 response,
// and a descriptive error for any other non-2xx status. Errors are returned
// with context rather than logged here, so callers decide how to report them.
func RegisterPreKeys(ctx context.Context, generatedPreKeys *GeneratedPreKeys, pni bool, username string, password string) error {
	// Non-nil (possibly empty) slices so the lists marshal as [] rather than null.
	preKeysJSON := make([]map[string]interface{}, 0, len(generatedPreKeys.PreKeys))
	for _, preKey := range generatedPreKeys.PreKeys {
		preKeyJSON, err := PreKeyToJSON(preKey)
		if err != nil {
			return fmt.Errorf("failed to convert prekey to JSON: %w", err)
		}
		preKeysJSON = append(preKeysJSON, preKeyJSON)
	}
	kyberPreKeysJSON := make([]map[string]interface{}, 0, len(generatedPreKeys.KyberPreKeys))
	for _, kyberPreKey := range generatedPreKeys.KyberPreKeys {
		kyberPreKeyJSON, err := KyberPreKeyToJSON(kyberPreKey)
		if err != nil {
			return fmt.Errorf("failed to convert kyber prekey to JSON: %w", err)
		}
		kyberPreKeysJSON = append(kyberPreKeysJSON, kyberPreKeyJSON)
	}

	registerJSON := map[string]interface{}{
		"preKeys":     preKeysJSON,
		"pqPreKeys":   kyberPreKeysJSON,
		"identityKey": base64.StdEncoding.EncodeToString(generatedPreKeys.IdentityKey),
	}
	jsonBytes, err := json.Marshal(registerJSON)
	if err != nil {
		return fmt.Errorf("failed to marshal prekey upload request: %w", err)
	}

	opts := &web.HTTPReqOpt{Body: jsonBytes, Username: &username, Password: &password}
	resp, err := web.SendHTTPRequest(ctx, http.MethodPut, keysPath(pni), opts)
	if err != nil {
		return fmt.Errorf("failed to send prekey upload request: %w", err)
	}
	defer resp.Body.Close()

	switch {
	case resp.StatusCode == http.StatusUnprocessableEntity:
		return errPrekeyUpload422
	case resp.StatusCode < 200 || resp.StatusCode >= 300:
		return fmt.Errorf("error registering prekeys: %v", resp.Status)
	}
	return nil
}
// prekeyResponse is the JSON body returned when fetching another account's
// prekey bundles: its identity key plus one entry per device.
type prekeyResponse struct {
	IdentityKey string         `json:"identityKey"`
	Devices     []prekeyDevice `json:"devices"`
}

// prekeyDevice is one device's entry in a prekeyResponse. PreKey and PQPreKey
// are pointers because the server may omit them.
type prekeyDevice struct {
	DeviceID       int           `json:"deviceId"`
	RegistrationID int           `json:"registrationId"`
	SignedPreKey   prekeyDetail  `json:"signedPreKey"`
	PreKey         *prekeyDetail `json:"preKey"`
	PQPreKey       *prekeyDetail `json:"pqPreKey"`
}

// prekeyDetail describes a single key inside a prekeyDevice. PublicKey and
// Signature hold base64 text.
type prekeyDetail struct {
	KeyID     int    `json:"keyId"`
	PublicKey string `json:"publicKey"`
	// Signature is absent for unsigned (EC one-time) prekeys, hence omitempty.
	Signature string `json:"signature,omitempty"`
}

// preKeyCountResponse is the JSON body returned when querying this account's
// own key counts: EC prekeys in Count and Kyber prekeys in PQCount.
type preKeyCountResponse struct {
	Count   int `json:"count"`
	PQCount int `json:"pqCount"`
}
// addBase64PaddingAndDecode decodes standard-alphabet base64 whose trailing '='
// padding may have been stripped. Padding is restored to a multiple of four
// characters before decoding; inputs that are already padded pass through
// unchanged.
func addBase64PaddingAndDecode(data string) ([]byte, error) {
	missing := (4 - len(data)%4) % 4
	padded := data + strings.Repeat("=", missing)
	return base64.StdEncoding.DecodeString(padded)
}
// FetchAndProcessPreKey downloads the prekey bundles of theirServiceID from the
// server, including Kyber keys (?pq=true). It then processes each device's
// bundle into the client's ACI session and identity stores.
//
// A non-negative specificDeviceID fetches only that device; a negative value
// fetches all devices ("/*"). The first decoding or processing error aborts
// the remaining devices.
func (cli *Client) FetchAndProcessPreKey(ctx context.Context, theirServiceID libsignalgo.ServiceID, specificDeviceID int) error {
	deviceIDPath := "/*"
	if specificDeviceID >= 0 {
		deviceIDPath = fmt.Sprintf("/%d", specificDeviceID)
	}
	path := "/v2/keys/" + theirServiceID.String() + deviceIDPath + "?pq=true"

	username, password := cli.Store.BasicAuthCreds()
	resp, err := web.SendHTTPRequest(ctx, http.MethodGet, path, &web.HTTPReqOpt{Username: &username, Password: &password})
	if err != nil {
		return fmt.Errorf("error sending request: %w", err)
	}
	var bundles prekeyResponse
	if err = web.DecodeHTTPResponseBody(ctx, &bundles, resp); err != nil {
		return fmt.Errorf("error decoding response body: %w", err)
	}

	identityKeyBytes, err := addBase64PaddingAndDecode(bundles.IdentityKey)
	if err != nil {
		return fmt.Errorf("error decoding identity key: %w", err)
	}
	identityKey, err := libsignalgo.DeserializeIdentityKey(identityKeyBytes)
	if err != nil {
		return fmt.Errorf("error deserializing identity key: %w", err)
	}
	if identityKey == nil {
		return fmt.Errorf("deserializing identity key returned nil with no error")
	}

	// Decoding order below is deliberate: it decides which error is reported
	// first when several fields are malformed.
	for _, device := range bundles.Devices {
		// Optional EC one-time prekey.
		var (
			oneTimeKeyID  uint32
			oneTimePublic *libsignalgo.PublicKey
		)
		if pk := device.PreKey; pk != nil {
			oneTimeKeyID = uint32(pk.KeyID)
			raw, decodeErr := addBase64PaddingAndDecode(pk.PublicKey)
			if decodeErr != nil {
				return fmt.Errorf("error decoding public key: %w", decodeErr)
			}
			oneTimePublic, decodeErr = libsignalgo.DeserializePublicKey(raw)
			if decodeErr != nil {
				return fmt.Errorf("error deserializing public key: %w", decodeErr)
			}
		}

		// Mandatory signed prekey (public part; its signature is decoded later).
		signedRaw, err := addBase64PaddingAndDecode(device.SignedPreKey.PublicKey)
		if err != nil {
			return fmt.Errorf("error decoding signed public key: %w", err)
		}
		signedPublic, err := libsignalgo.DeserializePublicKey(signedRaw)
		if err != nil {
			return fmt.Errorf("error deserializing signed public key: %w", err)
		}

		// Optional Kyber prekey with its signature.
		var (
			pqKeyID     uint32
			pqPublic    *libsignalgo.KyberPublicKey
			pqSignature []byte
		)
		if pq := device.PQPreKey; pq != nil {
			pqKeyID = uint32(pq.KeyID)
			pqRaw, decodeErr := addBase64PaddingAndDecode(pq.PublicKey)
			if decodeErr != nil {
				return fmt.Errorf("error decoding kyber public key: %w", decodeErr)
			}
			pqPublic, decodeErr = libsignalgo.DeserializeKyberPublicKey(pqRaw)
			if decodeErr != nil {
				return fmt.Errorf("error deserializing kyber public key: %w", decodeErr)
			}
			pqSignature, decodeErr = addBase64PaddingAndDecode(pq.Signature)
			if decodeErr != nil {
				return fmt.Errorf("error decoding kyber prekey signature: %w", decodeErr)
			}
		}

		signedSignature, err := addBase64PaddingAndDecode(device.SignedPreKey.Signature)
		if err != nil {
			return fmt.Errorf("error decoding signature: %w", err)
		}

		bundle, err := libsignalgo.NewPreKeyBundle(
			uint32(device.RegistrationID),
			uint32(device.DeviceID),
			oneTimeKeyID,
			oneTimePublic,
			uint32(device.SignedPreKey.KeyID),
			signedPublic,
			signedSignature,
			pqKeyID,
			pqPublic,
			pqSignature,
			identityKey,
		)
		if err != nil {
			return fmt.Errorf("error creating prekey bundle: %w", err)
		}
		address, err := theirServiceID.Address(uint(device.DeviceID))
		if err != nil {
			return fmt.Errorf("error creating address: %w", err)
		}
		if err = libsignalgo.ProcessPreKeyBundle(
			ctx,
			bundle,
			address,
			cli.Store.ACISessionStore,
			cli.Store.ACIIdentityStore,
		); err != nil {
			return fmt.Errorf("error processing prekey bundle: %w", err)
		}
	}
	return nil
}
// Keys endpoints, one per identity. They are used both for uploading prekeys
// (PUT) and for querying this account's prekey counts (GET).
const (
	aciKeysPath = "/v2/keys?identity=aci"
	pniKeysPath = "/v2/keys?identity=pni"
)

// keysPath returns the PNI keys endpoint when pni is true and the ACI one
// otherwise.
func keysPath(pni bool) string {
	path := aciKeysPath
	if pni {
		path = pniKeysPath
	}
	return path
}
// GetMyKeyCounts asks the server how many EC prekeys and Kyber ("pqCount")
// prekeys it currently holds for this account. The PNI identity is queried
// when pni is true, the ACI identity otherwise.
//
// It returns the EC count, then the Kyber count. Errors are wrapped with
// context and returned rather than logged here, so they are not reported twice.
func (cli *Client) GetMyKeyCounts(ctx context.Context, pni bool) (int, int, error) {
	username, password := cli.Store.BasicAuthCreds()
	opts := &web.HTTPReqOpt{Username: &username, Password: &password}
	resp, err := web.SendHTTPRequest(ctx, http.MethodGet, keysPath(pni), opts)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to send prekey count request: %w", err)
	}
	var counts preKeyCountResponse
	if err = web.DecodeHTTPResponseBody(ctx, &counts, resp); err != nil {
		return 0, 0, fmt.Errorf("failed to decode prekey count response: %w", err)
	}
	return counts.Count, counts.PQCount, nil
}
// CheckAndUploadNewPreKeys fetches the server's EC and Kyber prekey counts for
// the identity that pks belongs to (ACI or PNI). It lets the batch generators
// top up the local store, then uploads the store's prekeys if either generator
// reports that an upload is needed.
//
// Each failure is logged here and also returned unchanged to the caller.
func (cli *Client) CheckAndUploadNewPreKeys(ctx context.Context, pks store.PreKeyStore) error {
	logger := zerolog.Ctx(ctx).With().Str("action", "check and upload new prekeys").Logger()

	isPNI := pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI
	ecCount, kyberCount, err := cli.GetMyKeyCounts(ctx, isPNI)
	if err != nil {
		logger.Err(err).Msg("Error getting prekey counts")
		return err
	}

	uploadEC, err := cli.GenerateAndSaveNextPreKeyBatch(ctx, pks, ecCount)
	if err != nil {
		logger.Err(err).Msg("Error generating and saving next prekey batch")
		return err
	}
	uploadKyber, err := cli.GenerateAndSaveNextKyberPreKeyBatch(ctx, pks, kyberCount)
	if err != nil {
		logger.Err(err).Msg("Error generating and saving next kyber prekey batch")
		return err
	}

	if uploadEC || uploadKyber {
		if err = cli.RegisterAllPreKeys(ctx, pks); err != nil {
			logger.Err(err).Msg("Error registering prekey batches")
			return err
		}
		return nil
	}
	logger.Debug().Msg("No new prekeys to upload")
	return nil
}
// StartKeyCheckLoop starts a background goroutine that repeatedly runs
// CheckAndUploadNewPreKeys, first for the ACI prekey store and then for the
// PNI prekey store.
//
// Schedule:
//   - The first check runs immediately.
//   - After a fully successful round, the next check is at a random time
//     36 to 60 hours later.
//   - After a failure, the check is retried at a random time 5 to 29 minutes
//     later.
//
// If the PNI upload fails with errPrekeyUpload422, ClearKeysAndDisconnect is
// called and the goroutine exits. It also exits when ctx is cancelled.
func (cli *Client) StartKeyCheckLoop(ctx context.Context) {
	log := zerolog.Ctx(ctx).With().Str("action", "start key check loop").Logger()
	go func() {
		// Do the initial check immediately: rand.Intn(1) is always 0, so the
		// first wait is 0 minutes.
		windowStart := 0
		windowSize := 1
		for {
			randomMinutesInWindow := rand.Intn(windowSize) + windowStart
			checkTime := time.Duration(randomMinutesInWindow) * time.Minute
			log.Debug().Dur("check_time", checkTime).Msg("Waiting to check for new prekeys")

			// Use an explicit timer so a pending (possibly multi-hour) wait is
			// released as soon as ctx is cancelled.
			timer := time.NewTimer(checkTime)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}

			err := cli.CheckAndUploadNewPreKeys(ctx, cli.Store.ACIPreKeyStore)
			if err != nil {
				log.Err(err).Msg("Error checking and uploading new prekeys for ACI identity")
				// Retry within half an hour
				windowStart = 5
				windowSize = 25
				continue
			}
			err = cli.CheckAndUploadNewPreKeys(ctx, cli.Store.PNIPreKeyStore)
			if err != nil {
				if errors.Is(err, errPrekeyUpload422) {
					log.Err(err).Msg("Got 422 error while uploading PNI prekeys, deleting session")
					disconnectErr := cli.ClearKeysAndDisconnect(ctx)
					if disconnectErr != nil {
						log.Err(disconnectErr).Msg("ClearKeysAndDisconnect error")
					}
					return
				}
				log.Err(err).Msg("Error checking and uploading new prekeys for PNI identity")
				// Retry within half an hour
				windowStart = 5
				windowSize = 25
				continue
			}
			// After a successful check, check again in 36 to 60 hours
			windowStart = 36 * 60
			windowSize = 24 * 60
		}
	}()
}