package stack

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"strings"
	"testing"

	"github.com/pulumi/pulumi/pkg/v3/secrets"
	"github.com/pulumi/pulumi/pkg/v3/secrets/b64"
	"github.com/pulumi/pulumi/sdk/v3/go/common/encoding"
	"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
	"github.com/pulumi/pulumi/sdk/v3/go/common/resource/config"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testSecretsManager is a test double that serves as its own encrypter and
// decrypter (see its Encrypter and Decrypter methods) and records how often
// each single-value operation has been invoked.
type testSecretsManager struct {
	// encryptCalls is bumped by every EncryptValue call and decryptCalls by
	// every DecryptValue call.
	encryptCalls, decryptCalls int
}

// Type reports the fixed type name of this test manager.
func (t *testSecretsManager) Type() string {
	return "test"
}

// State returns a nil json.RawMessage; this test manager carries no state.
func (t *testSecretsManager) State() json.RawMessage {
	var none json.RawMessage
	return none
}

// Encrypter returns the manager itself as the encrypter, so every
// EncryptValue call is counted on this same instance. It never fails.
func (t *testSecretsManager) Encrypter() (config.Encrypter, error) {
	var self config.Encrypter = t
	return self, nil
}

// Decrypter returns the manager itself as the decrypter, so every
// DecryptValue call is counted on this same instance. It never fails.
func (t *testSecretsManager) Decrypter() (config.Decrypter, error) {
	var self config.Decrypter = t
	return self, nil
}

// EncryptValue increments the encrypt counter and produces a fake ciphertext
// of the form "<call number>:<plaintext>". Because the call number is part of
// the output, encrypting the same plaintext twice yields different results.
// The context is unused and the returned error is always nil.
func (t *testSecretsManager) EncryptValue(
	ctx context.Context, plaintext string,
) (string, error) {
	t.encryptCalls++
	callNumber := t.encryptCalls
	ciphertext := fmt.Sprintf("%v:%v", callNumber, plaintext)
	return ciphertext, nil
}

// DecryptValue increments the decrypt counter and returns everything after
// the first ":" in ciphertext. A ciphertext without any ":" is rejected with
// an error; the counter is incremented in that case as well. The context is
// unused.
func (t *testSecretsManager) DecryptValue(
	ctx context.Context, ciphertext string,
) (string, error) {
	t.decryptCalls++
	if sep := strings.Index(ciphertext, ":"); sep >= 0 {
		return ciphertext[sep+1:], nil
	}
	return "", errors.New("invalid ciphertext format")
}

// BulkDecrypt hands the ciphertexts to config.DefaultBulkDecrypt, passing the
// manager itself as the decrypter, and returns its result unchanged.
func (t *testSecretsManager) BulkDecrypt(
	ctx context.Context, ciphertexts []string,
) (map[string]string, error) {
	plaintexts, err := config.DefaultBulkDecrypt(ctx, t, ciphertexts)
	return plaintexts, err
}

// deserializeProperty pushes v through a JSON marshal/unmarshal round trip
// and then passes the result to DeserializePropertyValue together with dec
// and a panic crypter from config.NewPanicCrypter. Any marshal or unmarshal
// error is returned alongside a zero resource.PropertyValue.
func deserializeProperty(v interface{}, dec config.Decrypter) (resource.PropertyValue, error) {
	encoded, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		return resource.PropertyValue{}, marshalErr
	}

	// Decode back into v itself so the value handed on is whatever the JSON
	// decoder produced for the encoded form.
	unmarshalErr := json.Unmarshal(encoded, &v)
	if unmarshalErr != nil {
		return resource.PropertyValue{}, unmarshalErr
	}

	return DeserializePropertyValue(v, dec, config.NewPanicCrypter())
}

// TestCachingCrypter checks the caching behaviour of the manager returned by
// NewCachingSecretsManager by counting calls on the underlying
// testSecretsManager:
//
//   - serializing a secret instance for the first time encrypts it, and
//     serializing that same instance again does not;
//   - two distinct secret instances with equal plaintext are encrypted
//     separately and produce different ciphertexts;
//   - deserializing through a fresh caching manager decrypts each value, after
//     which re-serializing the decrypted values needs no further encryption
//     and reproduces the original ciphertexts.
//
// Error checks use require so the test stops at the first failure rather than
// continuing with a nil encrypter/decrypter or an invalid serialized value.
func TestCachingCrypter(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	sm := &testSecretsManager{}
	csm := NewCachingSecretsManager(sm)

	foo1 := resource.MakeSecret(resource.NewStringProperty("foo"))
	foo2 := resource.MakeSecret(resource.NewStringProperty("foo"))
	bar := resource.MakeSecret(resource.NewStringProperty("bar"))

	enc, err := csm.Encrypter()
	require.NoError(t, err)

	// Serialize the first copy of "foo". Encrypt should be called once, as this value has not yet been encrypted.
	foo1Ser, err := SerializePropertyValue(ctx, foo1, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 1, sm.encryptCalls)

	// Serialize the second copy of "foo". Because this is a different secret instance, Encrypt should be called
	// a second time even though the plaintext is the same as the last value we encrypted.
	foo2Ser, err := SerializePropertyValue(ctx, foo2, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 2, sm.encryptCalls)
	assert.NotEqual(t, foo1Ser, foo2Ser)

	// Serialize "bar". Encrypt should be called once, as this value has not yet been encrypted.
	barSer, err := SerializePropertyValue(ctx, bar, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)

	// Serialize the first copy of "foo" again. Encrypt should not be called, as this value has already been
	// encrypted.
	foo1Ser2, err := SerializePropertyValue(ctx, foo1, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, foo1Ser, foo1Ser2)

	// Serialize the second copy of "foo" again. Encrypt should not be called, as this value has already been
	// encrypted.
	foo2Ser2, err := SerializePropertyValue(ctx, foo2, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, foo2Ser, foo2Ser2)

	// Serialize "bar" again. Encrypt should not be called, as this value has already been encrypted.
	barSer2, err := SerializePropertyValue(ctx, bar, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, barSer, barSer2)

	dec, err := csm.Decrypter()
	require.NoError(t, err)

	// Decrypt foo1Ser. Decrypt should be called.
	foo1Dec, err := deserializeProperty(foo1Ser, dec)
	require.NoError(t, err)
	assert.True(t, foo1.DeepEquals(foo1Dec))
	assert.Equal(t, 1, sm.decryptCalls)

	// Decrypt foo2Ser. Decrypt should be called.
	foo2Dec, err := deserializeProperty(foo2Ser, dec)
	require.NoError(t, err)
	assert.True(t, foo2.DeepEquals(foo2Dec))
	assert.Equal(t, 2, sm.decryptCalls)

	// Decrypt barSer. Decrypt should be called.
	barDec, err := deserializeProperty(barSer, dec)
	require.NoError(t, err)
	assert.True(t, bar.DeepEquals(barDec))
	assert.Equal(t, 3, sm.decryptCalls)

	// Create a new CachingSecretsManager and re-run the decrypts. Each decrypt should insert the plain- and
	// ciphertext into the cache with the associated secret.
	csm = NewCachingSecretsManager(sm)

	dec, err = csm.Decrypter()
	require.NoError(t, err)

	// Decrypt foo1Ser. Decrypt should be called.
	foo1Dec, err = deserializeProperty(foo1Ser, dec)
	require.NoError(t, err)
	assert.True(t, foo1.DeepEquals(foo1Dec))
	assert.Equal(t, 4, sm.decryptCalls)

	// Decrypt foo2Ser. Decrypt should be called.
	foo2Dec, err = deserializeProperty(foo2Ser, dec)
	require.NoError(t, err)
	assert.True(t, foo2.DeepEquals(foo2Dec))
	assert.Equal(t, 5, sm.decryptCalls)

	// Decrypt barSer. Decrypt should be called.
	barDec, err = deserializeProperty(barSer, dec)
	require.NoError(t, err)
	assert.True(t, bar.DeepEquals(barDec))
	assert.Equal(t, 6, sm.decryptCalls)

	enc, err = csm.Encrypter()
	require.NoError(t, err)

	// Serialize the first copy of "foo" again. Encrypt should not be called, as this value has already been
	// cached by the earlier calls to Decrypt.
	foo1Ser2, err = SerializePropertyValue(ctx, foo1Dec, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, foo1Ser, foo1Ser2)

	// Serialize the second copy of "foo" again. Encrypt should not be called, as this value has already been
	// cached by the earlier calls to Decrypt.
	foo2Ser2, err = SerializePropertyValue(ctx, foo2Dec, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, foo2Ser, foo2Ser2)

	// Serialize "bar" again. Encrypt should not be called, as this value has already been cached by the
	// earlier calls to Decrypt.
	barSer2, err = SerializePropertyValue(ctx, barDec, enc, false /* showSecrets */)
	require.NoError(t, err)
	assert.Equal(t, 3, sm.encryptCalls)
	assert.Equal(t, barSer, barSer2)
}

// mapTestSecretsProvider is a test secrets provider that keeps a handle on
// the manager it hands out, so a test can inspect that manager afterwards.
type mapTestSecretsProvider struct {
	// m is the manager created by the most recent successful OfType call; it
	// stays nil until OfType has succeeded at least once.
	m *mapTestSecretsManager
}

// OfType asks b64.Base64SecretsProvider for a manager of the given type and
// state, wraps it in a mapTestSecretsManager, records the wrapper on the
// provider, and returns it. If the base64 provider fails, its error is
// returned and the provider's recorded manager is left untouched.
func (p *mapTestSecretsProvider) OfType(ty string, state json.RawMessage) (secrets.Manager, error) {
	inner, err := b64.Base64SecretsProvider.OfType(ty, state)
	if err != nil {
		return nil, err
	}

	wrapped := &mapTestSecretsManager{sm: inner}
	p.m = wrapped
	return wrapped, nil
}

// mapTestSecretsManager wraps another secrets.Manager and keeps a handle on
// the counting decrypter it hands out, so a test can read the call counts.
type mapTestSecretsManager struct {
	// sm is the wrapped manager that Type, State, Encrypter and Decrypter
	// forward to.
	sm secrets.Manager

	// d is the decrypter wrapper created by the most recent successful
	// Decrypter call; it stays nil until Decrypter has succeeded.
	d *mapTestDecrypter
}

// Type forwards to the wrapped manager's Type.
func (t *mapTestSecretsManager) Type() string {
	return t.sm.Type()
}

// State forwards to the wrapped manager's State.
func (t *mapTestSecretsManager) State() json.RawMessage {
	return t.sm.State()
}

// Encrypter forwards to the wrapped manager's Encrypter and returns its
// result unchanged; encryption is not instrumented by this wrapper.
func (t *mapTestSecretsManager) Encrypter() (config.Encrypter, error) {
	enc, err := t.sm.Encrypter()
	return enc, err
}

// Decrypter obtains the wrapped manager's decrypter, wraps it in a fresh
// mapTestDecrypter with zeroed counters, records that wrapper on the manager,
// and returns it. If the wrapped manager fails, its error is returned and the
// recorded decrypter is left untouched.
func (t *mapTestSecretsManager) Decrypter() (config.Decrypter, error) {
	inner, err := t.sm.Decrypter()
	if err != nil {
		return nil, err
	}

	counting := &mapTestDecrypter{d: inner}
	t.d = counting
	return counting, nil
}

// mapTestDecrypter wraps a config.Decrypter and counts how many times each of
// its two entry points is invoked.
type mapTestDecrypter struct {
	// d is the wrapped decrypter that performs the actual decryption.
	d config.Decrypter

	// decryptCalls counts DecryptValue invocations on this wrapper.
	decryptCalls int
	// bulkDecryptCalls counts BulkDecrypt invocations on this wrapper.
	bulkDecryptCalls int
}

// DecryptValue records one single-value decryption on this wrapper and then
// forwards the ciphertext to the wrapped decrypter, returning its result
// unchanged.
func (t *mapTestDecrypter) DecryptValue(
	ctx context.Context, ciphertext string,
) (string, error) {
	t.decryptCalls++
	plaintext, err := t.d.DecryptValue(ctx, ciphertext)
	return plaintext, err
}

// BulkDecrypt records one bulk decryption on this wrapper and then hands the
// ciphertexts to config.DefaultBulkDecrypt. The wrapped decrypter, not this
// wrapper, is passed along, so the bulk call does not go through this
// wrapper's DecryptValue and leaves decryptCalls unchanged.
func (t *mapTestDecrypter) BulkDecrypt(
	ctx context.Context, ciphertexts []string,
) (map[string]string, error) {
	t.bulkDecryptCalls++
	plaintexts, err := config.DefaultBulkDecrypt(ctx, t.d, ciphertexts)
	return plaintexts, err
}

// TestMapCrypter deserializes the checkpoint fixture in
// testdata/checkpoint-secrets.json through a counting secrets provider and
// asserts that the decrypter was invoked exactly once in bulk and never for
// an individual value.
func TestMapCrypter(t *testing.T) {
	t.Parallel()

	ctx := context.Background()

	data, err := os.ReadFile("testdata/checkpoint-secrets.json")
	require.NoError(t, err)

	chk, err := UnmarshalVersionedCheckpointToLatestCheckpoint(encoding.JSON, data)
	require.NoError(t, err)
	// Guard the dereference below so a fixture without a deployment fails the
	// test cleanly instead of panicking.
	require.NotNil(t, chk)
	require.NotNil(t, chk.Latest, "checkpoint fixture has no latest deployment")

	var prov mapTestSecretsProvider

	_, err = DeserializeDeploymentV3(ctx, *chk.Latest, &prov)
	require.NoError(t, err)

	// The provider and manager only record their children once OfType and
	// Decrypter have been called; verify that happened before reading counters.
	require.NotNil(t, prov.m, "secrets provider was never asked for a manager")
	d := prov.m.d
	require.NotNil(t, d, "secrets manager was never asked for a decrypter")

	assert.Equal(t, 1, d.bulkDecryptCalls)
	assert.Equal(t, 0, d.decryptCalls)
}