Fix: handle recursive type definition (#239)

This commit is contained in:
Marc Vertes
2019-07-09 00:08:12 +02:00
committed by Ludovic Fernandez
parent dea1f56f38
commit 122506cc79
7 changed files with 142 additions and 16 deletions

15
_test/struct22.go Normal file
View File

@@ -0,0 +1,15 @@
package main
type S struct {
Child *S
Name string
}
func main() {
s := &S{Name: "root"}
s.Child = &S{Name: "child"}
println(s.Child.Name)
}
// Output:
// child

15
_test/struct23.go Normal file
View File

@@ -0,0 +1,15 @@
package main
type S struct {
Child []*S
Name string
}
func main() {
s := &S{Name: "root"}
s.Child = append(s.Child, &S{Name: "child"})
println(s.Child[0].Name)
}
// Output:
// child

View File

@@ -0,0 +1,38 @@
package interp_test
import (
"fmt"
"log"
"github.com/containous/yaegi/interp"
)
// Generic example
func Example_eval() {
// Create a new interpreter context
i := interp.New(interp.Options{})
// Run some code: define a new function
_, err := i.Eval("func f(i int) int { return 2 * i }")
if err != nil {
log.Fatal(err)
}
// Access the interpreted f function with Eval
v, err := i.Eval("f")
if err != nil {
log.Fatal(err)
}
// Returned v is a reflect.Value, so we can use its interface
f, ok := v.Interface().(func(int) int)
if !ok {
log.Fatal("type assertion failed")
}
// Use interpreted f as it was pre-compiled
fmt.Println(f(2))
// Output:
// 4
}

View File

@@ -198,6 +198,16 @@ func convert(n *node) {
} }
} }
func isRecursiveStruct(t *itype) bool {
if t.cat == structT && t.rtype.Kind() == reflect.Interface {
return true
}
if t.cat == ptrT {
return isRecursiveStruct(t.val)
}
return false
}
func assign(n *node) { func assign(n *node) {
next := getExec(n.tnext) next := getExec(n.tnext)
dvalue := make([]func(*frame) reflect.Value, n.nleft) dvalue := make([]func(*frame) reflect.Value, n.nleft)
@@ -218,6 +228,8 @@ func assign(n *node) {
case src.kind == basicLit && src.val == nil: case src.kind == basicLit && src.val == nil:
t := dest.typ.TypeOf() t := dest.typ.TypeOf()
svalue[i] = func(*frame) reflect.Value { return reflect.New(t).Elem() } svalue[i] = func(*frame) reflect.Value { return reflect.New(t).Elem() }
case isRecursiveStruct(dest.typ):
svalue[i] = genValueInterfacePtr(src)
default: default:
svalue[i] = genValue(src) svalue[i] = genValue(src)
} }
@@ -1071,8 +1083,14 @@ func getIndexSeq(n *node) {
func getPtrIndexSeq(n *node) { func getPtrIndexSeq(n *node) {
index := n.val.([]int) index := n.val.([]int)
value := genValue(n.child[0])
tnext := getExec(n.tnext) tnext := getExec(n.tnext)
var value func(*frame) reflect.Value
if isRecursiveStruct(n.child[0].typ) {
v := genValue(n.child[0])
value = func(f *frame) reflect.Value { return v(f).Elem().Elem() }
} else {
value = genValue(n.child[0])
}
if n.fnext != nil { if n.fnext != nil {
fnext := getExec(n.fnext) fnext := getExec(n.fnext)
@@ -1341,6 +1359,7 @@ func mapLit(n *node) {
return next return next
} }
} }
func compositeBinMap(n *node) { func compositeBinMap(n *node) {
value := valueGenerator(n, n.findex) value := valueGenerator(n, n.findex)
next := getExec(n.tnext) next := getExec(n.tnext)
@@ -1697,9 +1716,12 @@ func _append(n *node) {
l := len(args) l := len(args)
values := make([]func(*frame) reflect.Value, l) values := make([]func(*frame) reflect.Value, l)
for i, arg := range args { for i, arg := range args {
if arg.typ.untyped { switch {
case isRecursiveStruct(n.typ.val):
values[i] = genValueInterfacePtr(arg)
case arg.typ.untyped:
values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem()) values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem())
} else { default:
values[i] = genValue(arg) values[i] = genValue(arg)
} }
} }
@@ -1713,9 +1735,14 @@ func _append(n *node) {
return next return next
} }
} else { } else {
value0 := genValue(n.child[2]) var value0 func(*frame) reflect.Value
if n.child[2].typ.untyped { switch {
case isRecursiveStruct(n.typ.val):
value0 = genValueInterfacePtr(n.child[2])
case n.child[2].typ.untyped:
value0 = genValueAs(n.child[2], n.child[1].typ.TypeOf().Elem()) value0 = genValueAs(n.child[2], n.child[1].typ.TypeOf().Elem())
default:
value0 = genValue(n.child[2])
} }
n.exec = func(f *frame) bltn { n.exec = func(f *frame) bltn {

View File

@@ -44,16 +44,17 @@ func (k sKind) String() string {
// A symbol represents an interpreter object such as type, constant, var, func, // A symbol represents an interpreter object such as type, constant, var, func,
// label, builtin or binary object. Symbols are defined within a scope. // label, builtin or binary object. Symbols are defined within a scope.
type symbol struct { type symbol struct {
kind sKind kind sKind
typ *itype // Type of value typ *itype // Type of value
node *node // Node value if index is negative node *node // Node value if index is negative
from []*node // list of nodes jumping to node if kind is label, or nil from []*node // list of nodes jumping to node if kind is label, or nil
recv *receiver // receiver node value, if sym refers to a method recv *receiver // receiver node value, if sym refers to a method
index int // index of value in frame or -1 index int // index of value in frame or -1
rval reflect.Value // default value (used for constants) rval reflect.Value // default value (used for constants)
path string // package path if typ.cat is SrcPkgT or BinPkgT path string // package path if typ.cat is SrcPkgT or BinPkgT
builtin bltnGenerator // Builtin function or nil builtin bltnGenerator // Builtin function or nil
global bool // true if symbol is defined in global space global bool // true if symbol is defined in global space
recursive bool // true if symbol is a recursive type definition
// TODO: implement constant checking // TODO: implement constant checking
//constant bool // true if symbol value is constant //constant bool // true if symbol value is constant
} }

View File

@@ -308,6 +308,11 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case identExpr: case identExpr:
if sym, _, found := sc.lookup(n.ident); found { if sym, _, found := sc.lookup(n.ident); found {
t = sym.typ t = sym.typ
if sym.recursive && t.incomplete {
t.incomplete = false
t.rtype = reflect.TypeOf((*interface{})(nil)).Elem()
sym.typ = t
}
if t.incomplete && t.node != n { if t.incomplete && t.node != n {
m := t.method m := t.method
if t, err = nodeType(interp, sc, t.node); err != nil { if t, err = nodeType(interp, sc, t.node); err != nil {
@@ -388,7 +393,13 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case structType: case structType:
t.cat = structT t.cat = structT
var incomplete bool var incomplete, found bool
var sym *symbol
if sname := structName(n); sname != "" {
if sym, _, found = sc.lookup(sname); found && sym.kind == typeSym {
sym.recursive = true
}
}
for _, c := range n.child[0].child { for _, c := range n.child[0].child {
switch { switch {
case len(c.child) == 1: case len(c.child) == 1:
@@ -432,6 +443,14 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
return t, err return t, err
} }
// struct name returns the name of a struct type
func structName(n *node) string {
if n.anc.kind == typeSpec {
return n.anc.child[0].ident
}
return ""
}
// fieldName returns an implicit struct field name according to node kind // fieldName returns an implicit struct field name according to node kind
func fieldName(n *node) string { func fieldName(n *node) string {
switch n.kind { switch n.kind {

View File

@@ -107,6 +107,17 @@ func genValue(n *node) func(*frame) reflect.Value {
} }
} }
func genValueInterfacePtr(n *node) func(*frame) reflect.Value {
value := genValue(n)
it := reflect.TypeOf((*interface{})(nil)).Elem()
return func(f *frame) reflect.Value {
v := reflect.New(it).Elem()
v.Set(value(f))
return v.Addr()
}
}
func genValueInterface(n *node) func(*frame) reflect.Value { func genValueInterface(n *node) func(*frame) reflect.Value {
value := genValue(n) value := genValue(n)