mirror of https://github.com/mautrix/go.git
226 lines
9.6 KiB
Go
226 lines
9.6 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.mau.fi/util/dbutil"
|
|
|
|
"maunium.net/go/mautrix/bridgev2/networkid"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
type MessageQuery struct {
|
|
BridgeID networkid.BridgeID
|
|
MetaType MetaTypeCreator
|
|
*dbutil.QueryHelper[*Message]
|
|
}
|
|
|
|
type Message struct {
|
|
RowID int64
|
|
BridgeID networkid.BridgeID
|
|
ID networkid.MessageID
|
|
PartID networkid.PartID
|
|
MXID id.EventID
|
|
|
|
Room networkid.PortalKey
|
|
SenderID networkid.UserID
|
|
SenderMXID id.UserID
|
|
Timestamp time.Time
|
|
EditCount int
|
|
IsDoublePuppeted bool
|
|
|
|
ThreadRoot networkid.MessageID
|
|
ReplyTo networkid.MessageOptionalPartID
|
|
|
|
Metadata any
|
|
}
|
|
|
|
const (
|
|
getMessageBaseQuery = `
|
|
SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid,
|
|
timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata
|
|
FROM message
|
|
`
|
|
getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3`
|
|
getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4`
|
|
getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2`
|
|
getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
|
|
getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1`
|
|
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
|
|
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
|
|
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
|
|
getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
|
|
getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
|
|
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
|
|
|
|
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
|
|
|
|
countMessagesInPortalQuery = `
|
|
SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3
|
|
`
|
|
|
|
insertMessageQuery = `
|
|
INSERT INTO message (
|
|
bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid,
|
|
timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
|
RETURNING rowid
|
|
`
|
|
updateMessageQuery = `
|
|
UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8,
|
|
timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13,
|
|
reply_to_part_id=$14, metadata=$15
|
|
WHERE bridge_id=$1 AND rowid=$16
|
|
`
|
|
deleteAllMessagePartsByIDQuery = `
|
|
DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3
|
|
`
|
|
deleteMessagePartByRowIDQuery = `
|
|
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
|
|
`
|
|
)
|
|
|
|
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
|
|
return mq.QueryMany(ctx, getAllMessagePartsByIDQuery, mq.BridgeID, receiver, id)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID, partID networkid.PartID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getMessagePartByIDQuery, mq.BridgeID, receiver, id, partID)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, receiver, id)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetByRowID(ctx context.Context, rowID int64) (*Message, error) {
|
|
return mq.QueryOne(ctx, getMessagePartByRowIDQuery, mq.BridgeID, rowID)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageOptionalPartID) (*Message, error) {
|
|
if id.PartID == nil {
|
|
return mq.GetFirstPartByID(ctx, receiver, id.MessageID)
|
|
} else {
|
|
return mq.GetPartByID(ctx, receiver, id.MessageID, *id.PartID)
|
|
}
|
|
}
|
|
|
|
func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) {
|
|
return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano())
|
|
}
|
|
|
|
func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) {
|
|
return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano())
|
|
}
|
|
|
|
func (mq *MessageQuery) GetFirstPortalMessage(ctx context.Context, portal networkid.PortalKey) (*Message, error) {
|
|
return mq.QueryOne(ctx, getOldestMessageInPortal, mq.BridgeID, portal.ID, portal.Receiver)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetFirstThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getFirstMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) {
|
|
return mq.QueryOne(ctx, getLastMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot)
|
|
}
|
|
|
|
func (mq *MessageQuery) GetLastNInPortal(ctx context.Context, portal networkid.PortalKey, n int) ([]*Message, error) {
|
|
return mq.QueryMany(ctx, getLastNInPortal, mq.BridgeID, portal.ID, portal.Receiver, n)
|
|
}
|
|
|
|
func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error {
|
|
ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID)
|
|
return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.ensureHasMetadata(mq.MetaType).sqlVariables()...).Scan(&msg.RowID)
|
|
}
|
|
|
|
func (mq *MessageQuery) Update(ctx context.Context, msg *Message) error {
|
|
ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID)
|
|
return mq.Exec(ctx, updateMessageQuery, msg.ensureHasMetadata(mq.MetaType).updateSQLVariables()...)
|
|
}
|
|
|
|
func (mq *MessageQuery) DeleteAllParts(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) error {
|
|
return mq.Exec(ctx, deleteAllMessagePartsByIDQuery, mq.BridgeID, receiver, id)
|
|
}
|
|
|
|
func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
|
|
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
|
|
}
|
|
|
|
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
|
|
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
|
|
return
|
|
}
|
|
|
|
func (m *Message) Scan(row dbutil.Scannable) (*Message, error) {
|
|
var timestamp int64
|
|
var threadRootID, replyToID, replyToPartID sql.NullString
|
|
var doublePuppeted sql.NullBool
|
|
err := row.Scan(
|
|
&m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID,
|
|
×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata},
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m.Timestamp = time.Unix(0, timestamp)
|
|
m.ThreadRoot = networkid.MessageID(threadRootID.String)
|
|
m.IsDoublePuppeted = doublePuppeted.Valid
|
|
if replyToID.Valid {
|
|
m.ReplyTo.MessageID = networkid.MessageID(replyToID.String)
|
|
if replyToPartID.Valid {
|
|
m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String)
|
|
}
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message {
|
|
if m.Metadata == nil {
|
|
m.Metadata = metaType()
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (m *Message) sqlVariables() []any {
|
|
return []any{
|
|
m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID,
|
|
m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot),
|
|
dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.JSON{Data: m.Metadata},
|
|
}
|
|
}
|
|
|
|
func (m *Message) updateSQLVariables() []any {
|
|
return append(m.sqlVariables(), m.RowID)
|
|
}
|
|
|
|
const FakeMXIDPrefix = "~fake:"
|
|
|
|
func (m *Message) SetFakeMXID() {
|
|
hash := sha256.Sum256([]byte(m.ID))
|
|
m.MXID = id.EventID(FakeMXIDPrefix + base64.RawURLEncoding.EncodeToString(hash[:]))
|
|
}
|
|
|
|
func (m *Message) HasFakeMXID() bool {
|
|
return strings.HasPrefix(m.MXID.String(), FakeMXIDPrefix)
|
|
}
|