435 lines
14 KiB
Go
435 lines
14 KiB
Go
// mautrix-signal - A Matrix-signal puppeting bridge.
|
|
// Copyright (C) 2023 Scott Weber
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package signalmeow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
mrand "math/rand"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
"go.mau.fi/util/exerrors"
|
|
"go.mau.fi/util/random"
|
|
"google.golang.org/protobuf/proto"
|
|
"nhooyr.io/websocket"
|
|
|
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
|
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/wspb"
|
|
)
|
|
|
|
// ConfirmDeviceResponse is the JSON body of a successful reply to the
// PUT /v1/devices/link request sent by confirmDevice.
type ConfirmDeviceResponse struct {
	// ACI is the account's ACI; PerformProvisioning stores it as the new device's ACI.
	ACI uuid.UUID `json:"uuid"`
	// PNI is the account's PNI; PerformProvisioning stores it as the new device's PNI.
	PNI uuid.UUID `json:"pni,omitempty"`
	// DeviceID is the device ID from the reply; PerformProvisioning substitutes 1 when it is 0.
	DeviceID int `json:"deviceId"`
}
|
|
|
|
// ProvisioningState identifies which stage of device linking a
// ProvisioningResponse reports.
type ProvisioningState int

const (
	// StateProvisioningError means provisioning failed; the response's Err is set.
	StateProvisioningError ProvisioningState = iota
	// StateProvisioningURLReceived means the linking URL is available in ProvisioningURL.
	StateProvisioningURLReceived
	// StateProvisioningDataReceived means the linked device was stored; see ProvisioningData.
	StateProvisioningDataReceived
	// StateProvisioningPreKeysRegistered means ACI and PNI prekeys were generated and registered.
	StateProvisioningPreKeysRegistered
)

// String implements fmt.Stringer. Known states render as their constant name;
// any other value renders as "ProvisioningState(N)".
func (s ProvisioningState) String() string {
	names := [...]string{
		StateProvisioningError:             "StateProvisioningError",
		StateProvisioningURLReceived:       "StateProvisioningURLReceived",
		StateProvisioningDataReceived:      "StateProvisioningDataReceived",
		StateProvisioningPreKeysRegistered: "StateProvisioningPreKeysRegistered",
	}
	if s < 0 || int(s) >= len(names) {
		return fmt.Sprintf("ProvisioningState(%d)", s)
	}
	return names[s]
}
|
|
|
|
// ProvisioningResponse is one progress update sent on the channel returned by
// PerformProvisioning. State determines which, if any, of the remaining fields
// is set.
type ProvisioningResponse struct {
	// State is the provisioning stage this update reports.
	State ProvisioningState
	// ProvisioningURL is the sgnl://linkdevice URL; set with StateProvisioningURLReceived.
	ProvisioningURL string
	// ProvisioningData is the device data written to the store; set with StateProvisioningDataReceived.
	ProvisioningData *store.DeviceData
	// Err describes the failure; set with StateProvisioningError.
	Err error
}
|
|
|
|
func PerformProvisioning(ctx context.Context, deviceStore store.DeviceStore, deviceName string) chan ProvisioningResponse {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "perform provisioning").Logger()
|
|
c := make(chan ProvisioningResponse, 4)
|
|
go func() {
|
|
defer close(c)
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
|
defer cancel()
|
|
ws, resp, err := web.OpenWebsocket(ctx, web.WebsocketProvisioningPath)
|
|
if err != nil {
|
|
log.Err(err).Any("resp", resp).Msg("error opening provisioning websocket")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
defer ws.Close(websocket.StatusInternalError, "Websocket StatusInternalError")
|
|
provisioningCipher := NewProvisioningCipher()
|
|
|
|
provisioningURL, err := startProvisioning(ctx, ws, provisioningCipher)
|
|
if err != nil {
|
|
log.Err(err).Msg("startProvisioning error")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
c <- ProvisioningResponse{State: StateProvisioningURLReceived, ProvisioningURL: provisioningURL, Err: err}
|
|
|
|
provisioningMessage, err := continueProvisioning(ctx, ws, provisioningCipher)
|
|
if err != nil {
|
|
log.Err(err).Msg("continueProvisioning error")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
ws.Close(websocket.StatusNormalClosure, "")
|
|
|
|
aciPublicKey := exerrors.Must(libsignalgo.DeserializePublicKey(provisioningMessage.GetAciIdentityKeyPublic()))
|
|
aciPrivateKey := exerrors.Must(libsignalgo.DeserializePrivateKey(provisioningMessage.GetAciIdentityKeyPrivate()))
|
|
aciIdentityKeyPair := exerrors.Must(libsignalgo.NewIdentityKeyPair(aciPublicKey, aciPrivateKey))
|
|
pniPublicKey := exerrors.Must(libsignalgo.DeserializePublicKey(provisioningMessage.GetPniIdentityKeyPublic()))
|
|
pniPrivateKey := exerrors.Must(libsignalgo.DeserializePrivateKey(provisioningMessage.GetPniIdentityKeyPrivate()))
|
|
pniIdentityKeyPair := exerrors.Must(libsignalgo.NewIdentityKeyPair(pniPublicKey, pniPrivateKey))
|
|
profileKey := libsignalgo.ProfileKey(provisioningMessage.GetProfileKey())
|
|
|
|
username := *provisioningMessage.Number
|
|
password := random.String(22)
|
|
code := provisioningMessage.ProvisioningCode
|
|
aciRegistrationID := mrand.Intn(16383) + 1
|
|
pniRegistrationID := mrand.Intn(16383) + 1
|
|
aciSignedPreKey := GenerateSignedPreKey(1, aciIdentityKeyPair)
|
|
pniSignedPreKey := GenerateSignedPreKey(1, pniIdentityKeyPair)
|
|
aciPQLastResortPreKey := GenerateKyberPreKeys(1, 1, aciIdentityKeyPair)[0]
|
|
pniPQLastResortPreKey := GenerateKyberPreKeys(1, 1, pniIdentityKeyPair)[0]
|
|
deviceResponse, err := confirmDevice(
|
|
ctx,
|
|
username,
|
|
password,
|
|
*code,
|
|
aciRegistrationID,
|
|
pniRegistrationID,
|
|
aciSignedPreKey,
|
|
pniSignedPreKey,
|
|
aciPQLastResortPreKey,
|
|
pniPQLastResortPreKey,
|
|
aciIdentityKeyPair,
|
|
deviceName,
|
|
)
|
|
if err != nil {
|
|
log.Err(err).Msg("confirmDevice error")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
|
|
deviceId := 1
|
|
if deviceResponse.DeviceID != 0 {
|
|
deviceId = deviceResponse.DeviceID
|
|
}
|
|
|
|
data := &store.DeviceData{
|
|
ACIIdentityKeyPair: aciIdentityKeyPair,
|
|
PNIIdentityKeyPair: pniIdentityKeyPair,
|
|
ACIRegistrationID: aciRegistrationID,
|
|
PNIRegistrationID: pniRegistrationID,
|
|
ACI: deviceResponse.ACI,
|
|
PNI: deviceResponse.PNI,
|
|
DeviceID: deviceId,
|
|
Number: *provisioningMessage.Number,
|
|
Password: password,
|
|
}
|
|
|
|
// Store the provisioning data
|
|
err = deviceStore.PutDevice(ctx, data)
|
|
if err != nil {
|
|
log.Err(err).Msg("error storing new device")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
|
|
device, err := deviceStore.DeviceByACI(ctx, data.ACI)
|
|
if err != nil {
|
|
log.Err(err).Msg("error retrieving new device")
|
|
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
|
|
return
|
|
}
|
|
|
|
// In case this is an existing device, we gotta clear out keys
|
|
device.ClearDeviceKeys(ctx)
|
|
|
|
// Store identity keys?
|
|
_, err = device.IdentityKeyStore.SaveIdentityKey(ctx, device.ACIServiceID(), device.ACIIdentityKeyPair.GetIdentityKey())
|
|
if err != nil {
|
|
c <- ProvisioningResponse{
|
|
State: StateProvisioningError,
|
|
Err: fmt.Errorf("error saving identity key: %w", err),
|
|
}
|
|
return
|
|
}
|
|
_, err = device.IdentityKeyStore.SaveIdentityKey(ctx, device.PNIServiceID(), device.PNIIdentityKeyPair.GetIdentityKey())
|
|
if err != nil {
|
|
c <- ProvisioningResponse{
|
|
State: StateProvisioningError,
|
|
Err: fmt.Errorf("error saving identity key: %w", err),
|
|
}
|
|
return
|
|
}
|
|
|
|
// Store signed prekeys (now that we have a device)
|
|
device.ACIPreKeyStore.StoreSignedPreKey(ctx, 1, aciSignedPreKey)
|
|
device.PNIPreKeyStore.StoreSignedPreKey(ctx, 1, pniSignedPreKey)
|
|
device.ACIPreKeyStore.StoreLastResortKyberPreKey(ctx, 1, aciPQLastResortPreKey)
|
|
device.PNIPreKeyStore.StoreLastResortKyberPreKey(ctx, 1, pniPQLastResortPreKey)
|
|
|
|
// Store our profile key
|
|
err = device.RecipientStore.StoreRecipient(ctx, &types.Recipient{
|
|
ACI: data.ACI,
|
|
PNI: data.PNI,
|
|
E164: data.Number,
|
|
Profile: types.Profile{
|
|
Key: profileKey,
|
|
},
|
|
})
|
|
if err != nil {
|
|
c <- ProvisioningResponse{
|
|
State: StateProvisioningError,
|
|
Err: fmt.Errorf("error storing profile key: %w", err),
|
|
}
|
|
return
|
|
}
|
|
|
|
// Return the provisioning data
|
|
c <- ProvisioningResponse{State: StateProvisioningDataReceived, ProvisioningData: data}
|
|
|
|
// Generate, store, and register prekeys
|
|
// TODO hacky client construction
|
|
cli := &Client{Store: device}
|
|
err = cli.GenerateAndRegisterPreKeys(ctx, device.ACIPreKeyStore)
|
|
if err != nil {
|
|
c <- ProvisioningResponse{
|
|
State: StateProvisioningError,
|
|
Err: fmt.Errorf("error generating and registering ACI prekeys: %w", err),
|
|
}
|
|
return
|
|
}
|
|
err = cli.GenerateAndRegisterPreKeys(ctx, device.PNIPreKeyStore)
|
|
if err != nil {
|
|
c <- ProvisioningResponse{
|
|
State: StateProvisioningError,
|
|
Err: fmt.Errorf("error generating and registering PNI prekeys: %w", err),
|
|
}
|
|
return
|
|
}
|
|
|
|
c <- ProvisioningResponse{State: StateProvisioningPreKeysRegistered}
|
|
}()
|
|
return c
|
|
}
|
|
|
|
// Returns the provisioningUrl and an error
|
|
func startProvisioning(ctx context.Context, ws *websocket.Conn, provisioningCipher *ProvisioningCipher) (string, error) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "start provisioning").Logger()
|
|
pubKey := provisioningCipher.GetPublicKey()
|
|
|
|
msg := &signalpb.WebSocketMessage{}
|
|
err := wspb.Read(ctx, ws, msg)
|
|
if err != nil {
|
|
log.Err(err).Msg("error reading websocket message")
|
|
return "", err
|
|
}
|
|
|
|
// Ensure the message is a request and has a valid verb and path
|
|
if msg.GetType() != signalpb.WebSocketMessage_REQUEST || msg.GetRequest().GetVerb() != http.MethodPut || msg.GetRequest().GetPath() != "/v1/address" {
|
|
return "", fmt.Errorf("unexpected websocket message: %v", msg)
|
|
}
|
|
|
|
var provisioningBody signalpb.ProvisioningUuid
|
|
err = proto.Unmarshal(msg.GetRequest().GetBody(), &provisioningBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to unmarshal provisioning UUID: %w", err)
|
|
}
|
|
|
|
provisioningURL := (&url.URL{
|
|
Scheme: "sgnl",
|
|
Host: "linkdevice",
|
|
RawQuery: url.Values{
|
|
"uuid": []string{provisioningBody.GetUuid()},
|
|
"pub_key": []string{base64.StdEncoding.EncodeToString(exerrors.Must(pubKey.Serialize()))},
|
|
}.Encode(),
|
|
}).String()
|
|
|
|
// Create and send response
|
|
response := web.CreateWSResponse(ctx, msg.GetRequest().GetId(), 200)
|
|
err = wspb.Write(ctx, ws, response)
|
|
if err != nil {
|
|
log.Err(err).Msg("error writing websocket message")
|
|
return "", err
|
|
}
|
|
return provisioningURL, nil
|
|
}
|
|
|
|
func continueProvisioning(ctx context.Context, ws *websocket.Conn, provisioningCipher *ProvisioningCipher) (*signalpb.ProvisionMessage, error) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "continue provisioning").Logger()
|
|
envelope := &signalpb.ProvisionEnvelope{}
|
|
msg := &signalpb.WebSocketMessage{}
|
|
err := wspb.Read(ctx, ws, msg)
|
|
if err != nil {
|
|
log.Err(err).Msg("error reading websocket message")
|
|
return nil, err
|
|
}
|
|
|
|
// Wait for provisioning message in a request, then send a response
|
|
if *msg.Type == signalpb.WebSocketMessage_REQUEST &&
|
|
*msg.Request.Verb == http.MethodPut &&
|
|
*msg.Request.Path == "/v1/message" {
|
|
|
|
err = proto.Unmarshal(msg.Request.Body, envelope)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
response := web.CreateWSResponse(ctx, *msg.Request.Id, 200)
|
|
err = wspb.Write(ctx, ws, response)
|
|
if err != nil {
|
|
log.Err(err).Msg("error writing websocket message")
|
|
return nil, err
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("invalid provisioning message, type: %v, verb: %v, path: %v", *msg.Type, *msg.Request.Verb, *msg.Request.Path)
|
|
log.Err(err).Msg("problem reading websocket message")
|
|
return nil, err
|
|
}
|
|
provisioningMessage, err := provisioningCipher.Decrypt(envelope)
|
|
return provisioningMessage, err
|
|
}
|
|
|
|
func confirmDevice(
|
|
ctx context.Context,
|
|
username string,
|
|
password string,
|
|
code string,
|
|
aciRegistrationID int,
|
|
pniRegistrationID int,
|
|
aciSignedPreKey *libsignalgo.SignedPreKeyRecord,
|
|
pniSignedPreKey *libsignalgo.SignedPreKeyRecord,
|
|
aciPQLastResortPreKey *libsignalgo.KyberPreKeyRecord,
|
|
pniPQLastResortPreKey *libsignalgo.KyberPreKeyRecord,
|
|
aciIdentityKeyPair *libsignalgo.IdentityKeyPair,
|
|
deviceName string,
|
|
) (*ConfirmDeviceResponse, error) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "confirm device").Logger()
|
|
ctx = log.WithContext(ctx)
|
|
encryptedDeviceName, err := EncryptDeviceName(deviceName, aciIdentityKeyPair.GetPublicKey())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encrypt device name: %w", err)
|
|
}
|
|
|
|
ws, resp, err := web.OpenWebsocket(ctx, web.WebsocketPath)
|
|
if err != nil {
|
|
log.Err(err).Any("resp", resp).Msg("error opening websocket")
|
|
return nil, err
|
|
}
|
|
defer ws.Close(websocket.StatusInternalError, "Websocket StatusInternalError")
|
|
|
|
aciSignedPreKeyJson, err := SignedPreKeyToJSON(aciSignedPreKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert signed ACI prekey to JSON: %w", err)
|
|
}
|
|
pniSignedPreKeyJson, err := SignedPreKeyToJSON(pniSignedPreKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert signed PNI prekey to JSON: %w", err)
|
|
}
|
|
|
|
aciPQLastResortPreKeyJson, err := KyberPreKeyToJSON(aciPQLastResortPreKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert ACI kyber last resort prekey to JSON: %w", err)
|
|
}
|
|
pniPQLastResortPreKeyJson, err := KyberPreKeyToJSON(pniPQLastResortPreKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert PNI kyber last resort prekey to JSON: %w", err)
|
|
}
|
|
|
|
data := map[string]any{
|
|
"verificationCode": code,
|
|
"accountAttributes": map[string]any{
|
|
"fetchesMessages": true,
|
|
"name": encryptedDeviceName,
|
|
"registrationId": aciRegistrationID,
|
|
"pniRegistrationId": pniRegistrationID,
|
|
"capabilities": map[string]any{
|
|
"pni": true,
|
|
},
|
|
},
|
|
"aciSignedPreKey": aciSignedPreKeyJson,
|
|
"pniSignedPreKey": pniSignedPreKeyJson,
|
|
"aciPqLastResortPreKey": aciPQLastResortPreKeyJson,
|
|
"pniPqLastResortPreKey": pniPQLastResortPreKeyJson,
|
|
}
|
|
|
|
jsonBytes, err := json.Marshal(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal JSON: %w", err)
|
|
}
|
|
|
|
// Create and send request TODO: Use SignalWebsocket
|
|
request := web.CreateWSRequest(http.MethodPut, "/v1/devices/link", jsonBytes, &username, &password)
|
|
one := uint64(1)
|
|
request.Id = &one
|
|
msg_type := signalpb.WebSocketMessage_REQUEST
|
|
message := &signalpb.WebSocketMessage{
|
|
Type: &msg_type,
|
|
Request: request,
|
|
}
|
|
err = wspb.Write(ctx, ws, message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed on write protobuf data to websocket: %w", err)
|
|
}
|
|
|
|
receivedMsg := &signalpb.WebSocketMessage{}
|
|
err = wspb.Read(ctx, ws, receivedMsg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read from websocket after devices call: %w", err)
|
|
}
|
|
|
|
status := int(*receivedMsg.Response.Status)
|
|
if status < 200 || status >= 300 {
|
|
return nil, fmt.Errorf("non-200 status code (%d) from devices response: %s", status, *receivedMsg.Response.Message)
|
|
}
|
|
|
|
// unmarshal JSON response into ConfirmDeviceResponse
|
|
deviceResp := ConfirmDeviceResponse{}
|
|
err = json.Unmarshal(receivedMsg.Response.Body, &deviceResp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal JSON: %w", err)
|
|
}
|
|
|
|
return &deviceResp, nil
|
|
}
|