pulumi/pkg/codegen/go/gen_program_optionals.go

165 lines
4.4 KiB
Go
Raw Permalink Normal View History

package gen
import (
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/pulumi/pulumi/pkg/v3/codegen"
"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/model"
"github.com/pulumi/pulumi/pkg/v3/codegen/hcl2/syntax"
"github.com/pulumi/pulumi/pkg/v3/codegen/pcl"
"github.com/pulumi/pulumi/pkg/v3/codegen/schema"
)
// optionalTemp pairs a name with an expression. Its Type and Traverse methods delegate to the
// type of the held expression, and it has no backing syntax node.
type optionalTemp struct {
	Name  string
	Value model.Expression
}

// Type reports the type of the held expression.
func (ot *optionalTemp) Type() model.Type {
	held := ot.Value
	return held.Type()
}

// Traverse applies traverser to the type of the held expression.
func (ot *optionalTemp) Traverse(traverser hcl.Traverser) (model.Traversable, hcl.Diagnostics) {
	typ := ot.Type()
	return typ.Traverse(traverser)
}

// SyntaxNode always reports syntax.None: the temporary is synthesized, not parsed.
func (ot *optionalTemp) SyntaxNode() hclsyntax.Node {
	return syntax.None
}
// optionalSpiller holds the traversal state shared by the preVisitor and postVisitor that
// rewriteOptionals passes to model.VisitExpression.
type optionalSpiller struct {
	// invocation is the plain (non-output-versioned) invoke call whose arguments are currently
	// being visited, or nil when traversal is outside such a call. preVisitor sets it and
	// postVisitor clears it once that call's children have been visited.
	invocation *model.FunctionCallExpression
	// intrinsicConvertTo holds the return type of the most recently visited __convert call
	// inside an invoke, so that the object constructor it wraps is matched against the
	// converted type instead of its own inferred type. preVisitor resets it to nil when it
	// visits invokes, object constructors and most other nodes.
	intrinsicConvertTo *model.Type
}
// preVisitor is the pre-order callback handed to model.VisitExpression.
//
// It tracks whether traversal is inside the arguments of a plain (non-output-versioned)
// invoke. Within such an invoke, it wraps the values of optional primitive properties
// (number, bool, int, string) in object constructors with a call to the matching goOptional*
// helper (see getOptionalConversion).
//
// os.intrinsicConvertTo records the return type of an enclosing __convert call. The object
// constructor that is the direct child of that call is then checked against the converted
// schema type rather than its own inferred type. Every node other than a __convert call
// clears the field, so it only ever applies to the immediate child of a __convert.
func (os *optionalSpiller) preVisitor(x model.Expression) (model.Expression, hcl.Diagnostics) {
	switch x := x.(type) {
	case *model.FunctionCallExpression:
		switch x.Name {
		case "invoke":
			// Output-versioned invokes do not need their arguments converted, so only plain
			// invokes are tracked. postVisitor forgets the tracked invoke once its children
			// have been visited.
			if isOutputInvoke, _, _ := pcl.RecognizeOutputVersionedInvoke(x); !isOutputInvoke {
				os.invocation = x
			}
			os.intrinsicConvertTo = nil
		case pcl.IntrinsicConvert:
			if os.invocation != nil {
				os.intrinsicConvertTo = &x.Signature.ReturnType
			}
		default:
			// Any other call sits between a __convert and whatever object it eventually
			// receives, so the converted type no longer describes that object.
			os.intrinsicConvertTo = nil
		}
	case *model.ObjectConsExpression:
		if os.invocation == nil {
			return x, nil
		}
		destType := x.Type()
		if os.intrinsicConvertTo != nil {
			destType = *os.intrinsicConvertTo
		}
		// Clear before visiting children; another __convert call is required to set it again.
		os.intrinsicConvertTo = nil
		os.wrapOptionalPrimitives(x, destType)
	default:
		// Clear before visiting children; another __convert call is required to set it again.
		os.intrinsicConvertTo = nil
	}
	return x, nil
}

// wrapOptionalPrimitives rewrites obj's items in place. It applies only when destType maps
// to a schema object type. Each item whose key names one of that type's optional primitive
// properties has its value replaced by a call to the matching goOptional* conversion.
func (os *optionalSpiller) wrapOptionalPrimitives(obj *model.ObjectConsExpression, destType model.Type) {
	schemaType, ok := pcl.GetSchemaForType(destType)
	if !ok {
		return
	}
	objectType, ok := schemaType.(*schema.ObjectType)
	if !ok {
		return
	}
	optionalPrimitives := os.optionalPrimitiveProperties(objectType)
	if len(optionalPrimitives) == 0 {
		return
	}
	for i, item := range obj.Items {
		// Keys for schematized objects should be simple string literals.
		key, ok := item.Key.(*model.LiteralValueExpression)
		if !ok || !model.StringType.AssignableFrom(key.Type()) {
			continue
		}
		// cty's AsString panics on unknown or null values (e.g. a dynamic key), so skip them.
		if !key.Value.IsKnown() || key.Value.IsNull() {
			continue
		}
		primitive, isOptional := optionalPrimitives[key.Value.AsString()]
		if !isOptional {
			continue
		}
		expectedModelType := os.getExpectedModelType(primitive)
		obj.Items[i].Value = &model.FunctionCallExpression{
			Name: os.getOptionalConversion(primitive),
			Signature: model.StaticFunctionSignature{
				Parameters: []model.Parameter{{
					Name: "val",
					Type: expectedModelType,
				}},
				ReturnType: model.NewOptionalType(expectedModelType),
			},
			Args: []model.Expression{item.Value},
		}
	}
}

// optionalPrimitiveProperties maps each non-required property of obj to its unwrapped schema
// type. Only properties whose unwrapped type is number, bool, int or string are included.
func (*optionalSpiller) optionalPrimitiveProperties(obj *schema.ObjectType) map[string]schema.Type {
	primitives := make(map[string]schema.Type)
	for _, prop := range obj.Properties {
		if prop.IsRequired() {
			continue
		}
		switch ty := codegen.UnwrapType(prop.Type); ty {
		case schema.NumberType, schema.BoolType, schema.IntType, schema.StringType:
			primitives[prop.Name] = ty
		}
	}
	return primitives
}
// postVisitor is the post-order callback handed to model.VisitExpression. Once all children
// of the invoke recorded by preVisitor have been visited, it forgets that invoke. Expressions
// visited afterwards are therefore no longer treated as invoke arguments.
func (os *optionalSpiller) postVisitor(x model.Expression) (model.Expression, hcl.Diagnostics) {
	call, isCall := x.(*model.FunctionCallExpression)
	if isCall && call.Name == "invoke" && call == os.invocation {
		os.invocation = nil
	}
	return x, nil
}
// getOptionalConversion names the goOptional* helper that lifts a value of the given
// primitive schema type into an optional value. It returns "" for any type other than
// number, bool, int or string.
func (*optionalSpiller) getOptionalConversion(ty schema.Type) string {
	helper := ""
	switch {
	case ty == schema.NumberType:
		helper = "goOptionalFloat64"
	case ty == schema.BoolType:
		helper = "goOptionalBool"
	case ty == schema.IntType:
		helper = "goOptionalInt"
	case ty == schema.StringType:
		helper = "goOptionalString"
	}
	return helper
}
// getExpectedModelType maps a primitive schema type to the model type of the value that the
// corresponding goOptional* helper accepts. It returns nil for any type other than number,
// bool, int or string.
func (*optionalSpiller) getExpectedModelType(ty schema.Type) model.Type {
	switch {
	case ty == schema.NumberType:
		return model.NumberType
	case ty == schema.BoolType:
		return model.BoolType
	case ty == schema.IntType:
		return model.IntType
	case ty == schema.StringType:
		return model.StringType
	}
	return nil
}
// rewriteOptionals walks x with spiller's visitors. Within the arguments of plain invokes, it
// wraps values of optional primitive properties in object constructors with goOptional*
// conversion calls.
//
// The rewrite is performed by the visitors themselves and produces no temporaries, so the
// returned []*optionalTemp is always nil. Diagnostics come from model.VisitExpression.
func (g *generator) rewriteOptionals(
	x model.Expression,
	spiller *optionalSpiller,
) (model.Expression, []*optionalTemp, hcl.Diagnostics) {
	// Both visitors are needed: preVisitor records the enclosing invoke and performs the
	// rewrite on the way down, and postVisitor forgets that invoke once its arguments have
	// been visited so later expressions are left alone.
	rewritten, diags := model.VisitExpression(x, spiller.preVisitor, spiller.postVisitor)
	return rewritten, nil, diags
}