mirror of https://github.com/mautrix/go.git
153 lines
4.8 KiB
Go
153 lines
4.8 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package federation
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
// ResolvedServerName is the result of ResolveServerName: the addresses to connect to for a
// Matrix server name, the Host header to send, and how long the result may be reused.
type ResolvedServerName struct {
	// ServerName is the server name that was passed to ResolveServerName.
	ServerName string `json:"server_name"`
	// HostHeader is the value to send in the Host header: either the original server name
	// or the server name delegated via .well-known.
	HostHeader string `json:"host_header"`
	// IPPort lists "host:port" addresses to connect to. Despite the name, the host part is
	// only an IP address when the (delegated) server name is an IP literal; otherwise it is
	// a hostname (e.g. an SRV target) that still has to be resolved before connecting.
	IPPort []string `json:"ip_port"`
	// Expires is the time after which this result should be considered stale.
	Expires time.Time `json:"expires"`
}
|
|
|
|
// ResolveServerNameOpts holds optional dependencies for ResolveServerName.
type ResolveServerNameOpts struct {
	// HTTPClient is used to fetch /.well-known/matrix/server. If nil, http.DefaultClient is used.
	HTTPClient *http.Client
	// DNSClient is used for the SRV record lookups. If nil, net.DefaultResolver is used.
	DNSClient *net.Resolver
}
|
|
|
|
// ErrInvalidServerName is returned by ResolveServerName when the given server name cannot be parsed.
var ErrInvalidServerName = errors.New("invalid server name")
|
|
|
|
// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names
|
|
func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) {
|
|
var opt ResolveServerNameOpts
|
|
if len(opts) > 0 && opts[0] != nil {
|
|
opt = *opts[0]
|
|
}
|
|
if opt.HTTPClient == nil {
|
|
opt.HTTPClient = http.DefaultClient
|
|
}
|
|
if opt.DNSClient == nil {
|
|
opt.DNSClient = net.DefaultResolver
|
|
}
|
|
output := ResolvedServerName{
|
|
ServerName: serverName,
|
|
HostHeader: serverName,
|
|
IPPort: []string{serverName},
|
|
Expires: time.Now().Add(24 * time.Hour),
|
|
}
|
|
hostname, port, ok := ParseServerName(serverName)
|
|
if !ok {
|
|
return nil, ErrInvalidServerName
|
|
}
|
|
// Steps 1 and 2: handle IP literals and hostnames with port
|
|
if net.ParseIP(hostname) != nil || port != 0 {
|
|
if port == 0 {
|
|
port = 8448
|
|
}
|
|
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
|
|
return &output, nil
|
|
}
|
|
// Step 3: resolve .well-known
|
|
wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Trace().
|
|
Str("server_name", serverName).
|
|
Err(err).
|
|
Msg("Failed to get well-known data")
|
|
} else if wellKnown != nil {
|
|
output.Expires = expiry
|
|
output.HostHeader = wellKnown.Server
|
|
hostname, port, ok = ParseServerName(wellKnown.Server)
|
|
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
|
|
if net.ParseIP(hostname) != nil || port != 0 {
|
|
if port == 0 {
|
|
port = 8448
|
|
}
|
|
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
|
|
return &output, nil
|
|
}
|
|
}
|
|
// Step 3.3, 3.4, 4 and 5: resolve SRV records
|
|
srv, err := RequestSRV(ctx, opt.DNSClient, hostname)
|
|
if err != nil {
|
|
// TODO log more noisily for abnormal errors?
|
|
zerolog.Ctx(ctx).Trace().
|
|
Str("server_name", serverName).
|
|
Str("hostname", hostname).
|
|
Err(err).
|
|
Msg("Failed to get SRV record")
|
|
} else if len(srv) > 0 {
|
|
output.IPPort = make([]string, len(srv))
|
|
for i, record := range srv {
|
|
output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port)))
|
|
}
|
|
return &output, nil
|
|
}
|
|
// Step 6 or 3.5: no SRV records were found, so default to port 8448
|
|
output.IPPort = []string{net.JoinHostPort(hostname, "8448")}
|
|
return &output, nil
|
|
}
|
|
|
|
// RequestSRV looks up the `_matrix-fed._tcp` SRV record for hostname using cli.
// Only when that lookup fails with a *net.DNSError whose IsNotFound flag is set does it retry
// with the deprecated `_matrix._tcp` record; any other outcome of the first lookup (success or
// a different error) is returned unchanged. The canonical name reported by LookupSRV is dropped.
func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) {
	lookup := func(service string) ([]*net.SRV, error) {
		_, records, err := cli.LookupSRV(ctx, service, "tcp", hostname)
		return records, err
	}
	records, err := lookup("matrix-fed")
	if err == nil {
		return records, nil
	}
	var dnsErr *net.DNSError
	if !errors.As(err, &dnsErr) || !dnsErr.IsNotFound {
		return records, err
	}
	return lookup("matrix")
}
|
|
|
|
// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
|
|
// plus the time when the cache should expire.
|
|
func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
|
|
wellKnownURL := url.URL{
|
|
Scheme: "https",
|
|
Host: hostname,
|
|
Path: "/.well-known/matrix/server",
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil)
|
|
if err != nil {
|
|
return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err)
|
|
}
|
|
resp, err := cli.Do(req)
|
|
if err != nil {
|
|
return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
|
|
}
|
|
var respData RespWellKnown
|
|
err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
|
|
if err != nil {
|
|
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
|
|
} else if respData.Server == "" {
|
|
return nil, time.Time{}, errors.New("server name not found in response")
|
|
}
|
|
// TODO parse cache-control header
|
|
return &respData, time.Now().Add(24 * time.Hour), nil
|
|
}
|