mirror of https://github.com/mautrix/go.git
450 lines
16 KiB
Go
450 lines
16 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package crypto
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/tidwall/gjson"
|
|
"go.mau.fi/util/exgjson"
|
|
"go.mau.fi/util/exzerolog"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
var (
|
|
AlreadyShared = errors.New("group session already shared")
|
|
NoGroupSession = errors.New("no group session created")
|
|
)
|
|
|
|
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
|
|
value := gjson.GetBytes(content, exgjson.Path(path...))
|
|
if !value.IsObject() {
|
|
return nil
|
|
}
|
|
var result T
|
|
err := json.Unmarshal([]byte(value.Raw), &result)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return &result
|
|
}
|
|
|
|
func getRelatesTo(content any) *event.RelatesTo {
|
|
contentJSON, ok := content.(json.RawMessage)
|
|
if ok {
|
|
return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to")
|
|
}
|
|
contentStruct, ok := content.(*event.Content)
|
|
if ok {
|
|
content = contentStruct.Parsed
|
|
}
|
|
relatable, ok := content.(event.Relatable)
|
|
if ok {
|
|
return relatable.OptionalGetRelatesTo()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getMentions(content any) *event.Mentions {
|
|
contentJSON, ok := content.(json.RawMessage)
|
|
if ok {
|
|
return getRawJSON[event.Mentions](contentJSON, "m.mentions")
|
|
}
|
|
contentStruct, ok := content.(*event.Content)
|
|
if ok {
|
|
content = contentStruct.Parsed
|
|
}
|
|
message, ok := content.(*event.MessageEventContent)
|
|
if ok {
|
|
return message.Mentions
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type rawMegolmEvent struct {
|
|
RoomID id.RoomID `json:"room_id"`
|
|
Type event.Type `json:"type"`
|
|
StateKey *string `json:"state_key,omitempty"`
|
|
Content interface{} `json:"content"`
|
|
}
|
|
|
|
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
|
|
func IsShareError(err error) bool {
|
|
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
|
|
}
|
|
|
|
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
|
|
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
|
|
var err error
|
|
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
|
|
if err != nil {
|
|
return 0, err
|
|
} else if decoded[0] != 3 || decoded[1] != 8 {
|
|
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
|
|
}
|
|
index, read := binary.Uvarint(decoded[2 : 2+binary.MaxVarintLen64])
|
|
if read <= 0 {
|
|
return 0, fmt.Errorf("failed to decode varint, read value %d", read)
|
|
}
|
|
return uint(index), nil
|
|
}
|
|
|
|
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
|
|
//
|
|
// If you use the event.Content struct, make sure you pass a pointer to the struct,
|
|
// as JSON serialization will not work correctly otherwise.
|
|
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
|
return mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, nil, content)
|
|
}
|
|
|
|
// EncryptMegolmEventWithStateKey encrypts data with the m.megolm.v1.aes-sha2 algorithm.
|
|
//
|
|
// If you use the event.Content struct, make sure you pass a pointer to the struct,
|
|
// as JSON serialization will not work correctly otherwise.
|
|
func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, roomID id.RoomID, evtType event.Type, stateKey *string, content interface{}) (*event.EncryptedEventContent, error) {
|
|
mach.megolmEncryptLock.Lock()
|
|
defer mach.megolmEncryptLock.Unlock()
|
|
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
|
} else if session == nil {
|
|
return nil, NoGroupSession
|
|
}
|
|
plaintext, err := json.Marshal(&rawMegolmEvent{
|
|
RoomID: roomID,
|
|
Type: evtType,
|
|
StateKey: stateKey,
|
|
Content: content,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
log := mach.machOrContextLog(ctx).With().
|
|
Str("event_type", evtType.Type).
|
|
Any("state_key", stateKey).
|
|
Str("room_id", roomID.String()).
|
|
Str("session_id", session.ID().String()).
|
|
Uint("expected_index", session.Internal.MessageIndex()).
|
|
Logger()
|
|
log.Trace().Msg("Encrypting event...")
|
|
ciphertext, err := session.Encrypt(plaintext)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idx, err := ParseMegolmMessageIndex(ciphertext)
|
|
if err != nil {
|
|
log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event")
|
|
} else {
|
|
log = log.With().Uint("message_index", idx).Logger()
|
|
}
|
|
log.Debug().Msg("Encrypted event successfully")
|
|
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to update outbound group session after encrypting: %w", err)
|
|
}
|
|
encrypted := &event.EncryptedEventContent{
|
|
Algorithm: id.AlgorithmMegolmV1,
|
|
SessionID: session.ID(),
|
|
MegolmCiphertext: ciphertext,
|
|
RelatesTo: getRelatesTo(content),
|
|
|
|
// These are deprecated
|
|
SenderKey: mach.account.IdentityKey(),
|
|
DeviceID: mach.Client.DeviceID,
|
|
}
|
|
if mach.PlaintextMentions {
|
|
encrypted.Mentions = getMentions(content)
|
|
}
|
|
return encrypted, nil
|
|
}
|
|
|
|
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
|
|
encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
|
|
if err != nil {
|
|
mach.machOrContextLog(ctx).Err(err).
|
|
Stringer("room_id", roomID).
|
|
Msg("Failed to get encryption event in room")
|
|
return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err)
|
|
}
|
|
session, err := NewOutboundGroupSession(roomID, encryptionEvent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !mach.DontStoreOutboundKeys {
|
|
signingKey, idKey := mach.account.Keys()
|
|
err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return session, err
|
|
}
|
|
|
|
type deviceSessionWrapper struct {
|
|
session *OlmSession
|
|
identity *id.Device
|
|
}
|
|
|
|
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
|
|
//
|
|
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
|
|
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
|
|
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
|
|
mach.megolmEncryptLock.Lock()
|
|
defer mach.megolmEncryptLock.Unlock()
|
|
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
|
} else if session != nil && session.Shared && !session.Expired() {
|
|
return AlreadyShared
|
|
}
|
|
log := mach.machOrContextLog(ctx).With().
|
|
Str("room_id", roomID.String()).
|
|
Str("action", "share megolm session").
|
|
Logger()
|
|
ctx = log.WithContext(ctx)
|
|
if session == nil || session.Expired() {
|
|
if session, err = mach.newOutboundGroupSession(ctx, roomID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
log = log.With().Str("session_id", session.ID().String()).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
log.Debug().Array("users", exzerolog.ArrayOfStrs(users)).Msg("Sharing group session for room")
|
|
|
|
withheldCount := 0
|
|
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
|
olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper)
|
|
missingSessions := make(map[id.UserID]map[id.DeviceID]*id.Device)
|
|
missingUserSessions := make(map[id.DeviceID]*id.Device)
|
|
var fetchKeysForUsers []id.UserID
|
|
|
|
for _, userID := range users {
|
|
log := log.With().Str("target_user_id", userID.String()).Logger()
|
|
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to get devices of user")
|
|
return fmt.Errorf("failed to get devices of user %s: %w", userID, err)
|
|
} else if devices == nil {
|
|
log.Debug().Msg("GetDevices returned nil, will fetch keys and retry")
|
|
fetchKeysForUsers = append(fetchKeysForUsers, userID)
|
|
} else if len(devices) == 0 {
|
|
log.Trace().Msg("User has no devices, skipping")
|
|
} else {
|
|
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user")
|
|
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
|
|
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
|
|
mach.findOlmSessionsForUser(ctx, session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
|
|
log.Debug().
|
|
Int("olm_session_count", len(olmSessions[userID])).
|
|
Int("withheld_count", len(toDeviceWithheld.Messages[userID])).
|
|
Int("missing_count", len(missingUserSessions)).
|
|
Msg("Completed first pass of finding olm sessions")
|
|
withheldCount += len(toDeviceWithheld.Messages[userID])
|
|
if len(missingUserSessions) > 0 {
|
|
missingSessions[userID] = missingUserSessions
|
|
missingUserSessions = make(map[id.DeviceID]*id.Device)
|
|
}
|
|
if len(toDeviceWithheld.Messages[userID]) == 0 {
|
|
delete(toDeviceWithheld.Messages, userID)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(fetchKeysForUsers) > 0 {
|
|
log.Debug().Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Fetching missing keys")
|
|
keys, err := mach.FetchKeys(ctx, fetchKeysForUsers, true)
|
|
if err != nil {
|
|
log.Err(err).Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Failed to fetch missing keys")
|
|
return fmt.Errorf("failed to fetch missing keys: %w", err)
|
|
}
|
|
for userID, devices := range keys {
|
|
log.Debug().
|
|
Int("device_count", len(devices)).
|
|
Str("target_user_id", userID.String()).
|
|
Msg("Got device keys for user")
|
|
missingSessions[userID] = devices
|
|
}
|
|
}
|
|
|
|
if len(missingSessions) > 0 {
|
|
log.Debug().Msg("Creating missing olm sessions")
|
|
err = mach.createOutboundSessions(ctx, missingSessions)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create missing olm sessions")
|
|
return fmt.Errorf("failed to create missing olm sessions: %w", err)
|
|
}
|
|
}
|
|
|
|
for userID, devices := range missingSessions {
|
|
if len(devices) == 0 {
|
|
// No missing sessions
|
|
continue
|
|
}
|
|
output, ok := olmSessions[userID]
|
|
if !ok {
|
|
output = make(map[id.DeviceID]deviceSessionWrapper)
|
|
olmSessions[userID] = output
|
|
}
|
|
withheld, ok := toDeviceWithheld.Messages[userID]
|
|
if !ok {
|
|
withheld = make(map[id.DeviceID]*event.Content)
|
|
toDeviceWithheld.Messages[userID] = withheld
|
|
}
|
|
|
|
log := log.With().Str("target_user_id", userID.String()).Logger()
|
|
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
|
|
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
|
|
log.Debug().
|
|
Int("olm_session_count", len(output)).
|
|
Int("withheld_count", len(withheld)).
|
|
Msg("Completed post-fetch retry of finding olm sessions")
|
|
withheldCount += len(toDeviceWithheld.Messages[userID])
|
|
if len(toDeviceWithheld.Messages[userID]) == 0 {
|
|
delete(toDeviceWithheld.Messages, userID)
|
|
}
|
|
}
|
|
|
|
err = mach.encryptAndSendGroupSession(ctx, session, olmSessions)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to share group session: %w", err)
|
|
}
|
|
|
|
if len(toDeviceWithheld.Messages) > 0 {
|
|
log.Debug().
|
|
Int("device_count", withheldCount).
|
|
Int("user_count", len(toDeviceWithheld.Messages)).
|
|
Msg("Sending to-device messages to report withheld key")
|
|
// TODO remove the next 4 lines once clients support m.room_key.withheld
|
|
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
|
if err != nil {
|
|
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
|
|
}
|
|
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
|
if err != nil {
|
|
log.Warn().Err(err).Msg("Failed to report withheld keys")
|
|
}
|
|
}
|
|
|
|
log.Debug().Msg("Group session successfully shared")
|
|
session.Shared = true
|
|
return mach.CryptoStore.AddOutboundGroupSession(ctx, session)
|
|
}
|
|
|
|
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
|
mach.olmLock.Lock()
|
|
defer mach.olmLock.Unlock()
|
|
log := zerolog.Ctx(ctx)
|
|
log.Trace().Msg("Encrypting group session for all found devices")
|
|
deviceCount := 0
|
|
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
|
for userID, sessions := range olmSessions {
|
|
if len(sessions) == 0 {
|
|
continue
|
|
}
|
|
output := make(map[id.DeviceID]*event.Content)
|
|
toDevice.Messages[userID] = output
|
|
for deviceID, device := range sessions {
|
|
log.Trace().
|
|
Stringer("target_user_id", userID).
|
|
Stringer("target_device_id", deviceID).
|
|
Stringer("target_identity_key", device.identity.IdentityKey).
|
|
Msg("Encrypting group session for device")
|
|
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
|
|
output[deviceID] = &event.Content{Parsed: content}
|
|
deviceCount++
|
|
log.Debug().
|
|
Stringer("target_user_id", userID).
|
|
Stringer("target_device_id", deviceID).
|
|
Stringer("target_identity_key", device.identity.IdentityKey).
|
|
Msg("Encrypted group session for device")
|
|
if !mach.DisableSharedGroupSessionTracking {
|
|
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
|
|
if err != nil {
|
|
log.Warn().
|
|
Err(err).
|
|
Stringer("target_user_id", userID).
|
|
Stringer("target_device_id", deviceID).
|
|
Stringer("target_identity_key", device.identity.IdentityKey).
|
|
Stringer("target_session_id", session.id).
|
|
Msg("Failed to mark outbound group session shared")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Debug().
|
|
Int("device_count", deviceCount).
|
|
Int("user_count", len(toDevice.Messages)).
|
|
Msg("Sending to-device messages to share group session")
|
|
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
|
|
return err
|
|
}
|
|
|
|
func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
|
|
for deviceID, device := range devices {
|
|
log := zerolog.Ctx(ctx).With().
|
|
Stringer("target_user_id", userID).
|
|
Stringer("target_device_id", deviceID).
|
|
Stringer("target_identity_key", device.IdentityKey).
|
|
Logger()
|
|
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
|
|
if state := session.Users[userKey]; state != OGSNotShared {
|
|
continue
|
|
} else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
|
|
session.Users[userKey] = OGSIgnored
|
|
} else if device.Trust == id.TrustStateBlacklisted {
|
|
log.Debug().Msg("Not encrypting group session for device: device is blacklisted")
|
|
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
|
|
RoomID: session.RoomID,
|
|
Algorithm: id.AlgorithmMegolmV1,
|
|
SessionID: session.ID(),
|
|
SenderKey: mach.account.IdentityKey(),
|
|
Code: event.RoomKeyWithheldBlacklisted,
|
|
Reason: "Device is blacklisted",
|
|
}}
|
|
session.Users[userKey] = OGSIgnored
|
|
} else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust {
|
|
log.Debug().
|
|
Str("min_trust", mach.SendKeysMinTrust.String()).
|
|
Str("device_trust", trustState.String()).
|
|
Msg("Not encrypting group session for device: device is not trusted")
|
|
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
|
|
RoomID: session.RoomID,
|
|
Algorithm: id.AlgorithmMegolmV1,
|
|
SessionID: session.ID(),
|
|
SenderKey: mach.account.IdentityKey(),
|
|
Code: event.RoomKeyWithheldUnverified,
|
|
Reason: "This device does not encrypt messages for unverified devices",
|
|
}}
|
|
session.Users[userKey] = OGSIgnored
|
|
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil {
|
|
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
|
|
} else if deviceSession == nil {
|
|
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
|
|
if missingOutput != nil {
|
|
missingOutput[deviceID] = device
|
|
}
|
|
} else {
|
|
output[deviceID] = deviceSessionWrapper{
|
|
session: deviceSession,
|
|
identity: device,
|
|
}
|
|
session.Users[userKey] = OGSAlreadyShared
|
|
}
|
|
}
|
|
}
|