diff --git a/_test/assert0.go b/_test/assert0.go new file mode 100644 index 00000000..a75934dc --- /dev/null +++ b/_test/assert0.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "time" +) + +type MyWriter interface { + Write(p []byte) (i int, err error) +} + +type TestStruct struct{} + +func (t TestStruct) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func usesWriter(w MyWriter) { + w.Write(nil) +} + +type MyStringer interface { + String() string +} + +func usesStringer(s MyStringer) { + fmt.Println(s.String()) +} + +func main() { + var t interface{} + t = TestStruct{} + var tw MyWriter + var ok bool + tw, ok = t.(MyWriter) + if !ok { + fmt.Println("TestStruct does not implement MyWriter") + } else { + fmt.Println("TestStruct implements MyWriter") + usesWriter(tw) + } + + var tt interface{} + tt = time.Nanosecond + var myD MyStringer + myD, ok = tt.(MyStringer) + if !ok { + fmt.Println("time.Nanosecond does not implement MyStringer") + } else { + fmt.Println("time.Nanosecond implements MyStringer") + usesStringer(myD) + } +} + +// Output: +// TestStruct implements MyWriter +// time.Nanosecond implements MyStringer +// 1ns diff --git a/interp/run.go b/interp/run.go index d61c15c0..4b013803 100644 --- a/interp/run.go +++ b/interp/run.go @@ -3,10 +3,12 @@ package interp //go:generate go run ../internal/cmd/genop/genop.go import ( + "errors" "fmt" "go/constant" "log" "reflect" + "regexp" "sync" "unsafe" ) @@ -72,6 +74,17 @@ var builtin = [...]bltnGenerator{ aXorAssign: xorAssign, } +var receiverStripperRxp *regexp.Regexp + +func init() { + re := `func\(((.*?(, |\)))(.*))` + var err error + receiverStripperRxp, err = regexp.Compile(re) + if err != nil { + panic(err) + } +} + type valueInterface struct { node *node value reflect.Value @@ -283,6 +296,17 @@ func typeAssert(n *node) { } } +func stripReceiverFromArgs(signature string) (string, error) { + fields := receiverStripperRxp.FindStringSubmatch(signature) + if len(fields) < 5 { + return "", errors.New("error while matching method signature") + } + if fields[3] == ")" { + return fmt.Sprintf("func()%s", fields[4]), nil + } + return fmt.Sprintf("func(%s", fields[4]), nil +} + func typeAssert2(n *node) { c0, c1 := n.child[0], n.child[1] value := genValue(c0) // input value @@ -298,14 +322,60 @@ func typeAssert2(n *node) { case isInterfaceSrc(typ): n.exec = func(f *frame) bltn { v, ok := value(f).Interface().(valueInterface) - if ok && v.node.typ.id() == typID { + defer func() { + assertOk := ok + if setStatus { + value1(f).SetBool(assertOk) + } + }() + if !ok { + return next + } + if v.node.typ.id() == typID { value0(f).Set(value(f)) - } else { + return next + } + m0 := v.node.typ.methods() + m1 := typ.methods() + if len(m0) < len(m1) { ok = false + return next } - if setStatus { - value1(f).SetBool(ok) + + for k, meth1 := range m1 { + var meth0 string + meth0, ok = m0[k] + if !ok { + return next + } + // As far as we know this equality check can fail because they are two ways to + // represent the signature of a method: one where the receiver appears before the + // func keyword, and one where it is just a func signature, and the receiver is + // seen as the first argument. That's why if that equality fails, we try harder to + // compare them afterwards. Hopefully that is the only reason this equality can fail. + if meth0 == meth1 { + continue + } + tm := lookupFieldOrMethod(v.node.typ, k) + if tm == nil { + ok = false + return next + } + + var err error + meth0, err = stripReceiverFromArgs(meth0) + if err != nil { + ok = false + return next + } + + if meth0 != meth1 { + ok = false + return next + } } + + value0(f).Set(value(f)) return next } case isInterface(typ): @@ -879,7 +949,10 @@ func call(n *node) { var method bool value := genValue(n.child[0]) var values []func(*frame) reflect.Value - if n.child[0].recv != nil { + + recvIndexLater := false + switch { + case n.child[0].recv != nil: // Compute method receiver value. if isRecursiveType(n.child[0].recv.node.typ, n.child[0].recv.node.typ.rtype) { values = append(values, genValueRecvInterfacePtr(n.child[0])) @@ -887,11 +960,17 @@ func call(n *node) { values = append(values, genValueRecv(n.child[0])) } method = true - } else if n.child[0].action == aMethod { + case len(n.child[0].child) > 0 && n.child[0].child[0].typ != nil && n.child[0].child[0].typ.cat == interfaceT: + recvIndexLater = true + values = append(values, genValueBinRecv(n.child[0], &receiver{node: n.child[0].child[0]})) + value = genValueBinMethodOnInterface(n, value) + method = true + case n.child[0].action == aMethod: // Add a place holder for interface method receiver. values = append(values, nil) method = true } + numRet := len(n.child[0].typ.ret) variadic := variadicPos(n) child := n.child[1:] @@ -1001,6 +1080,7 @@ func call(n *node) { n.exec = func(f *frame) bltn { var def *node var ok bool + bf := value(f) if def, ok = bf.Interface().(*node); ok { bf = def.rval @@ -1070,16 +1150,16 @@ func call(n *node) { var src reflect.Value if v == nil { src = def.recv.val - if len(def.recv.index) > 0 { - if src.Kind() == reflect.Ptr { - src = src.Elem().FieldByIndex(def.recv.index) - } else { - src = src.FieldByIndex(def.recv.index) - } - } } else { src = v(f) } + if recvIndexLater && def.recv != nil && len(def.recv.index) > 0 { + if src.Kind() == reflect.Ptr { + src = src.Elem().FieldByIndex(def.recv.index) + } else { + src = src.FieldByIndex(def.recv.index) + } + } // Accommodate to receiver type d := dest[0] if ks, kd := src.Kind(), d.Kind(); ks != kd { @@ -1619,6 +1699,15 @@ func getMethodByName(n *node) { n.exec = func(f *frame) bltn { val := value0(f).Interface().(valueInterface) + typ := val.node.typ + if typ.node == nil && typ.cat == valueT { + // happens with a var of empty interface type, that has value of concrete type + // from runtime, being asserted to "user-defined" interface. + if _, ok := typ.rtype.MethodByName(name); !ok { + panic(fmt.Sprintf("method %s not found", name)) + } + return next + } m, li := val.node.typ.lookupMethod(name) fr := f.clone() nod := *m diff --git a/interp/value.go b/interp/value.go index adf5b60f..022c375e 100644 --- a/interp/value.go +++ b/interp/value.go @@ -33,6 +33,30 @@ func valueOf(data []reflect.Value, i int) reflect.Value { return reflect.Value{} } +func genValueBinMethodOnInterface(n *node, defaultGen func(*frame) reflect.Value) func(*frame) reflect.Value { + if n == nil || n.child == nil || n.child[0] == nil || + n.child[0].child == nil || n.child[0].child[0] == nil { + return defaultGen + } + if n.child[0].child[1] == nil || n.child[0].child[1].ident == "" { + return defaultGen + } + value0 := genValue(n.child[0].child[0]) + + return func(f *frame) reflect.Value { + val, ok := value0(f).Interface().(valueInterface) + if !ok { + return defaultGen(f) + } + typ := val.node.typ + if typ.node != nil || typ.cat != valueT { + return defaultGen(f) + } + meth, _ := typ.rtype.MethodByName(n.child[0].child[1].ident) + return meth.Func + } +} + func genValueRecvIndirect(n *node) func(*frame) reflect.Value { v := genValueRecv(n) return func(f *frame) reflect.Value { return v(f).Elem() } @@ -55,6 +79,35 @@ func genValueRecv(n *node) func(*frame) reflect.Value { } } +func genValueBinRecv(n *node, recv *receiver) func(*frame) reflect.Value { + value := genValue(n) + binValue := genValue(recv.node) + + v := func(f *frame) reflect.Value { + if def, ok := value(f).Interface().(*node); ok { + if def != nil && def.recv != nil && def.recv.val.IsValid() { + return def.recv.val + } + } + + ival, _ := binValue(f).Interface().(valueInterface) + return ival.value + } + + fi := recv.index + if len(fi) == 0 { + return v + } + + return func(f *frame) reflect.Value { + r := v(f) + if r.Kind() == reflect.Ptr { + r = r.Elem() + } + return r.FieldByIndex(fi) + } +} + func genValueRecvInterfacePtr(n *node) func(*frame) reflect.Value { v := genValue(n.recv.node) fi := n.recv.index