mirror of https://github.com/mautrix/go.git
363 lines
12 KiB
Go
363 lines
12 KiB
Go
// Package ratchet provides the ratchet used by the Olm protocol.
|
|
package ratchet
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"io"
|
|
|
|
"golang.org/x/crypto/hkdf"
|
|
|
|
"maunium.net/go/mautrix/crypto/goolm/aessha2"
|
|
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
|
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
|
"maunium.net/go/mautrix/crypto/goolm/message"
|
|
"maunium.net/go/mautrix/crypto/olm"
|
|
)
|
|
|
|
// olmPickleVersion is the pickle version handed to libolmpickle when a
// ratchet is pickled to, or unpickled from, its JSON representation.
const olmPickleVersion uint8 = 1
|
|
|
|
// Fixed protocol values and limits used by the ratchet.
const (
	// protocolVersion is the message version written by Encrypt and the only
	// version accepted by Decrypt.
	protocolVersion = 3
	// messageKeySeed is the single byte that is HMAC-ed with a chain key in
	// order to derive a message key.
	messageKeySeed = 0x01
	// sharedKeyLength is the size in bytes of each key cut out of the HKDF
	// output (one root key followed by one chain key).
	sharedKeyLength = 32

	// maxMessageGap bounds how far a message counter may be ahead of a
	// receiver chain, which limits the number of hashes computed per message.
	maxMessageGap = 2000

	// maxReceiverChains and maxSkippedMessageKeys are not referenced in this
	// file. NOTE(review): presumably they are the intended upper bounds for
	// Ratchet.ReceiverChains and Ratchet.SkippedMessageKeys; confirm where
	// they are enforced.
	maxReceiverChains     = 5
	maxSkippedMessageKeys = 40
)
|
|
|
|
var (
	// olmKeysKDFInfo is the KDF info passed to aessha2.NewAESSHA2 together
	// with a message key whenever a message cipher is created.
	olmKeysKDFInfo = []byte("OLM_KEYS")

	// KdfInfo holds the HKDF info strings used by the ratchet: Root is used
	// when a ratchet is initialized from a shared secret, Ratchet is used
	// whenever the root key is advanced.
	KdfInfo = struct {
		Root    []byte
		Ratchet []byte
	}{
		Root:    []byte("OLM_ROOT"),
		Ratchet: []byte("OLM_RATCHET"),
	}
)
|
|
|
|
// Ratchet represents the olm ratchet as described in
|
|
//
|
|
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md
|
|
type Ratchet struct {
|
|
// The root key is used to generate chain keys from the ephemeral keys.
|
|
// A new root_key is derived each time a new chain is started.
|
|
RootKey crypto.Curve25519PublicKey `json:"root_key"`
|
|
|
|
// The sender chain is used to send messages. Each time a new ephemeral
|
|
// key is received from the remote server we generate a new sender chain
|
|
// with a new ephemeral key when we next send a message.
|
|
SenderChains senderChain `json:"sender_chain"`
|
|
|
|
// The receiver chain is used to decrypt received messages. We store the
|
|
// last few chains so we can decrypt any out of order messages we haven't
|
|
// received yet.
|
|
// New chains are prepended for easier access.
|
|
ReceiverChains []receiverChain `json:"receiver_chains"`
|
|
|
|
// Storing the keys of missed messages for future use.
|
|
// The order of the elements is not important.
|
|
SkippedMessageKeys []skippedMessageKey `json:"skipped_message_keys"`
|
|
}
|
|
|
|
// New creates a new ratchet, setting the kdfInfos and cipher.
|
|
func New() *Ratchet {
|
|
return &Ratchet{}
|
|
}
|
|
|
|
// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
|
|
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
|
|
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
|
|
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
|
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
|
return err
|
|
}
|
|
r.RootKey = derivedSecrets[0:sharedKeyLength]
|
|
newReceiverChain := newReceiverChain(derivedSecrets[sharedKeyLength:], theirRatchetKey)
|
|
r.ReceiverChains = append([]receiverChain{*newReceiverChain}, r.ReceiverChains...)
|
|
return nil
|
|
}
|
|
|
|
// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
|
|
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
|
|
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
|
|
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
|
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
|
return err
|
|
}
|
|
r.RootKey = derivedSecrets[0:sharedKeyLength]
|
|
newSenderChain := newSenderChain(derivedSecrets[sharedKeyLength:], ourRatchetKey)
|
|
r.SenderChains = *newSenderChain
|
|
return nil
|
|
}
|
|
|
|
// Encrypt encrypts the message in a message.Message with MAC.
|
|
func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) {
|
|
var err error
|
|
if !r.SenderChains.IsSet {
|
|
newRatchetKey, err := crypto.Curve25519GenerateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newChainKey, err := r.advanceRootKey(newRatchetKey, r.ReceiverChains[0].ratchetKey())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newSenderChain := newSenderChain(newChainKey, newRatchetKey)
|
|
r.SenderChains = *newSenderChain
|
|
}
|
|
|
|
messageKey := r.createMessageKeys(r.SenderChains.chainKey())
|
|
r.SenderChains.advance()
|
|
|
|
cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
encryptedText, err := cipher.Encrypt(plaintext)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cipher encrypt: %w", err)
|
|
}
|
|
|
|
message := &message.Message{}
|
|
message.Version = protocolVersion
|
|
message.Counter = messageKey.Index
|
|
message.RatchetKey = r.SenderChains.ratchetKey().PublicKey
|
|
message.Ciphertext = encryptedText
|
|
//creating the mac is done in encode
|
|
return message.EncodeAndMAC(cipher)
|
|
}
|
|
|
|
// Decrypt decrypts the ciphertext and verifies the MAC.
|
|
func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
|
|
message := &message.Message{}
|
|
//The mac is not verified here, as we do not know the key yet
|
|
err := message.Decode(input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if message.Version != protocolVersion {
|
|
return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
|
|
}
|
|
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
|
|
return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
|
|
}
|
|
var receiverChainFromMessage *receiverChain
|
|
for curChainIndex := range r.ReceiverChains {
|
|
if r.ReceiverChains[curChainIndex].ratchetKey().Equal(message.RatchetKey) {
|
|
receiverChainFromMessage = &r.ReceiverChains[curChainIndex]
|
|
break
|
|
}
|
|
}
|
|
if receiverChainFromMessage == nil {
|
|
//Advancing the chain is done in this method
|
|
return r.decryptForNewChain(message, input)
|
|
} else if receiverChainFromMessage.chainKey().Index > message.Counter {
|
|
// No need to advance the chain
|
|
// Chain already advanced beyond the key for this message
|
|
// Check if the message keys are in the skipped key list.
|
|
for curSkippedIndex := range r.SkippedMessageKeys {
|
|
if message.Counter != r.SkippedMessageKeys[curSkippedIndex].MKey.Index {
|
|
continue
|
|
}
|
|
|
|
// Found the key for this message. Check the MAC.
|
|
if cipher, err := aessha2.NewAESSHA2(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, olmKeysKDFInfo); err != nil {
|
|
return nil, err
|
|
} else if verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, cipher, input); err != nil {
|
|
return nil, err
|
|
} else if !verified {
|
|
return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC)
|
|
} else if result, err := cipher.Decrypt(message.Ciphertext); err != nil {
|
|
return nil, fmt.Errorf("cipher decrypt: %w", err)
|
|
} else if len(result) != 0 {
|
|
// Remove the key from the skipped keys now that we've
|
|
// decoded the message it corresponds to.
|
|
r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1]
|
|
r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1]
|
|
return result, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound)
|
|
} else {
|
|
//Advancing the chain is done in this method
|
|
return r.decryptForExistingChain(receiverChainFromMessage, message, input)
|
|
}
|
|
}
|
|
|
|
// advanceRootKey created the next root key and returns the next chainKey
|
|
func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatchetKey crypto.Curve25519PublicKey) (crypto.Curve25519PublicKey, error) {
|
|
sharedSecret, err := newRatchetKey.SharedSecret(oldRatchetKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet)
|
|
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
|
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
|
return nil, err
|
|
}
|
|
r.RootKey = derivedSecrets[:sharedKeyLength]
|
|
return derivedSecrets[sharedKeyLength:], nil
|
|
}
|
|
|
|
// createMessageKeys returns the messageKey derived from the chainKey
|
|
func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey {
|
|
hash := hmac.New(sha256.New, chainKey.Key)
|
|
hash.Write([]byte{messageKeySeed})
|
|
return messageKey{
|
|
Key: hash.Sum(nil),
|
|
Index: chainKey.Index,
|
|
}
|
|
}
|
|
|
|
// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified.
|
|
func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) {
|
|
if message.Counter < chain.CKey.Index {
|
|
return nil, fmt.Errorf("decrypt: %w", olm.ErrChainTooHigh)
|
|
}
|
|
// Limit the number of hashes we're prepared to compute
|
|
if message.Counter-chain.CKey.Index > maxMessageGap {
|
|
return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrMsgIndexTooHigh)
|
|
}
|
|
for chain.CKey.Index < message.Counter {
|
|
messageKey := r.createMessageKeys(chain.chainKey())
|
|
skippedKey := skippedMessageKey{
|
|
MKey: messageKey,
|
|
RKey: chain.ratchetKey(),
|
|
}
|
|
r.SkippedMessageKeys = append(r.SkippedMessageKeys, skippedKey)
|
|
chain.advance()
|
|
}
|
|
messageKey := r.createMessageKeys(chain.chainKey())
|
|
chain.advance()
|
|
cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
verified, err := message.VerifyMACInline(messageKey.Key, cipher, rawMessage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !verified {
|
|
return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC)
|
|
}
|
|
return cipher.Decrypt(message.Ciphertext)
|
|
}
|
|
|
|
// decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key.
|
|
func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte) ([]byte, error) {
|
|
// They shouldn't move to a new chain until we've sent them a message
|
|
// acknowledging the last one
|
|
if !r.SenderChains.IsSet {
|
|
return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrProtocolViolation)
|
|
}
|
|
// Limit the number of hashes we're prepared to compute
|
|
if message.Counter > maxMessageGap {
|
|
return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrMsgIndexTooHigh)
|
|
}
|
|
|
|
newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newChain := newReceiverChain(newChainKey, message.RatchetKey)
|
|
r.ReceiverChains = append([]receiverChain{*newChain}, r.ReceiverChains...)
|
|
/*
|
|
They have started using a new ephemeral ratchet key.
|
|
We needed to derive a new set of chain keys.
|
|
We can discard our previous ephemeral ratchet key.
|
|
We will generate a new key when we send the next message.
|
|
*/
|
|
r.SenderChains = senderChain{}
|
|
|
|
return r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage)
|
|
}
|
|
|
|
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
|
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
|
|
return libolmpickle.PickleAsJSON(r, olmPickleVersion, key)
|
|
}
|
|
|
|
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
|
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
|
|
return libolmpickle.UnpickleAsJSON(r, pickled, key, olmPickleVersion)
|
|
}
|
|
|
|
// UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet]
|
|
// accordingly.
|
|
func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder, includesChainIndex bool) error {
|
|
if err := r.RootKey.UnpickleLibOlm(decoder); err != nil {
|
|
return err
|
|
}
|
|
senderChainsCount, err := decoder.ReadUInt32()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := uint32(0); i < senderChainsCount; i++ {
|
|
if i == 0 {
|
|
// only the first sender key is stored
|
|
err = r.SenderChains.UnpickleLibOlm(decoder)
|
|
r.SenderChains.IsSet = true
|
|
} else {
|
|
// just eat the values
|
|
err = (&senderChain{}).UnpickleLibOlm(decoder)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
receiverChainCount, err := decoder.ReadUInt32()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.ReceiverChains = make([]receiverChain, receiverChainCount)
|
|
for i := uint32(0); i < receiverChainCount; i++ {
|
|
if err := r.ReceiverChains[i].UnpickleLibOlm(decoder); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
skippedMessageKeysCount, err := decoder.ReadUInt32()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.SkippedMessageKeys = make([]skippedMessageKey, skippedMessageKeysCount)
|
|
for i := uint32(0); i < skippedMessageKeysCount; i++ {
|
|
if err := r.SkippedMessageKeys[i].UnpickleLibOlm(decoder); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// pickle version 0x80000001 includes a chain index; pickle version 1 does not.
|
|
if includesChainIndex {
|
|
_, err = decoder.ReadUInt32()
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PickleLibOlm pickles the ratchet into the encoder.
|
|
func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) {
|
|
r.RootKey.PickleLibOlm(encoder)
|
|
r.SenderChains.PickleLibOlm(encoder)
|
|
|
|
// Receiver Chains
|
|
encoder.WriteUInt32(uint32(len(r.ReceiverChains)))
|
|
for _, curChain := range r.ReceiverChains {
|
|
curChain.PickleLibOlm(encoder)
|
|
}
|
|
|
|
// Skipped Message Keys
|
|
encoder.WriteUInt32(uint32(len(r.SkippedMessageKeys)))
|
|
for _, curChain := range r.SkippedMessageKeys {
|
|
curChain.PickleLibOlm(encoder)
|
|
}
|
|
}
|