648 lines
21 KiB
Go
648 lines
21 KiB
Go
// mautrix-signal - A Matrix-signal puppeting bridge.
|
|
// Copyright (C) 2023 Scott Weber
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package signalmeow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
|
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
|
|
)
|
|
|
|
// PREKEY_BATCH_SIZE is the number of EC prekeys and the number of Kyber prekeys
// that the local prekey store is topped up to. A (re)upload is triggered when
// the server reports fewer than PREKEY_BATCH_SIZE/2 keys of a kind. It is also
// the largest count GeneratePreKeys and GenerateKyberPreKeys accept.
const PREKEY_BATCH_SIZE = 100
|
|
|
|
// GeneratedPreKeys is the set of keys that RegisterPreKeys uploads to the server.
type GeneratedPreKeys struct {
	// PreKeys are the EC prekeys; sent in the "preKeys" list of the request.
	PreKeys []*libsignalgo.PreKeyRecord
	// KyberPreKeys are the Kyber prekeys; sent in the "pqPreKeys" list of the request.
	KyberPreKeys []*libsignalgo.KyberPreKeyRecord
	// IdentityKey is the serialized public identity key; sent base64-encoded as "identityKey".
	IdentityKey []uint8
}
|
|
|
|
func (cli *Client) GenerateAndRegisterPreKeys(ctx context.Context, pks store.PreKeyStore) error {
|
|
_, err := cli.GenerateAndSaveNextPreKeyBatch(ctx, pks, 0)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate and save next prekey batch: %w", err)
|
|
}
|
|
_, err = cli.GenerateAndSaveNextKyberPreKeyBatch(ctx, pks, 0)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate and save next kyber prekey batch: %w", err)
|
|
}
|
|
|
|
// We need to upload all currently valid prekeys, not just the ones we just generated
|
|
err = cli.RegisterAllPreKeys(ctx, pks)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to register prekey batches: %w", err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (cli *Client) RegisterAllPreKeys(ctx context.Context, pks store.PreKeyStore) error {
|
|
var identityKeyPair *libsignalgo.IdentityKeyPair
|
|
var pni bool
|
|
if pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI {
|
|
pni = true
|
|
identityKeyPair = cli.Store.PNIIdentityKeyPair
|
|
} else {
|
|
identityKeyPair = cli.Store.ACIIdentityKeyPair
|
|
}
|
|
|
|
// Get all prekeys and kyber prekeys from the database
|
|
preKeys, err := pks.AllPreKeys(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get all prekeys: %w", err)
|
|
}
|
|
kyberPreKeys, err := pks.AllNormalKyberPreKeys(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get all kyber prekeys: %w", err)
|
|
}
|
|
|
|
// We need to have some keys to upload
|
|
if len(preKeys) == 0 && len(kyberPreKeys) == 0 {
|
|
return fmt.Errorf("no prekeys to upload")
|
|
}
|
|
|
|
identityKey, err := identityKeyPair.GetPublicKey().Serialize()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to serialize identity key: %w", err)
|
|
}
|
|
|
|
generatedPreKeys := GeneratedPreKeys{
|
|
PreKeys: preKeys,
|
|
KyberPreKeys: kyberPreKeys,
|
|
IdentityKey: identityKey,
|
|
}
|
|
preKeyUsername := fmt.Sprintf("%s.%d", cli.Store.ACI, cli.Store.DeviceID)
|
|
log := zerolog.Ctx(ctx).With().Str("action", "register prekeys").Logger()
|
|
log.Debug().Int("num_prekeys", len(preKeys)).Int("num_kyber_prekeys", len(kyberPreKeys)).Msg("Registering prekeys")
|
|
err = RegisterPreKeys(ctx, &generatedPreKeys, pni, preKeyUsername, cli.Store.Password)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to register prekeys: %w", err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (cli *Client) GenerateAndSaveNextPreKeyBatch(ctx context.Context, pks store.PreKeyStore, serverCount int) (bool, error) {
|
|
storeCount, nextPreKeyID, err := pks.GetNextPreKeyID(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get next prekey ID: %w", err)
|
|
}
|
|
if serverCount < PREKEY_BATCH_SIZE/2 {
|
|
if storeCount >= PREKEY_BATCH_SIZE {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Store is full, but server is not, reuploading EC prekeys without generating more")
|
|
} else {
|
|
zerolog.Ctx(ctx).Info().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Generating and uploading EC prekeys")
|
|
}
|
|
} else if uint32(serverCount) > storeCount {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Server has more EC prekeys than store, reuploading")
|
|
} else {
|
|
zerolog.Ctx(ctx).Debug().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("EC prekey count is good")
|
|
return false, nil
|
|
}
|
|
if storeCount < PREKEY_BATCH_SIZE {
|
|
preKeys := GeneratePreKeys(nextPreKeyID, PREKEY_BATCH_SIZE-storeCount)
|
|
for _, preKey := range preKeys {
|
|
err = pks.StorePreKey(ctx, 0, preKey)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to save prekey: %w", err)
|
|
}
|
|
}
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (cli *Client) GenerateAndSaveNextKyberPreKeyBatch(ctx context.Context, pks store.PreKeyStore, serverCount int) (bool, error) {
|
|
var identityKeyPair *libsignalgo.IdentityKeyPair
|
|
if pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI {
|
|
identityKeyPair = cli.Store.PNIIdentityKeyPair
|
|
} else {
|
|
identityKeyPair = cli.Store.ACIIdentityKeyPair
|
|
}
|
|
storeCount, nextKyberPreKeyID, err := pks.GetNextKyberPreKeyID(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get next kyber prekey ID: %w", err)
|
|
}
|
|
if serverCount < PREKEY_BATCH_SIZE/2 {
|
|
if storeCount >= PREKEY_BATCH_SIZE {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Store is full, but server is not, reuploading kyber prekeys without generating more")
|
|
} else {
|
|
zerolog.Ctx(ctx).Info().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Generating and uploading kyber prekeys")
|
|
}
|
|
} else if uint32(serverCount) > storeCount {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Server has more kyber prekeys than store, reuploading")
|
|
} else {
|
|
zerolog.Ctx(ctx).Debug().
|
|
Int("server_count", serverCount).
|
|
Uint32("store_count", storeCount).
|
|
Msg("Kyber prekey count is good")
|
|
return false, nil
|
|
}
|
|
if storeCount < PREKEY_BATCH_SIZE {
|
|
kyberPreKeys := GenerateKyberPreKeys(nextKyberPreKeyID, PREKEY_BATCH_SIZE-storeCount, identityKeyPair)
|
|
for _, kyberPreKey := range kyberPreKeys {
|
|
err = pks.StoreKyberPreKey(ctx, 0, kyberPreKey)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to save kyber prekey: %w", err)
|
|
}
|
|
}
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func GeneratePreKeys(startKeyID uint32, count uint32) []*libsignalgo.PreKeyRecord {
|
|
if count > PREKEY_BATCH_SIZE {
|
|
panic("count must be less than or equal to PREKEY_BATCH_SIZE")
|
|
}
|
|
generatedPreKeys := make([]*libsignalgo.PreKeyRecord, 0, count)
|
|
for keyID := startKeyID; keyID < startKeyID+count; keyID++ {
|
|
privateKey, err := libsignalgo.GeneratePrivateKey()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error generating private key: %w", err))
|
|
}
|
|
preKey, err := libsignalgo.NewPreKeyRecordFromPrivateKey(keyID, privateKey)
|
|
if err != nil {
|
|
panic(fmt.Errorf("error creating prekey record: %w", err))
|
|
}
|
|
generatedPreKeys = append(generatedPreKeys, preKey)
|
|
}
|
|
return generatedPreKeys
|
|
}
|
|
|
|
func GenerateKyberPreKeys(startKeyID uint32, count uint32, identityKeyPair *libsignalgo.IdentityKeyPair) []*libsignalgo.KyberPreKeyRecord {
|
|
if count > PREKEY_BATCH_SIZE {
|
|
panic("count must be less than or equal to PREKEY_BATCH_SIZE")
|
|
}
|
|
generatedKyberPreKeys := make([]*libsignalgo.KyberPreKeyRecord, 0, count)
|
|
for keyID := startKeyID; keyID < startKeyID+count; keyID++ {
|
|
kyberPreKeyPair, err := libsignalgo.KyberKeyPairGenerate()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error generating kyber key pair: %w", err))
|
|
}
|
|
publicKey, err := kyberPreKeyPair.GetPublicKey()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error getting kyber public key: %w", err))
|
|
}
|
|
serializedPublicKey, err := publicKey.Serialize()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error serializing kyber public key: %w", err))
|
|
}
|
|
signature, err := identityKeyPair.GetPrivateKey().Sign(serializedPublicKey)
|
|
if err != nil {
|
|
panic(fmt.Errorf("error signing kyber public key: %w", err))
|
|
}
|
|
preKey, err := libsignalgo.NewKyberPreKeyRecord(keyID, time.Now(), kyberPreKeyPair, signature)
|
|
if err != nil {
|
|
panic(fmt.Errorf("error creating kyber prekey record: %w", err))
|
|
|
|
}
|
|
generatedKyberPreKeys = append(generatedKyberPreKeys, preKey)
|
|
}
|
|
return generatedKyberPreKeys
|
|
}
|
|
|
|
func GenerateSignedPreKey(startSignedKeyId uint32, identityKeyPair *libsignalgo.IdentityKeyPair) *libsignalgo.SignedPreKeyRecord {
|
|
// Generate a signed prekey
|
|
privateKey, err := libsignalgo.GeneratePrivateKey()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error generating private key: %w", err))
|
|
}
|
|
timestamp := time.Now()
|
|
publicKey, err := privateKey.GetPublicKey()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error getting public key: %w", err))
|
|
}
|
|
serializedPublicKey, err := publicKey.Serialize()
|
|
if err != nil {
|
|
panic(fmt.Errorf("error serializing public key: %w", err))
|
|
}
|
|
signature, err := identityKeyPair.GetPrivateKey().Sign(serializedPublicKey)
|
|
if err != nil {
|
|
panic(fmt.Errorf("error signing public key: %w", err))
|
|
}
|
|
signedPreKey, err := libsignalgo.NewSignedPreKeyRecordFromPrivateKey(startSignedKeyId, timestamp, privateKey, signature)
|
|
if err != nil {
|
|
panic(fmt.Errorf("error creating signed prekey record: %w", err))
|
|
}
|
|
|
|
return signedPreKey
|
|
}
|
|
|
|
func PreKeyToJSON(preKey *libsignalgo.PreKeyRecord) (map[string]interface{}, error) {
|
|
id, err := preKey.GetID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get ID: %w", err)
|
|
}
|
|
publicKey, err := preKey.GetPublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
serializedKey, err := publicKey.Serialize()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize public key: %w", err)
|
|
}
|
|
preKeyJson := map[string]interface{}{
|
|
"keyId": id,
|
|
"publicKey": base64.StdEncoding.EncodeToString(serializedKey),
|
|
}
|
|
return preKeyJson, nil
|
|
}
|
|
|
|
func SignedPreKeyToJSON(signedPreKey *libsignalgo.SignedPreKeyRecord) (map[string]interface{}, error) {
|
|
id, err := signedPreKey.GetID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get ID: %w", err)
|
|
}
|
|
publicKey, err := signedPreKey.GetPublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
serializedKey, err := publicKey.Serialize()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize public key: %w", err)
|
|
}
|
|
signature, err := signedPreKey.GetSignature()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get signature: %w", err)
|
|
}
|
|
signedPreKeyJson := map[string]interface{}{
|
|
"keyId": id,
|
|
"publicKey": base64.StdEncoding.EncodeToString(serializedKey),
|
|
"signature": base64.StdEncoding.EncodeToString(signature),
|
|
}
|
|
return signedPreKeyJson, nil
|
|
}
|
|
|
|
func KyberPreKeyToJSON(kyberPreKey *libsignalgo.KyberPreKeyRecord) (map[string]interface{}, error) {
|
|
id, err := kyberPreKey.GetID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get ID: %w", err)
|
|
}
|
|
publicKey, err := kyberPreKey.GetPublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
serializedKey, err := publicKey.Serialize()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize public key: %w", err)
|
|
}
|
|
signature, err := kyberPreKey.GetSignature()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get signature: %w", err)
|
|
}
|
|
kyberPreKeyJson := map[string]interface{}{
|
|
"keyId": id,
|
|
"publicKey": base64.StdEncoding.EncodeToString(serializedKey),
|
|
"signature": base64.StdEncoding.EncodeToString(signature),
|
|
}
|
|
return kyberPreKeyJson, nil
|
|
}
|
|
|
|
// errPrekeyUpload422 is returned (unwrapped) by RegisterPreKeys when the server
// answers a prekey upload with HTTP 422. StartKeyCheckLoop matches it with
// errors.Is during the PNI upload and responds by calling ClearKeysAndDisconnect.
var errPrekeyUpload422 = errors.New("http 422 while registering prekeys")
|
|
|
|
func RegisterPreKeys(ctx context.Context, generatedPreKeys *GeneratedPreKeys, pni bool, username string, password string) error {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "register prekeys").Logger()
|
|
// Convert generated prekeys to JSON
|
|
preKeysJson := []map[string]interface{}{}
|
|
kyberPreKeysJson := []map[string]interface{}{}
|
|
for _, preKey := range generatedPreKeys.PreKeys {
|
|
preKeyJson, err := PreKeyToJSON(preKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to convert prekey to JSON: %w", err)
|
|
}
|
|
preKeysJson = append(preKeysJson, preKeyJson)
|
|
}
|
|
for _, kyberPreKey := range generatedPreKeys.KyberPreKeys {
|
|
kyberPreKeyJson, err := KyberPreKeyToJSON(kyberPreKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to convert kyber prekey to JSON: %w", err)
|
|
}
|
|
kyberPreKeysJson = append(kyberPreKeysJson, kyberPreKeyJson)
|
|
}
|
|
|
|
identityKey := generatedPreKeys.IdentityKey
|
|
register_json := map[string]interface{}{
|
|
"preKeys": preKeysJson,
|
|
"pqPreKeys": kyberPreKeysJson,
|
|
"identityKey": base64.StdEncoding.EncodeToString(identityKey),
|
|
}
|
|
|
|
// Send request
|
|
jsonBytes, err := json.Marshal(register_json)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error marshalling register JSON")
|
|
return err
|
|
}
|
|
opts := &web.HTTPReqOpt{Body: jsonBytes, Username: &username, Password: &password}
|
|
resp, err := web.SendHTTPRequest(ctx, http.MethodPut, keysPath(pni), opts)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error sending request")
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
// status code not 2xx
|
|
if resp.StatusCode == 422 {
|
|
return errPrekeyUpload422
|
|
} else if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return fmt.Errorf("error registering prekeys: %v", resp.Status)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// prekeyResponse is the decoded body of the GET /v2/keys/<service ID>/<device>
// request made by FetchAndProcessPreKey.
type prekeyResponse struct {
	// IdentityKey is the recipient's base64 identity key; it may lack padding,
	// so it is decoded with addBase64PaddingAndDecode.
	IdentityKey string `json:"identityKey"`
	// Devices holds one prekey bundle per device in the response.
	Devices []prekeyDevice `json:"devices"`
}
|
|
|
|
// preKeyCountResponse is the decoded body of the key-count request made by
// GetMyKeyCounts.
type preKeyCountResponse struct {
	// Count is the number of EC prekeys the server holds.
	Count int `json:"count"`
	// PQCount is the number of Kyber (pq) prekeys the server holds.
	PQCount int `json:"pqCount"`
}
|
|
|
|
// prekeyDevice is a single device's prekey bundle inside a prekeyResponse.
type prekeyDevice struct {
	DeviceID       int          `json:"deviceId"`
	RegistrationID int          `json:"registrationId"`
	SignedPreKey   prekeyDetail `json:"signedPreKey"`
	// PreKey and PQPreKey are pointers because the server may omit them;
	// FetchAndProcessPreKey checks both for nil before using them.
	PreKey   *prekeyDetail `json:"preKey"`
	PQPreKey *prekeyDetail `json:"pqPreKey"`
}
|
|
|
|
// prekeyDetail describes one key within a prekeyDevice. PublicKey and Signature
// are base64 strings (possibly unpadded).
type prekeyDetail struct {
	KeyID     int    `json:"keyId"`
	PublicKey string `json:"publicKey"`
	// Signature is only read for the signed prekey and the Kyber prekey; it is
	// absent for plain EC prekeys. Note that omitempty only affects encoding;
	// when decoding, a missing field simply leaves this empty.
	Signature string `json:"signature,omitempty"`
}
|
|
|
|
// addBase64PaddingAndDecode decodes standard-alphabet base64 that may have had
// its trailing '=' padding stripped. It re-adds enough '=' to make the length a
// multiple of 4 before decoding.
func addBase64PaddingAndDecode(data string) ([]byte, error) {
	// (4 - r) % 4 maps a remainder of 0 to 0 pad characters and r>0 to 4-r.
	missing := (4 - len(data)%4) % 4
	padded := data + strings.Repeat("=", missing)
	return base64.StdEncoding.DecodeString(padded)
}
|
|
|
|
func (cli *Client) FetchAndProcessPreKey(ctx context.Context, theirServiceID libsignalgo.ServiceID, specificDeviceID int) error {
|
|
// Fetch prekey
|
|
deviceIDPath := "/*"
|
|
if specificDeviceID >= 0 {
|
|
deviceIDPath = "/" + fmt.Sprint(specificDeviceID)
|
|
}
|
|
path := "/v2/keys/" + theirServiceID.String() + deviceIDPath + "?pq=true"
|
|
username, password := cli.Store.BasicAuthCreds()
|
|
resp, err := web.SendHTTPRequest(ctx, http.MethodGet, path, &web.HTTPReqOpt{Username: &username, Password: &password})
|
|
if err != nil {
|
|
return fmt.Errorf("error sending request: %w", err)
|
|
}
|
|
var prekeyResponse prekeyResponse
|
|
err = web.DecodeHTTPResponseBody(ctx, &prekeyResponse, resp)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding response body: %w", err)
|
|
}
|
|
|
|
rawIdentityKey, err := addBase64PaddingAndDecode(prekeyResponse.IdentityKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding identity key: %w", err)
|
|
}
|
|
identityKey, err := libsignalgo.DeserializeIdentityKey([]byte(rawIdentityKey))
|
|
if err != nil {
|
|
return fmt.Errorf("error deserializing identity key: %w", err)
|
|
}
|
|
if identityKey == nil {
|
|
return fmt.Errorf("deserializing identity key returned nil with no error")
|
|
}
|
|
|
|
// Process each prekey in response (should only be one at the moment)
|
|
for _, d := range prekeyResponse.Devices {
|
|
var publicKey *libsignalgo.PublicKey
|
|
var preKeyID uint32
|
|
if d.PreKey != nil {
|
|
preKeyID = uint32(d.PreKey.KeyID)
|
|
rawPublicKey, err := addBase64PaddingAndDecode(d.PreKey.PublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding public key: %w", err)
|
|
}
|
|
publicKey, err = libsignalgo.DeserializePublicKey(rawPublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error deserializing public key: %w", err)
|
|
}
|
|
}
|
|
|
|
rawSignedPublicKey, err := addBase64PaddingAndDecode(d.SignedPreKey.PublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding signed public key: %w", err)
|
|
}
|
|
signedPublicKey, err := libsignalgo.DeserializePublicKey(rawSignedPublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error deserializing signed public key: %w", err)
|
|
}
|
|
|
|
var kyberPublicKey *libsignalgo.KyberPublicKey
|
|
var kyberPreKeyID uint32
|
|
var kyberPreKeySignature []byte
|
|
if d.PQPreKey != nil {
|
|
kyberPreKeyID = uint32(d.PQPreKey.KeyID)
|
|
rawKyberPublicKey, err := addBase64PaddingAndDecode(d.PQPreKey.PublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding kyber public key: %w", err)
|
|
}
|
|
kyberPublicKey, err = libsignalgo.DeserializeKyberPublicKey(rawKyberPublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error deserializing kyber public key: %w", err)
|
|
}
|
|
kyberPreKeySignature, err = addBase64PaddingAndDecode(d.PQPreKey.Signature)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding kyber prekey signature: %w", err)
|
|
}
|
|
}
|
|
|
|
rawSignature, err := addBase64PaddingAndDecode(d.SignedPreKey.Signature)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding signature: %w", err)
|
|
}
|
|
|
|
preKeyBundle, err := libsignalgo.NewPreKeyBundle(
|
|
uint32(d.RegistrationID),
|
|
uint32(d.DeviceID),
|
|
preKeyID,
|
|
publicKey,
|
|
uint32(d.SignedPreKey.KeyID),
|
|
signedPublicKey,
|
|
rawSignature,
|
|
kyberPreKeyID,
|
|
kyberPublicKey,
|
|
kyberPreKeySignature,
|
|
identityKey,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating prekey bundle: %w", err)
|
|
}
|
|
address, err := theirServiceID.Address(uint(d.DeviceID))
|
|
if err != nil {
|
|
return fmt.Errorf("error creating address: %w", err)
|
|
}
|
|
err = libsignalgo.ProcessPreKeyBundle(
|
|
ctx,
|
|
preKeyBundle,
|
|
address,
|
|
cli.Store.ACISessionStore,
|
|
cli.Store.ACIIdentityStore,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("error processing prekey bundle: %w", err)
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Keys endpoints for the account's own ACI and PNI identities. RegisterPreKeys
// sends PUT requests here and GetMyKeyCounts sends GET requests.
const (
	aciKeysPath = "/v2/keys?identity=aci"
	pniKeysPath = "/v2/keys?identity=pni"
)

// keysPath returns the keys endpoint for the PNI identity when pni is true,
// and for the ACI identity otherwise.
func keysPath(pni bool) string {
	selected := aciKeysPath
	if pni {
		selected = pniKeysPath
	}
	return selected
}
|
|
|
|
func (cli *Client) GetMyKeyCounts(ctx context.Context, pni bool) (int, int, error) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "get my key counts").Logger()
|
|
username, password := cli.Store.BasicAuthCreds()
|
|
resp, err := web.SendHTTPRequest(ctx, http.MethodGet, keysPath(pni), &web.HTTPReqOpt{Username: &username, Password: &password})
|
|
if err != nil {
|
|
log.Err(err).Msg("Error sending request")
|
|
return 0, 0, err
|
|
}
|
|
var preKeyCountResponse preKeyCountResponse
|
|
err = web.DecodeHTTPResponseBody(ctx, &preKeyCountResponse, resp)
|
|
if err != nil {
|
|
log.Err(err).Msg("Fetching prekey counts, error with response body")
|
|
return 0, 0, err
|
|
}
|
|
return preKeyCountResponse.Count, preKeyCountResponse.PQCount, err
|
|
}
|
|
|
|
func (cli *Client) CheckAndUploadNewPreKeys(ctx context.Context, pks store.PreKeyStore) error {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "check and upload new prekeys").Logger()
|
|
// Check if we need to upload prekeys
|
|
preKeyCount, kyberPreKeyCount, err := cli.GetMyKeyCounts(ctx, pks.GetServiceID().Type == libsignalgo.ServiceIDTypePNI)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error getting prekey counts")
|
|
return err
|
|
}
|
|
doECUpload, err := cli.GenerateAndSaveNextPreKeyBatch(ctx, pks, preKeyCount)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error generating and saving next prekey batch")
|
|
return err
|
|
}
|
|
doKyberUpload, err := cli.GenerateAndSaveNextKyberPreKeyBatch(ctx, pks, kyberPreKeyCount)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error generating and saving next kyber prekey batch")
|
|
return err
|
|
}
|
|
if !doECUpload && !doKyberUpload {
|
|
log.Debug().Msg("No new prekeys to upload")
|
|
return nil
|
|
}
|
|
err = cli.RegisterAllPreKeys(ctx, pks)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error registering prekey batches")
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cli *Client) StartKeyCheckLoop(ctx context.Context) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "start key check loop").Logger()
|
|
go func() {
|
|
// Do the initial check in 5-10 minutes after starting the loop
|
|
window_start := 0
|
|
window_size := 1
|
|
for {
|
|
random_minutes_in_window := rand.Intn(window_size) + window_start
|
|
check_time := time.Duration(random_minutes_in_window) * time.Minute
|
|
log.Debug().Dur("check_time", check_time).Msg("Waiting to check for new prekeys")
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(check_time):
|
|
err := cli.CheckAndUploadNewPreKeys(ctx, cli.Store.ACIPreKeyStore)
|
|
if err != nil {
|
|
log.Err(err).Msg("Error checking and uploading new prekeys for ACI identity")
|
|
// Retry within half an hour
|
|
window_start = 5
|
|
window_size = 25
|
|
continue
|
|
}
|
|
err = cli.CheckAndUploadNewPreKeys(ctx, cli.Store.PNIPreKeyStore)
|
|
if err != nil {
|
|
if errors.Is(err, errPrekeyUpload422) {
|
|
log.Err(err).Msg("Got 422 error while uploading PNI prekeys, deleting session")
|
|
disconnectErr := cli.ClearKeysAndDisconnect(ctx)
|
|
if disconnectErr != nil {
|
|
log.Err(disconnectErr).Msg("ClearKeysAndDisconnect error")
|
|
}
|
|
return
|
|
}
|
|
log.Err(err).Msg("Error checking and uploading new prekeys for PNI identity")
|
|
// Retry within half an hour
|
|
window_start = 5
|
|
window_size = 25
|
|
continue
|
|
}
|
|
// After a successful check, check again in 36 to 60 hours
|
|
window_start = 36 * 60
|
|
window_size = 24 * 60
|
|
}
|
|
}
|
|
}()
|
|
}
|