From 808f0bde9d63a869cf91d9daf9188dded1cd0abf Mon Sep 17 00:00:00 2001 From: Ethan Reesor Date: Thu, 23 Sep 2021 05:34:12 -0500 Subject: [PATCH] interp: add a function to directly compile Go AST Adds CompileAST, which can be used to compile Go AST directly. This allows users to delegate parsing of source to their own code instead of relying on the interpreter. CLoses #1251 --- interp/ast.go | 50 ++++++++++++------------- interp/compile_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++ interp/interp.go | 2 +- interp/program.go | 22 +++++++++-- interp/src.go | 10 ++++- 5 files changed, 137 insertions(+), 31 deletions(-) create mode 100644 interp/compile_test.go diff --git a/interp/ast.go b/interp/ast.go index ec1b5082..bd135f01 100644 --- a/interp/ast.go +++ b/interp/ast.go @@ -1,7 +1,6 @@ package interp import ( - "errors" "fmt" "go/ast" "go/constant" @@ -362,21 +361,14 @@ func wrapInMain(src string) string { return fmt.Sprintf("package main; func main() {%s\n}", src) } -// Note: no type analysis is performed at this stage, it is done in pre-order -// processing of CFG, in order to accommodate forward type declarations. - -// ast parses src string containing Go code and generates the corresponding AST. -// The package name and the AST root node are returned. -// The given name is used to set the filename of the relevant source file in the -// interpreter's FileSet. -func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error) { - var inFunc bool +func (interp *Interpreter) parse(src, name string, inc bool) (node ast.Node, err error) { mode := parser.DeclarationErrors // Allow incremental parsing of declarations or statements, by inserting // them in a pseudo file package or function. Those statements or // declarations will be always evaluated in the global scope. var tok token.Token + var inFunc bool if inc { tok = interp.firstToken(src) switch tok { @@ -393,18 +385,18 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error } if ok, err := interp.buildOk(&interp.context, name, src); !ok || err != nil { - return "", nil, err // skip source not matching build constraints + return nil, err // skip source not matching build constraints } f, err := parser.ParseFile(interp.fset, name, src, mode) if err != nil { // only retry if we're on an expression/statement about a func if !inc || tok != token.FUNC { - return "", nil, err + return nil, err } // do not bother retrying if we know it's an error we're going to ignore later on. if ignoreError(err, src) { - return "", nil, err + return nil, err } // do not lose initial error, in case retrying fails. initialError := err @@ -412,16 +404,32 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error src := wrapInMain(strings.TrimPrefix(src, "package main;")) f, err = parser.ParseFile(interp.fset, name, src, mode) if err != nil { - return "", nil, initialError + return nil, initialError } } - setYaegiTags(&interp.context, f.Comments) + if inFunc { + // return the body of the wrapper main function + return f.Decls[0].(*ast.FuncDecl).Body, nil + } + setYaegiTags(&interp.context, f.Comments) + return f, nil +} + +// Note: no type analysis is performed at this stage, it is done in pre-order +// processing of CFG, in order to accommodate forward type declarations. + +// ast parses src string containing Go code and generates the corresponding AST. +// The package name and the AST root node are returned. +// The given name is used to set the filename of the relevant source file in the +// interpreter's FileSet. +func (interp *Interpreter) ast(f ast.Node) (string, *node, error) { + var err error var root *node var anc astNode var st nodestack - var pkgName string + pkgName := "main" addChild := func(root **node, anc astNode, pos token.Pos, kind nkind, act action) *node { var i interface{} @@ -898,15 +906,7 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error } return true }) - if inFunc { - // Incremental parsing: statements were inserted in a pseudo function. - // Set root to function body so its statements are evaluated in global scope. - root = root.child[1].child[3] - root.anc = nil - } - if pkgName == "" { - return "", root, errors.New("no package name found") - } + interp.roots = append(interp.roots, root) return pkgName, root, err } diff --git a/interp/compile_test.go b/interp/compile_test.go new file mode 100644 index 00000000..3e2f54b9 --- /dev/null +++ b/interp/compile_test.go @@ -0,0 +1,84 @@ +package interp + +import ( + "go/ast" + "go/parser" + "go/token" + "testing" + + "github.com/traefik/yaegi/stdlib" +) + +func TestCompileAST(t *testing.T) { + file, err := parser.ParseFile(token.NewFileSet(), "_.go", ` + package main + + import "fmt" + + type Foo struct{} + + var foo Foo + const bar = "asdf" + + func main() { + fmt.Println(1) + } + `, 0) + if err != nil { + panic(err) + } + if len(file.Imports) != 1 || len(file.Decls) != 5 { + panic("wrong number of imports or decls") + } + + dType := file.Decls[1].(*ast.GenDecl) + dVar := file.Decls[2].(*ast.GenDecl) + dConst := file.Decls[3].(*ast.GenDecl) + dFunc := file.Decls[4].(*ast.FuncDecl) + + if dType.Tok != token.TYPE { + panic("decl[1] is not a type") + } + if dVar.Tok != token.VAR { + panic("decl[2] is not a var") + } + if dConst.Tok != token.CONST { + panic("decl[3] is not a const") + } + + cases := []struct { + desc string + node ast.Node + skip string + }{ + {desc: "file", node: file}, + {desc: "import", node: file.Imports[0]}, + {desc: "type", node: dType}, + {desc: "var", node: dVar, skip: "not supported"}, + {desc: "const", node: dConst}, + {desc: "func", node: dFunc}, + {desc: "block", node: dFunc.Body}, + {desc: "expr", node: dFunc.Body.List[0]}, + } + + i := New(Options{}) + _ = i.Use(stdlib.Symbols) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + if c.skip != "" { + t.Skip(c.skip) + } + + i := i + if _, ok := c.node.(*ast.File); ok { + i = New(Options{}) + _ = i.Use(stdlib.Symbols) + } + _, err := i.CompileAST(c.node) + if err != nil { + t.Fatalf("Failed to compile %s: %v", c.desc, err) + } + }) + } +} diff --git a/interp/interp.go b/interp/interp.go index e82cc9dd..4bcfe3a5 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -573,7 +573,7 @@ func isFile(filesystem fs.FS, path string) bool { } func (interp *Interpreter) eval(src, name string, inc bool) (res reflect.Value, err error) { - prog, err := interp.compile(src, name, inc) + prog, err := interp.compileSrc(src, name, inc) if err != nil { return res, err } diff --git a/interp/program.go b/interp/program.go index b72debb0..51c5cd90 100644 --- a/interp/program.go +++ b/interp/program.go @@ -2,6 +2,7 @@ package interp import ( "context" + "go/ast" "io/ioutil" "reflect" "runtime" @@ -17,7 +18,7 @@ type Program struct { // Compile parses and compiles a Go code represented as a string. func (interp *Interpreter) Compile(src string) (*Program, error) { - return interp.compile(src, "", true) + return interp.compileSrc(src, "", true) } // CompilePath parses and compiles a Go code located at the given path. @@ -31,10 +32,10 @@ func (interp *Interpreter) CompilePath(path string) (*Program, error) { if err != nil { return nil, err } - return interp.compile(string(b), path, false) + return interp.compileSrc(string(b), path, false) } -func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error) { +func (interp *Interpreter) compileSrc(src, name string, inc bool) (*Program, error) { if name != "" { interp.name = name } @@ -43,7 +44,20 @@ func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error) } // Parse source to AST. - pkgName, root, err := interp.ast(src, interp.name, inc) + n, err := interp.parse(src, interp.name, inc) + if err != nil { + return nil, err + } + + return interp.CompileAST(n) +} + +// CompileAST builds a Program for the given Go code AST. Files and block +// statements can be compiled, as can most expressions. Var declaration nodes +// cannot be compiled. +func (interp *Interpreter) CompileAST(n ast.Node) (*Program, error) { + // Convert AST. + pkgName, root, err := interp.ast(n) if err != nil || root == nil { return nil, err } diff --git a/interp/src.go b/interp/src.go index bc00bc5a..a1f20c21 100644 --- a/interp/src.go +++ b/interp/src.go @@ -73,8 +73,16 @@ func (interp *Interpreter) importSrc(rPath, importPath string, skipTest bool) (s return "", err } + n, err := interp.parse(string(buf), name, false) + if err != nil { + return "", err + } + if n == nil { + continue + } + var pname string - if pname, root, err = interp.ast(string(buf), name, false); err != nil { + if pname, root, err = interp.ast(n); err != nil { return "", err } if root == nil {