diff --git a/_test/assert0.go b/_test/assert0.go index c0a7af14..7c403d33 100644 --- a/_test/assert0.go +++ b/_test/assert0.go @@ -48,6 +48,14 @@ func main() { bType := reflect.TypeOf(TestStruct{}) fmt.Println(bType.Implements(aType)) + // not redundant with the above, because it goes through a slightly different code path. + if _, ok := t.(MyWriter); !ok { + fmt.Println("TestStruct does not implement MyWriter") + return + } else { + fmt.Println("TestStruct implements MyWriter") + } + t = 42 foo, ok := t.(MyWriter) if !ok { @@ -57,6 +65,12 @@ func main() { } _ = foo + if _, ok := t.(MyWriter); !ok { + fmt.Println("42 does not implement MyWriter") + } else { + fmt.Println("42 implements MyWriter") + } + var tt interface{} tt = time.Nanosecond var myD MyStringer @@ -72,6 +86,12 @@ func main() { dType := reflect.TypeOf(time.Nanosecond) fmt.Println(dType.Implements(cType)) + if _, ok := tt.(MyStringer); !ok { + fmt.Println("time.Nanosecond does not implement MyStringer") + } else { + fmt.Println("time.Nanosecond implements MyStringer") + } + tt = 42 bar, ok := tt.(MyStringer) if !ok { @@ -81,6 +101,11 @@ func main() { } _ = bar + if _, ok := tt.(MyStringer); !ok { + fmt.Println("42 does not implement MyStringer") + } else { + fmt.Println("42 implements MyStringer") + } } // Output: @@ -88,9 +113,13 @@ func main() { // 11 // 11 // true +// TestStruct implements MyWriter +// 42 does not implement MyWriter // 42 does not implement MyWriter // time.Nanosecond implements MyStringer // 1ns // 1ns // true +// time.Nanosecond implements MyStringer +// 42 does not implement MyStringer // 42 does not implement MyStringer diff --git a/_test/assert1.go b/_test/assert1.go index a44e163d..2ee35c23 100644 --- a/_test/assert1.go +++ b/_test/assert1.go @@ -27,6 +27,13 @@ func main() { bType := reflect.TypeOf(time.Nanosecond) fmt.Println(bType.Implements(aType)) + // not redundant with the above, because it goes through a slightly different code path. + if _, ok := t.(fmt.Stringer); !ok { + fmt.Println("time.Nanosecond does not implement fmt.Stringer") + return + } else { + fmt.Println("time.Nanosecond implements fmt.Stringer") + } t = 42 foo, ok := t.(fmt.Stringer) @@ -34,9 +41,17 @@ func main() { fmt.Println("42 does not implement fmt.Stringer") } else { fmt.Println("42 implements fmt.Stringer") + return } _ = foo + if _, ok := t.(fmt.Stringer); !ok { + fmt.Println("42 does not implement fmt.Stringer") + } else { + fmt.Println("42 implements fmt.Stringer") + return + } + var tt interface{} tt = TestStruct{} ss, ok := tt.(fmt.Stringer) @@ -49,12 +64,22 @@ func main() { // TODO(mpl): uncomment when fixed // cType := reflect.TypeOf(TestStruct{}) // fmt.Println(cType.Implements(aType)) + + if _, ok := tt.(fmt.Stringer); !ok { + fmt.Println("TestStuct does not implement fmt.Stringer") + return + } else { + fmt.Println("TestStuct implements fmt.Stringer") + } } // Output: // 1ns // 1ns // true +// time.Nanosecond implements fmt.Stringer +// 42 does not implement fmt.Stringer // 42 does not implement fmt.Stringer // hello world // hello world +// TestStuct implements fmt.Stringer diff --git a/interp/cfg.go b/interp/cfg.go index 5190253d..57aa23a9 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -648,7 +648,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { if n.child[0].ident == "_" { lc.gen = typeAssertStatus } else { - lc.gen = typeAssert2 + lc.gen = typeAssertLong } n.gen = nop case unaryExpr: @@ -1903,7 +1903,7 @@ func compDefineX(sc *scope, n *node) error { if n.child[0].ident == "_" { n.child[l].gen = typeAssertStatus } else { - n.child[l].gen = typeAssert2 + n.child[l].gen = typeAssertLong } types = append(types, n.child[l].child[1].typ, sc.getType("bool")) n.gen = nop diff --git a/interp/run.go b/interp/run.go index b99504ae..010a6ef7 100644 --- a/interp/run.go +++ b/interp/run.go @@ -69,7 +69,7 @@ var builtin = [...]bltnGenerator{ aStar: deref, aSub: sub, aSubAssign: subAssign, - aTypeAssert: typeAssert1, + aTypeAssert: typeAssertShort, aXor: xor, aXorAssign: xorAssign, } @@ -191,45 +191,6 @@ func runCfg(n *node, f *frame) { } } -func typeAssertStatus(n *node) { - c0, c1 := n.child[0], n.child[1] // cO contains the input value, c1 the type to assert - value := genValue(c0) // input value - value1 := genValue(n.anc.child[1]) // returned status - rtype := c1.typ.rtype // type to assert - next := getExec(n.tnext) - - switch { - case isInterfaceSrc(c1.typ): - typ := c1.typ - n.exec = func(f *frame) bltn { - v, ok := value(f).Interface().(valueInterface) - value1(f).SetBool(ok && v.node.typ.implements(typ)) - return next - } - case isInterface(c1.typ): - n.exec = func(f *frame) bltn { - v := value(f) - ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype) - value1(f).SetBool(ok) - return next - } - case c0.typ.cat == valueT || c0.typ.cat == errorT: - n.exec = func(f *frame) bltn { - v := value(f) - ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype) - value1(f).SetBool(ok) - return next - } - default: - n.exec = func(f *frame) bltn { - v, ok := value(f).Interface().(valueInterface) - ok = ok && v.value.IsValid() && canAssertTypes(v.value.Type(), rtype) - value1(f).SetBool(ok) - return next - } - } -} - func stripReceiverFromArgs(signature string) (string, error) { fields := receiverStripperRxp.FindStringSubmatch(signature) if len(fields) < 5 { @@ -241,25 +202,33 @@ func stripReceiverFromArgs(signature string) (string, error) { return fmt.Sprintf("func(%s", fields[4]), nil } -func typeAssert1(n *node) { - typeAssert(n, false) +func typeAssertShort(n *node) { + typeAssert(n, true, false) } -func typeAssert2(n *node) { - typeAssert(n, true) +func typeAssertLong(n *node) { + typeAssert(n, true, true) } -func typeAssert(n *node, withOk bool) { +func typeAssertStatus(n *node) { + typeAssert(n, false, true) +} + +func typeAssert(n *node, withResult, withOk bool) { c0, c1 := n.child[0], n.child[1] value := genValue(c0) // input value var value0, value1 func(*frame) reflect.Value setStatus := false - if withOk { + switch { + case withResult && withOk: value0 = genValue(n.anc.child[0]) // returned result value1 = genValue(n.anc.child[1]) // returned status setStatus = n.anc.child[1].ident != "_" // do not assign status to "_" - } else { + case withResult && !withOk: value0 = genValue(n) // returned result + case !withResult && withOk: + value1 = genValue(n.anc.child[1]) // returned status + setStatus = n.anc.child[1].ident != "_" // do not assign status to "_" } typ := c1.typ // type to assert or convert to @@ -270,7 +239,8 @@ func typeAssert(n *node, withOk bool) { switch { case isInterfaceSrc(typ): n.exec = func(f *frame) bltn { - v, ok := value(f).Interface().(valueInterface) + valf := value(f) + v, ok := valf.Interface().(valueInterface) if setStatus { defer func() { value1(f).SetBool(ok) @@ -283,7 +253,9 @@ func typeAssert(n *node, withOk bool) { return next } if v.node.typ.id() == typID { - value0(f).Set(value(f)) + if withResult { + value0(f).Set(valf) + } return next } m0 := v.node.typ.methods() @@ -329,7 +301,9 @@ func typeAssert(n *node, withOk bool) { } } - value0(f).Set(value(f)) + if withResult { + value0(f).Set(valf) + } return next } case isInterface(typ): @@ -362,9 +336,9 @@ func typeAssert(n *node, withOk bool) { } } - // TODO(mpl): make this case compliant with reflect's Implements. - v = genInterfaceWrapper(val.node, rtype)(f) - value0(f).Set(v) + if withResult { + value0(f).Set(genInterfaceWrapper(val.node, rtype)(f)) + } ok = true return next } @@ -392,7 +366,9 @@ func typeAssert(n *node, withOk bool) { } return next } - value0(f).Set(v) + if withResult { + value0(f).Set(v) + } return next } case n.child[0].typ.cat == valueT || n.child[0].typ.cat == errorT: @@ -418,7 +394,9 @@ func typeAssert(n *node, withOk bool) { } return next } - value0(f).Set(v) + if withResult { + value0(f).Set(v) + } return next } default: @@ -443,7 +421,9 @@ func typeAssert(n *node, withOk bool) { } return next } - value0(f).Set(v.value) + if withResult { + value0(f).Set(v.value) + } return next } }