diff --git a/_test/interface18.go b/_test/interface18.go new file mode 100644 index 00000000..072ababb --- /dev/null +++ b/_test/interface18.go @@ -0,0 +1,18 @@ +package main + +type T struct{} + +func (t *T) Error() string { return "T: error" } +func (*T) Foo() { println("foo") } + +var invalidT = &T{} + +func main() { + var err error + if err != invalidT { + println("ok") + } +} + +// Output: +// ok diff --git a/interp/cfg.go b/interp/cfg.go index 42163a0f..f83a9133 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -549,7 +549,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) { t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf() // Shift operator type is inherited from first parameter only // All other binary operators require both parameter types to be the same - if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equal(c1.typ) { + if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equals(c1.typ) { err = n.cfgErrorf("mismatched types %s and %s", c0.typ.id(), c1.typ.id()) break } diff --git a/interp/type.go b/interp/type.go index ec38a380..ba3295f3 100644 --- a/interp/type.go +++ b/interp/type.go @@ -636,18 +636,41 @@ func (t *itype) finalize() (*itype, error) { return t, err } -// equal returns true if the given type is identical to the receiver one -func (t *itype) equal(o *itype) bool { - if isInterface(t) || isInterface(o) { - // Check for identical methods sets - return reflect.DeepEqual(t.methods(), o.methods()) +// Equals returns true if the given type is identical to the receiver one. +func (t *itype) equals(o *itype) bool { + switch ti, oi := isInterface(t), isInterface(o); { + case ti && oi: + return t.methods().equals(o.methods()) + case ti && !oi: + return o.methods().contains(t.methods()) + case oi && !ti: + return t.methods().contains(o.methods()) + default: + return t.id() == o.id() } - return t.id() == o.id() } -// methods returns a map of method type strings, indexed by method names -func (t *itype) methods() map[string]string { - res := make(map[string]string) +// MethodSet defines the set of methods signatures as strings, indexed per method name. +type methodSet map[string]string + +// Contains returns true if the method set m contains the method set n. +func (m methodSet) contains(n methodSet) bool { + for k, v := range n { + if m[k] != v { + return false + } + } + return true +} + +// Equal returns true if the method set m is equal to the method set n. +func (m methodSet) equals(n methodSet) bool { + return m.contains(n) && n.contains(m) +} + +// Methods returns a map of method type strings, indexed by method names. +func (t *itype) methods() methodSet { + res := make(methodSet) switch t.cat { case interfaceT: // Get methods from recursive analysis of interface fields @@ -879,15 +902,9 @@ func (t *itype) refType(defined map[string]bool) reflect.Type { in := make([]reflect.Type, len(t.arg)) out := make([]reflect.Type, len(t.ret)) for i, v := range t.arg { - if defined[v.name] { - v.rtype = interf - } in[i] = v.refType(defined) } for i, v := range t.ret { - if defined[v.name] { - v.rtype = interf - } out[i] = v.refType(defined) } t.rtype = reflect.FuncOf(in, out, false)