mirror of https://github.com/mautrix/go.git
780 lines
28 KiB
Go
780 lines
28 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package crypto
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
|
|
"go.mau.fi/util/exzerolog"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto/ssss"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// OlmMachine is the main struct for handling Matrix end-to-end encryption.
|
|
type OlmMachine struct {
|
|
Client *mautrix.Client
|
|
SSSS *ssss.Machine
|
|
Log *zerolog.Logger
|
|
|
|
CryptoStore Store
|
|
StateStore StateStore
|
|
|
|
BackgroundCtx context.Context
|
|
|
|
PlaintextMentions bool
|
|
|
|
// Never ask the server for keys automatically as a side effect during Megolm decryption.
|
|
DisableDecryptKeyFetching bool
|
|
|
|
// Don't mark outbound Olm sessions as shared for devices they were initially sent to.
|
|
DisableSharedGroupSessionTracking bool
|
|
|
|
SendKeysMinTrust id.TrustState
|
|
ShareKeysMinTrust id.TrustState
|
|
|
|
AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection
|
|
|
|
account *OlmAccount
|
|
|
|
roomKeyRequestFilled *sync.Map
|
|
keyVerificationTransactionState *sync.Map
|
|
|
|
keyWaiters map[id.SessionID]chan struct{}
|
|
keyWaitersLock sync.Mutex
|
|
|
|
// Optional callback which is called when we save a session to store
|
|
SessionReceived func(context.Context, id.RoomID, id.SessionID, uint32)
|
|
|
|
devicesToUnwedge map[id.IdentityKey]bool
|
|
devicesToUnwedgeLock sync.Mutex
|
|
recentlyUnwedged map[id.IdentityKey]time.Time
|
|
recentlyUnwedgedLock sync.Mutex
|
|
olmHashSavePoints []time.Time
|
|
lastHashDelete time.Time
|
|
olmHashSavePointLock sync.Mutex
|
|
|
|
olmLock sync.Mutex
|
|
megolmEncryptLock sync.Mutex
|
|
megolmDecryptLock sync.Mutex
|
|
|
|
otkUploadLock sync.Mutex
|
|
lastOTKUpload time.Time
|
|
receivedOTKsForSelf atomic.Bool
|
|
|
|
CrossSigningKeys *CrossSigningKeysCache
|
|
crossSigningPubkeys *CrossSigningPublicKeysCache
|
|
|
|
crossSigningPubkeysFetched bool
|
|
|
|
DeleteOutboundKeysOnAck bool
|
|
DontStoreOutboundKeys bool
|
|
DeletePreviousKeysOnReceive bool
|
|
RatchetKeysOnDecrypt bool
|
|
DeleteFullyUsedKeysOnDecrypt bool
|
|
DeleteKeysOnDeviceDelete bool
|
|
DisableRatchetTracking bool
|
|
|
|
DisableDeviceChangeKeyRotation bool
|
|
|
|
secretLock sync.Mutex
|
|
secretListeners map[string]chan<- string
|
|
}
|
|
|
|
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
|
type StateStore interface {
|
|
// IsEncrypted returns whether a room is encrypted.
|
|
IsEncrypted(context.Context, id.RoomID) (bool, error)
|
|
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
|
|
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
|
|
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
|
|
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
|
|
}
|
|
|
|
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
|
|
func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
|
|
if log == nil {
|
|
logPtr := zerolog.Nop()
|
|
log = &logPtr
|
|
}
|
|
mach := &OlmMachine{
|
|
Client: client,
|
|
SSSS: ssss.NewSSSSMachine(client),
|
|
Log: log,
|
|
CryptoStore: cryptoStore,
|
|
StateStore: stateStore,
|
|
|
|
BackgroundCtx: context.Background(),
|
|
|
|
SendKeysMinTrust: id.TrustStateUnset,
|
|
ShareKeysMinTrust: id.TrustStateCrossSignedTOFU,
|
|
|
|
roomKeyRequestFilled: &sync.Map{},
|
|
keyVerificationTransactionState: &sync.Map{},
|
|
|
|
keyWaiters: make(map[id.SessionID]chan struct{}),
|
|
|
|
devicesToUnwedge: make(map[id.IdentityKey]bool),
|
|
recentlyUnwedged: make(map[id.IdentityKey]time.Time),
|
|
secretListeners: make(map[string]chan<- string),
|
|
}
|
|
mach.AllowKeyShare = mach.defaultAllowKeyShare
|
|
return mach
|
|
}
|
|
|
|
func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
|
|
log := zerolog.Ctx(ctx)
|
|
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
|
|
return mach.Log
|
|
}
|
|
return log
|
|
}
|
|
|
|
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
|
|
// This must be called before using the machine.
|
|
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
|
|
mach.account, err = mach.CryptoStore.GetAccount(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if mach.account == nil {
|
|
mach.account = NewOlmAccount()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mach *OlmMachine) saveAccount(ctx context.Context) error {
|
|
err := mach.CryptoStore.PutAccount(ctx, mach.account)
|
|
if err != nil {
|
|
mach.Log.Error().Err(err).Msg("Failed to save account")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion {
|
|
return mach.account.KeyBackupVersion
|
|
}
|
|
|
|
func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
|
|
mach.account.KeyBackupVersion = version
|
|
return mach.saveAccount(ctx)
|
|
}
|
|
|
|
// FlushStore calls the Flush method of the CryptoStore.
|
|
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
|
|
return mach.CryptoStore.Flush(ctx)
|
|
}
|
|
|
|
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
|
start := time.Now()
|
|
return func() {
|
|
duration := time.Now().Sub(start)
|
|
if duration > expectedDuration {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Str("action", thing).
|
|
Dur("duration", duration).
|
|
Msg("Executing encryption function took longer than expected")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Deprecated: moved to SigningKey.Fingerprint
|
|
func Fingerprint(key id.SigningKey) string {
|
|
return key.Fingerprint()
|
|
}
|
|
|
|
// Fingerprint returns the fingerprint of the Olm account that can be used for non-interactive verification.
|
|
func (mach *OlmMachine) Fingerprint() string {
|
|
return mach.account.SigningKey().Fingerprint()
|
|
}
|
|
|
|
func (mach *OlmMachine) GetAccount() *OlmAccount {
|
|
return mach.account
|
|
}
|
|
|
|
// OwnIdentity returns this device's id.Device struct
|
|
func (mach *OlmMachine) OwnIdentity() *id.Device {
|
|
return &id.Device{
|
|
UserID: mach.Client.UserID,
|
|
DeviceID: mach.Client.DeviceID,
|
|
IdentityKey: mach.account.IdentityKey(),
|
|
SigningKey: mach.account.SigningKey(),
|
|
Trust: id.TrustStateVerified,
|
|
Deleted: false,
|
|
}
|
|
}
|
|
|
|
type ASEventProcessor interface {
|
|
On(evtType event.Type, handler func(ctx context.Context, evt *event.Event))
|
|
OnOTK(func(ctx context.Context, otk *mautrix.OTKCount))
|
|
OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string))
|
|
Dispatch(ctx context.Context, evt *event.Event)
|
|
}
|
|
|
|
func (mach *OlmMachine) AddAppserviceListener(ep ASEventProcessor) {
|
|
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
|
|
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceBeeperRoomKeyAck, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationAccept, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationKey, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationMAC, mach.HandleToDeviceEvent)
|
|
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
|
|
ep.OnOTK(mach.HandleOTKCounts)
|
|
ep.OnDeviceList(mach.HandleDeviceLists)
|
|
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
|
|
}
|
|
|
|
func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) {
|
|
if len(dl.Changed) > 0 {
|
|
traceID := time.Now().Format("15:04:05.000000")
|
|
mach.Log.Debug().
|
|
Str("trace_id", traceID).
|
|
Interface("changes", dl.Changed).
|
|
Msg("Device list changes in /sync")
|
|
mach.FetchKeys(ctx, dl.Changed, false)
|
|
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
|
|
}
|
|
}
|
|
|
|
func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool {
|
|
if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID {
|
|
return false
|
|
}
|
|
switch id.Ed25519(otkCount.DeviceID) {
|
|
case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
|
|
receivedOTKsForSelf := mach.receivedOTKsForSelf.Load()
|
|
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
|
|
if otkCount.UserID != mach.Client.UserID || (!receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) {
|
|
mach.Log.Warn().
|
|
Str("target_user_id", otkCount.UserID.String()).
|
|
Str("target_device_id", otkCount.DeviceID.String()).
|
|
Msg("Dropping OTK counts targeted to someone else")
|
|
}
|
|
return
|
|
} else if !receivedOTKsForSelf {
|
|
mach.receivedOTKsForSelf.Store(true)
|
|
}
|
|
|
|
minCount := mach.account.InternalLibolm.MaxNumberOfOneTimeKeys() / 2
|
|
if otkCount.SignedCurve25519 < int(minCount) {
|
|
traceID := time.Now().Format("15:04:05.000000")
|
|
log := mach.Log.With().Str("trace_id", traceID).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
log.Debug().
|
|
Int("keys_left", otkCount.SignedCurve25519).
|
|
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
|
|
err := mach.ShareKeys(ctx, otkCount.SignedCurve25519)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to share keys")
|
|
} else {
|
|
log.Debug().Msg("Successfully shared keys")
|
|
}
|
|
}
|
|
}
|
|
|
|
// ProcessSyncResponse processes a single /sync response.
|
|
//
|
|
// This can be easily registered into a mautrix client using .OnSync():
|
|
//
|
|
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
|
|
func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool {
|
|
mach.HandleDeviceLists(ctx, &resp.DeviceLists, since)
|
|
|
|
for _, evt := range resp.ToDevice.Events {
|
|
evt.Type.Class = event.ToDeviceEventType
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil {
|
|
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
|
|
continue
|
|
}
|
|
mach.HandleToDeviceEvent(ctx, evt)
|
|
}
|
|
|
|
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
|
mach.MarkOlmHashSavePoint(ctx)
|
|
return true
|
|
}
|
|
|
|
// HandleMemberEvent handles a single membership event.
|
|
//
|
|
// Currently this is not automatically called, so you must add a listener yourself:
|
|
//
|
|
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
|
func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) {
|
|
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
|
|
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
|
|
Msg("Failed to check if room is encrypted to handle member event")
|
|
return
|
|
} else if !isEncrypted {
|
|
return
|
|
}
|
|
content := evt.Content.AsMember()
|
|
if content == nil {
|
|
return
|
|
}
|
|
var prevContent *event.MemberEventContent
|
|
if evt.Unsigned.PrevContent != nil {
|
|
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
|
|
prevContent = evt.Unsigned.PrevContent.AsMember()
|
|
}
|
|
if prevContent == nil {
|
|
prevContent = &event.MemberEventContent{Membership: "unknown"}
|
|
}
|
|
if prevContent.Membership == content.Membership ||
|
|
(prevContent.Membership == event.MembershipInvite && content.Membership == event.MembershipJoin) ||
|
|
(prevContent.Membership == event.MembershipBan && content.Membership == event.MembershipLeave) ||
|
|
(prevContent.Membership == event.MembershipLeave && content.Membership == event.MembershipBan) {
|
|
return
|
|
}
|
|
mach.Log.Trace().
|
|
Str("room_id", evt.RoomID.String()).
|
|
Str("user_id", evt.GetStateKey()).
|
|
Str("prev_membership", string(prevContent.Membership)).
|
|
Str("new_membership", string(content.Membership)).
|
|
Msg("Got membership state change, invalidating group session in room")
|
|
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
|
|
if err != nil {
|
|
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
|
|
}
|
|
}
|
|
|
|
func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) {
|
|
if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok {
|
|
mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler")
|
|
return
|
|
}
|
|
|
|
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
|
|
if err != nil {
|
|
mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event")
|
|
return
|
|
}
|
|
|
|
log := mach.machOrContextLog(ctx).With().
|
|
Str("decrypted_type", decryptedEvt.Type.Type).
|
|
Str("sender_device", decryptedEvt.SenderDevice.String()).
|
|
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
|
|
Logger()
|
|
log.Trace().Msg("Successfully decrypted to-device event")
|
|
|
|
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
|
|
case *event.RoomKeyEventContent:
|
|
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
|
|
log.Trace().Msg("Handled room key event")
|
|
case *event.ForwardedRoomKeyEventContent:
|
|
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
|
|
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
|
|
// close channel to notify listener that the key was received
|
|
close(ch.(chan struct{}))
|
|
}
|
|
}
|
|
log.Trace().Msg("Handled forwarded room key event")
|
|
case *event.DummyEventContent:
|
|
log.Debug().Msg("Received encrypted dummy event")
|
|
case *event.SecretSendEventContent:
|
|
mach.receiveSecret(ctx, decryptedEvt, decryptedContent)
|
|
log.Trace().Msg("Handled secret send event")
|
|
default:
|
|
log.Debug().Msg("Unhandled encrypted to-device event")
|
|
}
|
|
}
|
|
|
|
// Tuning for olm hash cleanup in MarkOlmHashSavePoint.
const (
	// olmHashSavePointCount is how many save points are kept; when exceeded, the oldest
	// one is dropped and used as the cutoff for deleting old olm hashes.
	olmHashSavePointCount = 5
	// olmHashDeleteMinInterval is the minimum time between two olm hash deletions.
	olmHashDeleteMinInterval = 10 * time.Minute
	// minSavePointInterval is the minimum time between two recorded save points.
	minSavePointInterval = 1 * time.Minute
)
|
|
|
|
// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed.
|
|
//
|
|
// This should be called after all to-device events in a sync have been processed.
|
|
// The function will then delete old olm hashes after enough syncs have happened
|
|
// (such that it's unlikely for the olm messages to repeat).
|
|
func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) {
|
|
mach.olmHashSavePointLock.Lock()
|
|
defer mach.olmHashSavePointLock.Unlock()
|
|
if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval {
|
|
return
|
|
}
|
|
mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now())
|
|
if len(mach.olmHashSavePoints) > olmHashSavePointCount {
|
|
sp := mach.olmHashSavePoints[0]
|
|
mach.olmHashSavePoints = mach.olmHashSavePoints[1:]
|
|
if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval {
|
|
err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp)
|
|
mach.lastHashDelete = time.Now()
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
|
|
// don't need to add any custom handlers if you use that method.
|
|
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
|
|
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
|
|
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
|
|
mach.Log.Debug().
|
|
Str("target_user_id", evt.ToUserID.String()).
|
|
Str("target_device_id", evt.ToDeviceID.String()).
|
|
Msg("Dropping to-device event targeted to someone else")
|
|
return
|
|
}
|
|
traceID := time.Now().Format("15:04:05.000000")
|
|
// TODO use context log?
|
|
log := mach.Log.With().
|
|
Str("trace_id", traceID).
|
|
Str("sender", evt.Sender.String()).
|
|
Str("type", evt.Type.Type).
|
|
Logger()
|
|
ctx = log.WithContext(ctx)
|
|
if evt.Type != event.ToDeviceEncrypted {
|
|
log.Debug().Msg("Starting handling to-device event")
|
|
}
|
|
switch content := evt.Content.Parsed.(type) {
|
|
case *event.EncryptedEventContent:
|
|
mach.HandleEncryptedEvent(ctx, evt)
|
|
return
|
|
case *event.RoomKeyRequestEventContent:
|
|
go mach.HandleRoomKeyRequest(ctx, evt.Sender, content)
|
|
case *event.BeeperRoomKeyAckEventContent:
|
|
mach.HandleBeeperRoomKeyAck(ctx, evt.Sender, content)
|
|
case *event.RoomKeyWithheldEventContent:
|
|
mach.HandleRoomKeyWithheld(ctx, content)
|
|
case *event.SecretRequestEventContent:
|
|
if content.Action == event.SecretRequestRequest {
|
|
mach.HandleSecretRequest(ctx, evt.Sender, content)
|
|
log.Trace().Msg("Handled secret request event")
|
|
}
|
|
default:
|
|
deviceID, _ := evt.Content.Raw["device_id"].(string)
|
|
log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event")
|
|
return
|
|
}
|
|
log.Debug().Msg("Finished handling to-device event")
|
|
}
|
|
|
|
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
|
|
// and if it's not found it asks the server for it.
|
|
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
|
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
|
} else if device != nil {
|
|
return device, nil
|
|
}
|
|
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch keys: %w", err)
|
|
} else if devices, ok := usersToDevices[userID]; ok {
|
|
if device, ok = devices[deviceID]; ok {
|
|
return device, nil
|
|
}
|
|
return nil, fmt.Errorf("didn't get identity for device %s of %s", deviceID, userID)
|
|
}
|
|
return nil, fmt.Errorf("didn't get any devices for %s", userID)
|
|
}
|
|
|
|
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
|
|
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
|
// the given identity key.
|
|
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
|
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
|
|
if err != nil || deviceIdentity != nil {
|
|
return deviceIdentity, err
|
|
}
|
|
mach.machOrContextLog(ctx).Debug().
|
|
Str("user_id", userID.String()).
|
|
Str("identity_key", identityKey.String()).
|
|
Msg("Didn't find identity in crypto store, fetching from server")
|
|
devices := mach.LoadDevices(ctx, userID)
|
|
for _, device := range devices {
|
|
if device.IdentityKey == identityKey {
|
|
return device, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
|
|
func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.Device, evtType event.Type, content event.Content) error {
|
|
if err := mach.createOutboundSessions(ctx, map[id.UserID]map[id.DeviceID]*id.Device{
|
|
device.UserID: {
|
|
device.DeviceID: device,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
mach.olmLock.Lock()
|
|
defer mach.olmLock.Unlock()
|
|
|
|
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if olmSess == nil {
|
|
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
|
|
}
|
|
|
|
encrypted := mach.encryptOlmEvent(ctx, olmSess, device, evtType, content)
|
|
encryptedContent := &event.Content{Parsed: &encrypted}
|
|
|
|
mach.machOrContextLog(ctx).Debug().
|
|
Str("decrypted_type", evtType.Type).
|
|
Str("to_user_id", device.UserID.String()).
|
|
Str("to_device_id", device.DeviceID.String()).
|
|
Str("to_identity_key", device.IdentityKey.String()).
|
|
Str("olm_session_id", olmSess.ID().String()).
|
|
Msg("Sending encrypted to-device event")
|
|
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted,
|
|
&mautrix.ReqSendToDevice{
|
|
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
|
|
device.UserID: {
|
|
device.DeviceID: encryptedContent,
|
|
},
|
|
},
|
|
},
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) error {
|
|
log := zerolog.Ctx(ctx)
|
|
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create inbound group session: %w", err)
|
|
} else if igs.ID() != sessionID {
|
|
log.Warn().
|
|
Str("expected_session_id", sessionID.String()).
|
|
Str("actual_session_id", igs.ID().String()).
|
|
Msg("Mismatched session ID while creating inbound group session")
|
|
return fmt.Errorf("mismatched session ID while creating inbound group session")
|
|
}
|
|
err = mach.CryptoStore.PutGroupSession(ctx, igs)
|
|
if err != nil {
|
|
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
|
return fmt.Errorf("failed to store new inbound group session: %w", err)
|
|
}
|
|
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
|
|
panic("different index")
|
|
}
|
|
mach.markSessionReceived(ctx, roomID, sessionID, igs.InternalLibolm.FirstKnownIndex())
|
|
log.Debug().
|
|
Str("session_id", sessionID.String()).
|
|
Str("sender_key", senderKey.String()).
|
|
Str("max_age", maxAge.String()).
|
|
Int("max_messages", maxMessages).
|
|
Bool("is_scheduled", isScheduled).
|
|
Msg("Received inbound group session")
|
|
return nil
|
|
}
|
|
|
|
func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) {
|
|
if mach.SessionReceived != nil {
|
|
mach.SessionReceived(ctx, roomID, id, firstKnownIndex)
|
|
}
|
|
|
|
mach.keyWaitersLock.Lock()
|
|
ch, ok := mach.keyWaiters[id]
|
|
if ok {
|
|
close(ch)
|
|
delete(mach.keyWaiters, id)
|
|
}
|
|
mach.keyWaitersLock.Unlock()
|
|
}
|
|
|
|
// WaitForSession waits for the given Megolm session to arrive.
|
|
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
|
mach.keyWaitersLock.Lock()
|
|
ch, ok := mach.keyWaiters[sessionID]
|
|
if !ok {
|
|
ch = make(chan struct{})
|
|
mach.keyWaiters[sessionID] = ch
|
|
}
|
|
mach.keyWaitersLock.Unlock()
|
|
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
|
|
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
|
|
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
|
|
return true
|
|
}
|
|
select {
|
|
case <-ch:
|
|
return true
|
|
case <-time.After(timeout):
|
|
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
|
|
// Check if the session somehow appeared in the store without telling us
|
|
// We accept withheld sessions as received, as then the decryption attempt will show the error.
|
|
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
|
|
case <-ctx.Done():
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) {
|
|
log := zerolog.Ctx(ctx).With().
|
|
Str("algorithm", string(content.Algorithm)).
|
|
Str("session_id", content.SessionID.String()).
|
|
Str("room_id", content.RoomID.String()).
|
|
Logger()
|
|
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
|
|
log.Debug().Msg("Ignoring weird room key")
|
|
return
|
|
}
|
|
|
|
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
|
}
|
|
var maxAge time.Duration
|
|
var maxMessages int
|
|
if config != nil {
|
|
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
|
|
if maxAge == 0 {
|
|
maxAge = 7 * 24 * time.Hour
|
|
}
|
|
maxMessages = config.RotationPeriodMessages
|
|
if maxMessages == 0 {
|
|
maxMessages = 100
|
|
}
|
|
}
|
|
if content.MaxAge != 0 {
|
|
maxAge = time.Duration(content.MaxAge) * time.Millisecond
|
|
}
|
|
if content.MaxMessages != 0 {
|
|
maxMessages = content.MaxMessages
|
|
}
|
|
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
|
|
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
|
|
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to redact previous megolm sessions")
|
|
} else {
|
|
log.Info().
|
|
Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)).
|
|
Msg("Redacted previous megolm sessions")
|
|
}
|
|
}
|
|
err = mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create inbound group session")
|
|
}
|
|
}
|
|
|
|
func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
|
|
if content.Algorithm != id.AlgorithmMegolmV1 {
|
|
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
|
|
return
|
|
}
|
|
// TODO log if there's a conflict? (currently ignored)
|
|
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
|
|
}
|
|
}
|
|
|
|
// ShareKeys uploads necessary keys to the server.
|
|
//
|
|
// If the Olm account hasn't been shared, the account keys will be uploaded.
|
|
// If currentOTKCount is less than half of the limit (100 / 2 = 50), enough one-time keys will be uploaded so exactly
|
|
// half of the limit is filled.
|
|
func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) error {
|
|
log := mach.machOrContextLog(ctx)
|
|
start := time.Now()
|
|
mach.otkUploadLock.Lock()
|
|
defer mach.otkUploadLock.Unlock()
|
|
if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
|
|
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
|
|
resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check current OTK counts: %w", err)
|
|
}
|
|
log.Debug().
|
|
Int("input_count", currentOTKCount).
|
|
Int("server_count", resp.OneTimeKeyCounts.SignedCurve25519).
|
|
Msg("Fetched current OTK count from server")
|
|
currentOTKCount = resp.OneTimeKeyCounts.SignedCurve25519
|
|
}
|
|
var deviceKeys *mautrix.DeviceKeys
|
|
if !mach.account.Shared {
|
|
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
|
|
err := mach.CryptoStore.PutDevice(ctx, mach.Client.UserID, &id.Device{
|
|
UserID: mach.Client.UserID,
|
|
DeviceID: mach.Client.DeviceID,
|
|
IdentityKey: deviceKeys.Keys.GetCurve25519(mach.Client.DeviceID),
|
|
SigningKey: deviceKeys.Keys.GetEd25519(mach.Client.DeviceID),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save initial keys: %w", err)
|
|
}
|
|
log.Debug().Msg("Going to upload initial account keys")
|
|
}
|
|
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
|
|
if len(oneTimeKeys) == 0 && deviceKeys == nil {
|
|
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
|
|
return nil
|
|
}
|
|
// Save the keys before sending the upload request in case there is a
|
|
// network failure.
|
|
if err := mach.saveAccount(ctx); err != nil {
|
|
return err
|
|
}
|
|
req := &mautrix.ReqUploadKeys{
|
|
DeviceKeys: deviceKeys,
|
|
OneTimeKeys: oneTimeKeys,
|
|
}
|
|
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
|
|
_, err := mach.Client.UploadKeys(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mach.lastOTKUpload = time.Now()
|
|
mach.account.InternalLibolm.MarkKeysAsPublished()
|
|
mach.account.InternalGoolm.MarkKeysAsPublished()
|
|
mach.account.Shared = true
|
|
return mach.saveAccount(ctx)
|
|
}
|
|
|
|
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
|
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
|
|
for {
|
|
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to redact expired megolm sessions")
|
|
} else if len(sessionIDs) > 0 {
|
|
log.Info().Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)).Msg("Redacted expired megolm sessions")
|
|
} else {
|
|
log.Debug().Msg("Didn't find any expired megolm sessions")
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Debug().Msg("Loop stopped")
|
|
return
|
|
case <-time.After(24 * time.Hour):
|
|
}
|
|
}
|
|
}
|