mautrix-go/crypto/libolm/account.go

411 lines
12 KiB
Go

package libolm
// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
import "C"
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"unsafe"
"github.com/tidwall/gjson"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
)
// Account stores a device account for end to end encrypted messaging.
//
// It wraps a libolm OlmAccount handle together with the Go-allocated buffer
// that the handle was initialised in (see NewBlankAccount).
type Account struct {
// int is the libolm account handle passed to every olm_account_* call.
int *C.OlmAccount
// mem is the byte slice whose first element was handed to olm_account when
// the handle was created.
mem []byte
}
// init registers this package's libolm-backed constructors with the olm
// package, so that olm.InitNewAccount, olm.InitBlankAccount and
// olm.InitNewAccountFromPickled produce *Account values.
func init() {
	olm.InitBlankAccount = func() olm.Account {
		account := NewBlankAccount()
		return account
	}
	olm.InitNewAccount = func() (olm.Account, error) {
		account, err := NewAccount()
		return account, err
	}
	olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
		account, err := AccountFromPickled(pickled, key)
		return account, err
	}
}
// Compile-time assertion that [*Account] implements [olm.Account]; the value
// is discarded.
var _ olm.Account = (*Account)(nil)
// AccountFromPickled creates a blank Account and unpickles the given base64
// pickle into it, decrypting with the supplied key.
//
// It returns olm.EmptyInput when pickled is empty. Any error from Unpickle is
// passed through unchanged; per the libolm error names, a wrong key is
// reported as "BAD_ACCOUNT_KEY" and undecodable base64 as "INVALID_BASE64".
// Note that when Unpickle fails, the freshly created Account is still
// returned alongside the error.
func AccountFromPickled(pickled, key []byte) (*Account, error) {
	if len(pickled) == 0 {
		return nil, olm.EmptyInput
	}
	account := NewBlankAccount()
	err := account.Unpickle(pickled, key)
	return account, err
}
// NewBlankAccount allocates a buffer of accountSize() bytes and initialises
// an OlmAccount inside it with olm_account. The buffer is stored on the
// returned Account next to the handle (presumably so that the Go garbage
// collector keeps the backing memory alive — confirm).
func NewBlankAccount() *Account {
	buf := make([]byte, accountSize())
	account := &Account{mem: buf}
	account.int = C.olm_account(unsafe.Pointer(&buf[0]))
	return account
}
// NewAccount creates a new [Account] with freshly generated keys.
//
// It returns olm.NotEnoughGoRandom if the system random source fails (the
// same error GenOneTimeKeys returns for this condition), or the libolm error
// if olm_create_account fails.
func NewAccount() (*Account, error) {
	a := NewBlankAccount()
	// The buffer is one byte longer than libolm asks for; presumably this
	// guarantees &random[0] is valid even if the requested length is zero.
	random := make([]byte, a.createRandomLen()+1)
	if _, err := rand.Read(random); err != nil {
		return nil, olm.NotEnoughGoRandom
	}
	ret := C.olm_create_account(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&random[0]),
		C.size_t(len(random)))
	if ret == errorVal() {
		return nil, a.lastError()
	}
	return a, nil
}
// accountSize reports how many bytes libolm needs for one account object, as
// given by olm_account_size.
func accountSize() uint {
	size := C.olm_account_size()
	return uint(size)
}
// lastError fetches the most recent error string recorded by libolm for this
// account and maps it to a Go error via convertError.
func (a *Account) lastError() error {
	msg := C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))
	return convertError(msg)
}
// Clear calls olm_clear_account on the underlying handle to clear the memory
// backing this Account. It returns the account's last libolm error if the
// call reports failure, and nil otherwise.
func (a *Account) Clear() error {
	if C.olm_clear_account((*C.OlmAccount)(a.int)) == errorVal() {
		return a.lastError()
	}
	return nil
}
// pickleLen reports the buffer size, in bytes, that olm_pickle_account needs
// to store this Account.
func (a *Account) pickleLen() uint {
	n := C.olm_pickle_account_length((*C.OlmAccount)(a.int))
	return uint(n)
}
// createRandomLen reports how many random bytes olm_create_account requires
// to create an Account.
func (a *Account) createRandomLen() uint {
	n := C.olm_create_account_random_length((*C.OlmAccount)(a.int))
	return uint(n)
}
// identityKeysLen reports the size of the output buffer required to hold the
// identity keys JSON.
func (a *Account) identityKeysLen() uint {
	n := C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))
	return uint(n)
}
// signatureLen reports the length of a base64-encoded ed25519 signature, as
// given by olm_account_signature_length.
func (a *Account) signatureLen() uint {
	n := C.olm_account_signature_length((*C.OlmAccount)(a.int))
	return uint(n)
}
// oneTimeKeysLen reports the size of the output buffer required to hold the
// one time keys JSON.
func (a *Account) oneTimeKeysLen() uint {
	n := C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))
	return uint(n)
}
// genOneTimeKeysRandomLen reports how many random bytes libolm requires to
// generate num new one time keys.
func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
	account := (*C.OlmAccount)(a.int)
	n := C.olm_account_generate_one_time_keys_random_length(account, C.size_t(num))
	return uint(n)
}
// Pickle serialises the Account as a base64 string, encrypted with the
// supplied key.
//
// It returns olm.NoKeyProvided when key is empty, and the account's last
// libolm error when olm_pickle_account fails. On success the result is the
// output buffer truncated to the length libolm reports.
func (a *Account) Pickle(key []byte) ([]byte, error) {
	if len(key) == 0 {
		return nil, olm.NoKeyProvided
	}
	buf := make([]byte, a.pickleLen())
	written := C.olm_pickle_account(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&key[0]), C.size_t(len(key)),
		unsafe.Pointer(&buf[0]), C.size_t(len(buf)),
	)
	if written == errorVal() {
		return nil, a.lastError()
	}
	return buf[:written], nil
}
// Unpickle restores the Account's state from a base64 pickle, decrypting it
// with the supplied key.
//
// It returns olm.NoKeyProvided when key is empty, olm.EmptyInput when pickled
// is empty, and the account's last libolm error when olm_unpickle_account
// fails.
//
// NOTE(review): the libolm header documents the pickled input buffer as
// destroyed by olm_unpickle_account, so callers should not rely on the
// contents of pickled afterwards — confirm against the linked libolm version.
func (a *Account) Unpickle(pickled, key []byte) error {
	if len(key) == 0 {
		return olm.NoKeyProvided
	}
	if len(pickled) == 0 {
		// Taking &pickled[0] below would panic on an empty slice.
		return olm.EmptyInput
	}
	r := C.olm_unpickle_account(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&key[0]),
		C.size_t(len(key)),
		unsafe.Pointer(&pickled[0]),
		C.size_t(len(pickled)))
	if r == errorVal() {
		return a.lastError()
	}
	return nil
}
// GobEncode pickles the Account with the package-level pickleKey and returns
// the pickle with its unpadded standard base64 encoding removed.
//
// Deprecated: this method is flagged as deprecated in this package.
func (a *Account) GobEncode() ([]byte, error) {
	pickled, err := a.Pickle(pickleKey)
	if err != nil {
		return nil, err
	}
	encoding := base64.RawStdEncoding
	rawPickled := make([]byte, encoding.DecodedLen(len(pickled)))
	if _, err = encoding.Decode(rawPickled, pickled); err != nil {
		return rawPickled, err
	}
	return rawPickled, nil
}
// GobDecode re-encodes rawPickled as unpadded standard base64 and unpickles
// it into the Account using the package-level pickleKey. If the Account has
// no libolm handle yet, it is first replaced with a blank one.
//
// Deprecated: this method is flagged as deprecated in this package.
func (a *Account) GobDecode(rawPickled []byte) error {
	if a.int == nil {
		*a = *NewBlankAccount()
	}
	encoding := base64.RawStdEncoding
	pickled := make([]byte, encoding.EncodedLen(len(rawPickled)))
	encoding.Encode(pickled, rawPickled)
	return a.Unpickle(pickled, pickleKey)
}
// MarshalJSON pickles the Account with the package-level pickleKey and
// returns the pickle wrapped in double quotes, i.e. as a JSON string.
//
// Deprecated: this method is flagged as deprecated in this package.
func (a *Account) MarshalJSON() ([]byte, error) {
	pickled, err := a.Pickle(pickleKey)
	if err != nil {
		return nil, err
	}
	quoted := make([]byte, 0, len(pickled)+2)
	quoted = append(quoted, '"')
	quoted = append(quoted, pickled...)
	quoted = append(quoted, '"')
	return quoted, nil
}
// UnmarshalJSON restores the Account from a JSON string holding a base64
// pickle, decrypting it with the package-level pickleKey. If the Account has
// no libolm handle yet, it is first replaced with a blank one.
//
// It returns olm.InputNotJSONString when data is not wrapped in double
// quotes, olm.EmptyInput when the quoted string is empty, and otherwise
// whatever Unpickle returns.
//
// Deprecated: this method is flagged as deprecated in this package.
func (a *Account) UnmarshalJSON(data []byte) error {
	// Require at least two bytes so that a lone `"` cannot satisfy both the
	// opening and closing quote checks and then panic when sliced.
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		return olm.InputNotJSONString
	}
	pickled := data[1 : len(data)-1]
	if len(pickled) == 0 {
		// An empty pickle cannot be passed to libolm (no first byte to point at).
		return olm.EmptyInput
	}
	if a.int == nil {
		*a = *NewBlankAccount()
	}
	return a.Unpickle(pickled, pickleKey)
}
// IdentityKeysJSON returns the public parts of the Account's identity keys
// as the raw JSON written by olm_account_identity_keys. On failure it
// returns the account's last libolm error.
func (a *Account) IdentityKeysJSON() ([]byte, error) {
	buf := make([]byte, a.identityKeysLen())
	status := C.olm_account_identity_keys(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&buf[0]),
		C.size_t(len(buf)),
	)
	if status == errorVal() {
		return nil, a.lastError()
	}
	return buf, nil
}
// IdentityKeys returns the public Ed25519 and Curve25519 identity keys of
// the Account, read from the "ed25519" and "curve25519" fields of
// IdentityKeysJSON. An error from IdentityKeysJSON is returned unchanged
// with empty keys.
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
	keysJSON, err := a.IdentityKeysJSON()
	if err != nil {
		return "", "", err
	}
	fields := gjson.GetManyBytes(keysJSON, "ed25519", "curve25519")
	signingKey := id.Ed25519(fields[0].Str)
	identityKey := id.Curve25519(fields[1].Str)
	return signingKey, identityKey, nil
}
// Sign returns the signature of message made with the Account's ed25519 key.
// The result buffer is signatureLen() bytes long.
//
// It returns olm.EmptyInput when message is empty and the account's last
// libolm error when olm_account_sign fails; failures are reported through
// the error return rather than by panicking.
func (a *Account) Sign(message []byte) ([]byte, error) {
	if len(message) == 0 {
		return nil, olm.EmptyInput
	}
	signature := make([]byte, a.signatureLen())
	r := C.olm_account_sign(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&message[0]),
		C.size_t(len(message)),
		unsafe.Pointer(&signature[0]),
		C.size_t(len(signature)))
	if r == errorVal() {
		return nil, a.lastError()
	}
	return signature, nil
}
// OneTimeKeys returns the public parts of the unpublished one time keys for
// the Account as a map from key id to base64-encoded Curve25519 key.
//
// libolm produces a JSON object with a single "curve25519" member, for
// example:
//
//	{
//	    "curve25519": {
//	        "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
//	        "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
//	    }
//	}
//
// and the value of that member is what this method returns. It returns the
// account's last libolm error if olm_account_one_time_keys fails, or the
// JSON decoding error if the output cannot be parsed.
func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
	oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
	r := C.olm_account_one_time_keys(
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&oneTimeKeysJSON[0]),
		C.size_t(len(oneTimeKeysJSON)))
	if r == errorVal() {
		return nil, a.lastError()
	}
	var oneTimeKeys struct {
		Curve25519 map[string]id.Curve25519 `json:"curve25519"`
	}
	// Decode before reading the field: Go does not specify whether a plain
	// variable read in a return statement happens before or after a function
	// call in the same statement, so the two must not share one return.
	if err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys); err != nil {
		return nil, err
	}
	return oneTimeKeys.Curve25519, nil
}
// MarkKeysAsPublished tells libolm that the current set of one time keys has
// been published. The return value of the libolm call is ignored.
func (a *Account) MarkKeysAsPublished() {
	account := (*C.OlmAccount)(a.int)
	C.olm_account_mark_keys_as_published(account)
}
// MaxNumberOfOneTimeKeys reports the largest number of one time keys this
// Account can store, as given by olm_account_max_number_of_one_time_keys.
func (a *Account) MaxNumberOfOneTimeKeys() uint {
	limit := C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))
	return uint(limit)
}
// GenOneTimeKeys generates num new one time keys. If the total number of
// keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
// keys are discarded.
//
// It returns olm.NotEnoughGoRandom if the random source fails, and the
// account's last libolm error if key generation fails.
func (a *Account) GenOneTimeKeys(num uint) error {
	// One byte more than libolm requests is allocated and passed along.
	entropy := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
	if _, err := rand.Read(entropy); err != nil {
		return olm.NotEnoughGoRandom
	}
	status := C.olm_account_generate_one_time_keys(
		(*C.OlmAccount)(a.int),
		C.size_t(num),
		unsafe.Pointer(&entropy[0]),
		C.size_t(len(entropy)),
	)
	if status == errorVal() {
		return a.lastError()
	}
	return nil
}
// NewOutboundSession creates a new out-bound session for sending messages to
// a given curve25519 identityKey and oneTimeKey.
//
// It returns olm.EmptyInput if either key is empty, olm.NotEnoughGoRandom if
// the random source fails (matching GenOneTimeKeys rather than panicking),
// and the session's last libolm error if olm_create_outbound_session fails.
// If the keys couldn't be decoded as base64 then the error will be
// "INVALID_BASE64".
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
	if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
		return nil, olm.EmptyInput
	}
	s := NewBlankSession()
	random := make([]byte, s.createOutboundRandomLen()+1)
	if _, err := rand.Read(random); err != nil {
		return nil, olm.NotEnoughGoRandom
	}
	// Both keys are copied into byte slices so libolm gets addressable memory.
	identityKey := []byte(theirIdentityKey)
	oneTimeKey := []byte(theirOneTimeKey)
	r := C.olm_create_outbound_session(
		(*C.OlmSession)(s.int),
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&identityKey[0]),
		C.size_t(len(identityKey)),
		unsafe.Pointer(&oneTimeKey[0]),
		C.size_t(len(oneTimeKey)),
		unsafe.Pointer(&random[0]),
		C.size_t(len(random)))
	if r == errorVal() {
		return nil, s.lastError()
	}
	return s, nil
}
// NewInboundSession creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message.
//
// It returns olm.EmptyInput when oneTimeKeyMsg is empty, and the session's
// last libolm error on failure. Per the libolm error names: undecodable
// base64 is "INVALID_BASE64", an unsupported protocol version is
// "BAD_MESSAGE_VERSION", an undecodable message is "BAD_MESSAGE_FORMAT", and
// a reference to an unknown one time key is "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
	if len(oneTimeKeyMsg) == 0 {
		return nil, olm.EmptyInput
	}
	session := NewBlankSession()
	// libolm receives a byte-slice copy of the message, not the Go string.
	message := []byte(oneTimeKeyMsg)
	status := C.olm_create_inbound_session(
		(*C.OlmSession)(session.int),
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&message[0]),
		C.size_t(len(message)),
	)
	if status == errorVal() {
		return nil, session.lastError()
	}
	return session, nil
}
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message, for the given sender identity
// key.
//
// It returns olm.EmptyInput when theirIdentityKey is nil or points to an
// empty key, or when oneTimeKeyMsg is empty. On failure it returns the
// session's last libolm error. Per the libolm error names: undecodable
// base64 is "INVALID_BASE64", an unsupported protocol version is
// "BAD_MESSAGE_VERSION", an undecodable message is "BAD_MESSAGE_FORMAT", and
// a reference to an unknown one time key is "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
	// An empty (but non-nil) identity key must be rejected too: indexing the
	// first byte of its byte-slice copy below would otherwise panic.
	if theirIdentityKey == nil || len(*theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
		return nil, olm.EmptyInput
	}
	s := NewBlankSession()
	identityKey := []byte(*theirIdentityKey)
	message := []byte(oneTimeKeyMsg)
	r := C.olm_create_inbound_session_from(
		(*C.OlmSession)(s.int),
		(*C.OlmAccount)(a.int),
		unsafe.Pointer(&identityKey[0]),
		C.size_t(len(identityKey)),
		unsafe.Pointer(&message[0]),
		C.size_t(len(message)))
	if r == errorVal() {
		return nil, s.lastError()
	}
	return s, nil
}
// RemoveOneTimeKeys removes the one time keys that the session used from the
// Account. On failure it returns the account's last libolm error; if the
// Account doesn't have any matching one time keys then the error will be
// "BAD_MESSAGE_KEY_ID".
//
// s must be a *Session from this package: the type assertion is unchecked
// and panics for any other olm.Session implementation.
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
	account := (*C.OlmAccount)(a.int)
	session := (*C.OlmSession)(s.(*Session).int)
	if C.olm_remove_one_time_keys(account, session) == errorVal() {
		return a.lastError()
	}
	return nil
}