mautrix-go/mediaproxy/mediaproxy.go

502 lines
15 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mediaproxy
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net"
	"net/http"
	"net/textproto"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"github.com/rs/zerolog"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/federation"
)
// GetMediaResponse is the closed set of results a GetMediaFunc may return.
// The unexported marker method means only the response types declared in
// this package can implement it.
type GetMediaResponse interface {
	isGetMediaResponse()
}

func (*GetMediaResponseURL) isGetMediaResponse()      {}
func (*GetMediaResponseData) isGetMediaResponse()     {}
func (*GetMediaResponseCallback) isGetMediaResponse() {}
func (*GetMediaResponseFile) isGetMediaResponse()     {}

type (
	// GetMediaResponseURL sends the requester to URL instead of serving the
	// bytes directly. ExpiresAt limits how long the redirect may be cached;
	// DownloadMedia treats a zero ExpiresAt as never expiring.
	GetMediaResponseURL struct {
		URL       string
		ExpiresAt time.Time
	}

	// GetMediaResponseWriter is a response whose body the proxy streams
	// itself through WriteTo, along with the metadata used for the
	// Content-Type and Content-Length headers.
	GetMediaResponseWriter interface {
		GetMediaResponse
		io.WriterTo
		GetContentType() string
		GetContentLength() int64
	}
)

// Compile-time proof that both streaming response types satisfy
// GetMediaResponseWriter.
var (
	_ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
	_ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
)
// GetMediaResponseData is a response whose body is streamed from Reader.
// The proxy takes ownership of Reader: WriteTo closes it once the body has
// been copied.
type GetMediaResponseData struct {
	Reader        io.ReadCloser
	ContentType   string
	ContentLength int64
}

// WriteTo copies d.Reader into w and then closes d.Reader. The reader is
// closed even if the copy fails, so the underlying resource (e.g. an upstream
// HTTP body) is never leaked. A close error is returned only when the copy
// itself succeeded, so the first failure is the one reported.
func (d *GetMediaResponseData) WriteTo(w io.Writer) (n int64, err error) {
	defer func() {
		if closeErr := d.Reader.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("failed to close media reader: %w", closeErr)
		}
	}()
	return io.Copy(w, d.Reader)
}

// GetContentType returns the media's MIME type as supplied by the provider.
func (d *GetMediaResponseData) GetContentType() string {
	return d.ContentType
}

// GetContentLength returns the body size in bytes as supplied by the provider.
func (d *GetMediaResponseData) GetContentLength() int64 {
	return d.ContentLength
}
// GetMediaResponseCallback is a response whose body is produced on demand:
// Callback writes the media into the writer it is given and reports how many
// bytes it wrote.
type GetMediaResponseCallback struct {
	Callback      func(w io.Writer) (int64, error)
	ContentType   string
	ContentLength int64
}

// WriteTo hands w to d.Callback and relays its byte count and error.
func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) {
	written, err := d.Callback(w)
	return written, err
}

// GetContentType returns the MIME type supplied by the provider.
func (d *GetMediaResponseCallback) GetContentType() string { return d.ContentType }

// GetContentLength returns the body length supplied by the provider.
func (d *GetMediaResponseCallback) GetContentLength() int64 { return d.ContentLength }
// GetMediaResponseFile is a response that is first materialised into a
// temporary file by Callback and then served from that file. When
// ContentType is empty, the proxy sniffs the type from the file contents.
type GetMediaResponseFile struct {
	// Callback writes the media into the provided temporary file.
	Callback func(w *os.File) error
	// ContentType is the media's MIME type, or "" to auto-detect.
	ContentType string
}
// GetMediaFunc resolves mediaID into a GetMediaResponse. params holds the
// request's query parameters, first value per key. Returning an error makes
// the proxy answer with a Matrix error instead of media.
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)

// MediaProxy serves media from an external source through the Matrix
// federation and client media download endpoints, resolving media IDs with
// GetMedia. Build one with New or NewFromConfig and mount it with
// RegisterRoutes or Listen.
type MediaProxy struct {
	// KeyServer provides the federation server-version handler and registers
	// its own routes on the root router.
	KeyServer *federation.KeyServer
	// ForceProxyLegacyFederation is not read in this file; presumably it
	// controls proxying of legacy federation media. TODO confirm.
	ForceProxyLegacyFederation bool
	// GetMedia resolves media IDs for every download request.
	GetMedia GetMediaFunc
	// PrepareProxyRequest is not read in this file; presumably a hook for
	// adjusting outgoing proxied requests. TODO confirm.
	PrepareProxyRequest func(*http.Request)

	serverName string
	serverKey  *federation.SigningKey

	// FederationRouter and ClientMediaRouter are created by RegisterRoutes
	// unless they were set beforehand.
	FederationRouter  *mux.Router
	ClientMediaRouter *mux.Router
}
// New creates a MediaProxy for serverName.
//
// serverKey is parsed with federation.ParseSynapseKey and served through a
// static key provider. The well-known target defaults to "<serverName>:443".
// getMedia is used to resolve every download request. An error is returned
// only if serverKey cannot be parsed.
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
	signingKey, err := federation.ParseSynapseKey(serverKey)
	if err != nil {
		return nil, err
	}
	keyServer := &federation.KeyServer{
		KeyProvider: &federation.StaticServerKey{
			ServerName: serverName,
			Key:        signingKey,
		},
		WellKnownTarget: fmt.Sprintf("%s:443", serverName),
		Version: federation.ServerVersion{
			Name:    "mautrix-go media proxy",
			Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
		},
	}
	mp := &MediaProxy{
		serverName: serverName,
		serverKey:  signingKey,
		GetMedia:   getMedia,
		KeyServer:  keyServer,
	}
	return mp, nil
}
// BasicConfig holds the settings consumed by NewFromConfig.
type BasicConfig struct {
	// ServerName is the Matrix server name the proxy serves media for.
	ServerName string `yaml:"server_name" json:"server_name"`
	// ServerKey is the signing key, in the format federation.ParseSynapseKey accepts.
	ServerKey string `yaml:"server_key" json:"server_key"`
	// WellKnownResponse, when non-empty, replaces the default
	// "<server_name>:443" well-known target.
	WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"`
}
// NewFromConfig behaves like New, taking its settings from cfg. A non-empty
// cfg.WellKnownResponse overrides the default well-known target.
func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) {
	mp, err := New(cfg.ServerName, cfg.ServerKey, getMedia)
	switch {
	case err != nil:
		return nil, err
	case cfg.WellKnownResponse != "":
		mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse
	}
	return mp, nil
}
// ServerConfig is the listen address used by Listen.
type ServerConfig struct {
	// Hostname is the address to bind; empty binds all interfaces.
	Hostname string `yaml:"hostname" json:"hostname"`
	// Port is the TCP port to listen on.
	Port uint16 `yaml:"port" json:"port"`
}
// Listen mounts the proxy on a fresh router and serves it on
// cfg.Hostname:cfg.Port. Like http.ListenAndServe, it blocks and always
// returns a non-nil error.
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
	router := mux.NewRouter()
	mp.RegisterRoutes(router)
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:8080"). Plain
	// "%s:%d" formatting produces an unparseable address for them. An empty
	// hostname still yields ":port".
	addr := net.JoinHostPort(cfg.Hostname, strconv.FormatUint(uint64(cfg.Port), 10))
	return http.ListenAndServe(addr, router)
}
// GetServerName returns the Matrix server name this proxy serves media for.
func (mp *MediaProxy) GetServerName() string { return mp.serverName }

// GetServerKey returns the signing key that New parsed from the config.
func (mp *MediaProxy) GetServerKey() *federation.SigningKey { return mp.serverKey }
// RegisterRoutes installs the federation, client media and key server
// endpoints on router.
//
// FederationRouter and ClientMediaRouter are used as-is if already set.
// Otherwise they are created as subrouters of router under
// /_matrix/federation and /_matrix/client/v1/media. Upload, config and URL
// preview endpoints answer with "not supported" errors. Unknown paths and
// methods get Matrix-style errors.
//
// Client media responses carry permissive CORS headers and a restrictive
// Content-Security-Policy. CORS preflight (OPTIONS) requests on any client
// media path are answered with 200 so browsers may then send authenticated
// requests.
func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
	if mp.FederationRouter == nil {
		mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
	}
	if mp.ClientMediaRouter == nil {
		mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
	}
	mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
	mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
	mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
	mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
	mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
	mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
	mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
	mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
	mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
	mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
	// Without this catch-all, an OPTIONS request only matches routes by path.
	// gorilla/mux then hands it to MethodNotAllowedHandler, which does not run
	// router middleware. The preflight would get a 405 with no CORS headers.
	// A fully matching route lets corsMiddleware add them.
	mp.ClientMediaRouter.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
	mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
	mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
	mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
	corsMiddleware := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*")
			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
			w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
			next.ServeHTTP(w, r)
		})
	}
	mp.ClientMediaRouter.Use(corsMiddleware)
	mp.KeyServer.Register(router)
}
// ResponseError is an error carrying an HTTP status code and a body that
// getMedia encodes as JSON in the response.
//
// Deprecated: use mautrix.RespError instead
type ResponseError struct {
	Status int
	Data   any
}

// Error renders the status code and data as "HTTP <status>: <data>".
func (re *ResponseError) Error() string {
	return "HTTP " + strconv.Itoa(re.Status) + ": " + fmt.Sprint(re.Data)
}
// ErrInvalidMediaIDSyntax may be returned (optionally wrapped) by a
// GetMediaFunc when a media ID is not in a format this proxy serves. getMedia
// answers such requests with M_NOT_FOUND.
var (
	ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
)
// queryToMap flattens query values into a map holding the first value of
// each key. Keys with an empty value list are omitted rather than indexed,
// so hand-built url.Values cannot cause an index-out-of-range panic.
func queryToMap(vals url.Values) map[string]string {
	m := make(map[string]string, len(vals))
	for k, v := range vals {
		if len(v) == 0 {
			continue
		}
		m[k] = v[0]
	}
	return m
}
// getMedia resolves the mediaID route variable through mp.GetMedia, passing
// along the request's query parameters.
//
// On success it returns the response for the caller to serve. Otherwise it
// writes an error response to w itself and returns nil, so callers must stop
// as soon as the result is nil. The error is mapped as follows:
//   - ErrInvalidMediaIDSyntax: M_NOT_FOUND explaining this is a media proxy
//   - mautrix.RespError: written as-is
//   - deprecated *ResponseError: its Status with JSON-encoded Data
//   - anything else: logged, then a generic M_NOT_FOUND
//
// A nil response with a nil error is also answered with M_NOT_FOUND.
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
	mediaID := mux.Vars(r)["mediaID"]
	resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
	if err == nil {
		if resp == nil {
			// Nothing to serve: without this the caller would return without
			// writing, and the client would receive an empty 200 response.
			zerolog.Ctx(r.Context()).Warn().Str("media_id", mediaID).Msg("GetMedia returned no response and no error")
			mautrix.MNotFound.WithMessage("Media not found").Write(w)
		}
		return resp
	}

	//lint:ignore SA1019 deprecated types need to be supported until they're removed
	var respError *ResponseError
	var mautrixRespError mautrix.RespError
	switch {
	case errors.Is(err, ErrInvalidMediaIDSyntax):
		mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
	case errors.As(err, &mautrixRespError):
		mautrixRespError.Write(w)
	case errors.As(err, &respError):
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(respError.Status)
		// The status line is already sent; an encode failure cannot be reported.
		_ = json.NewEncoder(w).Encode(respError.Data)
	default:
		zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
		mautrix.MNotFound.WithMessage("Media not found").Write(w)
	}
	return nil
}
// startMultipart commits a 200 OK multipart/mixed response on w and writes the
// leading JSON metadata part (currently always `{}`).
//
// It returns the writer for appending further parts. It returns nil if the
// metadata part could not be written; by then the status line has already
// been sent, so the caller can only give up.
func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer {
	mpw := multipart.NewWriter(w)
	// multipart.Writer only offers a form-data content type. Keep its
	// boundary, but advertise the mixed subtype instead.
	mixedType := strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)
	w.Header().Set("Content-Type", mixedType)
	w.WriteHeader(http.StatusOK)

	log := zerolog.Ctx(ctx)
	metaPart, err := mpw.CreatePart(textproto.MIMEHeader{"Content-Type": {"application/json"}})
	if err != nil {
		log.Err(err).Msg("Failed to create multipart metadata field")
		return nil
	}
	if _, err = metaPart.Write([]byte(`{}`)); err != nil {
		log.Err(err).Msg("Failed to write multipart metadata field")
		return nil
	}
	return mpw
}
// DownloadMediaFederation serves GET /_matrix/federation/v1/media/download/{mediaID}.
//
// The body is multipart/mixed: an empty JSON metadata part, then either a
// part with only a Location header (URL responses) or a part with the media
// bytes. Errors that occur before the multipart response starts are written
// as Matrix errors. After that point they can only be logged.
func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	log := zerolog.Ctx(ctx)
	// TODO check destination header in X-Matrix auth
	resp := mp.getMedia(w, r)
	if resp == nil {
		// getMedia has already written the error response.
		return
	}

	var mpw *multipart.Writer
	switch typed := resp.(type) {
	case *GetMediaResponseURL:
		if mpw = startMultipart(ctx, w); mpw == nil {
			return
		}
		if _, err := mpw.CreatePart(textproto.MIMEHeader{"Location": {typed.URL}}); err != nil {
			log.Err(err).Msg("Failed to create multipart redirect field")
			return
		}
	case *GetMediaResponseFile:
		started, err := doTempFileDownload(typed, func(wt io.WriterTo, size int64, mimeType string) error {
			if mpw = startMultipart(ctx, w); mpw == nil {
				return fmt.Errorf("failed to start multipart writer")
			}
			part, partErr := mpw.CreatePart(textproto.MIMEHeader{"Content-Type": {mimeType}})
			if partErr != nil {
				return fmt.Errorf("failed to create multipart data field: %w", partErr)
			}
			_, partErr = wt.WriteTo(part)
			return partErr
		})
		if err != nil {
			log.Err(err).Msg("Failed to do media proxy with temp file")
			// Once the respond callback has run, headers are out; only log.
			if !started {
				var mautrixRespError mautrix.RespError
				if errors.As(err, &mautrixRespError) {
					mautrixRespError.Write(w)
				} else {
					mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
				}
			}
			return
		}
	case GetMediaResponseWriter:
		if mpw = startMultipart(ctx, w); mpw == nil {
			return
		}
		part, err := mpw.CreatePart(textproto.MIMEHeader{"Content-Type": {typed.GetContentType()}})
		if err != nil {
			log.Err(err).Msg("Failed to create multipart data field")
			return
		}
		if _, err = typed.WriteTo(part); err != nil {
			log.Err(err).Msg("Failed to write multipart data field")
			return
		}
	default:
		panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
	}

	// Close writes the trailing boundary that terminates the multipart body.
	if err := mpw.Close(); err != nil {
		log.Err(err).Msg("Failed to close multipart writer")
	}
}
// addHeaders sets the Cache-Control, Content-Disposition and Content-Type
// headers for a media download.
//
// Media whose type is on a fixed allow-list of passive formats (selected
// text, image, audio and video types, plus PDF) is served inline; everything
// else is an attachment. The allow-list check ignores MIME parameters and
// letter case. For example, "text/plain; charset=utf-8", which
// http.DetectContentType reports for text, still counts as text/plain. The
// Content-Type header is sent exactly as given. A non-empty fileName is added
// as the filename parameter of Content-Disposition.
func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) {
	w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
	baseType, _, _ := strings.Cut(mimeType, ";")
	contentDisposition := "attachment"
	switch strings.ToLower(strings.TrimSpace(baseType)) {
	case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif",
		"image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime",
		"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav",
		"audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf":
		contentDisposition = "inline"
	}
	if fileName != "" {
		contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{
			"filename": fileName,
		})
	}
	w.Header().Set("Content-Disposition", contentDisposition)
	w.Header().Set("Content-Type", mimeType)
}
// DownloadMedia serves the client media download and thumbnail endpoints,
// /_matrix/client/v1/media/{download,thumbnail}/{serverName}/{mediaID}[/{fileName}].
//
// Requests for a server name other than the proxy's own get M_NOT_FOUND.
// Thumbnail requests go through the same path as downloads; their query
// parameters are only forwarded to GetMedia.
//
// URL responses become a 307 redirect. It is cacheable for a year when the URL
// never expires, until five minutes before expiry otherwise, and no-store once
// less than a second of that window remains. File and streamed responses are
// sent directly with the headers from addHeaders. Content-Length is advertised
// only when it is known and positive.
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	log := zerolog.Ctx(ctx)
	vars := mux.Vars(r)
	if vars["serverName"] != mp.serverName {
		mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
		return
	}
	resp := mp.getMedia(w, r)
	if resp == nil {
		return
	}
	if urlResp, ok := resp.(*GetMediaResponseURL); ok {
		w.Header().Set("Location", urlResp.URL)
		// The five-minute margin keeps caches from serving a redirect to a URL
		// that has already expired. Whole seconds are used so a sub-second
		// remainder becomes no-store rather than "max-age=0, immutable".
		if urlResp.ExpiresAt.IsZero() {
			w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
		} else if maxAge := int((time.Until(urlResp.ExpiresAt) - 5*time.Minute) / time.Second); maxAge > 0 {
			w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, immutable", maxAge))
		} else {
			w.Header().Set("Cache-Control", "no-store")
		}
		w.WriteHeader(http.StatusTemporaryRedirect)
	} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
		responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
			mp.addHeaders(w, mimeType, vars["fileName"])
			w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
			w.WriteHeader(http.StatusOK)
			_, err := wt.WriteTo(w)
			return err
		})
		if err != nil {
			log.Err(err).Msg("Failed to do media proxy with temp file")
			// Headers may already be sent once the callback ran; only log then.
			if !responseStarted {
				var mautrixRespError mautrix.RespError
				if errors.As(err, &mautrixRespError) {
					mautrixRespError.Write(w)
				} else {
					mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
				}
			}
		}
	} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
		mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"])
		// A negative length (e.g. -1 for "unknown", the convention used by
		// http.Response.ContentLength) is not a valid Content-Length value.
		if contentLength := dataResp.GetContentLength(); contentLength > 0 {
			w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
		}
		w.WriteHeader(http.StatusOK)
		_, err := dataResp.WriteTo(w)
		if err != nil {
			log.Err(err).Msg("Failed to write media data")
		}
	} else {
		panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
	}
}
// doTempFileDownload runs data.Callback against a fresh temporary file. It
// then rewinds the file and passes it to respond together with its size and
// MIME type. When data.ContentType is empty, the type is sniffed from the
// first 512 bytes with http.DetectContentType.
//
// The returned bool reports whether respond was invoked, i.e. whether a
// response may already have been (partially) written. Callers should write an
// error response only when it is false. The temp file is always closed and
// removed before returning; errors from that cleanup are ignored.
func doTempFileDownload(
	data *GetMediaResponseFile,
	respond func(w io.WriterTo, size int64, mimeType string) error,
) (bool, error) {
	tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
	if err != nil {
		return false, fmt.Errorf("failed to create temp file: %w", err)
	}
	defer func() {
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}()
	if err = data.Callback(tempFile); err != nil {
		return false, err
	}
	if _, err = tempFile.Seek(0, io.SeekStart); err != nil {
		return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
	}
	fileInfo, err := tempFile.Stat()
	if err != nil {
		return false, fmt.Errorf("failed to stat temp file: %w", err)
	}
	mimeType := data.ContentType
	if mimeType == "" {
		// http.DetectContentType looks at no more than 512 bytes. ReadFull
		// tolerates shorter files, including empty ones. A bare Read would
		// report io.EOF for an empty file and abort the download.
		buf := make([]byte, 512)
		n, readErr := io.ReadFull(tempFile, buf)
		if readErr != nil && !errors.Is(readErr, io.EOF) && !errors.Is(readErr, io.ErrUnexpectedEOF) {
			return false, fmt.Errorf("failed to read temp file to detect mime: %w", readErr)
		}
		if _, err = tempFile.Seek(0, io.SeekStart); err != nil {
			return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
		}
		mimeType = http.DetectContentType(buf[:n])
	}
	return true, respond(tempFile, fileInfo.Size(), mimeType)
}
// Canned Matrix errors for the parts of the media API this proxy does not
// implement, and for requests that match no route or method.
var (
	// ErrUploadNotSupported is an MUnrecognized error with status 501.
	ErrUploadNotSupported = mautrix.MUnrecognized.WithMessage("This is a media proxy and does not support media uploads.").WithStatus(http.StatusNotImplemented)
	// ErrPreviewURLNotSupported is an MUnrecognized error with status 501.
	ErrPreviewURLNotSupported = mautrix.MUnrecognized.WithMessage("This is a media proxy and does not support URL previews.").WithStatus(http.StatusNotImplemented)
	// ErrUnknownEndpoint keeps MUnrecognized's default status code.
	ErrUnknownEndpoint = mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")
	// ErrUnsupportedMethod is an MUnrecognized error with status 405.
	ErrUnsupportedMethod = mautrix.MUnrecognized.WithMessage("Invalid method for endpoint").WithStatus(http.StatusMethodNotAllowed)
)
// UploadNotSupported rejects upload, create and config requests with
// ErrUploadNotSupported.
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, _ *http.Request) {
	ErrUploadNotSupported.Write(w)
}

// PreviewURLNotSupported rejects URL preview requests with
// ErrPreviewURLNotSupported.
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, _ *http.Request) {
	ErrPreviewURLNotSupported.Write(w)
}

// UnknownEndpoint answers requests that match no route with ErrUnknownEndpoint.
func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, _ *http.Request) {
	ErrUnknownEndpoint.Write(w)
}

// UnsupportedMethod answers requests whose path matches but whose method does
// not with ErrUnsupportedMethod.
func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, _ *http.Request) {
	ErrUnsupportedMethod.Write(w)
}