mirror of https://github.com/mautrix/go.git
400 lines
11 KiB
Go
400 lines
11 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package federation
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"go.mau.fi/util/exslices"
|
|
"go.mau.fi/util/jsontime"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
type Client struct {
|
|
HTTP *http.Client
|
|
ServerName string
|
|
UserAgent string
|
|
Key *SigningKey
|
|
}
|
|
|
|
func NewClient(serverName string, key *SigningKey) *Client {
|
|
return &Client{
|
|
HTTP: &http.Client{
|
|
Transport: NewServerResolvingTransport(),
|
|
Timeout: 120 * time.Second,
|
|
},
|
|
UserAgent: mautrix.DefaultUserAgent,
|
|
ServerName: serverName,
|
|
Key: key,
|
|
}
|
|
}
|
|
|
|
func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) {
|
|
err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp)
|
|
return
|
|
}
|
|
|
|
func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) {
|
|
err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp)
|
|
return
|
|
}
|
|
|
|
func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) {
|
|
err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp)
|
|
return
|
|
}
|
|
|
|
// PDU and EDU are aliases for raw JSON: this package passes them through
// without decoding them.
type (
	PDU = json.RawMessage
	EDU = json.RawMessage
)
|
|
|
|
type ReqSendTransaction struct {
|
|
Destination string `json:"destination"`
|
|
TxnID string `json:"-"`
|
|
|
|
Origin string `json:"origin"`
|
|
OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
|
|
PDUs []PDU `json:"pdus"`
|
|
EDUs []EDU `json:"edus,omitempty"`
|
|
}
|
|
|
|
// PDUProcessingResult is the per-event entry of RespSendTransaction.
type PDUProcessingResult struct {
	// Error is omitted from the JSON when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
type RespSendTransaction struct {
|
|
PDUs map[id.EventID]PDUProcessingResult `json:"pdus"`
|
|
}
|
|
|
|
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
|
|
err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
|
|
return
|
|
}
|
|
|
|
type RespGetEventAuthChain struct {
|
|
AuthChain []PDU `json:"auth_chain"`
|
|
}
|
|
|
|
func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) {
|
|
err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp)
|
|
return
|
|
}
|
|
|
|
type ReqBackfill struct {
|
|
ServerName string
|
|
RoomID id.RoomID
|
|
Limit int
|
|
BackfillFrom []id.EventID
|
|
}
|
|
|
|
type RespBackfill struct {
|
|
Origin string `json:"origin"`
|
|
OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
|
|
PDUs []PDU `json:"pdus"`
|
|
}
|
|
|
|
func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: req.ServerName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "backfill", req.RoomID},
|
|
Query: url.Values{
|
|
"limit": {strconv.Itoa(req.Limit)},
|
|
"v": exslices.CastToString[string](req.BackfillFrom),
|
|
},
|
|
Authenticate: true,
|
|
ResponseJSON: &resp,
|
|
})
|
|
return
|
|
}
|
|
|
|
type ReqGetMissingEvents struct {
|
|
ServerName string `json:"-"`
|
|
RoomID id.RoomID `json:"-"`
|
|
EarliestEvents []id.EventID `json:"earliest_events"`
|
|
LatestEvents []id.EventID `json:"latest_events"`
|
|
Limit int `json:"limit,omitempty"`
|
|
MinDepth int `json:"min_depth,omitempty"`
|
|
}
|
|
|
|
type RespGetMissingEvents struct {
|
|
Events []PDU `json:"events"`
|
|
}
|
|
|
|
func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) {
|
|
err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp)
|
|
return
|
|
}
|
|
|
|
func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) {
|
|
err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp)
|
|
return
|
|
}
|
|
|
|
type RespGetState struct {
|
|
AuthChain []PDU `json:"auth_chain"`
|
|
PDUs []PDU `json:"pdus"`
|
|
}
|
|
|
|
func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "state", roomID},
|
|
Query: url.Values{
|
|
"event_id": {string(eventID)},
|
|
},
|
|
Authenticate: true,
|
|
ResponseJSON: &resp,
|
|
})
|
|
return
|
|
}
|
|
|
|
type RespGetStateIDs struct {
|
|
AuthChain []id.EventID `json:"auth_chain_ids"`
|
|
PDUs []id.EventID `json:"pdu_ids"`
|
|
}
|
|
|
|
func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "state_ids", roomID},
|
|
Query: url.Values{
|
|
"event_id": {string(eventID)},
|
|
},
|
|
Authenticate: true,
|
|
ResponseJSON: &resp,
|
|
})
|
|
return
|
|
}
|
|
|
|
func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "timestamp_to_event", roomID},
|
|
Query: url.Values{
|
|
"dir": {string(dir)},
|
|
"ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)},
|
|
},
|
|
Authenticate: true,
|
|
ResponseJSON: &resp,
|
|
})
|
|
return
|
|
}
|
|
|
|
func (c *Client) QueryProfile(ctx context.Context, serverName string, userID id.UserID) (resp *mautrix.RespUserProfile, err error) {
|
|
err = c.Query(ctx, serverName, "profile", url.Values{"user_id": {userID.String()}}, &resp)
|
|
return
|
|
}
|
|
|
|
func (c *Client) QueryDirectory(ctx context.Context, serverName string, roomAlias id.RoomAlias) (resp *mautrix.RespAliasResolve, err error) {
|
|
err = c.Query(ctx, serverName, "directory", url.Values{"room_alias": {roomAlias.String()}}, &resp)
|
|
return
|
|
}
|
|
|
|
func (c *Client) Query(ctx context.Context, serverName, queryType string, queryParams url.Values, respStruct any) (err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "query", queryType},
|
|
Query: queryParams,
|
|
Authenticate: true,
|
|
ResponseJSON: respStruct,
|
|
})
|
|
return
|
|
}
|
|
|
|
type RespOpenIDUserInfo struct {
|
|
Sub id.UserID `json:"sub"`
|
|
}
|
|
|
|
func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) {
|
|
_, _, err = c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: http.MethodGet,
|
|
Path: URLPath{"v1", "openid", "userinfo"},
|
|
Query: url.Values{"access_token": {accessToken}},
|
|
ResponseJSON: &resp,
|
|
})
|
|
return
|
|
}
|
|
|
|
// URLPath is a list of path segments relative to the "_matrix/federation"
// prefix.
type URLPath []any

// FullPath returns a newly allocated slice containing the "_matrix" and
// "federation" segments followed by the segments of fup.
func (fup URLPath) FullPath() []any {
	full := make([]any, 0, 2+len(fup))
	full = append(full, "_matrix", "federation")
	full = append(full, []any(fup)...)
	return full
}
|
|
|
|
// KeyURLPath is a list of path segments relative to the "_matrix/key" prefix.
type KeyURLPath []any

// FullPath returns a newly allocated slice containing the "_matrix" and "key"
// segments followed by the segments of fkup.
func (fkup KeyURLPath) FullPath() []any {
	full := make([]any, 0, 2+len(fkup))
	full = append(full, "_matrix", "key")
	full = append(full, []any(fkup)...)
	return full
}
|
|
|
|
type RequestParams struct {
|
|
ServerName string
|
|
Method string
|
|
Path mautrix.PrefixableURLPath
|
|
Query url.Values
|
|
Authenticate bool
|
|
RequestJSON any
|
|
|
|
ResponseJSON any
|
|
DontReadBody bool
|
|
}
|
|
|
|
func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error {
|
|
_, _, err := c.MakeFullRequest(ctx, RequestParams{
|
|
ServerName: serverName,
|
|
Method: method,
|
|
Path: path,
|
|
Authenticate: authenticate,
|
|
RequestJSON: reqJSON,
|
|
ResponseJSON: respJSON,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) {
|
|
req, err := c.compileRequest(ctx, params)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
resp, err := c.HTTP.Do(req)
|
|
if err != nil {
|
|
return nil, nil, mautrix.HTTPError{
|
|
Request: req,
|
|
Response: resp,
|
|
|
|
Message: "request error",
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
var body []byte
|
|
if resp.StatusCode >= 400 {
|
|
body, err = mautrix.ParseErrorResponse(req, resp)
|
|
return body, resp, err
|
|
} else if params.ResponseJSON != nil || !params.DontReadBody {
|
|
body, err = io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return body, resp, mautrix.HTTPError{
|
|
Request: req,
|
|
Response: resp,
|
|
|
|
Message: "failed to read response body",
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
if params.ResponseJSON != nil {
|
|
err = json.Unmarshal(body, params.ResponseJSON)
|
|
if err != nil {
|
|
return body, resp, mautrix.HTTPError{
|
|
Request: req,
|
|
Response: resp,
|
|
|
|
Message: "failed to unmarshal response JSON",
|
|
ResponseBody: string(body),
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return body, resp, nil
|
|
}
|
|
|
|
func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) {
|
|
reqURL := mautrix.BuildURL(&url.URL{
|
|
Scheme: "matrix-federation",
|
|
Host: params.ServerName,
|
|
}, params.Path.FullPath()...)
|
|
reqURL.RawQuery = params.Query.Encode()
|
|
var reqJSON json.RawMessage
|
|
var reqBody io.Reader
|
|
if params.RequestJSON != nil {
|
|
var err error
|
|
reqJSON, err = json.Marshal(params.RequestJSON)
|
|
if err != nil {
|
|
return nil, mautrix.HTTPError{
|
|
Message: "failed to marshal JSON",
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
reqBody = bytes.NewReader(reqJSON)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody)
|
|
if err != nil {
|
|
return nil, mautrix.HTTPError{
|
|
Message: "failed to create request",
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
req.Header.Set("User-Agent", c.UserAgent)
|
|
if params.Authenticate {
|
|
if c.ServerName == "" || c.Key == nil {
|
|
return nil, mautrix.HTTPError{
|
|
Message: "client not configured for authentication",
|
|
}
|
|
}
|
|
var contentAny any
|
|
if reqJSON != nil {
|
|
contentAny = reqJSON
|
|
}
|
|
auth, err := (&signableRequest{
|
|
Method: req.Method,
|
|
URI: reqURL.RequestURI(),
|
|
Origin: c.ServerName,
|
|
Destination: params.ServerName,
|
|
Content: contentAny,
|
|
}).Sign(c.Key)
|
|
if err != nil {
|
|
return nil, mautrix.HTTPError{
|
|
Message: "failed to sign request",
|
|
WrappedError: err,
|
|
}
|
|
}
|
|
req.Header.Set("Authorization", auth)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
// signableRequest is the JSON object that is signed to authenticate an
// outgoing request.
type signableRequest struct {
	Method      string `json:"method"`
	URI         string `json:"uri"`
	Origin      string `json:"origin"`
	Destination string `json:"destination"`
	// Content is the request body. It is omitted from the JSON when nil.
	Content any `json:"content,omitempty"`
}
|
|
|
|
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
|
|
sig, err := key.SignJSON(r)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return fmt.Sprintf(
|
|
`X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
|
|
r.Origin,
|
|
r.Destination,
|
|
key.ID,
|
|
base64.RawURLEncoding.EncodeToString(sig),
|
|
), nil
|
|
}
|