315 lines
8.2 KiB
Go
315 lines
8.2 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
"go.opentelemetry.io/otel/trace"
|
|
|
|
"github.com/christianselig/apollo-backend/internal/domain"
|
|
)
|
|
|
|
type postgresAccountRepository struct {
|
|
conn Connection
|
|
tracer trace.Tracer
|
|
}
|
|
|
|
func NewPostgresAccount(conn Connection) domain.AccountRepository {
|
|
tracer := otel.Tracer("db:postgres:accounts")
|
|
return &postgresAccountRepository{conn: conn, tracer: tracer}
|
|
}
|
|
|
|
func (p *postgresAccountRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Account, error) {
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
rows, err := p.conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, "failed querying accounts")
|
|
span.RecordError(err)
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var accs []domain.Account
|
|
for rows.Next() {
|
|
var acc domain.Account
|
|
if err := rows.Scan(
|
|
&acc.ID,
|
|
&acc.Username,
|
|
&acc.AccountID,
|
|
&acc.AccessToken,
|
|
&acc.RefreshToken,
|
|
&acc.TokenExpiresAt,
|
|
&acc.LastMessageID,
|
|
&acc.NextNotificationCheckAt,
|
|
&acc.NextStuckNotificationCheckAt,
|
|
&acc.CheckCount,
|
|
&acc.Development,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
accs = append(accs, acc)
|
|
}
|
|
return accs, nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) {
|
|
query := `
|
|
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
|
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
|
check_count, development
|
|
FROM accounts
|
|
WHERE id = $1 AND is_deleted IS FALSE`
|
|
|
|
accs, err := p.fetch(ctx, query, id)
|
|
if err != nil {
|
|
return domain.Account{}, err
|
|
}
|
|
|
|
if len(accs) == 0 {
|
|
return domain.Account{}, domain.ErrNotFound
|
|
}
|
|
return accs[0], nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
|
query := `
|
|
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
|
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
|
check_count, development
|
|
FROM accounts
|
|
WHERE reddit_account_id = $1 AND is_deleted IS FALSE`
|
|
|
|
accs, err := p.fetch(ctx, query, id)
|
|
if err != nil {
|
|
return domain.Account{}, err
|
|
}
|
|
|
|
if len(accs) == 0 {
|
|
return domain.Account{}, domain.ErrNotFound
|
|
}
|
|
|
|
return accs[0], nil
|
|
}
|
|
func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error {
|
|
query := `
|
|
INSERT INTO accounts (username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
|
last_message_id, next_notification_check_at, next_stuck_notification_check_at, is_deleted, development)
|
|
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW(), FALSE, $7)
|
|
ON CONFLICT(username) DO
|
|
UPDATE SET access_token = $3,
|
|
refresh_token = $4,
|
|
token_expires_at = $5,
|
|
last_message_id = $6,
|
|
is_deleted = FALSE
|
|
RETURNING id`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if err := p.conn.QueryRow(
|
|
ctx,
|
|
query,
|
|
acc.Username,
|
|
acc.AccountID,
|
|
acc.AccessToken,
|
|
acc.RefreshToken,
|
|
acc.TokenExpiresAt,
|
|
acc.LastMessageID,
|
|
acc.Development,
|
|
).Scan(&acc.ID); err != nil {
|
|
span.SetStatus(codes.Error, "failed upserting account")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
|
|
query := `
|
|
INSERT INTO accounts
|
|
(username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
|
last_message_id, next_notification_check_at, next_stuck_notification_check_at, is_deleted, development)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, FALSE, $9)
|
|
RETURNING id`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if err := p.conn.QueryRow(
|
|
ctx,
|
|
query,
|
|
acc.Username,
|
|
acc.AccountID,
|
|
acc.AccessToken,
|
|
acc.RefreshToken,
|
|
acc.TokenExpiresAt,
|
|
acc.LastMessageID,
|
|
acc.NextNotificationCheckAt,
|
|
acc.NextStuckNotificationCheckAt,
|
|
acc.Development,
|
|
).Scan(&acc.ID); err != nil {
|
|
span.SetStatus(codes.Error, "failed inserting account")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Account) error {
|
|
query := `
|
|
UPDATE accounts
|
|
SET username = $2,
|
|
reddit_account_id = $3,
|
|
access_token = $4,
|
|
refresh_token = $5,
|
|
token_expires_at = $6,
|
|
last_message_id = $7,
|
|
next_notification_check_at = $8,
|
|
next_stuck_notification_check_at = $9,
|
|
check_count = $10,
|
|
development = $11
|
|
WHERE id = $1`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if _, err := p.conn.Exec(
|
|
ctx,
|
|
query,
|
|
acc.ID,
|
|
acc.Username,
|
|
acc.AccountID,
|
|
acc.AccessToken,
|
|
acc.RefreshToken,
|
|
acc.TokenExpiresAt,
|
|
acc.LastMessageID,
|
|
acc.NextNotificationCheckAt,
|
|
acc.NextStuckNotificationCheckAt,
|
|
acc.CheckCount,
|
|
acc.Development,
|
|
); err != nil {
|
|
span.SetStatus(codes.Error, "failed to update account")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) Delete(ctx context.Context, id int64) error {
|
|
query := `UPDATE accounts SET is_deleted = TRUE WHERE id = $1`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if _, err := p.conn.Exec(ctx, query, id); err != nil {
|
|
span.SetStatus(codes.Error, "failed to delete account")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) Associate(ctx context.Context, acc *domain.Account, dev *domain.Device) error {
|
|
query := `
|
|
INSERT INTO devices_accounts
|
|
(account_id, device_id)
|
|
VALUES ($1, $2)
|
|
ON CONFLICT(account_id, device_id) DO NOTHING`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if _, err := p.conn.Exec(ctx, query, acc.ID, dev.ID); err != nil {
|
|
span.SetStatus(codes.Error, "failed to associate account to device")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domain.Account, dev *domain.Device) error {
|
|
query := `DELETE FROM devices_accounts WHERE account_id = $1 AND device_id = $2`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
if _, err := p.conn.Exec(ctx, query, acc.ID, dev.ID); err != nil {
|
|
span.SetStatus(codes.Error, "failed to disassociate account from device")
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) {
|
|
query := `
|
|
SELECT accounts.id, username, accounts.reddit_account_id, access_token, refresh_token, token_expires_at,
|
|
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
|
check_count, development
|
|
FROM accounts
|
|
INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id
|
|
INNER JOIN devices ON devices.id = devices_accounts.device_id
|
|
WHERE devices.apns_token = $1
|
|
AND accounts.is_deleted IS FALSE`
|
|
|
|
return p.fetch(ctx, query, token)
|
|
}
|
|
|
|
func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
|
|
query := `
|
|
UPDATE accounts
|
|
SET is_deleted = TRUE
|
|
WHERE token_expires_at < $1`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
res, err := p.conn.Exec(ctx, query, expiry)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, "failed to prune stale accounts")
|
|
span.RecordError(err)
|
|
}
|
|
|
|
span.SetAttributes(attribute.Int64("db.result.rows_affected", res.RowsAffected()))
|
|
|
|
return res.RowsAffected(), err
|
|
}
|
|
|
|
func (p *postgresAccountRepository) PruneOrphaned(ctx context.Context) (int64, error) {
|
|
query := `
|
|
WITH accounts_with_device_count AS (
|
|
SELECT accounts.id, COUNT(device_id) AS device_count
|
|
FROM accounts
|
|
LEFT JOIN devices_accounts ON accounts.id = devices_accounts.account_id
|
|
GROUP BY accounts.id
|
|
)
|
|
UPDATE accounts
|
|
SET is_deleted = TRUE
|
|
WHERE id IN (
|
|
SELECT id
|
|
FROM accounts_with_device_count
|
|
WHERE device_count = 0
|
|
)`
|
|
|
|
ctx, span := spanWithQuery(ctx, p.tracer, query)
|
|
defer span.End()
|
|
|
|
res, err := p.conn.Exec(ctx, query)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, "failed to prune orphaned accounts")
|
|
span.RecordError(err)
|
|
}
|
|
|
|
span.SetAttributes(attribute.Int64("db.result.rows_affected", res.RowsAffected()))
|
|
|
|
return res.RowsAffected(), err
|
|
}
|