From 3640f2f82024ac9e8df384d80f5666352c43e21a Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Thu, 20 Aug 2020 17:06:05 +0200 Subject: [PATCH] feat: add type assertion expression type checking This adds type checking to TypeAssertExpr. In order to allow for this, method types now have a receiver type in both reflect and native cases. --- _test/method35.go | 14 +++++++++++ interp/cfg.go | 58 ++++++++++++++++++++++++++++----------------- interp/run.go | 6 ++--- interp/type.go | 55 ++++++++++++++++++++++++++++++++++++++---- interp/typecheck.go | 58 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 157 insertions(+), 34 deletions(-) create mode 100644 _test/method35.go diff --git a/_test/method35.go b/_test/method35.go new file mode 100644 index 00000000..cee544a2 --- /dev/null +++ b/_test/method35.go @@ -0,0 +1,14 @@ +package main + +import "strconv" + +func main() { + var err error + _, err = strconv.Atoi("erwer") + if _, ok := err.(*strconv.NumError); ok { + println("here") + } +} + +// Output: +// here diff --git a/interp/cfg.go b/interp/cfg.go index 1d70b480..f8db96ab 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -334,6 +334,8 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { return false } recvTypeNode.typ = typ + n.child[2].typ.recv = typ + n.typ.recv = typ index := sc.add(typ) if len(fr.child) > 1 { sc.sym[fr.child[0].ident] = &symbol{index: index, kind: varSym, typ: typ} @@ -1334,11 +1336,15 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { // Search for field must then be performed on type T only (not *T) switch method, ok := n.typ.rtype.MethodByName(n.child[1].ident); { case ok: + hasRecvType := n.typ.rtype.Kind() != reflect.Interface n.val = method.Index n.gen = getIndexBinMethod n.action = aGetMethod n.recv = &receiver{node: n.child[0]} n.typ = &itype{cat: valueT, rtype: method.Type, isBinMethod: true} + if hasRecvType { + n.typ.recv = n.typ + } case n.typ.rtype.Kind() == reflect.Ptr: if field, ok := n.typ.rtype.Elem().FieldByName(n.child[1].ident); ok { n.typ = &itype{cat: valueT, rtype: field.Type} @@ -1358,7 +1364,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { if m2, ok2 := pt.MethodByName(n.child[1].ident); ok2 { n.val = m2.Index n.gen = getIndexBinPtrMethod - n.typ = &itype{cat: valueT, rtype: m2.Type} + n.typ = &itype{cat: valueT, rtype: m2.Type, recv: &itype{cat: valueT, rtype: pt}} n.recv = &receiver{node: n.child[0]} n.action = aGetMethod } else { @@ -1372,14 +1378,14 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { // Handle pointer on object defined in runtime if method, ok := n.typ.val.rtype.MethodByName(n.child[1].ident); ok { n.val = method.Index - n.typ = &itype{cat: valueT, rtype: method.Type} + n.typ = &itype{cat: valueT, rtype: method.Type, recv: n.typ} n.recv = &receiver{node: n.child[0]} n.gen = getIndexBinMethod n.action = aGetMethod } else if method, ok := reflect.PtrTo(n.typ.val.rtype).MethodByName(n.child[1].ident); ok { n.val = method.Index n.gen = getIndexBinMethod - n.typ = &itype{cat: valueT, rtype: method.Type} + n.typ = &itype{cat: valueT, rtype: method.Type, recv: &itype{cat: valueT, rtype: reflect.PtrTo(n.typ.val.rtype)}} n.recv = &receiver{node: n.child[0]} n.action = aGetMethod } else if field, ok := n.typ.val.rtype.FieldByName(n.child[1].ident); ok { @@ -1445,7 +1451,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { } n.recv = &receiver{node: n.child[0], index: lind} n.val = append([]int{m.Index}, lind...) - n.typ = &itype{cat: valueT, rtype: m.Type} + n.typ = &itype{cat: valueT, rtype: m.Type, recv: n.child[0].typ} } else if ti := n.typ.lookupField(n.child[1].ident); len(ti) > 0 { // Handle struct field n.val = ti @@ -1658,25 +1664,33 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { n.child[0].tnext = sbn.start case typeAssertExpr: - if len(n.child) > 1 { - wireChild(n) - c1 := n.child[1] - if c1.typ == nil { - if c1.typ, err = nodeType(interp, sc, c1); err != nil { - return - } - } - if n.anc.action != aAssignX { - if n.child[0].typ.cat == valueT && isFunc(c1.typ) { - // Avoid special wrapping of interfaces and func types. - n.typ = &itype{cat: valueT, rtype: c1.typ.TypeOf()} - } else { - n.typ = c1.typ - } - n.findex = sc.add(n.typ) - } - } else { + if len(n.child) == 1 { + // The "o.(type)" is handled by typeSwitch. n.gen = nop + break + } + + wireChild(n) + c0, c1 := n.child[0], n.child[1] + if c1.typ == nil { + if c1.typ, err = nodeType(interp, sc, c1); err != nil { + return + } + } + + err = check.typeAssertionExpr(c0, c1.typ) + if err != nil { + break + } + + if n.anc.action != aAssignX { + if c0.typ.cat == valueT && isFunc(c1.typ) { + // Avoid special wrapping of interfaces and func types. + n.typ = &itype{cat: valueT, rtype: c1.typ.TypeOf()} + } else { + n.typ = c1.typ + } + n.findex = sc.add(n.typ) } case sliceExpr: diff --git a/interp/run.go b/interp/run.go index f85bb929..32d6ef5b 100644 --- a/interp/run.go +++ b/interp/run.go @@ -152,7 +152,7 @@ func typeAssertStatus(n *node) { value1(f).SetBool(ok) return next } - case c0.typ.cat == valueT: + case c0.typ.cat == valueT || c0.typ.cat == errorT: n.exec = func(f *frame) bltn { v := value(f) ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype) @@ -205,7 +205,7 @@ func typeAssert(n *node) { value0(f).Set(v) return next } - case c0.typ.cat == valueT: + case c0.typ.cat == valueT || c0.typ.cat == errorT: n.exec = func(f *frame) bltn { v := value(f).Elem() typ := value0(f).Type() @@ -272,7 +272,7 @@ func typeAssert2(n *node) { } return next } - case n.child[0].typ.cat == valueT: + case n.child[0].typ.cat == valueT || n.child[0].typ.cat == errorT: n.exec = func(f *frame) bltn { v := value(f).Elem() ok := v.IsValid() && canAssertTypes(v.Type(), rtype) diff --git a/interp/type.go b/interp/type.go index 8f2454b9..dba9cef5 100644 --- a/interp/type.go +++ b/interp/type.go @@ -108,6 +108,7 @@ type itype struct { field []structField // Array of struct fields if structT or interfaceT key *itype // Type of key element if MapT or nil val *itype // Type of value element if chanT,chanSendT, chanRecvT, mapT, ptrT, aliasT, arrayT or variadicT + recv *itype // Receiver type for funcT or nil arg []*itype // Argument types if funcT or nil ret []*itype // Return types if funcT or nil method []*node // Associated methods or nil @@ -559,7 +560,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { if m, _ := lt.lookupMethod(name); m != nil { t, err = nodeType(interp, sc, m.child[2]) } else if bm, _, _, ok := lt.lookupBinMethod(name); ok { - t = &itype{cat: valueT, rtype: bm.Type, isBinMethod: true, scope: sc} + t = &itype{cat: valueT, rtype: bm.Type, recv: lt, isBinMethod: true, scope: sc} } else if ti := lt.lookupField(name); len(ti) > 0 { t = lt.fieldSeq(ti) } else if bs, _, ok := lt.lookupBinField(name); ok { @@ -761,11 +762,16 @@ func (t *itype) numIn() int { case funcT: return len(t.arg) case valueT: - if t.rtype.Kind() == reflect.Func { - return t.rtype.NumIn() + if t.rtype.Kind() != reflect.Func { + return 0 } + in := t.rtype.NumIn() + if t.recv != nil { + in-- + } + return in } - return 1 + return 0 } func (t *itype) in(i int) *itype { @@ -774,6 +780,9 @@ func (t *itype) in(i int) *itype { return t.arg[i] case valueT: if t.rtype.Kind() == reflect.Func { + if t.recv != nil { + i++ + } if t.rtype.IsVariadic() && i == t.rtype.NumIn()-1 { return &itype{cat: variadicT, val: &itype{cat: valueT, rtype: t.rtype.In(i).Elem()}} } @@ -995,6 +1004,14 @@ func (t *itype) methods() methodSet { res[m.Name] = m.Type.String() } case ptrT: + if typ.val.cat == valueT { + // Ptr receiver methods need to be found with the ptr type. + typ.TypeOf() // Ensure the rtype exists. + for i := typ.rtype.NumMethod() - 1; i >= 0; i-- { + m := typ.rtype.Method(i) + res[m.Name] = m.Type.String() + } + } for k, v := range getMethods(typ.val) { res[k] = v } @@ -1244,6 +1261,36 @@ func (t *itype) lookupBinMethod(name string) (m reflect.Method, index []int, isP return m, index, isPtr, ok } +func lookupFieldOrMethod(t *itype, name string) *itype { + switch { + case t.cat == valueT || t.cat == ptrT && t.val.cat == valueT: + m, _, isPtr, ok := t.lookupBinMethod(name) + if !ok { + return nil + } + var recv *itype + if t.rtype.Kind() != reflect.Interface { + recv = t + if isPtr && t.cat != ptrT && t.rtype.Kind() != reflect.Ptr { + recv = &itype{cat: ptrT, val: t} + } + } + return &itype{cat: valueT, rtype: m.Type, recv: recv} + case t.cat == interfaceT: + seq := t.lookupField(name) + if seq == nil { + return nil + } + return t.fieldSeq(seq) + default: + n, _ := t.lookupMethod(name) + if n == nil { + return nil + } + return n.typ + } +} + func exportName(s string) string { if canExport(s) { return s diff --git a/interp/typecheck.go b/interp/typecheck.go index c8a9aad9..f7f1c11c 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -527,6 +527,59 @@ func (check typecheck) sliceExpr(n *node) error { return nil } +// typeAssertionExpr type checks a type assert expression. +func (check typecheck) typeAssertionExpr(n *node, typ *itype) error { + // TODO(nick): This type check is not complete and should be revisited once + // https://github.com/golang/go/issues/39717 lands. It is currently impractical to + // type check Named types as they cannot be asserted. + + if n.typ.TypeOf().Kind() != reflect.Interface { + return n.cfgErrorf("invalid type assertion: non-interface type %s on left", n.typ.id()) + } + ims := n.typ.methods() + if len(ims) == 0 { + // Empty interface must be a dynamic check. + return nil + } + + if isInterface(typ) { + // Asserting to an interface is a dynamic check as we must look to the + // underlying struct. + return nil + } + + for name := range ims { + im := lookupFieldOrMethod(n.typ, name) + tm := lookupFieldOrMethod(typ, name) + if im == nil { + // This should not be possible. + continue + } + if tm == nil { + return n.cfgErrorf("impossible type assertion: %s does not implement %s (missing %v method)", typ.id(), n.typ.id(), name) + } + if tm.recv != nil && tm.recv.TypeOf().Kind() == reflect.Ptr && typ.TypeOf().Kind() != reflect.Ptr { + return n.cfgErrorf("impossible type assertion: %s does not implement %s as %q method has a pointer receiver", typ.id(), n.typ.id(), name) + } + + err := n.cfgErrorf("impossible type assertion: %s does not implement %s", typ.id(), n.typ.id()) + if im.numIn() != tm.numIn() || im.numOut() != tm.numOut() { + return err + } + for i := 0; i < im.numIn(); i++ { + if !im.in(i).equals(tm.in(i)) { + return err + } + } + for i := 0; i < im.numOut(); i++ { + if !im.out(i).equals(tm.out(i)) { + return err + } + } + } + return nil +} + // conversion type checks the conversion of n to typ. func (check typecheck) conversion(n *node, typ *itype) error { var c constant.Value @@ -601,11 +654,6 @@ func (check typecheck) arguments(n *node, child []*node, fun *node, ellipsis boo } var cnt int - if fun.kind == selectorExpr && fun.typ.cat == valueT && fun.recv != nil && !isInterface(fun.recv.node.typ) { - // If this is a bin call, and we have a receiver and the receiver is - // not an interface, then the first input is the receiver, so skip it. - cnt++ - } for i, arg := range child { ellip := i == l-1 && ellipsis if err := check.argument(arg, fun.typ, cnt, ellip); err != nil {