mirror of https://github.com/mautrix/go.git
319 lines
14 KiB
Go
319 lines
14 KiB
Go
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package crypto
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"go.mau.fi/util/exgjson"
|
|
|
|
"maunium.net/go/mautrix/crypto/olm"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// Sentinel errors returned while decrypting megolm events. Callers can match
// them with errors.Is, since several are wrapped with extra detail before
// being returned.
var (
	// IncorrectEncryptedContentType is returned when the parsed content of the
	// event passed to DecryptMegolmEvent is not an *event.EncryptedEventContent.
	IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
	// NoSessionFound is returned when the crypto store has no inbound group
	// session for the session ID referenced by the event.
	NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
	// DuplicateMessageIndex is returned when message index validation in the
	// crypto store rejects the event's index.
	DuplicateMessageIndex = errors.New("duplicate megolm message index")
	// WrongRoom is returned when the room ID inside the decrypted payload
	// differs from the room the event is being decrypted for.
	WrongRoom = errors.New("encrypted megolm event is not intended for this room")
	// DeviceKeyMismatch is returned when the keys of the looked-up sender
	// device differ from the keys stored on the megolm session.
	DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
	// SenderKeyMismatch is returned when the sender key in the event content
	// is set but differs from the sender key of the stored session.
	SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
	// RatchetError is returned when ratcheting, storing or deleting the
	// session fails after the event itself was decrypted.
	RatchetError = errors.New("failed to ratchet session after use")
)
|
|
|
|
type megolmEvent struct {
|
|
RoomID id.RoomID `json:"room_id"`
|
|
Type event.Type `json:"type"`
|
|
StateKey *string `json:"state_key"`
|
|
Content event.Content `json:"content"`
|
|
}
|
|
|
|
var (
|
|
relatesToContentPath = exgjson.Path("m.relates_to")
|
|
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
|
|
)
|
|
|
|
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
|
|
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
|
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
|
if !ok {
|
|
return nil, IncorrectEncryptedContentType
|
|
} else if content.Algorithm != id.AlgorithmMegolmV1 {
|
|
return nil, UnsupportedAlgorithm
|
|
}
|
|
log := mach.machOrContextLog(ctx).With().
|
|
Str("action", "decrypt megolm event").
|
|
Str("event_id", evt.ID.String()).
|
|
Str("sender", evt.Sender.String()).
|
|
Str("sender_key", content.SenderKey.String()).
|
|
Str("session_id", content.SessionID.String()).
|
|
Logger()
|
|
ctx = log.WithContext(ctx)
|
|
encryptionRoomID := evt.RoomID
|
|
// Allow the server to move encrypted events between rooms if both the real room and target room are on a non-federatable .local domain.
|
|
// The message index checks to prevent replay attacks still apply and aren't based on the room ID,
|
|
// so the event ID and timestamp must remain the same when the event is moved to a different room.
|
|
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
|
|
encryptionRoomID = id.RoomID(origRoomID)
|
|
}
|
|
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
log = log.With().Uint("message_index", messageIndex).Logger()
|
|
|
|
var trustLevel id.TrustState
|
|
var forwardedKeys bool
|
|
var device *id.Device
|
|
ownSigningKey, ownIdentityKey := mach.account.Keys()
|
|
if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 {
|
|
trustLevel = id.TrustStateVerified
|
|
} else {
|
|
if mach.DisableDecryptKeyFetching {
|
|
device, err = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, sess.SenderKey)
|
|
} else {
|
|
device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey)
|
|
}
|
|
if err != nil {
|
|
// We don't want to throw these errors as the message can still be decrypted.
|
|
log.Debug().Err(err).Msg("Failed to get device to verify session")
|
|
trustLevel = id.TrustStateUnknownDevice
|
|
} else if len(sess.ForwardingChains) == 0 || (len(sess.ForwardingChains) == 1 && sess.ForwardingChains[0] == sess.SenderKey.String()) {
|
|
if device == nil {
|
|
log.Debug().
|
|
Str("session_sender_key", sess.SenderKey.String()).
|
|
Msg("Couldn't resolve trust level of session: sent by unknown device")
|
|
trustLevel = id.TrustStateUnknownDevice
|
|
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
|
|
return nil, DeviceKeyMismatch
|
|
} else {
|
|
trustLevel, err = mach.ResolveTrustContext(ctx, device)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
} else {
|
|
forwardedKeys = true
|
|
lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1]
|
|
device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem))
|
|
if device != nil {
|
|
trustLevel, err = mach.ResolveTrustContext(ctx, device)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
log.Debug().
|
|
Str("forward_last_sender_key", lastChainItem).
|
|
Msg("Couldn't resolve trust level of session: forwarding chain ends with unknown device")
|
|
trustLevel = id.TrustStateForwarded
|
|
}
|
|
}
|
|
}
|
|
|
|
if content.RelatesTo != nil {
|
|
relation := gjson.GetBytes(evt.Content.VeryRaw, relatesToContentPath)
|
|
if relation.Exists() && !gjson.GetBytes(plaintext, relatesToTopLevelPath).IsObject() {
|
|
var raw []byte
|
|
if relation.Index > 0 {
|
|
raw = evt.Content.VeryRaw[relation.Index : relation.Index+len(relation.Raw)]
|
|
} else {
|
|
raw = []byte(relation.Raw)
|
|
}
|
|
updatedPlaintext, err := sjson.SetRawBytes(plaintext, relatesToTopLevelPath, raw)
|
|
if err != nil {
|
|
log.Warn().Msg("Failed to copy m.relates_to to decrypted payload")
|
|
} else if updatedPlaintext != nil {
|
|
plaintext = updatedPlaintext
|
|
}
|
|
} else if !relation.Exists() {
|
|
log.Warn().Msg("Failed to find m.relates_to in raw encrypted event even though it was present in parsed content")
|
|
}
|
|
}
|
|
|
|
megolmEvt := &megolmEvent{}
|
|
err = json.Unmarshal(plaintext, &megolmEvt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
|
|
} else if megolmEvt.RoomID != encryptionRoomID {
|
|
return nil, WrongRoom
|
|
}
|
|
if evt.StateKey != nil && megolmEvt.StateKey != nil {
|
|
megolmEvt.Type.Class = event.StateEventType
|
|
} else {
|
|
megolmEvt.Type.Class = evt.Type.Class
|
|
megolmEvt.StateKey = nil
|
|
}
|
|
log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger()
|
|
err = megolmEvt.Content.ParseRaw(megolmEvt.Type)
|
|
if err != nil {
|
|
if errors.Is(err, event.ErrUnsupportedContentType) {
|
|
log.Warn().Msg("Unsupported event type in encrypted event")
|
|
} else {
|
|
return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err)
|
|
}
|
|
}
|
|
log.Debug().Msg("Event decrypted successfully")
|
|
megolmEvt.Type.Class = evt.Type.Class
|
|
return &event.Event{
|
|
Sender: evt.Sender,
|
|
Type: megolmEvt.Type,
|
|
StateKey: megolmEvt.StateKey,
|
|
Timestamp: evt.Timestamp,
|
|
ID: evt.ID,
|
|
RoomID: evt.RoomID,
|
|
Content: megolmEvt.Content,
|
|
Unsigned: evt.Unsigned,
|
|
Mautrix: event.MautrixInfo{
|
|
TrustState: trustLevel,
|
|
TrustSource: device,
|
|
ForwardedKeys: forwardedKeys,
|
|
WasEncrypted: true,
|
|
ReceivedAt: evt.Mautrix.ReceivedAt,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// removeItem removes the first occurrence of item from slice.
//
// It returns the resulting slice and true if the item was found, or the
// unchanged slice and false otherwise. The removal is done in place, so the
// backing array of the input slice is modified; callers should only use the
// returned slice afterwards.
func removeItem(slice []uint, item uint) ([]uint, bool) {
	for i, s := range slice {
		if s == item {
			return append(slice[:i], slice[i+1:]...), true
		}
	}
	return slice, false
}
|
|
|
|
// missedIndexCutoff is how far behind the next expected message index a missed
// index may lag before actuallyDecryptMegolmEvent moves it from the missed
// indices list to the lost indices list (only for sessions with a non-zero
// ReceivedAt).
const missedIndexCutoff = 10
|
|
|
|
func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Context, sess *InboundGroupSession, evt *event.Event, content *event.EncryptedEventContent) (uint, error) {
|
|
log := *zerolog.Ctx(ctx)
|
|
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
|
|
if decodeErr != nil {
|
|
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
|
|
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
|
|
}
|
|
firstKnown := sess.Internal.FirstKnownIndex()
|
|
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
|
|
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
|
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
|
|
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
|
} else if !ok {
|
|
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
|
|
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
|
|
}
|
|
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
|
|
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
|
}
|
|
|
|
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
|
|
mach.megolmDecryptLock.Lock()
|
|
defer mach.megolmDecryptLock.Unlock()
|
|
|
|
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SessionID)
|
|
if err != nil {
|
|
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
|
} else if sess == nil {
|
|
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
|
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
|
return sess, nil, 0, SenderKeyMismatch
|
|
}
|
|
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
|
if err != nil {
|
|
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
|
|
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
|
|
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
|
}
|
|
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
|
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
|
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
|
} else if !ok {
|
|
return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
|
|
}
|
|
|
|
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
|
|
if mach.DisableRatchetTracking {
|
|
return sess, plaintext, messageIndex, nil
|
|
}
|
|
|
|
expectedMessageIndex := sess.RatchetSafety.NextIndex
|
|
didModify := false
|
|
switch {
|
|
case messageIndex > expectedMessageIndex:
|
|
// When the index jumps, add indices in between to the missed indices list.
|
|
for i := expectedMessageIndex; i < messageIndex; i++ {
|
|
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
|
|
}
|
|
fallthrough
|
|
case messageIndex == expectedMessageIndex:
|
|
// When the index moves forward (to the next one or jumping ahead), update the last received index.
|
|
sess.RatchetSafety.NextIndex = messageIndex + 1
|
|
didModify = true
|
|
default:
|
|
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
|
|
}
|
|
// Use presence of ReceivedAt as a sign that this is a recent megolm session,
|
|
// and therefore it's safe to drop missed indices entirely.
|
|
if !sess.ReceivedAt.IsZero() && len(sess.RatchetSafety.MissedIndices) > 0 && int(sess.RatchetSafety.MissedIndices[0]) < int(sess.RatchetSafety.NextIndex)-missedIndexCutoff {
|
|
limit := sess.RatchetSafety.NextIndex - missedIndexCutoff
|
|
var cutoff int
|
|
for ; cutoff < len(sess.RatchetSafety.MissedIndices) && sess.RatchetSafety.MissedIndices[cutoff] < limit; cutoff++ {
|
|
}
|
|
sess.RatchetSafety.LostIndices = append(sess.RatchetSafety.LostIndices, sess.RatchetSafety.MissedIndices[:cutoff]...)
|
|
sess.RatchetSafety.MissedIndices = sess.RatchetSafety.MissedIndices[cutoff:]
|
|
didModify = true
|
|
}
|
|
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
|
|
if len(sess.RatchetSafety.MissedIndices) > 0 {
|
|
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
|
|
}
|
|
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
|
|
log := zerolog.Ctx(ctx).With().
|
|
Uint32("prev_ratchet_index", ratchetCurrentIndex).
|
|
Uint32("new_ratchet_index", ratchetTargetIndex).
|
|
Uint("next_new_index", sess.RatchetSafety.NextIndex).
|
|
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
|
|
Uints("lost_indices", sess.RatchetSafety.LostIndices).
|
|
Int("max_messages", sess.MaxMessages).
|
|
Logger()
|
|
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
|
|
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to delete fully used session")
|
|
return sess, plaintext, messageIndex, RatchetError
|
|
} else {
|
|
log.Info().Msg("Deleted fully used session")
|
|
}
|
|
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
|
|
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
|
log.Err(err).Msg("Failed to ratchet session")
|
|
return sess, plaintext, messageIndex, RatchetError
|
|
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
|
log.Err(err).Msg("Failed to store ratcheted session")
|
|
return sess, plaintext, messageIndex, RatchetError
|
|
} else {
|
|
log.Info().Msg("Ratcheted session forward")
|
|
}
|
|
} else if didModify {
|
|
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
|
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
|
return sess, plaintext, messageIndex, RatchetError
|
|
} else {
|
|
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
|
|
}
|
|
} else {
|
|
log.Debug().Msg("Ratchet safety data didn't change")
|
|
}
|
|
return sess, plaintext, messageIndex, nil
|
|
}
|