From c1f5005b2a5ea11f56147528da4eeaf93241b83b Mon Sep 17 00:00:00 2001 From: Marc Vertes Date: Wed, 10 Jun 2020 11:21:16 +0200 Subject: [PATCH] fix: finish support of type assertions which was incomplete (#657) * fix: finish support of type assertions which was incomplete TypeAssert was optimistically returning ok without verifying that value could be converted to the required interface (in case of type assert of an interface type), or not checking the type in all conditions. There is now a working implements method for itype. Fixes #640. * style: appease lint * fix: remove useless code block * doc: improve comments * avoid test conflict --- _test/method33.go | 58 +++++++++++++++++++++++ _test/method34.go | 23 +++++++++ interp/gta.go | 1 + interp/run.go | 117 +++++++++++++++++++++++++++++++--------------- interp/type.go | 26 ++++++----- 5 files changed, 176 insertions(+), 49 deletions(-) create mode 100644 _test/method33.go create mode 100644 _test/method34.go diff --git a/_test/method33.go b/_test/method33.go new file mode 100644 index 00000000..da5e3d12 --- /dev/null +++ b/_test/method33.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" +) + +type T1 struct{} + +func (t1 T1) f() { + fmt.Println("T1.f()") +} + +func (t1 T1) g() { + fmt.Println("T1.g()") +} + +type T2 struct { + T1 +} + +func (t2 T2) f() { + fmt.Println("T2.f()") +} + +type I interface { + f() +} + +func printType(i I) { + if t1, ok := i.(T1); ok { + println("T1 ok") + t1.f() + t1.g() + } + + if t2, ok := i.(T2); ok { + println("T2 ok") + t2.f() + t2.g() + } +} + +func main() { + println("T1") + printType(T1{}) + println("T2") + printType(T2{}) +} + +// Output: +// T1 +// T1 ok +// T1.f() +// T1.g() +// T2 +// T2 ok +// T2.f() +// T1.g() diff --git a/_test/method34.go b/_test/method34.go new file mode 100644 index 00000000..314bf22b --- /dev/null +++ b/_test/method34.go @@ -0,0 +1,23 @@ +package main + +type Root struct { + Name string +} + +type One struct { + Root +} + +type Hi interface { + Hello() string +} + +func (r *Root) Hello() string { return "Hello " + r.Name } + +func main() { + var one interface{} = &One{Root{Name: "test2"}} + println(one.(Hi).Hello()) +} + +// Output: +// Hello test2 diff --git a/interp/gta.go b/interp/gta.go index c83ec32e..24e187a0 100644 --- a/interp/gta.go +++ b/interp/gta.go @@ -137,6 +137,7 @@ func (interp *Interpreter) gta(root *node, rpath, pkgID string) ([]*node, error) } } rcvrtype.method = append(rcvrtype.method, n) + n.child[0].child[0].lastChild().typ = rcvrtype } else { // Add a function symbol in the package name space sc.sym[n.child[1].ident] = &symbol{kind: funcSym, typ: n.typ, node: n, index: -1} diff --git a/interp/run.go b/interp/run.go index 7da12961..de4b8203 100644 --- a/interp/run.go +++ b/interp/run.go @@ -122,34 +122,38 @@ func runCfg(n *node, f *frame) { } func typeAssertStatus(n *node) { - c0, c1 := n.child[0], n.child[1] + c0, c1 := n.child[0], n.child[1] // cO contains the input value, c1 the type to assert value := genValue(c0) // input value value1 := genValue(n.anc.child[1]) // returned status - typ := c1.typ.rtype // type to assert next := getExec(n.tnext) switch { - case c0.typ.cat == valueT: + case isInterfaceSrc(c1.typ): + typ := c1.typ n.exec = func(f *frame) bltn { - v := value(f) - if !v.IsValid() || v.IsNil() { - value1(f).SetBool(false) - } - value1(f).SetBool(v.Type().Implements(typ)) + v, ok := value(f).Interface().(valueInterface) + value1(f).SetBool(ok && v.node.typ.implements(typ)) return next } - case c1.typ.cat == interfaceT: + case isInterface(c1.typ): + rtype := c1.typ.rtype n.exec = func(f *frame) bltn { - _, ok := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - value1(f).SetBool(ok) + v := value(f) + value1(f).SetBool(v.IsValid() && v.Type().Implements(rtype)) + return next + } + case c0.typ.cat == valueT: + rtype := c1.typ.rtype + n.exec = func(f *frame) bltn { + v := value(f) + value1(f).SetBool(v.IsValid() && v.Type() == rtype) return next } default: + typID := c1.typ.id() n.exec = func(f *frame) bltn { - _, ok := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - value1(f).SetBool(ok) + v, ok := value(f).Interface().(valueInterface) + value1(f).SetBool(ok && v.node.typ.id() == typID) return next } } @@ -162,24 +166,35 @@ func typeAssert(n *node) { next := getExec(n.tnext) switch { - case c0.typ.cat == valueT: + case isInterfaceSrc(c1.typ): + typ := n.child[1].typ + typID := n.child[1].typ.id() n.exec = func(f *frame) bltn { v := value(f) - dest(f).Set(v.Elem()) + vi, ok := v.Interface().(valueInterface) + if !ok { + panic(n.cfgErrorf("interface conversion: nil is not %v", typID)) + } + if !vi.node.typ.implements(typ) { + panic(n.cfgErrorf("interface conversion: %v is not %v", vi.node.typ.id(), typID)) + } + dest(f).Set(v) return next } - case c1.typ.cat == interfaceT: + case isInterface(c1.typ): + rtype := n.child[1].typ.rtype n.exec = func(f *frame) bltn { - v := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - dest(f).Set(reflect.ValueOf(valueInterface{v.node, v.value})) + dest(f).Set(value(f).Convert(rtype)) + return next + } + case c0.typ.cat == valueT: + n.exec = func(f *frame) bltn { + dest(f).Set(value(f).Elem()) return next } default: n.exec = func(f *frame) bltn { - v := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - dest(f).Set(v.value) + dest(f).Set(value(f).Interface().(valueInterface).value) return next } } @@ -189,30 +204,58 @@ func typeAssert2(n *node) { value := genValue(n.child[0]) // input value value0 := genValue(n.anc.child[0]) // returned result value1 := genValue(n.anc.child[1]) // returned status + typ := n.child[1].typ // type to assert or convert to + typID := typ.id() next := getExec(n.tnext) switch { - case n.child[0].typ.cat == valueT: - n.exec = func(f *frame) bltn { - if value(f).IsValid() && !value(f).IsNil() { - value0(f).Set(value(f).Elem()) - } - value1(f).SetBool(true) - return next - } - case n.child[1].typ.cat == interfaceT: + case isInterfaceSrc(typ): n.exec = func(f *frame) bltn { v, ok := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - value0(f).Set(reflect.ValueOf(valueInterface{v.node, v.value})) + if ok && v.node.typ.id() == typID { + value0(f).Set(value(f)) + } else { + ok = false + } + value1(f).SetBool(ok) + return next + } + case isInterface(typ): + rtype := typ.rtype + n.exec = func(f *frame) bltn { + v := value(f) + ok := v.IsValid() && v.Type().Implements(rtype) + if ok { + value0(f).Set(v.Convert(rtype)) + } + value1(f).SetBool(ok) + return next + } + case n.child[0].typ.cat == valueT: + rtype := n.child[1].typ.rtype + n.exec = func(f *frame) bltn { + v := value(f) + ok := v.IsValid() && !value(f).IsNil() + if ok { + if e := v.Elem(); e.Type() == rtype { + value0(f).Set(e) + } else { + ok = false + } + } value1(f).SetBool(ok) return next } default: n.exec = func(f *frame) bltn { v, ok := value(f).Interface().(valueInterface) - // TODO: verify that value(f) implements asserted type. - value0(f).Set(v.value) + if ok { + if v.node.typ.id() == typID { + value0(f).Set(v.value) + } else { + ok = false + } + } value1(f).SetBool(ok) return next } diff --git a/interp/type.go b/interp/type.go index 73c5bdbd..a69a714e 100644 --- a/interp/type.go +++ b/interp/type.go @@ -818,7 +818,7 @@ func (t *itype) methods() methodSet { res := make(methodSet) switch t.cat { case interfaceT: - // Get methods from recursive analysis of interface fields + // Get methods from recursive analysis of interface fields. for _, f := range t.field { if f.typ.cat == funcT { res[f.name] = f.typ.TypeOf().String() @@ -829,22 +829,25 @@ func (t *itype) methods() methodSet { } } case valueT, errorT: - // Get method from corresponding reflect.Type + // Get method from corresponding reflect.Type. for i := t.rtype.NumMethod() - 1; i >= 0; i-- { m := t.rtype.Method(i) res[m.Name] = m.Type.String() } case ptrT: - // Consider only methods where receiver is a pointer to type t - for _, m := range t.val.method { - if m.child[0].child[0].lastChild().typ.cat == ptrT { - res[m.ident] = m.typ.TypeOf().String() + for k, v := range t.val.methods() { + res[k] = v + } + case structT: + for _, f := range t.field { + for k, v := range f.typ.methods() { + res[k] = v } } - default: - for _, m := range t.method { - res[m.ident] = m.typ.TypeOf().String() - } + } + // Get all methods defined on this type. + for _, m := range t.method { + res[m.ident] = m.typ.TypeOf().String() } return res } @@ -1192,8 +1195,7 @@ func (t *itype) implements(it *itype) bool { if t.cat == valueT { return t.TypeOf().Implements(it.TypeOf()) } - // TODO: implement method check for interpreted types - return true + return t.methods().contains(it.methods()) } func defRecvType(n *node) *itype {