mirror of https://github.com/mautrix/go.git
411 lines
12 KiB
Go
411 lines
12 KiB
Go
package libolm
|
|
|
|
// #cgo LDFLAGS: -lolm -lstdc++
|
|
// #include <olm/olm.h>
|
|
import "C"
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"unsafe"
|
|
|
|
"github.com/tidwall/gjson"
|
|
|
|
"maunium.net/go/mautrix/crypto/olm"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// Account stores a device account for end to end encrypted messaging.
//
// It is a thin wrapper around a libolm OlmAccount object. The object is
// constructed inside a Go-allocated byte buffer (see NewBlankAccount).
type Account struct {
	// int is the libolm handle returned by C.olm_account; it points into mem.
	int *C.OlmAccount
	// mem is the Go buffer (accountSize() bytes) in which the C account
	// object was constructed. It is kept alongside int, presumably so the
	// buffer stays reachable for as long as the handle is used.
	mem []byte
}
|
|
|
|
func init() {
|
|
olm.InitNewAccount = func() (olm.Account, error) {
|
|
return NewAccount()
|
|
}
|
|
olm.InitBlankAccount = func() olm.Account {
|
|
return NewBlankAccount()
|
|
}
|
|
olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
|
|
return AccountFromPickled(pickled, key)
|
|
}
|
|
}
|
|
|
|
// Ensure that [Account] implements [olm.Account]. This blank-identifier
// assignment exists only so the compiler rejects the package if the method
// set ever drifts from the interface.
var _ olm.Account = (*Account)(nil)
|
|
|
|
// AccountFromPickled loads an Account from a pickled base64 string. Decrypts
|
|
// the Account using the supplied key. Returns error on failure. If the key
|
|
// doesn't match the one used to encrypt the Account then the error will be
|
|
// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be
|
|
// "INVALID_BASE64".
|
|
func AccountFromPickled(pickled, key []byte) (*Account, error) {
|
|
if len(pickled) == 0 {
|
|
return nil, olm.EmptyInput
|
|
}
|
|
a := NewBlankAccount()
|
|
return a, a.Unpickle(pickled, key)
|
|
}
|
|
|
|
func NewBlankAccount() *Account {
|
|
memory := make([]byte, accountSize())
|
|
return &Account{
|
|
int: C.olm_account(unsafe.Pointer(&memory[0])),
|
|
mem: memory,
|
|
}
|
|
}
|
|
|
|
// NewAccount creates a new [Account].
|
|
func NewAccount() (*Account, error) {
|
|
a := NewBlankAccount()
|
|
random := make([]byte, a.createRandomLen()+1)
|
|
_, err := rand.Read(random)
|
|
if err != nil {
|
|
panic(olm.NotEnoughGoRandom)
|
|
}
|
|
ret := C.olm_create_account(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&random[0]),
|
|
C.size_t(len(random)))
|
|
if ret == errorVal() {
|
|
return nil, a.lastError()
|
|
} else {
|
|
return a, nil
|
|
}
|
|
}
|
|
|
|
// accountSize returns the size of an account object in bytes.
|
|
func accountSize() uint {
|
|
return uint(C.olm_account_size())
|
|
}
|
|
|
|
// lastError returns an error describing the most recent error to happen to an
|
|
// account.
|
|
func (a *Account) lastError() error {
|
|
return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int))))
|
|
}
|
|
|
|
// Clear clears the memory used to back this Account.
|
|
func (a *Account) Clear() error {
|
|
r := C.olm_clear_account((*C.OlmAccount)(a.int))
|
|
if r == errorVal() {
|
|
return a.lastError()
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// pickleLen returns the number of bytes needed to store an Account.
|
|
func (a *Account) pickleLen() uint {
|
|
return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// createRandomLen returns the number of random bytes needed to create an
|
|
// Account.
|
|
func (a *Account) createRandomLen() uint {
|
|
return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// identityKeysLen returns the size of the output buffer needed to hold the
|
|
// identity keys.
|
|
func (a *Account) identityKeysLen() uint {
|
|
return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// signatureLen returns the length of an ed25519 signature encoded as base64.
|
|
func (a *Account) signatureLen() uint {
|
|
return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// oneTimeKeysLen returns the size of the output buffer needed to hold the one
|
|
// time keys.
|
|
func (a *Account) oneTimeKeysLen() uint {
|
|
return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// genOneTimeKeysRandomLen returns the number of random bytes needed to
|
|
// generate a given number of new one time keys.
|
|
func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
|
|
return uint(C.olm_account_generate_one_time_keys_random_length(
|
|
(*C.OlmAccount)(a.int),
|
|
C.size_t(num)))
|
|
}
|
|
|
|
// Pickle returns an Account as a base64 string. Encrypts the Account using the
|
|
// supplied key.
|
|
func (a *Account) Pickle(key []byte) ([]byte, error) {
|
|
if len(key) == 0 {
|
|
return nil, olm.NoKeyProvided
|
|
}
|
|
pickled := make([]byte, a.pickleLen())
|
|
r := C.olm_pickle_account(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&key[0]),
|
|
C.size_t(len(key)),
|
|
unsafe.Pointer(&pickled[0]),
|
|
C.size_t(len(pickled)))
|
|
if r == errorVal() {
|
|
return nil, a.lastError()
|
|
}
|
|
return pickled[:r], nil
|
|
}
|
|
|
|
func (a *Account) Unpickle(pickled, key []byte) error {
|
|
if len(key) == 0 {
|
|
return olm.NoKeyProvided
|
|
}
|
|
r := C.olm_unpickle_account(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&key[0]),
|
|
C.size_t(len(key)),
|
|
unsafe.Pointer(&pickled[0]),
|
|
C.size_t(len(pickled)))
|
|
if r == errorVal() {
|
|
return a.lastError()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Deprecated
|
|
func (a *Account) GobEncode() ([]byte, error) {
|
|
pickled, err := a.Pickle(pickleKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
|
rawPickled := make([]byte, length)
|
|
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
|
|
return rawPickled, err
|
|
}
|
|
|
|
// Deprecated
|
|
func (a *Account) GobDecode(rawPickled []byte) error {
|
|
if a.int == nil {
|
|
*a = *NewBlankAccount()
|
|
}
|
|
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
|
|
pickled := make([]byte, length)
|
|
base64.RawStdEncoding.Encode(pickled, rawPickled)
|
|
return a.Unpickle(pickled, pickleKey)
|
|
}
|
|
|
|
// Deprecated
|
|
func (a *Account) MarshalJSON() ([]byte, error) {
|
|
pickled, err := a.Pickle(pickleKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
quotes := make([]byte, len(pickled)+2)
|
|
quotes[0] = '"'
|
|
quotes[len(quotes)-1] = '"'
|
|
copy(quotes[1:len(quotes)-1], pickled)
|
|
return quotes, nil
|
|
}
|
|
|
|
// Deprecated
|
|
func (a *Account) UnmarshalJSON(data []byte) error {
|
|
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
|
return olm.InputNotJSONString
|
|
}
|
|
if a.int == nil {
|
|
*a = *NewBlankAccount()
|
|
}
|
|
return a.Unpickle(data[1:len(data)-1], pickleKey)
|
|
}
|
|
|
|
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
|
|
func (a *Account) IdentityKeysJSON() ([]byte, error) {
|
|
identityKeys := make([]byte, a.identityKeysLen())
|
|
r := C.olm_account_identity_keys(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&identityKeys[0]),
|
|
C.size_t(len(identityKeys)))
|
|
if r == errorVal() {
|
|
return nil, a.lastError()
|
|
} else {
|
|
return identityKeys, nil
|
|
}
|
|
}
|
|
|
|
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity
|
|
// keys for the Account.
|
|
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
|
|
identityKeysJSON, err := a.IdentityKeysJSON()
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519")
|
|
return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil
|
|
}
|
|
|
|
// Sign returns the signature of a message using the ed25519 key for this
|
|
// Account.
|
|
func (a *Account) Sign(message []byte) ([]byte, error) {
|
|
if len(message) == 0 {
|
|
panic(olm.EmptyInput)
|
|
}
|
|
signature := make([]byte, a.signatureLen())
|
|
r := C.olm_account_sign(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&message[0]),
|
|
C.size_t(len(message)),
|
|
unsafe.Pointer(&signature[0]),
|
|
C.size_t(len(signature)))
|
|
if r == errorVal() {
|
|
panic(a.lastError())
|
|
}
|
|
return signature, nil
|
|
}
|
|
|
|
// OneTimeKeys returns the public parts of the unpublished one time keys for
|
|
// the Account.
|
|
//
|
|
// The returned data is a struct with the single value "Curve25519", which is
|
|
// itself an object mapping key id to base64-encoded Curve25519 key. For
|
|
// example:
|
|
//
|
|
// {
|
|
// Curve25519: {
|
|
// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
|
|
// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
|
|
// }
|
|
// }
|
|
func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
|
|
oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
|
|
r := C.olm_account_one_time_keys(
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&oneTimeKeysJSON[0]),
|
|
C.size_t(len(oneTimeKeysJSON)))
|
|
if r == errorVal() {
|
|
return nil, a.lastError()
|
|
}
|
|
var oneTimeKeys struct {
|
|
Curve25519 map[string]id.Curve25519 `json:"curve25519"`
|
|
}
|
|
return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys)
|
|
}
|
|
|
|
// MarkKeysAsPublished marks the current set of one time keys as being
|
|
// published.
|
|
func (a *Account) MarkKeysAsPublished() {
|
|
C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int))
|
|
}
|
|
|
|
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
|
// Account can store.
|
|
func (a *Account) MaxNumberOfOneTimeKeys() uint {
|
|
return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int)))
|
|
}
|
|
|
|
// GenOneTimeKeys generates a number of new one time keys. If the total number
|
|
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
|
|
// keys are discarded.
|
|
func (a *Account) GenOneTimeKeys(num uint) error {
|
|
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
|
|
_, err := rand.Read(random)
|
|
if err != nil {
|
|
return olm.NotEnoughGoRandom
|
|
}
|
|
r := C.olm_account_generate_one_time_keys(
|
|
(*C.OlmAccount)(a.int),
|
|
C.size_t(num),
|
|
unsafe.Pointer(&random[0]),
|
|
C.size_t(len(random)))
|
|
if r == errorVal() {
|
|
return a.lastError()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// NewOutboundSession creates a new out-bound session for sending messages to a
|
|
// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the
|
|
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
|
|
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
|
|
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
|
return nil, olm.EmptyInput
|
|
}
|
|
s := NewBlankSession()
|
|
random := make([]byte, s.createOutboundRandomLen()+1)
|
|
_, err := rand.Read(random)
|
|
if err != nil {
|
|
panic(olm.NotEnoughGoRandom)
|
|
}
|
|
r := C.olm_create_outbound_session(
|
|
(*C.OlmSession)(s.int),
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&([]byte(theirIdentityKey)[0])),
|
|
C.size_t(len(theirIdentityKey)),
|
|
unsafe.Pointer(&([]byte(theirOneTimeKey)[0])),
|
|
C.size_t(len(theirOneTimeKey)),
|
|
unsafe.Pointer(&random[0]),
|
|
C.size_t(len(random)))
|
|
if r == errorVal() {
|
|
return nil, s.lastError()
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// NewInboundSession creates a new in-bound session for sending/receiving
|
|
// messages from an incoming PRE_KEY message. Returns error on failure. If
|
|
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
|
|
// the message was for an unsupported protocol version then the error will be
|
|
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
|
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
|
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
|
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
|
|
if len(oneTimeKeyMsg) == 0 {
|
|
return nil, olm.EmptyInput
|
|
}
|
|
s := NewBlankSession()
|
|
r := C.olm_create_inbound_session(
|
|
(*C.OlmSession)(s.int),
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
|
|
C.size_t(len(oneTimeKeyMsg)))
|
|
if r == errorVal() {
|
|
return nil, s.lastError()
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
|
|
// messages from an incoming PRE_KEY message. Returns error on failure. If
|
|
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
|
|
// the message was for an unsupported protocol version then the error will be
|
|
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
|
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
|
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
|
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
|
|
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
|
|
return nil, olm.EmptyInput
|
|
}
|
|
s := NewBlankSession()
|
|
r := C.olm_create_inbound_session_from(
|
|
(*C.OlmSession)(s.int),
|
|
(*C.OlmAccount)(a.int),
|
|
unsafe.Pointer(&([]byte(*theirIdentityKey)[0])),
|
|
C.size_t(len(*theirIdentityKey)),
|
|
unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
|
|
C.size_t(len(oneTimeKeyMsg)))
|
|
if r == errorVal() {
|
|
return nil, s.lastError()
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// RemoveOneTimeKeys removes the one time keys that the session used from the
|
|
// Account. Returns error on failure. If the Account doesn't have any
|
|
// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID".
|
|
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
|
|
r := C.olm_remove_one_time_keys(
|
|
(*C.OlmAccount)(a.int),
|
|
(*C.OlmSession)(s.(*Session).int))
|
|
if r == errorVal() {
|
|
return a.lastError()
|
|
}
|
|
return nil
|
|
}
|