// Copyright 2016-2023, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package deploytest

import (
	"context"
	"errors"
	"testing"

	"github.com/blang/semver"
	"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
	"github.com/pulumi/pulumi/sdk/v3/go/common/resource/plugin"
	"github.com/pulumi/pulumi/sdk/v3/go/common/tokens"
	"github.com/stretchr/testify/assert"
)

func TestProvider(t *testing.T) {
	t.Parallel()
	t.Run("SignalCancellation", func(t *testing.T) {
		t.Parallel()
		t.Run("has CancelF", func(t *testing.T) {
			t.Parallel()
			var called bool
			prov := &Provider{
				CancelF: func() error {
					called = true
					return errors.New("expected error")
				},
			}
			assert.Error(t, prov.SignalCancellation(context.Background()))
			assert.True(t, called)
		})
		t.Run("no CancelF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			assert.NoError(t, prov.SignalCancellation(context.Background()))
		})
	})
	t.Run("Close", func(t *testing.T) {
		t.Parallel()
		prov := &Provider{}
		assert.NoError(t, prov.Close())
		// Ensure idempotent.
		assert.NoError(t, prov.Close())
	})
	t.Run("GetPluginInfo", func(t *testing.T) {
		t.Parallel()
		prov := &Provider{
			Name:    "expected-name",
			Version: semver.MustParse("1.0.0"),
		}
		info, err := prov.GetPluginInfo(context.Background())
		assert.NoError(t, err)
		assert.Equal(t, "expected-name", info.Name)
		// Ensure reference is passed correctly.
		assert.Equal(t, &prov.Version, info.Version)
	})
	t.Run("GetSchema", func(t *testing.T) {
		t.Parallel()
		t.Run("has GetSchemaF", func(t *testing.T) {
			t.Parallel()
			expectedErr := errors.New("expected error")
			expectedVersion := semver.MustParse("1.0.0")
			var called bool
			prov := &Provider{
				GetSchemaF: func(request plugin.GetSchemaRequest) ([]byte, error) {
					assert.Equal(t, 1, request.Version)
					assert.Equal(t, "expected-subpackage", request.SubpackageName)
					assert.Equal(t, &expectedVersion, request.SubpackageVersion)
					called = true
					return nil, expectedErr
				},
			}
			_, err := prov.GetSchema(context.Background(), plugin.GetSchemaRequest{
				Version:           1,
				SubpackageName:    "expected-subpackage",
				SubpackageVersion: &expectedVersion,
			})
			assert.ErrorIs(t, err, expectedErr)
			assert.True(t, called)
		})
		t.Run("no GetSchemaF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			b, err := prov.GetSchema(context.Background(), plugin.GetSchemaRequest{})
			assert.NoError(t, err)
			assert.Equal(t, []byte("{}"), b.Schema)
		})
	})
	t.Run("CheckConfig", func(t *testing.T) {
		t.Parallel()
		t.Run("has CheckConfigF", func(t *testing.T) {
			t.Parallel()
			expectedErr := errors.New("expected error")
			var called bool
			prov := &Provider{
				CheckConfigF: func(
					urn resource.URN, olds, news resource.PropertyMap, allowUnknowns bool,
				) (resource.PropertyMap, []plugin.CheckFailure, error) {
					assert.Equal(t, resource.URN("expected-urn"), urn)
					assert.Equal(t, resource.NewStringProperty("old-value"), olds["old"])
					assert.Equal(t, resource.NewStringProperty("new-value"), news["new"])
					called = true
					return nil, nil, expectedErr
				},
			}
			_, err := prov.CheckConfig(context.Background(), plugin.CheckConfigRequest{
				URN: resource.URN("expected-urn"),
				Olds: resource.PropertyMap{
					"old": resource.NewStringProperty("old-value"),
				},
				News: resource.PropertyMap{
					"new": resource.NewStringProperty("new-value"),
				},
				AllowUnknowns: true,
			})
			assert.ErrorIs(t, err, expectedErr)
			assert.True(t, called)
		})
		t.Run("no CheckConfigF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			resp, err := prov.CheckConfig(context.Background(), plugin.CheckConfigRequest{
				News: resource.PropertyMap{
					"expected": resource.NewStringProperty("expected-value"),
				},
				AllowUnknowns: true,
			})
			assert.NoError(t, err)
			assert.Empty(t, resp.Failures)
			// Should return the news.
			assert.Equal(t, resource.NewStringProperty("expected-value"), resp.Properties["expected"])
		})
	})
	t.Run("Construct", func(t *testing.T) {
		t.Parallel()
		t.Run("has ConstructF", func(t *testing.T) {
			t.Run("inject error", func(t *testing.T) {
				t.Parallel()
				expectedErr := errors.New("expected error")
				var dialCalled bool
				var constructCalled bool
				expectedResmon := &ResourceMonitor{}
				prov := &Provider{
					DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
						assert.Equal(t, "expected-endpoint", endpoint)
						dialCalled = true
						// Returns no error to avoid short-circuiting due to the monitor being provided
						// an invalid resource monitor address.
						return expectedResmon, nil
					},
					ConstructF: func(
						monitor *ResourceMonitor,
						typ, name string, parent resource.URN, inputs resource.PropertyMap,
						info plugin.ConstructInfo, options plugin.ConstructOptions,
					) (plugin.ConstructResult, error) {
						assert.Equal(t, expectedResmon, monitor)
						constructCalled = true
						return plugin.ConstructResult{}, expectedErr
					},
				}
				_, err := prov.Construct(context.Background(), plugin.ConstructRequest{
					Type: tokens.Type("some-type"),
					Name: "name",
					Info: plugin.ConstructInfo{
						MonitorAddress: "expected-endpoint",
					},
					Parent: resource.URN("<parent-urn>"),
				})
				assert.ErrorIs(t, err, expectedErr)
				assert.True(t, dialCalled)
				assert.True(t, constructCalled)
			})
			t.Run("dial error", func(t *testing.T) {
				t.Parallel()
				t.Run("invalid address", func(t *testing.T) {
					t.Parallel()
					prov := &Provider{
						ConstructF: func(
							monitor *ResourceMonitor,
							typ, name string, parent resource.URN, inputs resource.PropertyMap,
							info plugin.ConstructInfo, options plugin.ConstructOptions,
						) (plugin.ConstructResult, error) {
							assert.Fail(t, "Construct should not be called")
							return plugin.ConstructResult{}, nil
						},
					}
					_, err := prov.Construct(context.Background(), plugin.ConstructRequest{
						Type:   tokens.Type("some-type"),
						Name:   "name",
						Parent: resource.URN("<parent-urn>"),
					})
					assert.ErrorContains(t, err, "could not connect to resource monitor")
				})
				t.Run("injected error", func(t *testing.T) {
					t.Parallel()
					expectedErr := errors.New("expected error")
					var dialCalled bool
					prov := &Provider{
						DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
							dialCalled = true
							// Returns no error to avoid short-circuiting due to the monitor being provided
							// an invalid resource monitor address.
							return nil, expectedErr
						},
						ConstructF: func(
							monitor *ResourceMonitor,
							typ, name string, parent resource.URN, inputs resource.PropertyMap,
							info plugin.ConstructInfo, options plugin.ConstructOptions,
						) (plugin.ConstructResult, error) {
							assert.Fail(t, "Construct should not be called")
							return plugin.ConstructResult{}, nil
						},
					}
					_, err := prov.Construct(context.Background(), plugin.ConstructRequest{
						Type:   tokens.Type("some-type"),
						Name:   "name",
						Parent: resource.URN("<parent-urn>"),
					})
					assert.ErrorIs(t, err, expectedErr)
					assert.True(t, dialCalled)
				})
			})
		})
		t.Run("no ConstructF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{
				DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
					assert.Fail(t, "DialMonitor should not be called")
					return nil, nil
				},
			}
			_, err := prov.Construct(context.Background(), plugin.ConstructRequest{
				Type:   tokens.Type("some-type"),
				Name:   "name",
				Parent: resource.URN("<parent-urn>"),
			})
			assert.NoError(t, err)
		})
	})
	t.Run("Invoke", func(t *testing.T) {
		t.Parallel()
		t.Run("has InvokeF", func(t *testing.T) {
			t.Parallel()
			expectedPropertyMap := resource.PropertyMap{
				"key": resource.NewStringProperty("expected-value"),
			}
			var called bool
			prov := &Provider{
				InvokeF: func(tok tokens.ModuleMember, inputs resource.PropertyMap,
				) (resource.PropertyMap, []plugin.CheckFailure, error) {
					assert.Equal(t, tokens.ModuleMember("expected-tok"), tok)
					called = true
					return expectedPropertyMap, nil, nil
				},
			}
			resp, err := prov.Invoke(context.Background(), plugin.InvokeRequest{
				Tok: "expected-tok",
			})
			assert.NoError(t, err)
			assert.True(t, called)
			assert.Equal(t, expectedPropertyMap, resp.Properties)
		})
		t.Run("no InvokeF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			resp, err := prov.Invoke(context.Background(), plugin.InvokeRequest{})
			assert.NoError(t, err)
			assert.Empty(t, resp.Failures)
			assert.Equal(t, resource.PropertyMap{}, resp.Properties)
		})
	})
	t.Run("StreamInvoke", func(t *testing.T) {
		t.Parallel()
		t.Run("has StreamInvokeF", func(t *testing.T) {
			t.Parallel()
			expectedErr := errors.New("expected error")
			prov := &Provider{
				StreamInvokeF: func(
					tok tokens.ModuleMember, args resource.PropertyMap,
					onNext func(resource.PropertyMap) error,
				) ([]plugin.CheckFailure, error) {
					assert.Equal(t, tokens.ModuleMember("expected-tok"), tok)
					return nil, expectedErr
				},
			}
			_, err := prov.StreamInvoke(context.Background(), plugin.StreamInvokeRequest{Tok: "expected-tok"})
			assert.ErrorIs(t, err, expectedErr)
		})
		t.Run("no StreamInvokeF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			_, err := prov.StreamInvoke(context.Background(), plugin.StreamInvokeRequest{})
			assert.ErrorContains(t, err, "StreamInvoke unimplemented")
		})
	})
	t.Run("Call", func(t *testing.T) {
		t.Parallel()
		t.Run("has CallF", func(t *testing.T) {
			t.Run("inject error", func(t *testing.T) {
				t.Parallel()
				expectedErr := errors.New("expected error")
				var dialCalled bool
				var callCalled bool
				expectedResmon := &ResourceMonitor{}
				prov := &Provider{
					DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
						dialCalled = true
						// Returns no error to avoid short-circuiting due to the monitor being provided
						// an invalid resource monitor address.
						return expectedResmon, nil
					},
					CallF: func(
						monitor *ResourceMonitor, tok tokens.ModuleMember, args resource.PropertyMap,
						info plugin.CallInfo, options plugin.CallOptions,
					) (plugin.CallResult, error) {
						assert.Equal(t, expectedResmon, monitor)
						assert.Equal(t, tokens.ModuleMember("expected-tok"), tok)
						callCalled = true
						return plugin.CallResult{}, expectedErr
					},
				}
				_, err := prov.Call(context.Background(), plugin.CallRequest{Tok: "expected-tok"})
				assert.ErrorIs(t, err, expectedErr)
				assert.True(t, dialCalled)
				assert.True(t, callCalled)
			})
			t.Run("dial error", func(t *testing.T) {
				t.Parallel()
				t.Run("invalid address", func(t *testing.T) {
					t.Parallel()
					prov := &Provider{
						CallF: func(
							monitor *ResourceMonitor, tok tokens.ModuleMember, args resource.PropertyMap,
							info plugin.CallInfo, options plugin.CallOptions,
						) (plugin.CallResult, error) {
							assert.Fail(t, "Call should not be called")
							return plugin.CallResult{}, nil
						},
					}
					_, err := prov.Call(context.Background(), plugin.CallRequest{})
					assert.ErrorContains(t, err, "could not connect to resource monitor")
				})
				t.Run("injected error", func(t *testing.T) {
					t.Parallel()
					expectedErr := errors.New("expected error")
					var dialCalled bool
					prov := &Provider{
						DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
							dialCalled = true
							// Returns no error to avoid short-circuiting due to the monitor being provided
							// an invalid resource monitor address.
							return nil, expectedErr
						},
						CallF: func(
							monitor *ResourceMonitor, tok tokens.ModuleMember, args resource.PropertyMap,
							info plugin.CallInfo, options plugin.CallOptions,
						) (plugin.CallResult, error) {
							assert.Fail(t, "Call should not be called")
							return plugin.CallResult{}, expectedErr
						},
					}
					_, err := prov.Call(context.Background(), plugin.CallRequest{})
					assert.ErrorIs(t, err, expectedErr)
					assert.True(t, dialCalled)
				})
			})
		})
		t.Run("no CallF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{
				DialMonitorF: func(ctx context.Context, endpoint string) (*ResourceMonitor, error) {
					assert.Fail(t, "Dial should not be called")
					return nil, nil
				},
			}
			_, err := prov.Call(context.Background(), plugin.CallRequest{})
			assert.NoError(t, err)
		})
	})
	t.Run("GetMapping", func(t *testing.T) {
		t.Parallel()
		t.Run("has GetMappingF", func(t *testing.T) {
			t.Parallel()
			expectedErr := errors.New("expected error")
			prov := &Provider{
				GetMappingF: func(key, provider string) ([]byte, string, error) {
					assert.Equal(t, "expected-key", key)
					assert.Equal(t, "expected-provider", provider)
					return nil, "", expectedErr
				},
			}
			_, err := prov.GetMapping(context.Background(), plugin.GetMappingRequest{
				Key:      "expected-key",
				Provider: "expected-provider",
			})
			assert.ErrorIs(t, err, expectedErr)
		})
		t.Run("no GetMappingF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			resp, err := prov.GetMapping(context.Background(), plugin.GetMappingRequest{})
			assert.NoError(t, err)
			assert.Equal(t, "", resp.Provider)
			assert.Nil(t, resp.Data)
		})
	})
	t.Run("GetMappings", func(t *testing.T) {
		t.Parallel()
		t.Run("has GetMappingsF", func(t *testing.T) {
			t.Parallel()
			expectedErr := errors.New("expected error")
			prov := &Provider{
				GetMappingsF: func(key string) ([]string, error) {
					assert.Equal(t, "expected-key", key)
					return nil, expectedErr
				},
			}
			_, err := prov.GetMappings(context.Background(), plugin.GetMappingsRequest{Key: "expected-key"})
			assert.ErrorIs(t, err, expectedErr)
		})
		t.Run("no GetMappingsF", func(t *testing.T) {
			t.Parallel()
			prov := &Provider{}
			mappingStrs, err := prov.GetMappings(context.Background(), plugin.GetMappingsRequest{})
			assert.NoError(t, err)
			assert.Empty(t, mappingStrs)
		})
	})
}