mirror of https://github.com/authelia/authelia.git
570 lines
16 KiB
Go
570 lines
16 KiB
Go
package storage
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the storage provider and the key
|
|
// provided by this command to encrypt the values again and update them using a transaction.
|
|
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
|
|
key := sha256.Sum256([]byte(rawKey))
|
|
|
|
if bytes.Equal(key[:], p.keys.encryption[:]) {
|
|
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
|
|
}
|
|
|
|
if _, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil {
|
|
return fmt.Errorf("error changing the storage encryption key: %w", err)
|
|
}
|
|
|
|
tx, err := p.db.Beginx()
|
|
if err != nil {
|
|
return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
|
|
}
|
|
|
|
encChangeFuncs := []EncryptionChangeKeyFunc{
|
|
schemaEncryptionChangeKeyOneTimeCode,
|
|
schemaEncryptionChangeKeyTOTP,
|
|
schemaEncryptionChangeKeyWebAuthn,
|
|
}
|
|
|
|
for i := 0; true; i++ {
|
|
typeOAuth2Session := OAuth2SessionType(i)
|
|
|
|
if typeOAuth2Session.Table() == "" {
|
|
break
|
|
}
|
|
|
|
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
|
|
}
|
|
|
|
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)
|
|
|
|
for _, encChangeFunc := range encChangeFuncs {
|
|
if err = encChangeFunc(ctx, p, tx, key); err != nil {
|
|
if rerr := tx.Rollback(); rerr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
|
|
func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error) {
|
|
version, err := p.SchemaVersion(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
if version < 1 {
|
|
return result, ErrSchemaEncryptionVersionUnsupported
|
|
}
|
|
|
|
result = EncryptionValidationResult{
|
|
Tables: map[string]EncryptionValidationTableResult{},
|
|
}
|
|
|
|
if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
|
|
result.InvalidCheckValue = true
|
|
}
|
|
|
|
if verbose {
|
|
encCheckFuncs := []EncryptionCheckKeyFunc{
|
|
schemaEncryptionCheckKeyOneTimeCode,
|
|
schemaEncryptionCheckKeyTOTP,
|
|
schemaEncryptionCheckKeyWebAuthn,
|
|
}
|
|
|
|
for i := 0; true; i++ {
|
|
typeOAuth2Session := OAuth2SessionType(i)
|
|
|
|
if typeOAuth2Session.Table() == "" {
|
|
break
|
|
}
|
|
|
|
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
|
|
}
|
|
|
|
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption)
|
|
|
|
for _, encCheckFunc := range encCheckFuncs {
|
|
table, tableResult := encCheckFunc(ctx, p)
|
|
|
|
result.Tables[table] = tableResult
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func schemaEncryptionChangeKeyOneTimeCode(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var count int
|
|
|
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimeCode)); err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
return nil
|
|
}
|
|
|
|
configs := make([]encOneTimeCode, 0, count)
|
|
|
|
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTCEncryptedData, tableOneTimeCode)); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("error selecting one-time codes: %w", err)
|
|
}
|
|
|
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTCEncryptedData, tableOneTimeCode))
|
|
|
|
for _, c := range configs {
|
|
if c.Code, err = provider.decrypt(c.Code); err != nil {
|
|
return fmt.Errorf("error decrypting one-time code with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if c.Code, err = utils.Encrypt(c.Code, &key); err != nil {
|
|
return fmt.Errorf("error encrypting one-time code with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if _, err = tx.ExecContext(ctx, query, c.Code, c.ID); err != nil {
|
|
return fmt.Errorf("error updating one-time code with id '%d': %w", c.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var count int
|
|
|
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableTOTPConfigurations)); err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
return nil
|
|
}
|
|
|
|
configs := make([]encTOTPConfiguration, 0, count)
|
|
|
|
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
|
}
|
|
|
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))
|
|
|
|
for _, c := range configs {
|
|
if c.Secret, err = provider.decrypt(c.Secret); err != nil {
|
|
return fmt.Errorf("error decrypting TOTP configuration secret with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if c.Secret, err = utils.Encrypt(c.Secret, &key); err != nil {
|
|
return fmt.Errorf("error encrypting TOTP configuration secret with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if _, err = tx.ExecContext(ctx, query, c.Secret, c.ID); err != nil {
|
|
return fmt.Errorf("error updating TOTP configuration secret with id '%d': %w", c.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func schemaEncryptionChangeKeyWebAuthn(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var count int
|
|
|
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableWebAuthnCredentials)); err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
return nil
|
|
}
|
|
|
|
credentials := make([]encWebAuthnCredential, 0, count)
|
|
|
|
if err = tx.SelectContext(ctx, &credentials, fmt.Sprintf(queryFmtSelectWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials)); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("error selecting WebAuthn credentials: %w", err)
|
|
}
|
|
|
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials))
|
|
|
|
for _, d := range credentials {
|
|
if d.PublicKey, err = provider.decrypt(d.PublicKey); err != nil {
|
|
return fmt.Errorf("error decrypting WebAuthn credential public key with id '%d': %w", d.ID, err)
|
|
}
|
|
|
|
if d.PublicKey, err = utils.Encrypt(d.PublicKey, &key); err != nil {
|
|
return fmt.Errorf("error encrypting WebAuthn credential public key with id '%d': %w", d.ID, err)
|
|
}
|
|
|
|
if _, err = tx.ExecContext(ctx, query, d.PublicKey, d.ID); err != nil {
|
|
return fmt.Errorf("error updating WebAuthn credential public key with id '%d': %w", d.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionChangeKeyFunc {
|
|
return func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var count int
|
|
|
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, typeOAuth2Session.Table())); err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
return nil
|
|
}
|
|
|
|
sessions := make([]encOAuth2Session, 0, count)
|
|
|
|
if err = tx.SelectContext(ctx, &sessions, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
|
|
return fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)
|
|
}
|
|
|
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionEncryptedData, typeOAuth2Session.Table()))
|
|
|
|
for _, s := range sessions {
|
|
if s.Session, err = provider.decrypt(s.Session); err != nil {
|
|
return fmt.Errorf("error decrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
|
|
}
|
|
|
|
if s.Session, err = utils.Encrypt(s.Session, &key); err != nil {
|
|
return fmt.Errorf("error encrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
|
|
}
|
|
|
|
if _, err = tx.ExecContext(ctx, query, s.Session, s.ID); err != nil {
|
|
return fmt.Errorf("error updating oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var count int
|
|
|
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
return nil
|
|
}
|
|
|
|
configs := make([]encEncryption, 0, count)
|
|
|
|
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("error selecting encyption value: %w", err)
|
|
}
|
|
|
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))
|
|
|
|
for _, c := range configs {
|
|
if c.Value, err = provider.decrypt(c.Value); err != nil {
|
|
return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
|
|
return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err)
|
|
}
|
|
|
|
if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
|
|
return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func schemaEncryptionCheckKeyOneTimeCode(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
|
var (
|
|
rows *sqlx.Rows
|
|
err error
|
|
)
|
|
|
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTCEncryptedData, tableOneTimeCode)); err != nil {
|
|
return tableOneTimeCode, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time-codes: %w", err)}
|
|
}
|
|
|
|
var config encOneTimeCode
|
|
|
|
for rows.Next() {
|
|
result.Total++
|
|
|
|
if err = rows.StructScan(&config); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return tableOneTimeCode, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time-code to struct: %w", err)}
|
|
}
|
|
|
|
if _, err = provider.decrypt(config.Code); err != nil {
|
|
result.Invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
return tableOneTimeCode, result
|
|
}
|
|
|
|
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
|
var (
|
|
rows *sqlx.Rows
|
|
err error
|
|
)
|
|
|
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
|
|
return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting TOTP configurations: %w", err)}
|
|
}
|
|
|
|
var config encTOTPConfiguration
|
|
|
|
for rows.Next() {
|
|
result.Total++
|
|
|
|
if err = rows.StructScan(&config); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning TOTP configuration to struct: %w", err)}
|
|
}
|
|
|
|
if _, err = provider.decrypt(config.Secret); err != nil {
|
|
result.Invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
return tableTOTPConfigurations, result
|
|
}
|
|
|
|
func schemaEncryptionCheckKeyWebAuthn(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
|
var (
|
|
rows *sqlx.Rows
|
|
err error
|
|
)
|
|
|
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectWebAuthnCredentialsEncryptedData, tableWebAuthnCredentials)); err != nil {
|
|
return tableWebAuthnCredentials, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting WebAuthn credentials: %w", err)}
|
|
}
|
|
|
|
var credential encWebAuthnCredential
|
|
|
|
for rows.Next() {
|
|
result.Total++
|
|
|
|
if err = rows.StructScan(&credential); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return tableWebAuthnCredentials, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning WebAuthn credential to struct: %w", err)}
|
|
}
|
|
|
|
if _, err = provider.decrypt(credential.PublicKey); err != nil {
|
|
result.Invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
return tableWebAuthnCredentials, result
|
|
}
|
|
|
|
func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionCheckKeyFunc {
|
|
return func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
|
var (
|
|
rows *sqlx.Rows
|
|
err error
|
|
)
|
|
|
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
|
|
return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)}
|
|
}
|
|
|
|
var session encOAuth2Session
|
|
|
|
for rows.Next() {
|
|
result.Total++
|
|
|
|
if err = rows.StructScan(&session); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error scanning oauth2 %s session to struct: %w", typeOAuth2Session.String(), err)}
|
|
}
|
|
|
|
if _, err = provider.decrypt(session.Session); err != nil {
|
|
result.Invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
return typeOAuth2Session.Table(), result
|
|
}
|
|
}
|
|
|
|
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
|
var (
|
|
rows *sqlx.Rows
|
|
err error
|
|
)
|
|
|
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
|
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
|
|
}
|
|
|
|
var config encEncryption
|
|
|
|
for rows.Next() {
|
|
result.Total++
|
|
|
|
if err = rows.StructScan(&config); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
|
|
}
|
|
|
|
if _, err = provider.decrypt(config.Value); err != nil {
|
|
result.Invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
return tableEncryption, result
|
|
}
|
|
|
|
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
|
return utils.Encrypt(clearText, &p.keys.encryption)
|
|
}
|
|
|
|
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
|
return utils.Decrypt(cipherText, &p.keys.encryption)
|
|
}
|
|
|
|
func (p *SQLProvider) otcHMACSignature(values ...[]byte) string {
|
|
h := hmac.New(sha512.New, p.keys.otcHMAC)
|
|
|
|
for i := 0; i < len(values); i++ {
|
|
h.Write(values[i])
|
|
}
|
|
|
|
return fmt.Sprintf("%x", h.Sum(nil))
|
|
}
|
|
|
|
func (p *SQLProvider) otpHMACSignature(values ...[]byte) string {
|
|
h := hmac.New(sha256.New, p.keys.otpHMAC)
|
|
|
|
for i := 0; i < len(values); i++ {
|
|
h.Write(values[i])
|
|
}
|
|
|
|
return fmt.Sprintf("%x", h.Sum(nil))
|
|
}
|
|
|
|
func (p *SQLProvider) getHMACOneTimeCode(ctx context.Context) (key []byte, err error) {
|
|
return p.getHMACKey(ctx, "hmac_key_otc", sha512.BlockSize)
|
|
}
|
|
|
|
func (p *SQLProvider) getHMACOneTimePassword(ctx context.Context) (key []byte, err error) {
|
|
return p.getHMACKey(ctx, "hmac_key_otp", sha256.BlockSize)
|
|
}
|
|
|
|
func (p *SQLProvider) getHMACKey(ctx context.Context, name string, size int) (key []byte, err error) {
|
|
if key, err = p.getEncryptionValue(ctx, name); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
key = make([]byte, size)
|
|
|
|
_, err = rand.Read(key)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate hmac key: %w", err)
|
|
}
|
|
|
|
if err = p.setEncryptionValue(ctx, name, key); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
|
|
var encryptedValue []byte
|
|
|
|
err = p.db.GetContext(ctx, &encryptedValue, p.sqlSelectEncryptionValue, name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return p.decrypt(encryptedValue)
|
|
}
|
|
|
|
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
|
|
if value, err = p.encrypt(value); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
|
|
valueClearText, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
value, err := utils.Encrypt([]byte(valueClearText.String()), key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = conn.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
|
|
|
|
return err
|
|
}
|