mautrix-go/bridgev2/portalinternal_generate.go

174 lines
4.3 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build ignore
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"strings"
"go.mau.fi/util/exerrors"
)
// header is written at the very top of portalinternal.go.
//
// Its first line follows Go's standard generated-code marker
// (^// Code generated .* DO NOT EDIT\.$, see https://go.dev/s/generatedcode) so
// that gopls, linters and code review tools recognize the file as generated.
// The go:generate directives let `go generate` rerun this program and then
// goimports on its output. The blank line before the package clause keeps the
// marker comment from becoming the package's doc comment.
const header = `// Code generated by portalinternal_generate.go; DO NOT EDIT.

//go:generate go run portalinternal_generate.go
//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go

package bridgev2

`
// postImportHeader is written right after the copied import block. It declares
// PortalInternals as a defined type over Portal, and the Portal.Internal
// accessor that converts a *Portal into a *PortalInternals. The generated
// wrapper methods are declared on *PortalInternals; since a defined type does
// not inherit its underlying type's methods, each wrapper converts back with
// (*Portal)(portal) to call the original method.
const postImportHeader = `
type PortalInternals Portal
// Deprecated: portal internals should be used carefully and only when necessary.
func (portal *Portal) Internal() *PortalInternals {
return (*PortalInternals)(portal)
}
`
// getTypeName renders a type expression parsed from the portal sources back
// into Go source text, for use in the generated wrapper signatures.
//
// Only expression forms that can appear in method signatures are supported.
// Anything else panics, which aborts generation instead of emitting a wrapper
// with a silently wrong type.
func getTypeName(expr ast.Expr) string {
	switch e := expr.(type) {
	case *ast.Ident:
		return e.Name
	case *ast.BasicLit:
		// Used for array lengths such as the 32 in [32]byte.
		return e.Value
	case *ast.StarExpr:
		return "*" + getTypeName(e.X)
	case *ast.ParenExpr:
		// Parentheses can be significant, e.g. chan (<-chan int).
		return "(" + getTypeName(e.X) + ")"
	case *ast.Ellipsis:
		// Variadic parameter type such as ...string.
		return "..." + getTypeName(e.Elt)
	case *ast.ArrayType:
		// A nil Len means a slice; otherwise it is a fixed-size array.
		if e.Len == nil {
			return "[]" + getTypeName(e.Elt)
		}
		return fmt.Sprintf("[%s]%s", getTypeName(e.Len), getTypeName(e.Elt))
	case *ast.MapType:
		return fmt.Sprintf("map[%s]%s", getTypeName(e.Key), getTypeName(e.Value))
	case *ast.ChanType:
		switch e.Dir {
		case ast.SEND:
			return "chan<- " + getTypeName(e.Value)
		case ast.RECV:
			return "<-chan " + getTypeName(e.Value)
		default:
			return "chan " + getTypeName(e.Value)
		}
	case *ast.FuncType:
		params := "func(" + strings.Join(fieldTypeNames(e.Params), ", ") + ")"
		results := fieldTypeNames(e.Results)
		switch len(results) {
		case 0:
			return params
		case 1:
			return params + " " + results[0]
		default:
			// Multiple results must be parenthesized to be valid Go.
			return params + " (" + strings.Join(results, ", ") + ")"
		}
	case *ast.SelectorExpr:
		return fmt.Sprintf("%s.%s", getTypeName(e.X), e.Sel.Name)
	case *ast.IndexExpr:
		// Generic instantiation with one type argument, e.g. Set[string].
		return fmt.Sprintf("%s[%s]", getTypeName(e.X), getTypeName(e.Index))
	case *ast.IndexListExpr:
		// Generic instantiation with several type arguments, e.g. Map[K, V].
		indices := make([]string, len(e.Indices))
		for i, index := range e.Indices {
			indices[i] = getTypeName(index)
		}
		return fmt.Sprintf("%s[%s]", getTypeName(e.X), strings.Join(indices, ", "))
	case *ast.InterfaceType:
		if e.Methods == nil || len(e.Methods.List) == 0 {
			return "interface{}"
		}
	case *ast.StructType:
		if e.Fields == nil || len(e.Fields.List) == 0 {
			return "struct{}"
		}
	}
	panic(fmt.Errorf("unknown type %T", expr))
}

// fieldTypeNames returns one rendered type per declared name in list, or a
// single entry for an unnamed field, so "a, b int" yields ["int", "int"].
// A nil list yields nil.
func fieldTypeNames(list *ast.FieldList) []string {
	if list == nil {
		return nil
	}
	var types []string
	for _, field := range list.List {
		typ := getTypeName(field.Type)
		count := len(field.Names)
		if count == 0 {
			count = 1
		}
		for i := 0; i < count; i++ {
			types = append(types, typ)
		}
	}
	return types
}
// Output sinks for all generated text. main points them at the output for
// portalinternal.go before anything is generated; processFile and main emit
// everything through them.
var (
	write  func(str string)
	writef func(format string, args ...any)
)
// main parses the Portal source files listed below and runs processFile on
// each one to generate PortalInternals wrapper methods. It writes the result
// to portalinternal.go with mode 0644.
//
// The output is built in memory and written with a single os.WriteFile call at
// the end. If generation panics (for example on a type expression getTypeName
// does not support), the existing portalinternal.go is left untouched rather
// than truncated to a half-written file that breaks the package build.
func main() {
	fset := token.NewFileSet()
	fileNames := []string{"portal.go", "portalbackfill.go", "portalreid.go", "space.go", "matrixinvite.go"}
	files := make([]*ast.File, len(fileNames))
	for i, name := range fileNames {
		files[i] = exerrors.Must(parser.ParseFile(fset, name, nil, parser.SkipObjectResolution))
	}

	// Writes to a strings.Builder never fail, so their results are ignored.
	var out strings.Builder
	write = func(str string) {
		out.WriteString(str)
	}
	writef = func(format string, args ...any) {
		fmt.Fprintf(&out, format, args...)
	}

	write(header)
	// Only portal.go's imports are copied. This presumably relies on the
	// goimports go:generate step in header to add anything else the wrappers
	// need and drop what is unused.
	write("import (\n")
	for _, imp := range files[0].Imports {
		write("\t")
		if imp.Name != nil {
			writef("%s ", imp.Name.Name)
		}
		writef("%s\n", imp.Path.Value)
	}
	write(")\n")
	write(postImportHeader)
	for _, f := range files {
		processFile(f)
	}

	exerrors.PanicIfNotNil(os.WriteFile("portalinternal.go", []byte(out.String()), 0644))
}
// processFile writes a PortalInternals wrapper for every unexported method in f
// whose receiver is `portal *Portal` (or `portal Portal`). Output goes through
// write/writef. Each wrapper is named after the method with its first letter
// upper-cased, declares the same parameters and results, and forwards the call
// to the original method through a (*Portal) conversion.
func processFile(f *ast.File) {
	// Methods can only be declared at the top level of a file, so walking
	// f.Decls finds them all without descending into function bodies.
	for _, decl := range f.Decls {
		funcDecl, ok := decl.(*ast.FuncDecl)
		if !ok || funcDecl.Name.IsExported() || !isPortalReceiver(funcDecl.Recv) {
			continue
		}
		writePortalWrapper(funcDecl)
	}
}

// isPortalReceiver reports whether recv is a receiver named portal of type
// *Portal or Portal. Requiring the name portal also guarantees that no
// parameter of the method is named portal, which would clash with the
// wrapper's own receiver.
func isPortalReceiver(recv *ast.FieldList) bool {
	if recv == nil || len(recv.List) == 0 || len(recv.List[0].Names) == 0 ||
		recv.List[0].Names[0].Name != "portal" {
		return false
	}
	typ := recv.List[0].Type
	if star, ok := typ.(*ast.StarExpr); ok {
		typ = star.X
	}
	ident, ok := typ.(*ast.Ident)
	return ok && ident.Name == "Portal"
}

// writePortalWrapper writes the PortalInternals wrapper for a single method.
func writePortalWrapper(funcDecl *ast.FuncDecl) {
	params := funcDecl.Type.Params.List
	writef("\nfunc (portal *PortalInternals) %s(", exportedName(funcDecl.Name.Name))
	// args collects the names that are passed through to the wrapped method.
	var args []string
	for i, param := range params {
		if i != 0 {
			write(", ")
		}
		names := forwardableNames(param, len(args))
		args = append(args, names...)
		writef("%s %s", strings.Join(names, ", "), getTypeName(param.Type))
	}
	write(") ")

	results := funcDecl.Type.Results
	hasResults := results != nil && len(results.List) > 0
	if hasResults {
		// A single unnamed result is the only form that needs no parentheses.
		needsParentheses := len(results.List) > 1 || len(results.List[0].Names) > 0
		if needsParentheses {
			write("(")
		}
		for i, result := range results.List {
			if i != 0 {
				write(", ")
			}
			for j, name := range result.Names {
				if j != 0 {
					write(", ")
				}
				write(name.Name)
			}
			if len(result.Names) > 0 {
				write(" ")
			}
			write(getTypeName(result.Type))
		}
		if needsParentheses {
			write(")")
		}
		write(" ")
	}

	write("{\n\t")
	if hasResults {
		write("return ")
	}
	writef("(*Portal)(portal).%s(%s", funcDecl.Name.Name, strings.Join(args, ", "))
	// A variadic argument must be spread with ... to be forwarded unchanged.
	if len(params) > 0 {
		if _, variadic := params[len(params)-1].Type.(*ast.Ellipsis); variadic {
			write("...")
		}
	}
	write(")\n}\n")
}

// forwardableNames returns the names the wrapper declares for param. Missing or
// blank ("_") names are replaced with argN (N is the argument's position,
// starting from offset) so that every argument can be forwarded. The argN names
// are assumed not to clash with the method's other parameter or result names.
func forwardableNames(param *ast.Field, offset int) []string {
	if len(param.Names) == 0 {
		return []string{fmt.Sprintf("arg%d", offset)}
	}
	names := make([]string, len(param.Names))
	for i, name := range param.Names {
		if name.Name == "_" {
			names[i] = fmt.Sprintf("arg%d", offset+i)
		} else {
			names[i] = name.Name
		}
	}
	return names
}

// exportedName upper-cases the first character of name. It splits on a rune
// boundary so a non-ASCII leading letter is not cut in half.
func exportedName(name string) string {
	for i := range name {
		if i > 0 {
			return strings.ToUpper(name[:i]) + name[i:]
		}
	}
	return strings.ToUpper(name)
}