diff --git a/interp/ast.go b/interp/ast.go index ec1b5082..bd135f01 100644 --- a/interp/ast.go +++ b/interp/ast.go @@ -1,7 +1,6 @@ package interp import ( - "errors" "fmt" "go/ast" "go/constant" @@ -362,21 +361,14 @@ func wrapInMain(src string) string { return fmt.Sprintf("package main; func main() {%s\n}", src) } -// Note: no type analysis is performed at this stage, it is done in pre-order -// processing of CFG, in order to accommodate forward type declarations. - -// ast parses src string containing Go code and generates the corresponding AST. -// The package name and the AST root node are returned. -// The given name is used to set the filename of the relevant source file in the -// interpreter's FileSet. -func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error) { - var inFunc bool +func (interp *Interpreter) parse(src, name string, inc bool) (node ast.Node, err error) { mode := parser.DeclarationErrors // Allow incremental parsing of declarations or statements, by inserting // them in a pseudo file package or function. Those statements or // declarations will be always evaluated in the global scope. var tok token.Token + var inFunc bool if inc { tok = interp.firstToken(src) switch tok { @@ -393,18 +385,18 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error } if ok, err := interp.buildOk(&interp.context, name, src); !ok || err != nil { - return "", nil, err // skip source not matching build constraints + return nil, err // skip source not matching build constraints } f, err := parser.ParseFile(interp.fset, name, src, mode) if err != nil { // only retry if we're on an expression/statement about a func if !inc || tok != token.FUNC { - return "", nil, err + return nil, err } // do not bother retrying if we know it's an error we're going to ignore later on. if ignoreError(err, src) { - return "", nil, err + return nil, err } // do not lose initial error, in case retrying fails. initialError := err @@ -412,16 +404,32 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error src := wrapInMain(strings.TrimPrefix(src, "package main;")) f, err = parser.ParseFile(interp.fset, name, src, mode) if err != nil { - return "", nil, initialError + return nil, initialError } } - setYaegiTags(&interp.context, f.Comments) + if inFunc { + // return the body of the wrapper main function + return f.Decls[0].(*ast.FuncDecl).Body, nil + } + setYaegiTags(&interp.context, f.Comments) + return f, nil +} + +// Note: no type analysis is performed at this stage, it is done in pre-order +// processing of CFG, in order to accommodate forward type declarations. + +// ast parses src string containing Go code and generates the corresponding AST. +// The package name and the AST root node are returned. +// The given name is used to set the filename of the relevant source file in the +// interpreter's FileSet. +func (interp *Interpreter) ast(f ast.Node) (string, *node, error) { + var err error var root *node var anc astNode var st nodestack - var pkgName string + pkgName := "main" addChild := func(root **node, anc astNode, pos token.Pos, kind nkind, act action) *node { var i interface{} @@ -898,15 +906,7 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error } return true }) - if inFunc { - // Incremental parsing: statements were inserted in a pseudo function. - // Set root to function body so its statements are evaluated in global scope. - root = root.child[1].child[3] - root.anc = nil - } - if pkgName == "" { - return "", root, errors.New("no package name found") - } + interp.roots = append(interp.roots, root) return pkgName, root, err } diff --git a/interp/compile_test.go b/interp/compile_test.go new file mode 100644 index 00000000..3e2f54b9 --- /dev/null +++ b/interp/compile_test.go @@ -0,0 +1,84 @@ +package interp + +import ( + "go/ast" + "go/parser" + "go/token" + "testing" + + "github.com/traefik/yaegi/stdlib" +) + +func TestCompileAST(t *testing.T) { + file, err := parser.ParseFile(token.NewFileSet(), "_.go", ` + package main + + import "fmt" + + type Foo struct{} + + var foo Foo + const bar = "asdf" + + func main() { + fmt.Println(1) + } + `, 0) + if err != nil { + panic(err) + } + if len(file.Imports) != 1 || len(file.Decls) != 5 { + panic("wrong number of imports or decls") + } + + dType := file.Decls[1].(*ast.GenDecl) + dVar := file.Decls[2].(*ast.GenDecl) + dConst := file.Decls[3].(*ast.GenDecl) + dFunc := file.Decls[4].(*ast.FuncDecl) + + if dType.Tok != token.TYPE { + panic("decl[1] is not a type") + } + if dVar.Tok != token.VAR { + panic("decl[2] is not a var") + } + if dConst.Tok != token.CONST { + panic("decl[3] is not a const") + } + + cases := []struct { + desc string + node ast.Node + skip string + }{ + {desc: "file", node: file}, + {desc: "import", node: file.Imports[0]}, + {desc: "type", node: dType}, + {desc: "var", node: dVar, skip: "not supported"}, + {desc: "const", node: dConst}, + {desc: "func", node: dFunc}, + {desc: "block", node: dFunc.Body}, + {desc: "expr", node: dFunc.Body.List[0]}, + } + + i := New(Options{}) + _ = i.Use(stdlib.Symbols) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + if c.skip != "" { + t.Skip(c.skip) + } + + i := i + if _, ok := c.node.(*ast.File); ok { + i = New(Options{}) + _ = i.Use(stdlib.Symbols) + } + _, err := i.CompileAST(c.node) + if err != nil { + t.Fatalf("Failed to compile %s: %v", c.desc, err) + } + }) + } +} diff --git a/interp/interp.go b/interp/interp.go index e82cc9dd..4bcfe3a5 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -573,7 +573,7 @@ func isFile(filesystem fs.FS, path string) bool { } func (interp *Interpreter) eval(src, name string, inc bool) (res reflect.Value, err error) { - prog, err := interp.compile(src, name, inc) + prog, err := interp.compileSrc(src, name, inc) if err != nil { return res, err } diff --git a/interp/program.go b/interp/program.go index b72debb0..51c5cd90 100644 --- a/interp/program.go +++ b/interp/program.go @@ -2,6 +2,7 @@ package interp import ( "context" + "go/ast" "io/ioutil" "reflect" "runtime" @@ -17,7 +18,7 @@ type Program struct { // Compile parses and compiles a Go code represented as a string. func (interp *Interpreter) Compile(src string) (*Program, error) { - return interp.compile(src, "", true) + return interp.compileSrc(src, "", true) } // CompilePath parses and compiles a Go code located at the given path. @@ -31,10 +32,10 @@ func (interp *Interpreter) CompilePath(path string) (*Program, error) { if err != nil { return nil, err } - return interp.compile(string(b), path, false) + return interp.compileSrc(string(b), path, false) } -func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error) { +func (interp *Interpreter) compileSrc(src, name string, inc bool) (*Program, error) { if name != "" { interp.name = name } @@ -43,7 +44,20 @@ func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error) } // Parse source to AST. - pkgName, root, err := interp.ast(src, interp.name, inc) + n, err := interp.parse(src, interp.name, inc) + if err != nil { + return nil, err + } + + return interp.CompileAST(n) +} + +// CompileAST builds a Program for the given Go code AST. Files and block +// statements can be compiled, as can most expressions. Var declaration nodes +// cannot be compiled. +func (interp *Interpreter) CompileAST(n ast.Node) (*Program, error) { + // Convert AST. + pkgName, root, err := interp.ast(n) if err != nil || root == nil { return nil, err } diff --git a/interp/src.go b/interp/src.go index bc00bc5a..a1f20c21 100644 --- a/interp/src.go +++ b/interp/src.go @@ -73,8 +73,16 @@ func (interp *Interpreter) importSrc(rPath, importPath string, skipTest bool) (s return "", err } + n, err := interp.parse(string(buf), name, false) + if err != nil { + return "", err + } + if n == nil { + continue + } + var pname string - if pname, root, err = interp.ast(string(buf), name, false); err != nil { + if pname, root, err = interp.ast(n); err != nil { return "", err } if root == nil {