mirror of https://github.com/pulumi/pulumi.git
297 lines
7.5 KiB
Go
297 lines
7.5 KiB
Go
package python
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/hcl/v2"
|
|
"github.com/hashicorp/hcl/v2/hclsyntax"
|
|
"github.com/pulumi/pulumi/pkg/v3/codegen"
|
|
"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/model"
|
|
"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/syntax"
|
|
"github.com/pulumi/pulumi/pkg/v3/codegen/pcl"
|
|
"github.com/pulumi/pulumi/sdk/v3/go/common/util/contract"
|
|
"github.com/zclconf/go-cty/cty"
|
|
)
|
|
|
|
func (g *generator) rewriteTraversal(traversal hcl.Traversal, source model.Expression,
|
|
parts []model.Traversable,
|
|
) model.Expression {
|
|
// TODO(pdg): transfer trivia
|
|
|
|
var rootName string
|
|
var currentTraversal hcl.Traversal
|
|
currentParts := []model.Traversable{parts[0]}
|
|
currentExpression := source
|
|
|
|
if len(traversal) > 0 {
|
|
if root, isRoot := traversal[0].(hcl.TraverseRoot); isRoot {
|
|
traversal = traversal[1:]
|
|
rootName, currentTraversal = root.Name, hcl.Traversal{root}
|
|
}
|
|
}
|
|
|
|
for i, traverser := range traversal {
|
|
var key cty.Value
|
|
switch traverser := traverser.(type) {
|
|
case hcl.TraverseAttr:
|
|
key = cty.StringVal(traverser.Name)
|
|
case hcl.TraverseIndex:
|
|
key = traverser.Key
|
|
default:
|
|
contract.Failf("unexpected traverser of type %T (%v)", traverser, traverser.SourceRange())
|
|
}
|
|
|
|
if key.Type() != cty.String {
|
|
currentTraversal = append(currentTraversal, traverser)
|
|
currentParts = append(currentParts, parts[i+1])
|
|
continue
|
|
}
|
|
|
|
keyVal, objectKey := key.AsString(), false
|
|
|
|
receiver := parts[i]
|
|
_, hasSchema := pcl.GetSchemaForType(model.GetTraversableType(receiver))
|
|
_, isComponent := receiver.(*pcl.Component)
|
|
|
|
if hasSchema || isComponent {
|
|
objectKey = true
|
|
switch t := traverser.(type) {
|
|
case hcl.TraverseAttr:
|
|
t.Name = keyVal
|
|
traverser, traversal[i] = t, t
|
|
case hcl.TraverseIndex:
|
|
t.Key = cty.StringVal(keyVal)
|
|
traverser, traversal[i] = t, t
|
|
}
|
|
}
|
|
|
|
if objectKey && isLegalIdentifier(keyVal) {
|
|
currentTraversal = append(currentTraversal, traverser)
|
|
if i < len(traversal)-1 {
|
|
currentParts = append(currentParts, parts[i+1])
|
|
}
|
|
continue
|
|
}
|
|
|
|
if currentExpression == nil {
|
|
currentExpression = &model.ScopeTraversalExpression{
|
|
RootName: rootName,
|
|
Traversal: currentTraversal,
|
|
Parts: currentParts,
|
|
}
|
|
|
|
currentTraversal, currentParts = nil, nil
|
|
} else if len(currentTraversal) > 0 {
|
|
currentExpression = &model.RelativeTraversalExpression{
|
|
Source: currentExpression,
|
|
Traversal: currentTraversal,
|
|
Parts: currentParts,
|
|
}
|
|
currentTraversal, currentParts = nil, []model.Traversable{currentExpression.Type()}
|
|
}
|
|
|
|
currentExpression = &model.IndexExpression{
|
|
Collection: currentExpression,
|
|
Key: &model.LiteralValueExpression{
|
|
Value: cty.StringVal(keyVal),
|
|
},
|
|
}
|
|
|
|
// typechecking the index expression will compute the type of the expression
|
|
// so that when we call Type() later on, we get the correct type.
|
|
typecheckDiags := currentExpression.Typecheck(true)
|
|
g.diagnostics = g.diagnostics.Extend(typecheckDiags)
|
|
}
|
|
|
|
if currentExpression == source {
|
|
return nil
|
|
}
|
|
|
|
return currentExpression
|
|
}
|
|
|
|
type quoteTemp struct {
|
|
Name string
|
|
VariableType model.Type
|
|
Value model.Expression
|
|
}
|
|
|
|
func (qt *quoteTemp) Type() model.Type {
|
|
return qt.VariableType
|
|
}
|
|
|
|
func (qt *quoteTemp) Traverse(traverser hcl.Traverser) (model.Traversable, hcl.Diagnostics) {
|
|
return qt.VariableType.Traverse(traverser)
|
|
}
|
|
|
|
func (qt *quoteTemp) SyntaxNode() hclsyntax.Node {
|
|
return syntax.None
|
|
}
|
|
|
|
type quoteAllocations struct {
|
|
quotes map[model.Expression]string
|
|
temps []*quoteTemp
|
|
}
|
|
|
|
type quoteAllocator struct {
|
|
allocations *quoteAllocations
|
|
allocated codegen.StringSet
|
|
stack []model.Expression
|
|
}
|
|
|
|
func (qa *quoteAllocator) allocate(longString bool) (string, bool) {
|
|
if longString {
|
|
if !qa.allocated.Has(`"`) && !qa.allocated.Has(`"""`) {
|
|
qa.allocated.Add(`"""`)
|
|
return `"""`, true
|
|
}
|
|
|
|
if !qa.allocated.Has(`'`) && !qa.allocated.Has(`'''`) {
|
|
qa.allocated.Add(`'''`)
|
|
return `'''`, true
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
if !qa.allocated.Has(`"`) {
|
|
qa.allocated.Add(`"`)
|
|
return `"`, true
|
|
}
|
|
|
|
if !qa.allocated.Has(`'`) {
|
|
qa.allocated.Add(`'`)
|
|
return `'`, true
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
func (qa *quoteAllocator) free(quotes string) {
|
|
qa.allocated.Delete(quotes)
|
|
}
|
|
|
|
func (qa *quoteAllocator) inTemplate() bool {
|
|
if len(qa.stack) < 2 {
|
|
return false
|
|
}
|
|
_, isTemplate := qa.stack[len(qa.stack)-2].(*model.TemplateExpression)
|
|
return isTemplate
|
|
}
|
|
|
|
func (qa *quoteAllocator) allocateExpression(x model.Expression) (model.Expression, hcl.Diagnostics) {
|
|
qa.stack = append(qa.stack, x)
|
|
|
|
var longString bool
|
|
switch x := x.(type) {
|
|
case *model.LiteralValueExpression:
|
|
if !model.StringType.AssignableFrom(x.Type()) || qa.inTemplate() {
|
|
return x, nil
|
|
}
|
|
v := x.Value.AsString()
|
|
switch strings.Count(v, "\n") {
|
|
case 0:
|
|
// OK
|
|
case 1:
|
|
longString = v[0] != '\n' && v[len(v)-1] != '\n'
|
|
default:
|
|
longString = true
|
|
}
|
|
case *model.TemplateExpression:
|
|
for i, part := range x.Parts {
|
|
if lit, ok := part.(*model.LiteralValueExpression); ok && model.StringType.AssignableFrom(lit.Type()) {
|
|
v := lit.Value.AsString()
|
|
switch strings.Count(v, "\n") {
|
|
case 0:
|
|
continue
|
|
case 1:
|
|
if i == 0 && v[0] == '\n' || i == len(x.Parts)-1 && v[len(v)-1] == '\n' {
|
|
continue
|
|
}
|
|
}
|
|
longString = true
|
|
break
|
|
}
|
|
}
|
|
default:
|
|
return x, nil
|
|
}
|
|
|
|
if quote, ok := qa.allocate(longString); ok {
|
|
qa.allocations.quotes[x] = quote
|
|
return x, nil
|
|
}
|
|
|
|
allocator := "eAllocator{allocated: codegen.StringSet{}, allocations: qa.allocations}
|
|
value, valueDiags := model.VisitExpression(x, allocator.allocateExpression, allocator.freeExpression)
|
|
|
|
temp := "eTemp{
|
|
Name: fmt.Sprintf("str%d", len(qa.allocations.temps)),
|
|
VariableType: x.Type(),
|
|
Value: value,
|
|
}
|
|
qa.allocations.temps = append(qa.allocations.temps, temp)
|
|
|
|
return &model.ScopeTraversalExpression{
|
|
RootName: temp.Name,
|
|
Traversal: hcl.Traversal{hcl.TraverseRoot{Name: ""}},
|
|
Parts: []model.Traversable{temp},
|
|
}, valueDiags
|
|
}
|
|
|
|
func (qa *quoteAllocator) freeExpression(x model.Expression) (model.Expression, hcl.Diagnostics) {
|
|
defer func() {
|
|
qa.stack = qa.stack[:len(qa.stack)-1]
|
|
}()
|
|
|
|
switch x := x.(type) {
|
|
case *model.LiteralValueExpression:
|
|
if !model.StringType.AssignableFrom(x.Type()) || qa.inTemplate() {
|
|
return x, nil
|
|
}
|
|
// OK
|
|
case *model.TemplateExpression:
|
|
// OK
|
|
default:
|
|
return x, nil
|
|
}
|
|
|
|
quotes, ok := qa.allocations.quotes[x]
|
|
contract.Assertf(ok, "cannot free unknown expression")
|
|
qa.free(quotes)
|
|
return x, nil
|
|
}
|
|
|
|
func (g *generator) rewriteQuotes(x model.Expression) (model.Expression, []*quoteTemp, hcl.Diagnostics) {
|
|
var diagnostics hcl.Diagnostics
|
|
|
|
// First, rewrite traversals that require string indices into index expressions.
|
|
x, rewriteDiags := model.VisitExpression(x, nil, func(x model.Expression) (model.Expression, hcl.Diagnostics) {
|
|
switch x := x.(type) {
|
|
case *model.RelativeTraversalExpression:
|
|
idx := g.rewriteTraversal(x.Traversal, x.Source, x.Parts)
|
|
if idx != nil {
|
|
return idx, nil
|
|
}
|
|
case *model.ScopeTraversalExpression:
|
|
idx := g.rewriteTraversal(x.Traversal, nil, x.Parts)
|
|
if idx != nil {
|
|
return idx, nil
|
|
}
|
|
}
|
|
return x, nil
|
|
})
|
|
diagnostics = append(diagnostics, rewriteDiags...)
|
|
|
|
// Then lift any expressions that cannot be allocated quotes into temps.
|
|
allocations := "eAllocations{
|
|
quotes: g.quotes,
|
|
}
|
|
allocator := "eAllocator{allocated: codegen.StringSet{}, allocations: allocations}
|
|
x, rewriteDiags = model.VisitExpression(x, allocator.allocateExpression, allocator.freeExpression)
|
|
diagnostics = append(diagnostics, rewriteDiags...)
|
|
|
|
return x, allocations.temps, diagnostics
|
|
}
|