mirror of https://github.com/mautrix/go.git
502 lines
15 KiB
Go
502 lines
15 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mediaproxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/rs/zerolog"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/federation"
|
|
)
|
|
|
|
type GetMediaResponse interface {
|
|
isGetMediaResponse()
|
|
}
|
|
|
|
func (*GetMediaResponseURL) isGetMediaResponse() {}
|
|
func (*GetMediaResponseData) isGetMediaResponse() {}
|
|
func (*GetMediaResponseCallback) isGetMediaResponse() {}
|
|
func (*GetMediaResponseFile) isGetMediaResponse() {}
|
|
|
|
// GetMediaResponseURL tells the proxy to point the requester at another URL
// instead of serving the media bytes itself.
type GetMediaResponseURL struct {
	// URL is sent back to the requester: as a Location header with a temporary
	// redirect on client downloads, and as a Location multipart field on
	// federation downloads.
	URL string
	// ExpiresAt is used by DownloadMedia to derive the Cache-Control header.
	// The zero value is treated as "cache for a year".
	ExpiresAt time.Time
}
|
|
|
|
type GetMediaResponseWriter interface {
|
|
GetMediaResponse
|
|
io.WriterTo
|
|
GetContentType() string
|
|
GetContentLength() int64
|
|
}
|
|
|
|
var (
|
|
_ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
|
|
_ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
|
|
)
|
|
|
|
// GetMediaResponseData is a GetMediaResponse that streams media out of a
// reader.
type GetMediaResponseData struct {
	// Reader supplies the media bytes. WriteTo closes it once the copy has
	// finished, whether or not the copy succeeded.
	Reader io.ReadCloser
	// ContentType is the MIME type reported for the media.
	ContentType string
	// ContentLength is the size of the media in bytes; zero means the length is
	// not reported.
	ContentLength int64
}

// WriteTo copies everything from Reader into w and then closes Reader.
//
// It returns the number of bytes copied and the first error hit while copying.
func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) {
	// Release the reader here so the underlying resource is not leaked after
	// the copy. The close error is deliberately dropped: callers act on the
	// copy result, and a failed close of a fully-read source is not actionable.
	defer func() { _ = d.Reader.Close() }()
	return io.Copy(w, d.Reader)
}

// GetContentType returns the MIME type of the media.
func (d *GetMediaResponseData) GetContentType() string {
	return d.ContentType
}

// GetContentLength returns the media size in bytes, or zero if it was not set.
func (d *GetMediaResponseData) GetContentLength() int64 {
	return d.ContentLength
}
|
|
|
|
// GetMediaResponseCallback is a GetMediaResponse whose content is produced by
// a caller-supplied function instead of a reader.
type GetMediaResponseCallback struct {
	// Callback writes the media into w. Its results are returned from WriteTo
	// unchanged.
	Callback func(w io.Writer) (int64, error)
	// ContentType is the MIME type reported for the media.
	ContentType string
	// ContentLength is the size of the media in bytes; zero means the length is
	// not reported.
	ContentLength int64
}

// GetContentType returns the MIME type of the media.
func (d *GetMediaResponseCallback) GetContentType() string {
	return d.ContentType
}

// GetContentLength returns the media size in bytes, or zero if it was not set.
func (d *GetMediaResponseCallback) GetContentLength() int64 {
	return d.ContentLength
}

// WriteTo hands w to Callback and passes its results straight through.
func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) {
	written, err := d.Callback(w)
	return written, err
}
|
|
|
|
// GetMediaResponseFile is a GetMediaResponse for media that first has to be
// written to a file before it can be served.
type GetMediaResponseFile struct {
	// Callback receives a temporary file created by the proxy and is expected
	// to write the media into it.
	Callback func(w *os.File) error
	// ContentType is the MIME type of the media. When empty, the proxy sniffs
	// the type from the start of the written file.
	ContentType string
}
|
|
|
|
// GetMediaFunc looks up a piece of media for the proxy.
//
// mediaID is the media ID path segment of the request and params holds the
// first value of each URL query parameter. The returned response decides how
// the media is delivered; a non-nil error is turned into an error response.
type GetMediaFunc = func(
	ctx context.Context,
	mediaID string,
	params map[string]string,
) (response GetMediaResponse, err error)
|
|
|
|
type MediaProxy struct {
|
|
KeyServer *federation.KeyServer
|
|
|
|
ForceProxyLegacyFederation bool
|
|
|
|
GetMedia GetMediaFunc
|
|
PrepareProxyRequest func(*http.Request)
|
|
|
|
serverName string
|
|
serverKey *federation.SigningKey
|
|
|
|
FederationRouter *mux.Router
|
|
ClientMediaRouter *mux.Router
|
|
}
|
|
|
|
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
|
|
parsed, err := federation.ParseSynapseKey(serverKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &MediaProxy{
|
|
serverName: serverName,
|
|
serverKey: parsed,
|
|
GetMedia: getMedia,
|
|
KeyServer: &federation.KeyServer{
|
|
KeyProvider: &federation.StaticServerKey{
|
|
ServerName: serverName,
|
|
Key: parsed,
|
|
},
|
|
WellKnownTarget: fmt.Sprintf("%s:443", serverName),
|
|
Version: federation.ServerVersion{
|
|
Name: "mautrix-go media proxy",
|
|
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// BasicConfig is the YAML/JSON-loadable set of options consumed by
// NewFromConfig.
type BasicConfig struct {
	// ServerName is the server name the proxy answers for.
	ServerName string `yaml:"server_name" json:"server_name"`
	// ServerKey is the signing key, in the format accepted by New.
	ServerKey string `yaml:"server_key" json:"server_key"`
	// WellKnownResponse overrides the key server's well-known target when it is
	// not empty.
	WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"`
}
|
|
|
|
func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) {
|
|
mp, err := New(cfg.ServerName, cfg.ServerKey, getMedia)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cfg.WellKnownResponse != "" {
|
|
mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse
|
|
}
|
|
return mp, nil
|
|
}
|
|
|
|
// ServerConfig is the YAML/JSON-loadable listen address used by Listen.
type ServerConfig struct {
	// Hostname is the host part of the listen address.
	Hostname string `yaml:"hostname" json:"hostname"`
	// Port is the TCP port of the listen address.
	Port uint16 `yaml:"port" json:"port"`
}
|
|
|
|
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
|
|
router := mux.NewRouter()
|
|
mp.RegisterRoutes(router)
|
|
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
|
|
}
|
|
|
|
func (mp *MediaProxy) GetServerName() string {
|
|
return mp.serverName
|
|
}
|
|
|
|
func (mp *MediaProxy) GetServerKey() *federation.SigningKey {
|
|
return mp.serverKey
|
|
}
|
|
|
|
func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
|
|
if mp.FederationRouter == nil {
|
|
mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
|
|
}
|
|
if mp.ClientMediaRouter == nil {
|
|
mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
|
|
}
|
|
|
|
mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
|
|
mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
|
|
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
|
|
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
|
|
mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
|
|
mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
|
|
mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
|
|
mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
|
|
mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
|
|
mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
|
|
mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
|
|
mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
|
|
mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
|
|
mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
|
|
corsMiddleware := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
|
|
w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
mp.ClientMediaRouter.Use(corsMiddleware)
|
|
mp.KeyServer.Register(router)
|
|
}
|
|
|
|
// ResponseError is an error that carries the HTTP status code and the
// JSON-encodable body to send back when a GetMediaFunc fails.
//
// Deprecated: use mautrix.RespError instead
type ResponseError struct {
	// Status is written as the HTTP status code of the error response.
	Status int
	// Data is JSON-encoded into the body of the error response.
	Data any
}

// Error renders the status code and payload as "HTTP <status>: <data>".
func (re *ResponseError) Error() string {
	return "HTTP " + strconv.Itoa(re.Status) + ": " + fmt.Sprint(re.Data)
}
|
|
|
|
// ErrInvalidMediaIDSyntax can be returned (or wrapped) by a GetMediaFunc to
// signal that the media ID is not one this proxy understands; the proxy then
// answers with a "not found" error naming itself as a media proxy.
var ErrInvalidMediaIDSyntax error = errors.New("invalid media ID syntax")
|
|
|
|
// queryToMap flattens URL query values into a plain map, keeping only the
// first value of each key.
//
// Keys that carry no values at all are left out instead of being indexed,
// which would panic. The result is never nil.
func queryToMap(vals url.Values) map[string]string {
	m := make(map[string]string, len(vals))
	for k, v := range vals {
		// Parsed queries always have at least one value per key, but
		// hand-built url.Values may not.
		if len(v) == 0 {
			continue
		}
		m[k] = v[0]
	}
	return m
}
|
|
|
|
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
|
|
mediaID := mux.Vars(r)["mediaID"]
|
|
resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
|
|
if err != nil {
|
|
//lint:ignore SA1019 deprecated types need to be supported until they're removed
|
|
var respError *ResponseError
|
|
var mautrixRespError mautrix.RespError
|
|
if errors.Is(err, ErrInvalidMediaIDSyntax) {
|
|
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
|
|
} else if errors.As(err, &mautrixRespError) {
|
|
mautrixRespError.Write(w)
|
|
} else if errors.As(err, &respError) {
|
|
w.Header().Add("Content-Type", "application/json")
|
|
w.WriteHeader(respError.Status)
|
|
_ = json.NewEncoder(w).Encode(respError.Data)
|
|
} else {
|
|
zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
|
|
mautrix.MNotFound.WithMessage("Media not found").Write(w)
|
|
}
|
|
return nil
|
|
}
|
|
return resp
|
|
}
|
|
|
|
func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer {
|
|
mpw := multipart.NewWriter(w)
|
|
w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1))
|
|
w.WriteHeader(http.StatusOK)
|
|
metaPart, err := mpw.CreatePart(textproto.MIMEHeader{
|
|
"Content-Type": {"application/json"},
|
|
})
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field")
|
|
return nil
|
|
}
|
|
_, err = metaPart.Write([]byte(`{}`))
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field")
|
|
return nil
|
|
}
|
|
return mpw
|
|
}
|
|
|
|
func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
log := zerolog.Ctx(ctx)
|
|
// TODO check destination header in X-Matrix auth
|
|
|
|
resp := mp.getMedia(w, r)
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
var mpw *multipart.Writer
|
|
if urlResp, ok := resp.(*GetMediaResponseURL); ok {
|
|
mpw = startMultipart(ctx, w)
|
|
if mpw == nil {
|
|
return
|
|
}
|
|
_, err := mpw.CreatePart(textproto.MIMEHeader{
|
|
"Location": {urlResp.URL},
|
|
})
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create multipart redirect field")
|
|
return
|
|
}
|
|
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
|
|
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
|
|
mpw = startMultipart(ctx, w)
|
|
if mpw == nil {
|
|
return fmt.Errorf("failed to start multipart writer")
|
|
}
|
|
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
|
|
"Content-Type": {mimeType},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create multipart data field: %w", err)
|
|
}
|
|
_, err = wt.WriteTo(dataPart)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to do media proxy with temp file")
|
|
if !responseStarted {
|
|
var mautrixRespError mautrix.RespError
|
|
if errors.As(err, &mautrixRespError) {
|
|
mautrixRespError.Write(w)
|
|
} else {
|
|
mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
|
|
mpw = startMultipart(ctx, w)
|
|
if mpw == nil {
|
|
return
|
|
}
|
|
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
|
|
"Content-Type": {dataResp.GetContentType()},
|
|
})
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create multipart data field")
|
|
return
|
|
}
|
|
_, err = dataResp.WriteTo(dataPart)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to write multipart data field")
|
|
return
|
|
}
|
|
} else {
|
|
panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
|
|
}
|
|
err := mpw.Close()
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to close multipart writer")
|
|
return
|
|
}
|
|
}
|
|
|
|
func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) {
|
|
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
|
contentDisposition := "attachment"
|
|
switch mimeType {
|
|
case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif",
|
|
"image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime",
|
|
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav",
|
|
"audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf":
|
|
contentDisposition = "inline"
|
|
}
|
|
if fileName != "" {
|
|
contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{
|
|
"filename": fileName,
|
|
})
|
|
}
|
|
w.Header().Set("Content-Disposition", contentDisposition)
|
|
w.Header().Set("Content-Type", mimeType)
|
|
}
|
|
|
|
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
log := zerolog.Ctx(ctx)
|
|
vars := mux.Vars(r)
|
|
if vars["serverName"] != mp.serverName {
|
|
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
|
|
return
|
|
}
|
|
resp := mp.getMedia(w, r)
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
if urlResp, ok := resp.(*GetMediaResponseURL); ok {
|
|
w.Header().Set("Location", urlResp.URL)
|
|
expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds()
|
|
if urlResp.ExpiresAt.IsZero() {
|
|
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
|
} else if expirySeconds > 0 {
|
|
cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds))
|
|
w.Header().Set("Cache-Control", cacheControl)
|
|
} else {
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
}
|
|
w.WriteHeader(http.StatusTemporaryRedirect)
|
|
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
|
|
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
|
|
mp.addHeaders(w, mimeType, vars["fileName"])
|
|
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
|
w.WriteHeader(http.StatusOK)
|
|
_, err := wt.WriteTo(w)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to do media proxy with temp file")
|
|
if !responseStarted {
|
|
var mautrixRespError mautrix.RespError
|
|
if errors.As(err, &mautrixRespError) {
|
|
mautrixRespError.Write(w)
|
|
} else {
|
|
mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
|
|
}
|
|
}
|
|
}
|
|
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
|
|
mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"])
|
|
if dataResp.GetContentLength() != 0 {
|
|
w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10))
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, err := dataResp.WriteTo(w)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to write media data")
|
|
}
|
|
} else {
|
|
panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
|
|
}
|
|
}
|
|
|
|
func doTempFileDownload(
|
|
data *GetMediaResponseFile,
|
|
respond func(w io.WriterTo, size int64, mimeType string) error,
|
|
) (bool, error) {
|
|
tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to create temp file: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = tempFile.Close()
|
|
_ = os.Remove(tempFile.Name())
|
|
}()
|
|
err = data.Callback(tempFile)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
_, err = tempFile.Seek(0, io.SeekStart)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
|
|
}
|
|
fileInfo, err := tempFile.Stat()
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to stat temp file: %w", err)
|
|
}
|
|
mimeType := data.ContentType
|
|
if mimeType == "" {
|
|
buf := make([]byte, 512)
|
|
n, err := tempFile.Read(buf)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to read temp file to detect mime: %w", err)
|
|
}
|
|
buf = buf[:n]
|
|
_, err = tempFile.Seek(0, io.SeekStart)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
|
|
}
|
|
mimeType = http.DetectContentType(buf)
|
|
}
|
|
err = respond(tempFile, fileInfo.Size(), mimeType)
|
|
if err != nil {
|
|
return true, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
var (
|
|
ErrUploadNotSupported = mautrix.MUnrecognized.
|
|
WithMessage("This is a media proxy and does not support media uploads.").
|
|
WithStatus(http.StatusNotImplemented)
|
|
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
|
|
WithMessage("This is a media proxy and does not support URL previews.").
|
|
WithStatus(http.StatusNotImplemented)
|
|
ErrUnknownEndpoint = mautrix.MUnrecognized.
|
|
WithMessage("Unrecognized endpoint")
|
|
ErrUnsupportedMethod = mautrix.MUnrecognized.
|
|
WithMessage("Invalid method for endpoint").
|
|
WithStatus(http.StatusMethodNotAllowed)
|
|
)
|
|
|
|
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
|
|
ErrUploadNotSupported.Write(w)
|
|
}
|
|
|
|
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
|
|
ErrPreviewURLNotSupported.Write(w)
|
|
}
|
|
|
|
func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
ErrUnknownEndpoint.Write(w)
|
|
}
|
|
|
|
func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
|
|
ErrUnsupportedMethod.Write(w)
|
|
}
|