interp: support more type assertion cases

Fixes #955
This commit is contained in:
mpl
2020-11-30 18:00:04 +01:00
committed by GitHub
parent 662d2a6afe
commit 1e0f6ece6e
3 changed files with 213 additions and 13 deletions

58
_test/assert0.go Normal file
View File

@@ -0,0 +1,58 @@
package main
import (
"fmt"
"time"
)
type MyWriter interface {
Write(p []byte) (i int, err error)
}
type TestStruct struct{}
func (t TestStruct) Write(p []byte) (n int, err error) {
return len(p), nil
}
func usesWriter(w MyWriter) {
w.Write(nil)
}
type MyStringer interface {
String() string
}
func usesStringer(s MyStringer) {
fmt.Println(s.String())
}
func main() {
var t interface{}
t = TestStruct{}
var tw MyWriter
var ok bool
tw, ok = t.(MyWriter)
if !ok {
fmt.Println("TestStruct does not implement MyWriter")
} else {
fmt.Println("TestStruct implements MyWriter")
usesWriter(tw)
}
var tt interface{}
tt = time.Nanosecond
var myD MyStringer
myD, ok = tt.(MyStringer)
if !ok {
fmt.Println("time.Nanosecond does not implement MyStringer")
} else {
fmt.Println("time.Nanosecond implements MyStringer")
usesStringer(myD)
}
}
// Output:
// TestStruct implements MyWriter
// time.Nanosecond implements MyStringer
// 1ns

View File

@@ -3,10 +3,12 @@ package interp
//go:generate go run ../internal/cmd/genop/genop.go //go:generate go run ../internal/cmd/genop/genop.go
import ( import (
"errors"
"fmt" "fmt"
"go/constant" "go/constant"
"log" "log"
"reflect" "reflect"
"regexp"
"sync" "sync"
"unsafe" "unsafe"
) )
@@ -72,6 +74,17 @@ var builtin = [...]bltnGenerator{
aXorAssign: xorAssign, aXorAssign: xorAssign,
} }
var receiverStripperRxp *regexp.Regexp
func init() {
re := `func\(((.*?(, |\)))(.*))`
var err error
receiverStripperRxp, err = regexp.Compile(re)
if err != nil {
panic(err)
}
}
type valueInterface struct { type valueInterface struct {
node *node node *node
value reflect.Value value reflect.Value
@@ -283,6 +296,17 @@ func typeAssert(n *node) {
} }
} }
func stripReceiverFromArgs(signature string) (string, error) {
fields := receiverStripperRxp.FindStringSubmatch(signature)
if len(fields) < 5 {
return "", errors.New("error while matching method signature")
}
if fields[3] == ")" {
return fmt.Sprintf("func()%s", fields[4]), nil
}
return fmt.Sprintf("func(%s", fields[4]), nil
}
func typeAssert2(n *node) { func typeAssert2(n *node) {
c0, c1 := n.child[0], n.child[1] c0, c1 := n.child[0], n.child[1]
value := genValue(c0) // input value value := genValue(c0) // input value
@@ -298,14 +322,60 @@ func typeAssert2(n *node) {
case isInterfaceSrc(typ): case isInterfaceSrc(typ):
n.exec = func(f *frame) bltn { n.exec = func(f *frame) bltn {
v, ok := value(f).Interface().(valueInterface) v, ok := value(f).Interface().(valueInterface)
if ok && v.node.typ.id() == typID { defer func() {
assertOk := ok
if setStatus {
value1(f).SetBool(assertOk)
}
}()
if !ok {
return next
}
if v.node.typ.id() == typID {
value0(f).Set(value(f)) value0(f).Set(value(f))
} else { return next
}
m0 := v.node.typ.methods()
m1 := typ.methods()
if len(m0) < len(m1) {
ok = false ok = false
return next
} }
if setStatus {
value1(f).SetBool(ok) for k, meth1 := range m1 {
var meth0 string
meth0, ok = m0[k]
if !ok {
return next
}
// As far as we know this equality check can fail because they are two ways to
// represent the signature of a method: one where the receiver appears before the
// func keyword, and one where it is just a func signature, and the receiver is
// seen as the first argument. That's why if that equality fails, we try harder to
// compare them afterwards. Hopefully that is the only reason this equality can fail.
if meth0 == meth1 {
continue
}
tm := lookupFieldOrMethod(v.node.typ, k)
if tm == nil {
ok = false
return next
}
var err error
meth0, err = stripReceiverFromArgs(meth0)
if err != nil {
ok = false
return next
}
if meth0 != meth1 {
ok = false
return next
}
} }
value0(f).Set(value(f))
return next return next
} }
case isInterface(typ): case isInterface(typ):
@@ -879,7 +949,10 @@ func call(n *node) {
var method bool var method bool
value := genValue(n.child[0]) value := genValue(n.child[0])
var values []func(*frame) reflect.Value var values []func(*frame) reflect.Value
if n.child[0].recv != nil {
recvIndexLater := false
switch {
case n.child[0].recv != nil:
// Compute method receiver value. // Compute method receiver value.
if isRecursiveType(n.child[0].recv.node.typ, n.child[0].recv.node.typ.rtype) { if isRecursiveType(n.child[0].recv.node.typ, n.child[0].recv.node.typ.rtype) {
values = append(values, genValueRecvInterfacePtr(n.child[0])) values = append(values, genValueRecvInterfacePtr(n.child[0]))
@@ -887,11 +960,17 @@ func call(n *node) {
values = append(values, genValueRecv(n.child[0])) values = append(values, genValueRecv(n.child[0]))
} }
method = true method = true
} else if n.child[0].action == aMethod { case len(n.child[0].child) > 0 && n.child[0].child[0].typ != nil && n.child[0].child[0].typ.cat == interfaceT:
recvIndexLater = true
values = append(values, genValueBinRecv(n.child[0], &receiver{node: n.child[0].child[0]}))
value = genValueBinMethodOnInterface(n, value)
method = true
case n.child[0].action == aMethod:
// Add a place holder for interface method receiver. // Add a place holder for interface method receiver.
values = append(values, nil) values = append(values, nil)
method = true method = true
} }
numRet := len(n.child[0].typ.ret) numRet := len(n.child[0].typ.ret)
variadic := variadicPos(n) variadic := variadicPos(n)
child := n.child[1:] child := n.child[1:]
@@ -1001,6 +1080,7 @@ func call(n *node) {
n.exec = func(f *frame) bltn { n.exec = func(f *frame) bltn {
var def *node var def *node
var ok bool var ok bool
bf := value(f) bf := value(f)
if def, ok = bf.Interface().(*node); ok { if def, ok = bf.Interface().(*node); ok {
bf = def.rval bf = def.rval
@@ -1070,16 +1150,16 @@ func call(n *node) {
var src reflect.Value var src reflect.Value
if v == nil { if v == nil {
src = def.recv.val src = def.recv.val
if len(def.recv.index) > 0 {
if src.Kind() == reflect.Ptr {
src = src.Elem().FieldByIndex(def.recv.index)
} else {
src = src.FieldByIndex(def.recv.index)
}
}
} else { } else {
src = v(f) src = v(f)
} }
if recvIndexLater && def.recv != nil && len(def.recv.index) > 0 {
if src.Kind() == reflect.Ptr {
src = src.Elem().FieldByIndex(def.recv.index)
} else {
src = src.FieldByIndex(def.recv.index)
}
}
// Accommodate to receiver type // Accommodate to receiver type
d := dest[0] d := dest[0]
if ks, kd := src.Kind(), d.Kind(); ks != kd { if ks, kd := src.Kind(), d.Kind(); ks != kd {
@@ -1619,6 +1699,15 @@ func getMethodByName(n *node) {
n.exec = func(f *frame) bltn { n.exec = func(f *frame) bltn {
val := value0(f).Interface().(valueInterface) val := value0(f).Interface().(valueInterface)
typ := val.node.typ
if typ.node == nil && typ.cat == valueT {
// happens with a var of empty interface type, that has value of concrete type
// from runtime, being asserted to "user-defined" interface.
if _, ok := typ.rtype.MethodByName(name); !ok {
panic(fmt.Sprintf("method %s not found", name))
}
return next
}
m, li := val.node.typ.lookupMethod(name) m, li := val.node.typ.lookupMethod(name)
fr := f.clone() fr := f.clone()
nod := *m nod := *m

View File

@@ -33,6 +33,30 @@ func valueOf(data []reflect.Value, i int) reflect.Value {
return reflect.Value{} return reflect.Value{}
} }
func genValueBinMethodOnInterface(n *node, defaultGen func(*frame) reflect.Value) func(*frame) reflect.Value {
if n == nil || n.child == nil || n.child[0] == nil ||
n.child[0].child == nil || n.child[0].child[0] == nil {
return defaultGen
}
if n.child[0].child[1] == nil || n.child[0].child[1].ident == "" {
return defaultGen
}
value0 := genValue(n.child[0].child[0])
return func(f *frame) reflect.Value {
val, ok := value0(f).Interface().(valueInterface)
if !ok {
return defaultGen(f)
}
typ := val.node.typ
if typ.node != nil || typ.cat != valueT {
return defaultGen(f)
}
meth, _ := typ.rtype.MethodByName(n.child[0].child[1].ident)
return meth.Func
}
}
func genValueRecvIndirect(n *node) func(*frame) reflect.Value { func genValueRecvIndirect(n *node) func(*frame) reflect.Value {
v := genValueRecv(n) v := genValueRecv(n)
return func(f *frame) reflect.Value { return v(f).Elem() } return func(f *frame) reflect.Value { return v(f).Elem() }
@@ -55,6 +79,35 @@ func genValueRecv(n *node) func(*frame) reflect.Value {
} }
} }
func genValueBinRecv(n *node, recv *receiver) func(*frame) reflect.Value {
value := genValue(n)
binValue := genValue(recv.node)
v := func(f *frame) reflect.Value {
if def, ok := value(f).Interface().(*node); ok {
if def != nil && def.recv != nil && def.recv.val.IsValid() {
return def.recv.val
}
}
ival, _ := binValue(f).Interface().(valueInterface)
return ival.value
}
fi := recv.index
if len(fi) == 0 {
return v
}
return func(f *frame) reflect.Value {
r := v(f)
if r.Kind() == reflect.Ptr {
r = r.Elem()
}
return r.FieldByIndex(fi)
}
}
func genValueRecvInterfacePtr(n *node) func(*frame) reflect.Value { func genValueRecvInterfacePtr(n *node) func(*frame) reflect.Value {
v := genValue(n.recv.node) v := genValue(n.recv.node)
fi := n.recv.index fi := n.recv.index