mirror of https://github.com/mautrix/go.git
336 lines
8.9 KiB
Go
336 lines
8.9 KiB
Go
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package bridgev2
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"go.mau.fi/util/dbutil"
|
|
|
|
"maunium.net/go/mautrix/bridge/status"
|
|
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
|
|
"maunium.net/go/mautrix/bridgev2/database"
|
|
"maunium.net/go/mautrix/bridgev2/networkid"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// CommandProcessor handles command messages received in Matrix rooms.
// NewBridge builds the bridge's processor via the newCommandProcessor
// callback and stores it in Bridge.Commands.
type CommandProcessor interface {
	// Handle processes one command. roomID and eventID identify the Matrix
	// room and event the command arrived in, user is the sender, and message
	// is the command text. replyTo is presumably the event any response
	// should reply to — TODO confirm against implementations.
	Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID)
}
|
|
|
|
// Bridge ties together the bridge database, the Matrix connector and the
// network connector. It also holds in-memory caches of users, logins,
// portals and ghosts.
type Bridge struct {
	// ID identifies this bridge; NewBridge also passes it to database.New.
	ID networkid.BridgeID
	// DB is the bridge database wrapper created by NewBridge.
	DB *database.Database
	// Log is the base logger; Start derives its startup context from it.
	Log zerolog.Logger

	// Matrix is the Matrix-side connector. NewBridge calls Init and
	// BotIntent on it, StartConnectors starts it, and stop stops it.
	Matrix MatrixConnector
	// Bot is the intent returned by Matrix.BotIntent() during NewBridge.
	Bot MatrixAPI
	// Network is the remote-network connector. If it also implements
	// StoppableNetwork, stop calls its Stop method.
	Network NetworkConnector
	// Commands is built by the newCommandProcessor callback given to NewBridge.
	Commands CommandProcessor
	// Config is the bridge configuration. When nil is passed, NewBridge
	// substitutes a default with CommandPrefix "!bridge".
	Config *bridgeconfig.BridgeConfig

	// DisappearLoop is started in a goroutine by StartConnectors when the
	// network reports DisappearingMessages support and Background is false.
	DisappearLoop *DisappearLoop

	// In-memory caches. stop holds cacheLock while iterating userLoginsByID;
	// presumably cacheLock guards all of these maps — verify against the
	// accessors defined elsewhere in the package.
	usersByMXID    map[id.UserID]*User
	userLoginsByID map[networkid.UserLoginID]*UserLogin
	portalsByKey   map[networkid.PortalKey]*Portal
	portalsByMXID  map[id.RoomID]*Portal
	ghostsByID     map[networkid.UserID]*Ghost
	cacheLock      sync.Mutex

	// didSplitPortals records whether MigrateToSplitPortals changed any rows
	// during StartConnectors. PostStart uses it to force a resend of bridge
	// info and capabilities.
	didSplitPortals bool

	// Background is set by RunOnce. When true, StartConnectors skips the
	// split-portal migration and the disappearing message loop, and
	// PostStart does nothing.
	Background bool

	// Backfill queue signalling channels, created unbuffered in NewBridge.
	// stop closes stopBackfillQueue. Both are presumably consumed by
	// RunBackfillQueue, which is defined elsewhere — TODO confirm.
	wakeupBackfillQueue chan struct{}
	stopBackfillQueue   chan struct{}
}
|
|
|
|
func NewBridge(
|
|
bridgeID networkid.BridgeID,
|
|
db *dbutil.Database,
|
|
log zerolog.Logger,
|
|
cfg *bridgeconfig.BridgeConfig,
|
|
matrix MatrixConnector,
|
|
network NetworkConnector,
|
|
newCommandProcessor func(*Bridge) CommandProcessor,
|
|
) *Bridge {
|
|
br := &Bridge{
|
|
ID: bridgeID,
|
|
DB: database.New(bridgeID, network.GetDBMetaTypes(), db),
|
|
Log: log,
|
|
|
|
Matrix: matrix,
|
|
Network: network,
|
|
Config: cfg,
|
|
|
|
usersByMXID: make(map[id.UserID]*User),
|
|
userLoginsByID: make(map[networkid.UserLoginID]*UserLogin),
|
|
portalsByKey: make(map[networkid.PortalKey]*Portal),
|
|
portalsByMXID: make(map[id.RoomID]*Portal),
|
|
ghostsByID: make(map[networkid.UserID]*Ghost),
|
|
|
|
wakeupBackfillQueue: make(chan struct{}),
|
|
stopBackfillQueue: make(chan struct{}),
|
|
}
|
|
if br.Config == nil {
|
|
br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"}
|
|
}
|
|
br.Commands = newCommandProcessor(br)
|
|
br.Matrix.Init(br)
|
|
br.Bot = br.Matrix.BotIntent()
|
|
br.Network.Init(br)
|
|
br.DisappearLoop = &DisappearLoop{br: br}
|
|
return br
|
|
}
|
|
|
|
// DBUpgradeError is returned by StartConnectors when upgrading the database
// fails. Err is the underlying failure. Section names the part of the
// database being upgraded; StartConnectors uses "main" for the core bridge
// tables.
type DBUpgradeError struct {
	Err     error
	Section string
}

// Error returns the message of the wrapped error. It does not include
// Section; callers wanting that should read the field directly. A value
// with a nil Err reports a generic message instead of panicking.
func (e DBUpgradeError) Error() string {
	if e.Err == nil {
		return "database upgrade failed"
	}
	return e.Err.Error()
}

// Unwrap returns the wrapped error so errors.Is and errors.As can see
// through a DBUpgradeError.
func (e DBUpgradeError) Unwrap() error {
	return e.Err
}
|
|
|
|
func (br *Bridge) Start() error {
|
|
ctx := br.Log.WithContext(context.Background())
|
|
err := br.StartConnectors(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = br.StartLogins(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
br.PostStart(ctx)
|
|
return nil
|
|
}
|
|
|
|
func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) error {
|
|
br.Background = true
|
|
err := br.StartConnectors(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if loginID == "" {
|
|
br.Log.Info().Msg("No login ID provided to RunOnce, running all logins for 20 seconds")
|
|
err = br.StartLogins(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer br.Stop()
|
|
select {
|
|
case <-time.After(20 * time.Second):
|
|
case <-ctx.Done():
|
|
}
|
|
return nil
|
|
}
|
|
|
|
defer br.stop(true)
|
|
login, err := br.GetExistingUserLoginByID(ctx, loginID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get user login: %w", err)
|
|
} else if login == nil {
|
|
return ErrNotLoggedIn
|
|
}
|
|
syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI)
|
|
if !ok {
|
|
br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce")
|
|
login.Client.Connect(ctx)
|
|
defer login.Disconnect(nil)
|
|
select {
|
|
case <-time.After(20 * time.Second):
|
|
case <-ctx.Done():
|
|
}
|
|
return nil
|
|
} else {
|
|
br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode")
|
|
return syncClient.ConnectBackground(login.Log.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func (br *Bridge) StartConnectors(ctx context.Context) error {
|
|
br.Log.Info().Msg("Starting bridge")
|
|
|
|
err := br.DB.Upgrade(ctx)
|
|
if err != nil {
|
|
return DBUpgradeError{Err: err, Section: "main"}
|
|
}
|
|
if !br.Background {
|
|
br.didSplitPortals = br.MigrateToSplitPortals(ctx)
|
|
}
|
|
br.Log.Info().Msg("Starting Matrix connector")
|
|
err = br.Matrix.Start(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start Matrix connector: %w", err)
|
|
}
|
|
br.Log.Info().Msg("Starting network connector")
|
|
err = br.Network.Start(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start network connector: %w", err)
|
|
}
|
|
if br.Network.GetCapabilities().DisappearingMessages && !br.Background {
|
|
go br.DisappearLoop.Start()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (br *Bridge) PostStart(ctx context.Context) {
|
|
if br.Background {
|
|
return
|
|
}
|
|
rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion)
|
|
bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer)
|
|
if err != nil {
|
|
br.Log.Err(err).Str("db_bridge_info_version", rawBridgeInfoVer).Msg("Failed to parse bridge info version")
|
|
return
|
|
}
|
|
expectedBridgeInfoVer, expectedCapVer := br.Network.GetBridgeInfoVersion()
|
|
doResendBridgeInfo := bridgeInfoVer != expectedBridgeInfoVer || br.didSplitPortals || br.Config.ResendBridgeInfo
|
|
doResendCapabilities := capVer != expectedCapVer || br.didSplitPortals
|
|
if doResendBridgeInfo || doResendCapabilities {
|
|
br.ResendBridgeInfo(ctx, doResendBridgeInfo, doResendCapabilities)
|
|
}
|
|
br.DB.KV.Set(ctx, database.KeyBridgeInfoVersion, fmt.Sprintf("%d,%d", expectedBridgeInfoVer, expectedCapVer))
|
|
}
|
|
|
|
// parseBridgeInfoVersion parses the "<info>,<capabilities>" version marker
// that PostStart stores with the format "%d,%d".
//
// An empty string means no marker has been stored yet and yields 0, 0, nil.
// On any parse failure both versions are returned as zero alongside the
// error, so callers never see a half-parsed result (e.g. "5" alone).
func parseBridgeInfoVersion(version string) (info, capabilities int, err error) {
	if version == "" {
		return 0, 0, nil
	}
	if _, err = fmt.Sscanf(version, "%d,%d", &info, &capabilities); err != nil {
		return 0, 0, err
	}
	return info, capabilities, nil
}
|
|
|
|
func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps bool) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger()
|
|
portals, err := br.GetAllPortalsWithMXID(ctx)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to get portals")
|
|
return
|
|
}
|
|
for _, portal := range portals {
|
|
if resendInfo {
|
|
portal.UpdateBridgeInfo(ctx)
|
|
}
|
|
if resendCaps {
|
|
logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey)
|
|
if err != nil {
|
|
log.Err(err).
|
|
Stringer("room_id", portal.MXID).
|
|
Object("portal_key", portal.PortalKey).
|
|
Msg("Failed to get user logins in portal")
|
|
} else {
|
|
found := false
|
|
for _, login := range logins {
|
|
if portal.CapState.ID == "" || login.ID == portal.CapState.Source {
|
|
portal.UpdateCapabilities(ctx, login, true)
|
|
found = true
|
|
}
|
|
}
|
|
if !found && len(logins) > 0 {
|
|
portal.CapState.Source = ""
|
|
portal.UpdateCapabilities(ctx, logins[0], true)
|
|
} else if !found {
|
|
log.Warn().
|
|
Stringer("room_id", portal.MXID).
|
|
Object("portal_key", portal.PortalKey).
|
|
Msg("No user login found to update capabilities")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
log.Info().
|
|
Bool("capabilities", resendCaps).
|
|
Bool("info", resendInfo).
|
|
Msg("Resent bridge info to all portals")
|
|
}
|
|
|
|
func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger()
|
|
ctx = log.WithContext(ctx)
|
|
if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" {
|
|
return false
|
|
}
|
|
affected, err := br.DB.Portal.MigrateToSplitPortals(ctx)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to migrate portals")
|
|
return false
|
|
}
|
|
log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals")
|
|
br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true")
|
|
return affected > 0
|
|
}
|
|
|
|
func (br *Bridge) StartLogins(ctx context.Context) error {
|
|
userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get users with logins: %w", err)
|
|
}
|
|
startedAny := false
|
|
for _, userID := range userIDs {
|
|
br.Log.Info().Stringer("user_id", userID).Msg("Loading user")
|
|
var user *User
|
|
user, err = br.GetUserByMXID(ctx, userID)
|
|
if err != nil {
|
|
br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user")
|
|
} else {
|
|
for _, login := range user.GetUserLogins() {
|
|
startedAny = true
|
|
br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login")
|
|
login.Client.Connect(login.Log.WithContext(ctx))
|
|
}
|
|
}
|
|
}
|
|
if !startedAny {
|
|
br.Log.Info().Msg("No user logins found")
|
|
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured})
|
|
}
|
|
go br.RunBackfillQueue()
|
|
|
|
br.Log.Info().Msg("Bridge started")
|
|
return nil
|
|
}
|
|
|
|
func (br *Bridge) Stop() {
|
|
br.stop(false)
|
|
}
|
|
|
|
func (br *Bridge) stop(isRunOnce bool) {
|
|
br.Log.Info().Msg("Shutting down bridge")
|
|
close(br.stopBackfillQueue)
|
|
br.Matrix.Stop()
|
|
if !isRunOnce {
|
|
br.cacheLock.Lock()
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(br.userLoginsByID))
|
|
for _, login := range br.userLoginsByID {
|
|
go login.Disconnect(wg.Done)
|
|
}
|
|
wg.Wait()
|
|
br.cacheLock.Unlock()
|
|
}
|
|
if stopNet, ok := br.Network.(StoppableNetwork); ok {
|
|
stopNet.Stop()
|
|
}
|
|
err := br.DB.Close()
|
|
if err != nil {
|
|
br.Log.Warn().Err(err).Msg("Failed to close database")
|
|
}
|
|
br.Log.Info().Msg("Shutdown complete")
|
|
}
|