diff --git a/_test/assert1.go b/_test/assert1.go new file mode 100644 index 00000000..aeba87fb --- /dev/null +++ b/_test/assert1.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "time" +) + +type TestStruct struct{} + +func (t TestStruct) String() string { + return "hello world" +} + +func main() { + var t interface{} + t = time.Nanosecond + s, ok := t.(fmt.Stringer) + if !ok { + fmt.Println("time.Nanosecond does not implement fmt.Stringer") + return + } + fmt.Println(s.String()) + + var tt interface{} + tt = TestStruct{} + ss, ok := tt.(fmt.Stringer) + if !ok { + fmt.Println("TestStuct does not implement fmt.Stringer") + return + } + fmt.Println(ss.String()) +} + +// Output: +// 1ns +// hello world diff --git a/interp/run.go b/interp/run.go index 4b013803..256f61c4 100644 --- a/interp/run.go +++ b/interp/run.go @@ -380,14 +380,54 @@ func typeAssert2(n *node) { } case isInterface(typ): n.exec = func(f *frame) bltn { - v := value(f).Elem() - ok := v.IsValid() && canAssertTypes(v.Type(), rtype) - if ok { + var leftType reflect.Type + v := value(f) + val, ok := v.Interface().(valueInterface) + defer func() { + assertOk := ok + if setStatus { + value1(f).SetBool(assertOk) + } + }() + if ok && val.node.typ.cat != valueT { + m0 := val.node.typ.methods() + m1 := typ.methods() + if len(m0) < len(m1) { + ok = false + return next + } + + for k, meth1 := range m1 { + var meth0 string + meth0, ok = m0[k] + if !ok { + return next + } + if meth0 != meth1 { + ok = false + return next + } + } + + v = genInterfaceWrapper(val.node, rtype)(f) value0(f).Set(v) + ok = true + return next } - if setStatus { - value1(f).SetBool(ok) + + if ok { + v = val.value + leftType = val.node.typ.rtype + } else { + v = v.Elem() + leftType = v.Type() + ok = true } + ok = v.IsValid() && canAssertTypes(leftType, rtype) + if !ok { + return next + } + value0(f).Set(v) return next } case n.child[0].typ.cat == valueT || n.child[0].typ.cat == errorT: @@ -844,6 +884,9 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { if src.Type().Kind() != dest.Type().Kind() { dest.Set(src.Addr()) } else { + if wrappedSrc, ok := src.Interface().(valueInterface); ok { + src = wrappedSrc.value + } dest.Set(src) } d = d[numRet+1:]