From 4af992bccbfa1f1a6fb6123bc1986866e0df561b Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Mon, 30 Aug 2021 18:38:12 +0200 Subject: [PATCH] interp: create real recursive types with unsafe type swapping As the unsafe and pointer methods in `reflect` are to be depreciated, and seeing no replacement functions, it is now forced that some unsafe is needed to replace this as when and interface is dereferenced it is made unsettable by reflect. With this in mind, this adds real recursive types by hot swapping the struct field type on the fly. This removes a lot of compensation code, simplifying all previous cases. **Note:** While the struct field type is swapped for the real type, the type string is not changed. Due to this, unsafe will recreate the same type. --- internal/unsafe2/unsafe.go | 52 ++++++++++ internal/unsafe2/unsafe_test.go | 33 +++++++ interp/run.go | 50 +--------- interp/type.go | 165 ++++++++++++++++---------------- interp/typecheck.go | 4 - interp/value.go | 80 ---------------- 6 files changed, 167 insertions(+), 217 deletions(-) create mode 100644 internal/unsafe2/unsafe.go create mode 100644 internal/unsafe2/unsafe_test.go diff --git a/internal/unsafe2/unsafe.go b/internal/unsafe2/unsafe.go new file mode 100644 index 00000000..5f885d0e --- /dev/null +++ b/internal/unsafe2/unsafe.go @@ -0,0 +1,52 @@ +package unsafe2 + +import ( + "reflect" + "unsafe" +) + +type dummy struct{} + +// DummyType represents a stand-in for a recursive type. +var DummyType = reflect.TypeOf(dummy{}) + +type rtype struct { + _ [48]byte +} + +type emptyInterface struct { + typ *rtype + _ unsafe.Pointer +} + +type structField struct { + _ int64 + typ *rtype + _ uintptr +} + +type structType struct { + rtype + _ int64 + fields []structField +} + +// SwapFieldType swaps the type of the struct field with the given type. +// +// The struct type must have been created at runtime. This is very unsafe. +func SwapFieldType(s reflect.Type, idx int, t reflect.Type) { + if s.Kind() != reflect.Struct || idx >= s.NumField() { + return + } + + rtyp := unpackType(s) + styp := (*structType)(unsafe.Pointer(rtyp)) + f := styp.fields[idx] + f.typ = unpackType(t) + styp.fields[idx] = f +} + +func unpackType(t reflect.Type) *rtype { + v := reflect.New(t).Elem().Interface() + return (*emptyInterface)(unsafe.Pointer(&v)).typ +} diff --git a/internal/unsafe2/unsafe_test.go b/internal/unsafe2/unsafe_test.go new file mode 100644 index 00000000..64e303b3 --- /dev/null +++ b/internal/unsafe2/unsafe_test.go @@ -0,0 +1,33 @@ +package unsafe2_test + +import ( + "reflect" + "testing" + + "github.com/traefik/yaegi/internal/unsafe2" +) + +func TestSwapFieldType(t *testing.T) { + f := []reflect.StructField{ + { + Name: "A", + Type: reflect.TypeOf(int(0)), + }, + { + Name: "B", + Type: reflect.PtrTo(unsafe2.DummyType), + }, + { + Name: "C", + Type: reflect.TypeOf(int64(0)), + }, + } + typ := reflect.StructOf(f) + ntyp := reflect.PtrTo(typ) + + unsafe2.SwapFieldType(typ, 1, ntyp) + + if typ.Field(1).Type != ntyp { + t.Fatalf("unexpected field type: want %s; got %s", ntyp, typ.Field(1).Type) + } +} diff --git a/interp/run.go b/interp/run.go index ea575271..80adbe4f 100644 --- a/interp/run.go +++ b/interp/run.go @@ -10,7 +10,6 @@ import ( "regexp" "strings" "sync" - "unsafe" ) // bltn type defines functions which run at CFG execution. @@ -568,18 +567,6 @@ func convert(n *node) { } } -func isRecursiveType(t *itype, rtype reflect.Type) bool { - if t.cat == structT && rtype.Kind() == reflect.Interface { - return true - } - switch t.cat { - case aliasT, arrayT, mapT, ptrT, sliceT: - return isRecursiveType(t.val, t.val.rtype) - default: - return false - } -} - func assign(n *node) { next := getExec(n.tnext) dvalue := make([]func(*frame) reflect.Value, n.nleft) @@ -1038,11 +1025,7 @@ func call(n *node) { switch { case n.child[0].recv != nil: // Compute method receiver value. - if isRecursiveType(n.child[0].recv.node.typ, n.child[0].recv.node.typ.rtype) { - values = append(values, genValueRecvInterfacePtr(n.child[0])) - } else { - values = append(values, genValueRecv(n.child[0])) - } + values = append(values, genValueRecv(n.child[0])) method = true case len(n.child[0].child) > 0 && n.child[0].child[0].typ != nil && isInterfaceSrc(n.child[0].child[0].typ): recvIndexLater = true @@ -1096,8 +1079,6 @@ func call(n *node) { values = append(values, genValueInterface(c)) case isInterfaceBin(arg): values = append(values, genInterfaceWrapper(c, arg.rtype)) - case isRecursiveType(c.typ, c.typ.rtype): - values = append(values, genValueRecursiveInterfacePtrValue(c)) default: values = append(values, genValue(c)) } @@ -1852,9 +1833,6 @@ func getIndexSeq(n *node) { fnext := getExec(n.fnext) n.exec = func(f *frame) bltn { v := value(f) - if v.Type().Kind() == reflect.Interface && n.child[0].typ.recursive { - v = writableDeref(v) - } r := v.FieldByIndex(index) getFrame(f, l).data[i] = r if r.Bool() { @@ -1865,34 +1843,16 @@ func getIndexSeq(n *node) { } else { n.exec = func(f *frame) bltn { v := value(f) - if v.Type().Kind() == reflect.Interface && n.child[0].typ.recursive { - v = writableDeref(v) - } getFrame(f, l).data[i] = v.FieldByIndex(index) return tnext } } } -//go:nocheckptr -func writableDeref(v reflect.Value) reflect.Value { - // Here we have an interface to a struct. Any attempt to dereference it will - // make a copy of the struct. We need to get a Value to the actual struct. - // TODO: using unsafe is a temporary measure. Rethink this. - // TODO: InterfaceData has been depreciated, this is even less of a good idea now. - return reflect.NewAt(v.Elem().Type(), unsafe.Pointer(v.InterfaceData()[1])).Elem() //nolint:govet,staticcheck -} - func getPtrIndexSeq(n *node) { index := n.val.([]int) tnext := getExec(n.tnext) - var value func(*frame) reflect.Value - if isRecursiveType(n.child[0].typ, n.child[0].typ.rtype) { - v := genValue(n.child[0]) - value = func(f *frame) reflect.Value { return v(f).Elem().Elem() } - } else { - value = genValue(n.child[0]) - } + value := genValue(n.child[0]) i := n.findex l := n.level @@ -2546,8 +2506,6 @@ func doComposite(n *node, hasType bool, keyed bool) { values[fieldIndex] = genValueAsFunctionWrapper(val) case isArray(val.typ) && val.typ.val != nil && isInterfaceSrc(val.typ.val): values[fieldIndex] = genValueInterfaceArray(val) - case isRecursiveType(ft, rft): - values[fieldIndex] = genValueRecursiveInterface(val, rft) case isInterfaceSrc(ft) && !isEmptyInterface(ft): values[fieldIndex] = genValueInterface(val) case isInterface(ft): @@ -2946,8 +2904,6 @@ func _append(n *node) { values[i] = genValueInterface(arg) case isInterfaceBin(n.typ.val): values[i] = genInterfaceWrapper(arg, n.typ.val.rtype) - case isRecursiveType(n.typ.val, n.typ.val.rtype): - values[i] = genValueRecursiveInterface(arg, n.typ.val.rtype) case arg.typ.untyped: values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem()) default: @@ -2972,8 +2928,6 @@ func _append(n *node) { value0 = genValueInterface(n.child[2]) case isInterfaceBin(elem): value0 = genInterfaceWrapper(n.child[2], elem.rtype) - case isRecursiveType(elem, elem.rtype): - value0 = genValueRecursiveInterface(n.child[2], elem.rtype) case n.child[2].typ.untyped: value0 = genValueAs(n.child[2], n.child[1].typ.TypeOf().Elem()) default: diff --git a/interp/type.go b/interp/type.go index 8f02f757..a748d6a1 100644 --- a/interp/type.go +++ b/interp/type.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" "sync" + + "github.com/traefik/yaegi/internal/unsafe2" ) // tcat defines interpreter type categories. @@ -274,6 +276,11 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { if n.typ != nil && !n.typ.incomplete { return n.typ, nil } + if sname := typeName(n); sname != "" { + if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym && sym.typ != nil && sym.typ.isComplete() { + return sym.typ, nil + } + } repr := strings.Builder{} t := &itype{node: n, scope: sc} @@ -610,6 +617,8 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { ident := filepath.Join(n.ident, baseName) sym, _, found = sc.lookup(ident) if !found { + t.name = n.ident + t.path = sc.pkgID t.incomplete = true sc.sym[n.ident] = &symbol{kind: typeSym, typ: t} break @@ -853,12 +862,12 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { switch { case t == nil: - case t.cat == nilT: - t.str = "nil" case t.name != "" && t.path != "": t.str = t.path + "." + t.name case repr.Len() > 0: t.str = repr.String() + case t.cat == nilT: + t.str = "nil" } return t, err @@ -1582,13 +1591,37 @@ var ( constVal = reflect.TypeOf((*constant.Value)(nil)).Elem() ) +type fieldRebuild struct { + typ *itype + idx int +} + +type refTypeContext struct { + defined map[string]*itype + refs map[string][]fieldRebuild + rebuilding bool +} + +// Clone creates a copy if the ref type context without the `needsRebuild` set. +func (c *refTypeContext) Clone() *refTypeContext { + return &refTypeContext{defined: c.defined, refs: c.refs, rebuilding: c.rebuilding} +} + // RefType returns a reflect.Type representation from an interpreter type. // In simple cases, reflect types are directly mapped from the interpreter // counterpart. // For recursive named struct or interfaces, as reflect does not permit to -// create a recursive named struct, an interface{} is returned in place to -// avoid infinitely nested structs. -func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.Type { +// create a recursive named struct, a nil type is set temporarily for each recursive +// field. When done, the nil type fields are updated with the original reflect type +// pointer using unsafe. We thus obtain a usable recursive type definition, except +// for string representation, as created reflect types are still unnamed. +func (t *itype) refType(ctx *refTypeContext) reflect.Type { + if ctx == nil { + ctx = &refTypeContext{ + defined: map[string]*itype{}, + refs: map[string][]fieldRebuild{}, + } + } if t.incomplete || t.cat == nilT { var err error if t, err = t.finalize(); err != nil { @@ -1612,82 +1645,82 @@ func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.T t.recursive = recursive } } - if wrapRecursive && t.recursive { - return interf - } - if t.rtype != nil { + if t.rtype != nil && !ctx.rebuilding { return t.rtype } - if defined[name] != nil && defined[name].rtype != nil { - return defined[name].rtype - } - if t.val != nil && t.val.cat == structT && t.val.rtype == nil && hasRecursiveStruct(t.val, copyDefined(defined)) { - // Replace reference to self (direct or indirect) by an interface{} to handle - // recursive types with reflect. - typ := *t.val - t.val = &typ - t.val.rtype = interf - recursive = true + if dt := ctx.defined[name]; dt != nil { + if dt.rtype != nil { + t.rtype = dt.rtype + return dt.rtype + } + + // To indicate that a rebuild is needed on the nearest struct + // field, create an entry with a nil type. + flds := ctx.refs[name] + ctx.refs[name] = append(flds, fieldRebuild{}) + return unsafe2.DummyType } switch t.cat { case aliasT: - t.rtype = t.val.refType(defined, wrapRecursive) + t.rtype = t.val.refType(ctx) case arrayT: - t.rtype = reflect.ArrayOf(t.length, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ArrayOf(t.length, t.val.refType(ctx)) case sliceT, variadicT: - t.rtype = reflect.SliceOf(t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.SliceOf(t.val.refType(ctx)) case chanT: - t.rtype = reflect.ChanOf(reflect.BothDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.BothDir, t.val.refType(ctx)) case chanRecvT: - t.rtype = reflect.ChanOf(reflect.RecvDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.RecvDir, t.val.refType(ctx)) case chanSendT: - t.rtype = reflect.ChanOf(reflect.SendDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.SendDir, t.val.refType(ctx)) case errorT: t.rtype = reflect.TypeOf(new(error)).Elem() case funcT: - if t.name != "" { - defined[name] = t // TODO(marc): make sure that key is name and not t.name. - } variadic := false in := make([]reflect.Type, len(t.arg)) out := make([]reflect.Type, len(t.ret)) for i, v := range t.arg { - in[i] = v.refType(defined, true) + in[i] = v.refType(ctx) variadic = v.cat == variadicT } for i, v := range t.ret { - out[i] = v.refType(defined, true) + out[i] = v.refType(ctx) } t.rtype = reflect.FuncOf(in, out, variadic) case interfaceT: t.rtype = interf case mapT: - t.rtype = reflect.MapOf(t.key.refType(defined, wrapRecursive), t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.MapOf(t.key.refType(ctx), t.val.refType(ctx)) case ptrT: - t.rtype = reflect.PtrTo(t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.PtrTo(t.val.refType(ctx)) case structT: if t.name != "" { - // Check against local t.name and not name to catch recursive type definitions. - if defined[t.name] != nil { - recursive = true - } - defined[t.name] = t + ctx.defined[name] = t } var fields []reflect.StructField - // TODO(mpl): make Anonymous work for recursive types too. Maybe not worth the - // effort, and we're better off just waiting for - // https://github.com/golang/go/issues/39717 to land. - for _, f := range t.field { + for i, f := range t.field { + fctx := ctx.Clone() field := reflect.StructField{ - Name: exportName(f.name), Type: f.typ.refType(defined, wrapRecursive), + Name: exportName(f.name), Type: f.typ.refType(fctx), Tag: reflect.StructTag(f.tag), Anonymous: (f.embed && !recursive), } fields = append(fields, field) + // Find any nil type refs that indicates a rebuild is needed on this field. + for _, flds := range ctx.refs { + for j, fld := range flds { + if fld.typ == nil { + flds[j] = fieldRebuild{typ: t, idx: i} + } + } + } } - if recursive && wrapRecursive { - t.rtype = interf - } else { - t.rtype = reflect.StructOf(fields) + t.rtype = reflect.StructOf(fields) + + // The rtype has now been built, we can go back and rebuild + // all the recursive types that relied on this type. + for _, f := range ctx.refs[name] { + ftyp := f.typ.field[f.idx].typ.refType(&refTypeContext{defined: ctx.defined, rebuilding: true}) + unsafe2.SwapFieldType(f.typ.rtype, f.idx, ftyp) } default: if z, _ := t.zero(); z.IsValid() { @@ -1699,7 +1732,7 @@ func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.T // TypeOf returns the reflection type of dynamic interpreter type t. func (t *itype) TypeOf() reflect.Type { - return t.refType(map[string]*itype{}, false) + return t.refType(nil) } func (t *itype) frameType() (r reflect.Type) { @@ -1802,44 +1835,6 @@ func (t *itype) elem() *itype { return t.val } -func copyDefined(m map[string]*itype) map[string]*itype { - n := make(map[string]*itype, len(m)) - for k, v := range m { - n[k] = v - } - return n -} - -// hasRecursiveStruct determines if a struct is a recursion or a recursion -// intermediate. A recursion intermediate is a struct that contains a recursive -// struct. -func hasRecursiveStruct(t *itype, defined map[string]*itype) bool { - if len(defined) == 0 { - return false - } - - typ := t - for typ != nil { - if typ.cat != structT { - typ = typ.val - continue - } - - if defined[typ.path+"/"+typ.name] != nil { - return true - } - defined[typ.path+"/"+typ.name] = typ - - for _, f := range typ.field { - if hasRecursiveStruct(f.typ, copyDefined(defined)) { - return true - } - } - return false - } - return false -} - func constToInt(c constant.Value) int { if constant.BitLen(c) > 64 { panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) diff --git a/interp/typecheck.go b/interp/typecheck.go index b683ce1b..38b629ac 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -752,10 +752,6 @@ func (check typecheck) builtin(name string, n *node, child []*node, ellipsis boo } return nil } - // We cannot check a recursive type. - if isRecursiveType(typ, typ.TypeOf()) { - return nil - } fun := &node{ typ: &itype{ diff --git a/interp/value.go b/interp/value.go index 8c33ffb6..8f9ea0e2 100644 --- a/interp/value.go +++ b/interp/value.go @@ -129,25 +129,6 @@ func genValueBinRecv(n *node, recv *receiver) func(*frame) reflect.Value { } } -func genValueRecvInterfacePtr(n *node) func(*frame) reflect.Value { - v := genValue(n.recv.node) - fi := n.recv.index - - return func(f *frame) reflect.Value { - r := v(f) - r = r.Elem().Elem() - - if len(fi) == 0 { - return r - } - - if r.Kind() == reflect.Ptr { - r = r.Elem() - } - return r.FieldByIndex(fi) - } -} - func genValueAsFunctionWrapper(n *node) func(*frame) reflect.Value { value := genValue(n) typ := n.typ.TypeOf() @@ -240,10 +221,6 @@ func genDestValue(typ *itype, n *node) func(*frame) reflect.Value { return genInterfaceWrapper(n, typ.rtype) case n.kind == basicLit && n.val == nil: return func(*frame) reflect.Value { return reflect.New(typ.rtype).Elem() } - case isRecursiveType(typ, typ.rtype): - return genValueRecursiveInterface(n, typ.rtype) - case isRecursiveType(n.typ, n.typ.rtype): - return genValueRecursiveInterfacePtrValue(n) case n.typ.untyped && isComplex(typ.TypeOf()): return genValueComplex(n) case n.typ.untyped && !typ.untyped: @@ -440,63 +417,6 @@ func genValueNode(n *node) func(*frame) reflect.Value { } } -func genValueRecursiveInterface(n *node, t reflect.Type) func(*frame) reflect.Value { - value := genValue(n) - - return func(f *frame) reflect.Value { - vv := value(f) - v := reflect.New(t).Elem() - toRecursive(v, vv) - return v - } -} - -func toRecursive(dest, src reflect.Value) { - if !src.IsValid() { - return - } - - switch dest.Kind() { - case reflect.Map: - v := reflect.MakeMapWithSize(dest.Type(), src.Len()) - for _, kv := range src.MapKeys() { - vv := reflect.New(dest.Type().Elem()).Elem() - toRecursive(vv, src.MapIndex(kv)) - vv.SetMapIndex(kv, vv) - } - dest.Set(v) - case reflect.Slice: - l := src.Len() - v := reflect.MakeSlice(dest.Type(), l, l) - for i := 0; i < l; i++ { - toRecursive(v.Index(i), src.Index(i)) - } - dest.Set(v) - case reflect.Ptr: - v := reflect.New(dest.Type().Elem()).Elem() - s := src - if s.Elem().Kind() != reflect.Struct { // In the case of *interface{}, we want *struct{} - s = s.Elem() - } - toRecursive(v, s) - dest.Set(v.Addr()) - default: - dest.Set(src) - } -} - -func genValueRecursiveInterfacePtrValue(n *node) func(*frame) reflect.Value { - value := genValue(n) - - return func(f *frame) reflect.Value { - v := value(f) - if v.IsZero() { - return v - } - return v.Elem().Elem() - } -} - func vInt(v reflect.Value) (i int64) { if c := vConstantValue(v); c != nil { i, _ = constant.Int64Val(constant.ToInt(c))