mirror of https://github.com/mautrix/go.git
174 lines
4.3 KiB
Go
174 lines
4.3 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
//go:build ignore
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"os"
|
|
"strings"
|
|
|
|
"go.mau.fi/util/exerrors"
|
|
)
|
|
|
|
const (
	// header is written at the very top of portalinternal.go, before the
	// import block. It marks the file as generated and carries the
	// go:generate directives that rebuild and then goimports-format it.
	header = `// GENERATED BY portalinternal_generate.go; DO NOT EDIT

//go:generate go run portalinternal_generate.go
//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go

package bridgev2

`

	// postImportHeader is written right after the import block. It declares
	// the PortalInternals view type that the generated wrapper methods hang
	// off, plus the (*Portal).Internal accessor that converts to it.
	postImportHeader = `
type PortalInternals Portal

// Deprecated: portal internals should be used carefully and only when necessary.
func (portal *Portal) Internal() *PortalInternals {
	return (*PortalInternals)(portal)
}
`
)
|
|
|
|
// getTypeName renders a type expression taken from the parsed source back
// into Go syntax so it can be copied into a generated method signature.
//
// It panics on expression kinds it does not know how to render, so the
// generator fails loudly instead of emitting code that does not compile.
func getTypeName(expr ast.Expr) string {
	switch e := expr.(type) {
	case *ast.Ident:
		return e.Name
	case *ast.BasicLit:
		// Only reachable as an array length, e.g. the 4 in [4]byte.
		return e.Value
	case *ast.StarExpr:
		return "*" + getTypeName(e.X)
	case *ast.ParenExpr:
		return "(" + getTypeName(e.X) + ")"
	case *ast.Ellipsis:
		// Variadic parameter type (...T). Elt is nil only for [...]T array
		// literals, which never appear in signatures, but don't crash on it.
		if e.Elt == nil {
			return "..."
		}
		return "..." + getTypeName(e.Elt)
	case *ast.ArrayType:
		if e.Len == nil {
			return "[]" + getTypeName(e.Elt)
		}
		// Fixed-size array: keep the length, otherwise [N]T would become []T.
		return fmt.Sprintf("[%s]%s", getTypeName(e.Len), getTypeName(e.Elt))
	case *ast.MapType:
		return fmt.Sprintf("map[%s]%s", getTypeName(e.Key), getTypeName(e.Value))
	case *ast.ChanType:
		// Preserve channel direction; dropping it changes the type.
		switch e.Dir {
		case ast.SEND:
			return "chan<- " + getTypeName(e.Value)
		case ast.RECV:
			return "<-chan " + getTypeName(e.Value)
		default:
			return "chan " + getTypeName(e.Value)
		}
	case *ast.FuncType:
		sig := "func(" + strings.Join(fieldListTypes(e.Params), ", ") + ")"
		switch results := fieldListTypes(e.Results); len(results) {
		case 0:
			return sig
		case 1:
			return sig + " " + results[0]
		default:
			return sig + " (" + strings.Join(results, ", ") + ")"
		}
	case *ast.SelectorExpr:
		return fmt.Sprintf("%s.%s", getTypeName(e.X), e.Sel.Name)
	case *ast.IndexExpr:
		// Generic type instantiated with one type argument, e.g. Set[string].
		return fmt.Sprintf("%s[%s]", getTypeName(e.X), getTypeName(e.Index))
	case *ast.IndexListExpr:
		// Generic type instantiated with several type arguments.
		args := make([]string, len(e.Indices))
		for i, index := range e.Indices {
			args[i] = getTypeName(index)
		}
		return fmt.Sprintf("%s[%s]", getTypeName(e.X), strings.Join(args, ", "))
	case *ast.InterfaceType:
		if e.Methods == nil || len(e.Methods.List) == 0 {
			return "interface{}"
		}
	case *ast.StructType:
		if e.Fields == nil || len(e.Fields.List) == 0 {
			return "struct{}"
		}
	}
	panic(fmt.Errorf("unknown type %T", expr))
}

// fieldListTypes returns one rendered type per declared entry in fields, so
// a grouped field such as `a, b int` yields ["int", "int"] and the rendered
// signature keeps its arity. A nil list yields nil.
func fieldListTypes(fields *ast.FieldList) []string {
	if fields == nil {
		return nil
	}
	var rendered []string
	for _, field := range fields.List {
		typeName := getTypeName(field.Type)
		count := len(field.Names)
		if count == 0 {
			count = 1
		}
		for i := 0; i < count; i++ {
			rendered = append(rendered, typeName)
		}
	}
	return rendered
}
|
|
|
|
// Output sinks shared by the generator. main assigns both before any code is
// generated; processFile emits the wrapper methods through them.
var (
	// write appends str verbatim to the generated output.
	write func(str string)
	// writef appends a fmt-formatted string to the generated output.
	writef func(format string, args ...any)
)
|
|
|
|
func main() {
|
|
fset := token.NewFileSet()
|
|
fileNames := []string{"portal.go", "portalbackfill.go", "portalreid.go", "space.go", "matrixinvite.go"}
|
|
files := make([]*ast.File, len(fileNames))
|
|
for i, name := range fileNames {
|
|
files[i] = exerrors.Must(parser.ParseFile(fset, name, nil, parser.SkipObjectResolution))
|
|
}
|
|
file := exerrors.Must(os.OpenFile("portalinternal.go", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644))
|
|
write = func(str string) {
|
|
exerrors.Must(file.WriteString(str))
|
|
}
|
|
writef = func(format string, args ...any) {
|
|
exerrors.Must(fmt.Fprintf(file, format, args...))
|
|
}
|
|
write(header)
|
|
write("import (\n")
|
|
for _, i := range files[0].Imports {
|
|
write("\t")
|
|
if i.Name != nil {
|
|
writef("%s ", i.Name.Name)
|
|
}
|
|
writef("%s\n", i.Path.Value)
|
|
}
|
|
write(")\n")
|
|
write(postImportHeader)
|
|
for _, f := range files {
|
|
processFile(f)
|
|
}
|
|
exerrors.PanicIfNotNil(file.Close())
|
|
}
|
|
|
|
func processFile(f *ast.File) {
|
|
ast.Inspect(f, func(node ast.Node) (retVal bool) {
|
|
retVal = true
|
|
funcDecl, ok := node.(*ast.FuncDecl)
|
|
if !ok || funcDecl.Name.IsExported() {
|
|
return
|
|
}
|
|
if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 || len(funcDecl.Recv.List[0].Names) == 0 ||
|
|
funcDecl.Recv.List[0].Names[0].Name != "portal" {
|
|
return
|
|
}
|
|
writef("\nfunc (portal *PortalInternals) %s%s(", strings.ToUpper(funcDecl.Name.Name[0:1]), funcDecl.Name.Name[1:])
|
|
for i, param := range funcDecl.Type.Params.List {
|
|
if i != 0 {
|
|
write(", ")
|
|
}
|
|
for j, name := range param.Names {
|
|
if j != 0 {
|
|
write(", ")
|
|
}
|
|
write(name.Name)
|
|
}
|
|
if len(param.Names) > 0 {
|
|
write(" ")
|
|
}
|
|
write(getTypeName(param.Type))
|
|
}
|
|
write(") ")
|
|
if funcDecl.Type.Results != nil && len(funcDecl.Type.Results.List) > 0 {
|
|
needsParentheses := len(funcDecl.Type.Results.List) > 1 || len(funcDecl.Type.Results.List[0].Names) > 0
|
|
if needsParentheses {
|
|
write("(")
|
|
}
|
|
for i, result := range funcDecl.Type.Results.List {
|
|
if i != 0 {
|
|
write(", ")
|
|
}
|
|
for j, name := range result.Names {
|
|
if j != 0 {
|
|
write(", ")
|
|
}
|
|
write(name.Name)
|
|
}
|
|
if len(result.Names) > 0 {
|
|
write(" ")
|
|
}
|
|
write(getTypeName(result.Type))
|
|
}
|
|
if needsParentheses {
|
|
write(")")
|
|
}
|
|
write(" ")
|
|
}
|
|
write("{\n\t")
|
|
if funcDecl.Type.Results != nil {
|
|
write("return ")
|
|
}
|
|
writef("(*Portal)(portal).%s(", funcDecl.Name.Name)
|
|
for i, param := range funcDecl.Type.Params.List {
|
|
for j, name := range param.Names {
|
|
if i != 0 || j != 0 {
|
|
write(", ")
|
|
}
|
|
write(name.Name)
|
|
}
|
|
}
|
|
write(")\n}\n")
|
|
return
|
|
})
|
|
}
|