authelia/internal/storage/sql_provider_encryption.go

570 lines
16 KiB
Go

package storage
import (
"bytes"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"database/sql"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/authelia/authelia/v4/internal/utils"
)
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the storage provider and the key
// provided by this command to encrypt the values again and update them using a transaction.
//
// The new key is derived as the SHA-256 digest of rawKey. Every table is re-encrypted inside a single transaction
// bound to ctx; if any step fails the transaction is rolled back so no partially re-encrypted state is persisted.
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
	key := sha256.Sum256([]byte(rawKey))

	if bytes.Equal(key[:], p.keys.encryption[:]) {
		return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
	}

	// NOTE(review): with verbose=false this only enforces the schema version; the returned result (including
	// InvalidCheckValue) is discarded, so a wrong current key is only detected if the re-encryption below fails to
	// decrypt a stored value.
	if _, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil {
		return fmt.Errorf("error changing the storage encryption key: %w", err)
	}

	// Bind the transaction to ctx so that cancellation aborts (and rolls back) the key change rather than being
	// ignored when the transaction is started.
	tx, err := p.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
	}

	encChangeFuncs := []EncryptionChangeKeyFunc{
		schemaEncryptionChangeKeyOneTimeCode,
		schemaEncryptionChangeKeyTOTP,
		schemaEncryptionChangeKeyWebAuthn,
	}

	// OAuth2 session types are enumerated from zero until one reports no backing table.
	for i := 0; ; i++ {
		typeOAuth2Session := OAuth2SessionType(i)

		if typeOAuth2Session.Table() == "" {
			break
		}

		encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
	}

	// The encryption table is processed last, after all other tables have been re-encrypted.
	encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)

	for _, encChangeFunc := range encChangeFuncs {
		if err = encChangeFunc(ctx, p, tx, key); err != nil {
			if rerr := tx.Rollback(); rerr != nil {
				return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
			}

			return fmt.Errorf("rollback due to error: %w", err)
		}
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction to change encryption key: %w", err)
	}

	return nil
}
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
//
// The stored check value is always decrypted; a failure to read or decrypt it is reported through
// result.InvalidCheckValue rather than as an error. When verbose is true, each encrypted table handled in this file is
// also scanned and its per-table summary is recorded in result.Tables, keyed by table name. The only errors returned
// come from reading the schema version or from an unsupported (pre-1) schema version.
func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error) {
	version, err := p.SchemaVersion(ctx)

	switch {
	case err != nil:
		return result, err
	case version < 1:
		return result, ErrSchemaEncryptionVersionUnsupported
	}

	result = EncryptionValidationResult{Tables: map[string]EncryptionValidationTableResult{}}

	_, checkErr := p.getEncryptionValue(ctx, encryptionNameCheck)
	result.InvalidCheckValue = checkErr != nil

	if !verbose {
		return result, nil
	}

	checks := []EncryptionCheckKeyFunc{
		schemaEncryptionCheckKeyOneTimeCode,
		schemaEncryptionCheckKeyTOTP,
		schemaEncryptionCheckKeyWebAuthn,
	}

	// Session types are enumerated from zero until one reports no backing table.
	for i := 0; OAuth2SessionType(i).Table() != ""; i++ {
		checks = append(checks, schemaEncryptionCheckKeyOpenIDConnect(OAuth2SessionType(i)))
	}

	checks = append(checks, schemaEncryptionCheckKeyEncryption)

	for _, check := range checks {
		name, summary := check(ctx, p)
		result.Tables[name] = summary
	}

	return result, nil
}
// schemaEncryptionChangeKeyOneTimeCode re-encrypts every stored one-time code. Each code is decrypted with the
// provider's current key, encrypted again with key, and written back through tx. An error from the initial row count
// is returned unwrapped; every other failure is wrapped with the offending row's id.
func schemaEncryptionChangeKeyOneTimeCode(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
	var total int

	if err = tx.GetContext(ctx, &total, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimeCode)); err != nil {
		return err
	}

	// Nothing to re-encrypt.
	if total == 0 {
		return nil
	}

	codes := make([]encOneTimeCode, 0, total)

	err = tx.SelectContext(ctx, &codes, fmt.Sprintf(queryFmtSelectOTCEncryptedData, tableOneTimeCode))

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil
	case err != nil:
		return fmt.Errorf("error selecting one-time codes: %w", err)
	}

	update := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTCEncryptedData, tableOneTimeCode))

	for i := range codes {
		row := &codes[i]

		var plain []byte

		if plain, err = provider.decrypt(row.Code); err != nil {
			return fmt.Errorf("error decrypting one-time code with id '%d': %w", row.ID, err)
		}

		if row.Code, err = utils.Encrypt(plain, &key); err != nil {
			return fmt.Errorf("error encrypting one-time code with id '%d': %w", row.ID, err)
		}

		if _, err = tx.ExecContext(ctx, update, row.Code, row.ID); err != nil {
			return fmt.Errorf("error updating one-time code with id '%d': %w", row.ID, err)
		}
	}

	return nil
}
// schemaEncryptionChangeKeyTOTP re-encrypts every stored TOTP secret. Each secret is decrypted with the provider's
// current key, encrypted again with key, and written back through tx. An error from the initial row count is returned
// unwrapped; every other failure is wrapped with the offending row's id.
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
	var total int

	if err = tx.GetContext(ctx, &total, fmt.Sprintf(queryFmtSelectRowCount, tableTOTPConfigurations)); err != nil {
		return err
	}

	// Nothing to re-encrypt.
	if total == 0 {
		return nil
	}

	totps := make([]encTOTPConfiguration, 0, total)

	err = tx.SelectContext(ctx, &totps, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations))

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil
	case err != nil:
		return fmt.Errorf("error selecting TOTP configurations: %w", err)
	}

	update := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))

	for i := range totps {
		row := &totps[i]

		var secret []byte

		if secret, err = provider.decrypt(row.Secret); err != nil {
			return fmt.Errorf("error decrypting TOTP configuration secret with id '%d': %w", row.ID, err)
		}

		if row.Secret, err = utils.Encrypt(secret, &key); err != nil {
			return fmt.Errorf("error encrypting TOTP configuration secret with id '%d': %w", row.ID, err)
		}

		if _, err = tx.ExecContext(ctx, update, row.Secret, row.ID); err != nil {
			return fmt.Errorf("error updating TOTP configuration secret with id '%d': %w", row.ID, err)
		}
	}

	return nil
}
// schemaEncryptionChangeKeyWebAuthn re-encrypts the stored public key of every WebAuthn credential. Each public key
// is decrypted with the provider's current key, encrypted again with key, and written back through tx. An error from
// the initial row count is returned unwrapped; every other failure is wrapped with the offending row's id.
func schemaEncryptionChangeKeyWebAuthn(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
	var total int

	if err = tx.GetContext(ctx, &total, fmt.Sprintf(queryFmtSelectRowCount, tableWebAuthnCredentials)); err != nil {
		return err
	}

	// Nothing to re-encrypt.
	if total == 0 {
		return nil
	}

	creds := make([]encWebAuthnCredential, 0, total)

	err = tx.SelectContext(ctx, &creds, fmt.Sprintf(queryFmtSelectWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials))

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil
	case err != nil:
		return fmt.Errorf("error selecting WebAuthn credentials: %w", err)
	}

	update := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials))

	for i := range creds {
		cred := &creds[i]

		var publicKey []byte

		if publicKey, err = provider.decrypt(cred.PublicKey); err != nil {
			return fmt.Errorf("error decrypting WebAuthn credential public key with id '%d': %w", cred.ID, err)
		}

		if cred.PublicKey, err = utils.Encrypt(publicKey, &key); err != nil {
			return fmt.Errorf("error encrypting WebAuthn credential public key with id '%d': %w", cred.ID, err)
		}

		if _, err = tx.ExecContext(ctx, update, cred.PublicKey, cred.ID); err != nil {
			return fmt.Errorf("error updating WebAuthn credential public key with id '%d': %w", cred.ID, err)
		}
	}

	return nil
}
// schemaEncryptionChangeKeyOpenIDConnect builds an EncryptionChangeKeyFunc that re-encrypts the session data of every
// row in the table backing typeOAuth2Session. Each session is decrypted with the provider's current key, encrypted
// again with the new key, and written back through the supplied transaction. An error from the initial row count is
// returned unwrapped; every other failure is wrapped with the session type and the offending row's id.
func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionChangeKeyFunc {
	return func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
		table := typeOAuth2Session.Table()

		var total int

		if err = tx.GetContext(ctx, &total, fmt.Sprintf(queryFmtSelectRowCount, table)); err != nil {
			return err
		}

		// Nothing to re-encrypt.
		if total == 0 {
			return nil
		}

		sessions := make([]encOAuth2Session, 0, total)

		if err = tx.SelectContext(ctx, &sessions, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, table)); err != nil {
			return fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)
		}

		update := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionEncryptedData, table))

		for i := range sessions {
			row := &sessions[i]

			var data []byte

			if data, err = provider.decrypt(row.Session); err != nil {
				return fmt.Errorf("error decrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), row.ID, err)
			}

			if row.Session, err = utils.Encrypt(data, &key); err != nil {
				return fmt.Errorf("error encrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), row.ID, err)
			}

			if _, err = tx.ExecContext(ctx, update, row.Session, row.ID); err != nil {
				return fmt.Errorf("error updating oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), row.ID, err)
			}
		}

		return nil
	}
}
// schemaEncryptionChangeKeyEncryption re-encrypts every value stored in the encryption table. Each value is decrypted
// with the provider's current key, encrypted again with key, and written back through tx. An error from the initial
// row count is returned unwrapped; every other failure is wrapped with the offending row's id.
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
	var count int

	if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
		return err
	}

	// Nothing to re-encrypt.
	if count == 0 {
		return nil
	}

	configs := make([]encEncryption, 0, count)

	if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}

		return fmt.Errorf("error selecting encryption value: %w", err)
	}

	query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))

	for _, c := range configs {
		if c.Value, err = provider.decrypt(c.Value); err != nil {
			return fmt.Errorf("error decrypting encryption value with id '%d': %w", c.ID, err)
		}

		if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
			return fmt.Errorf("error encrypting encryption value with id '%d': %w", c.ID, err)
		}

		if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
			return fmt.Errorf("error updating encryption value with id '%d': %w", c.ID, err)
		}
	}

	return nil
}
// schemaEncryptionCheckKeyOneTimeCode attempts to decrypt every stored one-time code with the currently configured
// key. It returns the table name and a result holding the number of rows visited (Total) and the number that failed
// to decrypt (Invalid). Query, scan and iteration failures are reported through the result's Error field instead.
func schemaEncryptionCheckKeyOneTimeCode(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
	var (
		rows *sqlx.Rows
		err  error
	)

	if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTCEncryptedData, tableOneTimeCode)); err != nil {
		return tableOneTimeCode, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time-codes: %w", err)}
	}

	var config encOneTimeCode

	for rows.Next() {
		result.Total++

		if err = rows.StructScan(&config); err != nil {
			_ = rows.Close()

			return tableOneTimeCode, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time-code to struct: %w", err)}
		}

		if _, err = provider.decrypt(config.Code); err != nil {
			result.Invalid++
		}
	}

	_ = rows.Close()

	// Next returns false both at the end of the result set and when iteration fails. Without this check, a failure
	// part-way through would be reported as a complete validation over too few rows. Err remains valid after Close.
	if err = rows.Err(); err != nil {
		return tableOneTimeCode, EncryptionValidationTableResult{Error: fmt.Errorf("error iterating one time-codes: %w", err)}
	}

	return tableOneTimeCode, result
}
// schemaEncryptionCheckKeyTOTP attempts to decrypt every stored TOTP secret with the currently configured key. It
// returns the table name and a result holding the number of rows visited (Total) and the number that failed to
// decrypt (Invalid). Query, scan and iteration failures are reported through the result's Error field instead.
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
	var (
		rows *sqlx.Rows
		err  error
	)

	if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
		return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting TOTP configurations: %w", err)}
	}

	var config encTOTPConfiguration

	for rows.Next() {
		result.Total++

		if err = rows.StructScan(&config); err != nil {
			_ = rows.Close()

			return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning TOTP configuration to struct: %w", err)}
		}

		if _, err = provider.decrypt(config.Secret); err != nil {
			result.Invalid++
		}
	}

	_ = rows.Close()

	// Next returns false both at the end of the result set and when iteration fails. Without this check, a failure
	// part-way through would be reported as a complete validation over too few rows. Err remains valid after Close.
	if err = rows.Err(); err != nil {
		return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error iterating TOTP configurations: %w", err)}
	}

	return tableTOTPConfigurations, result
}
// schemaEncryptionCheckKeyWebAuthn attempts to decrypt the stored public key of every WebAuthn credential with the
// currently configured key. It returns the table name and a result holding the number of rows visited (Total) and the
// number that failed to decrypt (Invalid). Query, scan and iteration failures are reported through the result's Error
// field instead.
func schemaEncryptionCheckKeyWebAuthn(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
	var (
		rows *sqlx.Rows
		err  error
	)

	if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials)); err != nil {
		return tableWebAuthnCredentials, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting WebAuthn credentials: %w", err)}
	}

	var credential encWebAuthnCredential

	for rows.Next() {
		result.Total++

		if err = rows.StructScan(&credential); err != nil {
			_ = rows.Close()

			return tableWebAuthnCredentials, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning WebAuthn credential to struct: %w", err)}
		}

		if _, err = provider.decrypt(credential.PublicKey); err != nil {
			result.Invalid++
		}
	}

	_ = rows.Close()

	// Next returns false both at the end of the result set and when iteration fails. Without this check, a failure
	// part-way through would be reported as a complete validation over too few rows. Err remains valid after Close.
	if err = rows.Err(); err != nil {
		return tableWebAuthnCredentials, EncryptionValidationTableResult{Error: fmt.Errorf("error iterating WebAuthn credentials: %w", err)}
	}

	return tableWebAuthnCredentials, result
}
// schemaEncryptionCheckKeyOpenIDConnect builds an EncryptionCheckKeyFunc that attempts to decrypt the session data of
// every row in the table backing typeOAuth2Session with the currently configured key. The function returns the table
// name and a result holding the number of rows visited (Total) and the number that failed to decrypt (Invalid).
// Query, scan and iteration failures are reported through the result's Error field instead.
func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionCheckKeyFunc {
	return func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
		var (
			rows *sqlx.Rows
			err  error
		)

		table = typeOAuth2Session.Table()

		if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, table)); err != nil {
			return table, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)}
		}

		var session encOAuth2Session

		for rows.Next() {
			result.Total++

			if err = rows.StructScan(&session); err != nil {
				_ = rows.Close()

				return table, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning oauth2 %s session to struct: %w", typeOAuth2Session.String(), err)}
			}

			if _, err = provider.decrypt(session.Session); err != nil {
				result.Invalid++
			}
		}

		_ = rows.Close()

		// Next returns false both at the end of the result set and when iteration fails. Without this check, a failure
		// part-way through would be reported as a complete validation over too few rows. Err remains valid after Close.
		if err = rows.Err(); err != nil {
			return table, EncryptionValidationTableResult{Error: fmt.Errorf("error iterating oauth2 %s sessions: %w", typeOAuth2Session.String(), err)}
		}

		return table, result
	}
}
// schemaEncryptionCheckKeyEncryption attempts to decrypt every value in the encryption table with the currently
// configured key. It returns the table name and a result holding the number of rows visited (Total) and the number
// that failed to decrypt (Invalid). Query, scan and iteration failures are reported through the result's Error field
// instead.
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
	var (
		rows *sqlx.Rows
		err  error
	)

	if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
		return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
	}

	var config encEncryption

	for rows.Next() {
		result.Total++

		if err = rows.StructScan(&config); err != nil {
			_ = rows.Close()

			return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
		}

		if _, err = provider.decrypt(config.Value); err != nil {
			result.Invalid++
		}
	}

	_ = rows.Close()

	// Next returns false both at the end of the result set and when iteration fails. Without this check, a failure
	// part-way through would be reported as a complete validation over too few rows. Err remains valid after Close.
	if err = rows.Err(); err != nil {
		return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error iterating encryption values: %w", err)}
	}

	return tableEncryption, result
}
// encrypt seals clearText with the provider's configured encryption key using utils.Encrypt.
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
	key := &p.keys.encryption

	cipherText, err = utils.Encrypt(clearText, key)

	return cipherText, err
}
// decrypt opens cipherText with the provider's configured encryption key using utils.Decrypt. A non-nil error is
// what the key-check routines in this file count as an invalid (undecryptable) value.
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
	key := &p.keys.encryption

	clearText, err = utils.Decrypt(cipherText, key)

	return clearText, err
}
// otcHMACSignature returns the lowercase hex encoding of the HMAC-SHA512 of values, fed in order, keyed with the
// provider's one-time code HMAC key.
func (p *SQLProvider) otcHMACSignature(values ...[]byte) string {
	mac := hmac.New(sha512.New, p.keys.otcHMAC)

	for _, value := range values {
		// hash.Hash.Write is documented to never return an error.
		mac.Write(value)
	}

	digest := mac.Sum(nil)

	return fmt.Sprintf("%x", digest)
}
// otpHMACSignature returns the lowercase hex encoding of the HMAC-SHA256 of values, fed in order, keyed with the
// provider's one-time password HMAC key.
func (p *SQLProvider) otpHMACSignature(values ...[]byte) string {
	mac := hmac.New(sha256.New, p.keys.otpHMAC)

	for _, value := range values {
		// hash.Hash.Write is documented to never return an error.
		mac.Write(value)
	}

	digest := mac.Sum(nil)

	return fmt.Sprintf("%x", digest)
}
// getHMACOneTimeCode loads the one-time code HMAC key stored under "hmac_key_otc". If the key is absent, a new random
// key of sha512.BlockSize bytes is generated and persisted.
func (p *SQLProvider) getHMACOneTimeCode(ctx context.Context) (key []byte, err error) {
	const name = "hmac_key_otc"

	key, err = p.getHMACKey(ctx, name, sha512.BlockSize)

	return key, err
}
// getHMACOneTimePassword loads the one-time password HMAC key stored under "hmac_key_otp". If the key is absent, a
// new random key of sha256.BlockSize bytes is generated and persisted.
func (p *SQLProvider) getHMACOneTimePassword(ctx context.Context) (key []byte, err error) {
	const name = "hmac_key_otp"

	key, err = p.getHMACKey(ctx, name, sha256.BlockSize)

	return key, err
}
// getHMACKey returns the decrypted encryption value stored under name.
//
// When no such value exists (sql.ErrNoRows), size bytes are generated from crypto/rand, stored encrypted under name,
// and returned. Any other lookup error is returned unchanged.
func (p *SQLProvider) getHMACKey(ctx context.Context, name string, size int) (key []byte, err error) {
	key, err = p.getEncryptionValue(ctx, name)

	switch {
	case err == nil:
		return key, nil
	case !errors.Is(err, sql.ErrNoRows):
		return nil, err
	}

	// The key has never been generated; create and persist a fresh one.
	fresh := make([]byte, size)

	if _, err = rand.Read(fresh); err != nil {
		return nil, fmt.Errorf("failed to generate hmac key: %w", err)
	}

	if err = p.setEncryptionValue(ctx, name, fresh); err != nil {
		return nil, err
	}

	return fresh, nil
}
// getEncryptionValue fetches the encrypted value for name using p.sqlSelectEncryptionValue and decrypts it with the
// configured key. Query errors are returned unwrapped; getHMACKey relies on this to detect sql.ErrNoRows.
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
	var sealed []byte

	if err = p.db.GetContext(ctx, &sealed, p.sqlSelectEncryptionValue, name); err != nil {
		return nil, err
	}

	return p.decrypt(sealed)
}
// setEncryptionValue encrypts value with the configured key and upserts it under name using
// p.sqlUpsertEncryptionValue. Errors are returned unwrapped.
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
	var sealed []byte

	if sealed, err = p.encrypt(value); err != nil {
		return err
	}

	_, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, sealed)

	return err
}
// setNewEncryptionCheckValue generates a random UUID, encrypts its string form with key, and upserts the result under
// encryptionNameCheck using conn. The key is passed explicitly rather than read from p.keys, so the caller chooses
// which key the check value is sealed with. NOTE(review): conn is presumably either the database handle or an open
// transaction — confirm against callers.
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}

	sealed, err := utils.Encrypt([]byte(id.String()), key)
	if err != nil {
		return err
	}

	if _, err = conn.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, sealed); err != nil {
		return err
	}

	return nil
}