// Copyright 2016-2020, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pcl

import (
	"fmt"
	"io"
	"path"
	"sort"

	"github.com/hashicorp/hcl/v2"
	"github.com/hashicorp/hcl/v2/hclsyntax"
	"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/model"
	"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/syntax"
	"github.com/pulumi/pulumi/pkg/v3/codegen/schema"
	"github.com/pulumi/pulumi/sdk/v3/go/common/slice"
	"github.com/spf13/afero"
)

// Node is a single top-level definition within a program or component: a config
// variable, a local, a resource, a component, or an output.
//
// The interface carries unexported methods, so only types declared in this
// package (or types embedding one of them) can satisfy it.
type Node interface {
	model.Definition

	// Name returns the node's lexical name.
	Name() string
	// Type returns the node's type.
	Type() model.Type
	// VisitExpressions visits the expressions that make up the node's body using
	// the given pre- and post-visitors, returning any diagnostics produced.
	VisitExpressions(pre, post model.ExpressionVisitor) hcl.Diagnostics

	// Binding-state bookkeeping: markBinding/markBound record that binding of the
	// node has started/finished; isBinding/isBound report that state.
	markBinding()
	markBound()
	isBinding() bool
	isBound() bool

	// Dependency bookkeeping: the nodes this node depends on.
	getDependencies() []Node
	setDependencies(nodes []Node)

	// isNode is a marker method that restricts implementations to this package.
	isNode()
}

// node holds the binding state and dependency list shared by Node
// implementations (presumably embedded by the concrete node types — confirm).
type node struct {
	binding bool   // set once binding of the node has started
	bound   bool   // set once binding of the node has finished
	deps    []Node // the nodes this node depends on
}

// markBinding records that binding of the node has started.
func (n *node) markBinding() { n.binding = true }

// markBound records that binding of the node has finished.
func (n *node) markBound() { n.bound = true }

// isBinding reports whether binding has started but not yet finished.
func (n *node) isBinding() bool { return !n.bound && n.binding }

// isBound reports whether binding has finished.
func (n *node) isBound() bool { return n.bound }

// getDependencies returns the recorded dependency list (not a copy).
func (n *node) getDependencies() []Node { return n.deps }

// setDependencies replaces the recorded dependency list with nodes (not copied).
func (n *node) setDependencies(nodes []Node) { n.deps = nodes }

// isNode marks *node as a Node implementation detail.
func (*node) isNode() {}

// Program represents a semantically-analyzed Pulumi HCL2 program.
type Program struct {
	// Nodes holds the program's top-level definitions (config variables, locals,
	// resources, components, and outputs).
	Nodes []Node

	// files holds the source files that make up the program; their names and
	// contents are what Source and writeSourceFiles emit.
	files []*syntax.File

	// binder is used by BindExpression to bind further expressions, and its
	// referencedPackages table backs PackageReferences and PackageSnapshots.
	binder *binder
}

// NewDiagnosticWriter creates a new hcl.DiagnosticWriter for use with diagnostics generated by the program.
// The writer outputs to w and is handed the program's source files (presumably so it can render source
// snippets alongside each diagnostic); width and color are passed through to syntax.NewDiagnosticWriter as-is.
func (p *Program) NewDiagnosticWriter(w io.Writer, width uint, color bool) hcl.DiagnosticWriter {
	return syntax.NewDiagnosticWriter(w, p.files, width, color)
}

// BindExpression binds an HCL2 expression in the top-level context of the program.
// Binding is delegated to the binder the program retains; the bound expression is returned together with any
// diagnostics the binder reports.
func (p *Program) BindExpression(node hclsyntax.Node) (model.Expression, hcl.Diagnostics) {
	return p.binder.bindExpression(node)
}

// Packages returns the full definitions of the packages referenced by this program, in the same order as
// PackageReferences. It panics if any referenced package's definition cannot be loaded.
func (p *Program) Packages() []*schema.Package {
	references := p.PackageReferences()
	definitions := make([]*schema.Package, 0, len(references))
	for _, reference := range references {
		definition, err := reference.Definition()
		if err != nil {
			panic(fmt.Errorf("loading package definition: %w", err))
		}
		definitions = append(definitions, definition)
	}
	return definitions
}

// PackageReferences returns references to the packages used by this program, ordered by their keys in the
// binder's referenced-package table so the result is deterministic.
func (p *Program) PackageReferences() []schema.PackageReference {
	referenced := p.binder.referencedPackages

	names := slice.Prealloc[string](len(referenced))
	for name := range referenced {
		names = append(names, name)
	}
	sort.Strings(names)

	references := slice.Prealloc[schema.PackageReference](len(referenced))
	for _, name := range names {
		references = append(references, referenced[name])
	}
	return references
}

// PackageSnapshots returns the package schemas used by this program, in the same order as PackageReferences.
// A partial package contributes a snapshot holding only the members the program references; any other package
// contributes its full definition. The first failure aborts the call and is returned with the package name.
func (p *Program) PackageSnapshots() ([]*schema.Package, error) {
	references := p.PackageReferences()
	snapshots := slice.Prealloc[*schema.Package](len(references))

	for _, ref := range references {
		var (
			pkg *schema.Package
			err error
		)
		switch r := ref.(type) {
		case *schema.PartialPackage:
			pkg, err = r.Snapshot()
		default:
			pkg, err = r.Definition()
		}
		if err != nil {
			return nil, fmt.Errorf("defining package '%v': %w", ref.Name(), err)
		}
		snapshots = append(snapshots, pkg)
	}
	return snapshots, nil
}

// Source returns the text of this program's own source files, keyed by file name.
// Files belonging to components the program uses are not included.
func (p *Program) Source() map[string]string {
	sources := make(map[string]string, len(p.files))
	for _, f := range p.files {
		sources[f.Name] = string(f.Bytes)
	}
	return sources
}

// writeSourceFiles creates directory on fs (if needed), writes this program's source files into it, and then
// recurses into every component the program uses, writing each to its source path joined onto directory.
// seenPaths records component directories that have already been written, so a component directory is only
// written once per traversal.
func (p *Program) writeSourceFiles(directory string, fs afero.Fs, seenPaths map[string]bool) error {
	if err := fs.MkdirAll(directory, 0o755); err != nil {
		return fmt.Errorf("could not create output directory: %w", err)
	}

	for _, f := range p.files {
		target := path.Join(directory, f.Name)
		if err := afero.WriteFile(fs, target, f.Bytes, 0o600); err != nil {
			return fmt.Errorf("could not write output program: %w", err)
		}
	}

	for _, n := range p.Nodes {
		component, ok := n.(*Component)
		if !ok {
			continue
		}
		componentDir := path.Join(directory, component.source)
		if _, seen := seenPaths[componentDir]; seen {
			continue
		}
		seenPaths[componentDir] = true
		if err := component.Program.writeSourceFiles(componentDir, fs, seenPaths); err != nil {
			return err
		}
	}
	return nil
}

// WriteSource writes the program's source files, together with those of every component it uses (recursively),
// into fs rooted at "/".
func (p *Program) WriteSource(fs afero.Fs) error {
	return p.writeSourceFiles("/", fs, make(map[string]bool))
}

// collectComponentsRecursive adds to components every component used by this program, keyed by DirPath, and
// recurses into each newly found component's program. A directory already present is neither overwritten nor
// revisited.
func (p *Program) collectComponentsRecursive(components map[string]*Component) {
	for _, n := range p.Nodes {
		component, ok := n.(*Component)
		if !ok {
			continue
		}
		dir := component.DirPath()
		if _, seen := components[dir]; seen {
			continue
		}
		components[dir] = component
		component.Program.collectComponentsRecursive(components)
	}
}

// CollectComponents returns every component used by this program, including components nested inside other
// components, keyed by the component's DirPath.
func (p *Program) CollectComponents() map[string]*Component {
	found := make(map[string]*Component)
	p.collectComponentsRecursive(found)
	return found
}

// collectPackageSnapshots records in seenPackages the package snapshots used by this program and by every
// component it uses, directly or transitively, keyed by package name. An entry that is already present is never
// overwritten. Snapshots are therefore taken first from this program and then from its components in sorted
// DirPath order.
//
// CollectComponents already walks nested components recursively, so each collected component only contributes
// its own snapshots here. Recursing into every component again would re-walk the same subtrees repeatedly, and
// that work grows exponentially with component nesting depth. The components are sorted because Go map
// iteration order is random; sorting makes the result deterministic when two components reference the same
// package name.
func (p *Program) collectPackageSnapshots(seenPackages map[string]*schema.Package) error {
	// addSnapshots adds one program's own snapshots, keeping any entry that is already present.
	addSnapshots := func(program *Program) error {
		packages, err := program.PackageSnapshots()
		if err != nil {
			return err
		}
		for _, pkg := range packages {
			if _, seen := seenPackages[pkg.Name]; !seen {
				seenPackages[pkg.Name] = pkg
			}
		}
		return nil
	}

	if err := addSnapshots(p); err != nil {
		return err
	}

	components := p.CollectComponents()
	dirs := make([]string, 0, len(components))
	for dir := range components {
		dirs = append(dirs, dir)
	}
	sort.Strings(dirs)

	for _, dir := range dirs {
		if err := addSnapshots(components[dir].Program); err != nil {
			return err
		}
	}
	return nil
}

// CollectNestedPackageSnapshots returns the package schemas used by this program and, transitively, by every
// component it uses, keyed by package name. When several programs reference the same package name, the first
// snapshot collected is kept, and this program's own packages are collected first.
//
// On error the returned map is nil rather than partially populated.
func (p *Program) CollectNestedPackageSnapshots() (map[string]*schema.Package, error) {
	seenPackages := map[string]*schema.Package{}
	if err := p.collectPackageSnapshots(seenPackages); err != nil {
		return nil, err
	}
	return seenPackages, nil
}

// ConfigVariables returns the program's config variable nodes in the order they appear in Nodes.
// It returns nil when the program declares none.
func (p *Program) ConfigVariables() []*ConfigVariable {
	var result []*ConfigVariable
	for _, n := range p.Nodes {
		if v, ok := n.(*ConfigVariable); ok {
			result = append(result, v)
		}
	}
	return result
}

// OutputVariables returns the program's output variable nodes in the order they appear in Nodes.
// It returns nil when the program declares none.
func (p *Program) OutputVariables() []*OutputVariable {
	var result []*OutputVariable
	for _, n := range p.Nodes {
		if v, ok := n.(*OutputVariable); ok {
			result = append(result, v)
		}
	}
	return result
}