interp: improve processing of recursive types

Make sure to keep always a single copy of incomplete type structures.
Remove remnants of recursive types processing.

Now `import "go.uber.org/zap"` works again (see #1172), fixing regressions
introduced since #1236.
This commit is contained in:
Marc Vertes
2021-09-13 18:24:10 +02:00
committed by GitHub
parent 3eb2c79fd8
commit bd9a6a4f8a
4 changed files with 98 additions and 159 deletions

View File

@@ -123,13 +123,12 @@ type itype struct {
path string // for a defined type, the package import path
length int // length of array if ArrayT
rtype reflect.Type // Reflection type if ValueT, or nil
incomplete bool // true if type must be parsed again (out of order declarations)
recursive bool // true if the type has an element which refer to itself
untyped bool // true for a literal value (string or number)
isBinMethod bool // true if the type refers to a bin method function
node *node // root AST node of type definition
scope *scope // type declaration scope (in case of re-parse incomplete type)
str string // String representation of the type
incomplete bool // true if type must be parsed again (out of order declarations)
untyped bool // true for a literal value (string or number)
isBinMethod bool // true if the type refers to a bin method function
}
func untypedBool() *itype {
@@ -213,6 +212,14 @@ func wrapperValueTOf(rtype reflect.Type, val *itype, opts ...itypeOption) *itype
return t
}
func variadicOf(val *itype, opts ...itypeOption) *itype {
t := &itype{cat: variadicT, val: val, str: "..." + val.str}
for _, opt := range opts {
opt(t)
}
return t
}
// ptrOf returns a pointer to t.
func ptrOf(val *itype, opts ...itypeOption) *itype {
if val.ptr != nil {
@@ -319,12 +326,17 @@ func mapOf(key, val *itype, opts ...itypeOption) *itype {
}
// interfaceOf returns an interface type with the given fields.
func interfaceOf(fields []structField, opts ...itypeOption) *itype {
func interfaceOf(t *itype, fields []structField, opts ...itypeOption) *itype {
str := "interface{}"
if len(fields) > 0 {
str = "interface { " + methodsTypeString(fields) + "}"
}
t := &itype{cat: interfaceT, field: fields, str: str}
if t == nil {
t = &itype{}
}
t.cat = interfaceT
t.field = fields
t.str = str
for _, opt := range opts {
opt(t)
}
@@ -332,12 +344,17 @@ func interfaceOf(fields []structField, opts ...itypeOption) *itype {
}
// structOf returns a struct type with the given fields.
func structOf(fields []structField, opts ...itypeOption) *itype {
func structOf(t *itype, fields []structField, opts ...itypeOption) *itype {
str := "struct {}"
if len(fields) > 0 {
str = "struct { " + fieldsTypeString(fields) + "}"
}
t := &itype{cat: structT, field: fields, str: str}
if t == nil {
t = &itype{}
}
t.cat = structT
t.field = fields
t.str = str
for _, opt := range opts {
opt(t)
}
@@ -346,21 +363,31 @@ func structOf(fields []structField, opts ...itypeOption) *itype {
// nodeType returns a type definition for the corresponding AST subtree.
func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
return nodeType2(interp, sc, n, map[*node]bool{})
}
func nodeType2(interp *Interpreter, sc *scope, n *node, seen map[*node]bool) (t *itype, err error) {
if n.typ != nil && !n.typ.incomplete {
return n.typ, nil
}
if sname := typeName(n); sname != "" {
if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym && sym.typ != nil && sym.typ.isComplete() {
return sym.typ, nil
sym, _, found := sc.lookup(sname)
if found && sym.kind == typeSym && sym.typ != nil {
if sym.typ.isComplete() {
return sym.typ, nil
}
if seen[n] {
// TODO (marc): find a better way to distinguish recursive vs incomplete types.
sym.typ.incomplete = false
return sym.typ, nil
}
}
}
seen[n] = true
t := &itype{node: n, scope: sc}
var err error
switch n.kind {
case addressExpr, starExpr:
val, err := nodeType(interp, sc, n.child[0])
val, err := nodeType2(interp, sc, n.child[0], seen)
if err != nil {
return nil, err
}
@@ -370,7 +397,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case arrayType:
c0 := n.child[0]
if len(n.child) == 1 {
val, err := nodeType(interp, sc, c0)
val, err := nodeType2(interp, sc, c0, seen)
if err != nil {
return nil, err
}
@@ -422,7 +449,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
length = constToInt(v)
}
val, err := nodeType(interp, sc, n.child[1])
val, err := nodeType2(interp, sc, n.child[1], seen)
if err != nil {
return nil, err
}
@@ -459,17 +486,17 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
case unaryExpr:
t, err = nodeType(interp, sc, n.child[0])
t, err = nodeType2(interp, sc, n.child[0], seen)
case binaryExpr:
// Get type of first operand.
if t, err = nodeType(interp, sc, n.child[0]); err != nil {
if t, err = nodeType2(interp, sc, n.child[0], seen); err != nil {
return nil, err
}
// For operators other than shift, get the type from the 2nd operand if the first is untyped.
if t.untyped && !isShiftNode(n) {
var t1 *itype
t1, err = nodeType(interp, sc, n.child[1])
t1, err = nodeType2(interp, sc, n.child[1], seen)
if !(t1.untyped && isInt(t1.TypeOf()) && isFloat(t.TypeOf())) {
t = t1
}
@@ -487,7 +514,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
a.child[0].typ = &itype{cat: interfaceT, val: dt, str: "interface{}"}
case a.kind == defineStmt && len(a.child) > a.nleft+a.nright:
if dt, err = nodeType(interp, sc, a.child[a.nleft]); err != nil {
if dt, err = nodeType2(interp, sc, a.child[a.nleft], seen); err != nil {
return nil, err
}
@@ -503,14 +530,13 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case callExpr:
if isBuiltinCall(n, sc) {
// Builtin types are special and may depend from their input arguments.
t.cat = builtinT
switch n.child[0].ident {
case bltnComplex:
var nt0, nt1 *itype
if nt0, err = nodeType(interp, sc, n.child[1]); err != nil {
if nt0, err = nodeType2(interp, sc, n.child[1], seen); err != nil {
return nil, err
}
if nt1, err = nodeType(interp, sc, n.child[2]); err != nil {
if nt1, err = nodeType2(interp, sc, n.child[2], seen); err != nil {
return nil, err
}
if nt0.incomplete || nt1.incomplete {
@@ -535,7 +561,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
}
case bltnReal, bltnImag:
if t, err = nodeType(interp, sc, n.child[1]); err != nil {
if t, err = nodeType2(interp, sc, n.child[1], seen); err != nil {
return nil, err
}
if !t.incomplete {
@@ -553,20 +579,22 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case bltnCap, bltnCopy, bltnLen:
t = sc.getType("int")
case bltnAppend, bltnMake:
t, err = nodeType(interp, sc, n.child[1])
t, err = nodeType2(interp, sc, n.child[1], seen)
case bltnNew:
t, err = nodeType(interp, sc, n.child[1])
t, err = nodeType2(interp, sc, n.child[1], seen)
incomplete := t.incomplete
t = ptrOf(t, withScope(sc))
t.incomplete = incomplete
case bltnRecover:
t = sc.getType("interface{}")
default:
t = &itype{cat: builtinT}
}
if err != nil {
return nil, err
}
} else {
if t, err = nodeType(interp, sc, n.child[0]); err != nil {
if t, err = nodeType2(interp, sc, n.child[0], seen); err != nil || t == nil {
return nil, err
}
switch t.cat {
@@ -582,7 +610,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
case compositeLitExpr:
t, err = nodeType(interp, sc, n.child[0])
t, err = nodeType2(interp, sc, n.child[0], seen)
case chanType, chanTypeRecv, chanTypeSend:
dir := chanSendRecv
@@ -592,7 +620,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case chanTypeSend:
dir = chanSend
}
val, err := nodeType(interp, sc, n.child[0])
val, err := nodeType2(interp, sc, n.child[0], seen)
if err != nil {
return nil, err
}
@@ -600,15 +628,15 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
t.incomplete = val.incomplete
case ellipsisExpr:
t.cat = variadicT
if t.val, err = nodeType(interp, sc, n.child[0]); err != nil {
val, err := nodeType2(interp, sc, n.child[0], seen)
if err != nil {
return nil, err
}
t.str = "..." + t.val.str
t = variadicOf(val, withNode(n), withScope(sc))
t.incomplete = t.val.incomplete
case funcLit:
t, err = nodeType(interp, sc, n.child[2])
t, err = nodeType2(interp, sc, n.child[2], seen)
case funcType:
var incomplete bool
@@ -616,7 +644,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
args := make([]*itype, 0, len(n.child[0].child))
for _, arg := range n.child[0].child {
cl := len(arg.child) - 1
typ, err := nodeType(interp, sc, arg.child[cl])
typ, err := nodeType2(interp, sc, arg.child[cl], seen)
if err != nil {
return nil, err
}
@@ -633,7 +661,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
// Handle returned values
for _, ret := range n.child[1].child {
cl := len(ret.child) - 1
typ, err := nodeType(interp, sc, ret.child[cl])
typ, err := nodeType2(interp, sc, ret.child[cl], seen)
if err != nil {
return nil, err
}
@@ -656,9 +684,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
ident := filepath.Join(n.ident, baseName)
sym, _, found = sc.lookup(ident)
if !found {
t.name = n.ident
t.path = sc.pkgName
t.incomplete = true
t = &itype{name: n.ident, path: sc.pkgName, node: n, incomplete: true, scope: sc}
sc.sym[n.ident] = &symbol{kind: typeSym, typ: t}
break
}
@@ -669,7 +695,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
if t.incomplete && t.node != n {
m := t.method
if t, err = nodeType(interp, sc, t.node); err != nil {
if t, err = nodeType2(interp, sc, t.node, seen); err != nil {
return nil, err
}
t.method = m
@@ -681,7 +707,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
case indexExpr:
var lt *itype
if lt, err = nodeType(interp, sc, n.child[0]); err != nil {
if lt, err = nodeType2(interp, sc, n.child[0], seen); err != nil {
return nil, err
}
if lt.incomplete {
@@ -694,13 +720,12 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
case interfaceType:
t.cat = interfaceT
var incomplete bool
if sname := typeName(n); sname != "" {
if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym {
sym.typ = t
t = interfaceOf(sym.typ, sym.typ.field, withNode(n), withScope(sc))
}
}
var incomplete bool
fields := make([]structField, 0, len(n.child[0].child))
for _, field := range n.child[0].child {
f0 := field.child[0]
@@ -712,7 +737,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
fields = append(fields, structField{name: "Error", typ: typ})
continue
}
typ, err := nodeType(interp, sc, f0)
typ, err := nodeType2(interp, sc, f0, seen)
if err != nil {
return nil, err
}
@@ -720,25 +745,25 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
incomplete = incomplete || typ.incomplete
continue
}
typ, err := nodeType(interp, sc, field.child[1])
typ, err := nodeType2(interp, sc, field.child[1], seen)
if err != nil {
return nil, err
}
fields = append(fields, structField{name: f0.ident, typ: typ})
incomplete = incomplete || typ.incomplete
}
*t = *interfaceOf(fields, withNode(n), withScope(sc))
t = interfaceOf(t, fields, withNode(n), withScope(sc))
t.incomplete = incomplete
case landExpr, lorExpr:
t = sc.getType("bool")
case mapType:
key, err := nodeType(interp, sc, n.child[0])
key, err := nodeType2(interp, sc, n.child[0], seen)
if err != nil {
return nil, err
}
val, err := nodeType(interp, sc, n.child[1])
val, err := nodeType2(interp, sc, n.child[1], seen)
if err != nil {
return nil, err
}
@@ -746,7 +771,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
t.incomplete = key.incomplete || val.incomplete
case parenExpr:
t, err = nodeType(interp, sc, n.child[0])
t, err = nodeType2(interp, sc, n.child[0], seen)
case selectorExpr:
// Resolve the left part of selector, then lookup the right part on it
@@ -772,12 +797,11 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
}
if lt, err = nodeType(interp, localScope, n.child[0]); err != nil {
if lt, err = nodeType2(interp, localScope, n.child[0], seen); err != nil {
return nil, err
}
if lt.incomplete {
t.incomplete = true
break
}
name := n.child[1].ident
@@ -803,7 +827,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
default:
if m, _ := lt.lookupMethod(name); m != nil {
t, err = nodeType(interp, sc, m.child[2])
t, err = nodeType2(interp, sc, m.child[2], seen)
} else if bm, _, _, ok := lt.lookupBinMethod(name); ok {
t = valueTOf(bm.Type, isBinMethod(), withRecv(lt), withScope(sc))
} else if ti := lt.lookupField(name); len(ti) > 0 {
@@ -816,7 +840,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
case sliceExpr:
t, err = nodeType(interp, sc, n.child[0])
t, err = nodeType2(interp, sc, n.child[0], seen)
if err != nil {
return nil, err
}
@@ -830,22 +854,17 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
case structType:
t.cat = structT
var (
methods []*node
incomplete bool
)
if sname := typeName(n); sname != "" {
if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym {
methods = sym.typ.method
sym.typ = t
t = structOf(sym.typ, sym.typ.field, withNode(n), withScope(sc))
}
}
var incomplete bool
fields := make([]structField, 0, len(n.child[0].child))
for _, c := range n.child[0].child {
switch {
case len(c.child) == 1:
typ, err := nodeType(interp, sc, c.child[0])
typ, err := nodeType2(interp, sc, c.child[0], seen)
if err != nil {
return nil, err
}
@@ -853,7 +872,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
incomplete = incomplete || typ.incomplete
case len(c.child) == 2 && c.child[1].kind == basicLit:
tag := vString(c.child[1].rval)
typ, err := nodeType(interp, sc, c.child[0])
typ, err := nodeType2(interp, sc, c.child[0], seen)
if err != nil {
return nil, err
}
@@ -866,7 +885,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
tag = vString(c.lastChild().rval)
l--
}
typ, err := nodeType(interp, sc, c.child[l-1])
typ, err := nodeType2(interp, sc, c.child[l-1], seen)
if err != nil {
return nil, err
}
@@ -876,15 +895,14 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
}
}
}
*t = *structOf(fields, withNode(n), withScope(sc))
t.method = methods // Recover the symbol methods.
t = structOf(t, fields, withNode(n), withScope(sc))
t.incomplete = incomplete
default:
err = n.cfgErrorf("type definition not implemented: %s", n.kind)
}
if err == nil && t.cat == nilT && !t.incomplete {
if err == nil && t != nil && t.cat == nilT && !t.incomplete {
err = n.cfgErrorf("use of untyped nil %s", t.name)
}
@@ -997,40 +1015,13 @@ func (t *itype) finalize() (*itype, error) {
return t, err
}
// ReferTo returns true if the type contains a reference to a
// full type name. It allows to assess a type recursive status.
func (t *itype) referTo(name string, seen map[*itype]bool) bool {
if t.path+"/"+t.name == name {
return true
}
if seen[t] {
return false
}
seen[t] = true
switch t.cat {
case aliasT, arrayT, chanT, chanRecvT, chanSendT, ptrT, sliceT, variadicT:
return t.val.referTo(name, seen)
case funcT:
for _, a := range t.arg {
if a.referTo(name, seen) {
return true
}
}
for _, a := range t.ret {
if a.referTo(name, seen) {
return true
}
}
case mapT:
return t.key.referTo(name, seen) || t.val.referTo(name, seen)
case structT, interfaceT:
for _, f := range t.field {
if f.typ.referTo(name, seen) {
return true
}
func (t *itype) addMethod(n *node) {
for _, m := range t.method {
if m == n {
return
}
}
return false
t.method = append(t.method, n)
}
func (t *itype) numIn() int {
@@ -1100,27 +1091,6 @@ func (t *itype) concrete() *itype {
return t
}
// IsRecursive returns true if type is recursive.
// Only a named struct or interface can be recursive.
func (t *itype) isRecursive() bool {
if t.name == "" {
return false
}
switch t.cat {
case structT, interfaceT:
for _, f := range t.field {
if f.typ.referTo(t.path+"/"+t.name, map[*itype]bool{}) {
return true
}
}
}
return false
}
func (t *itype) isIndirectRecursive() bool {
return t.isRecursive() || t.val != nil && t.val.isIndirectRecursive()
}
// isVariadic returns true if the function type is variadic.
// If the type is not a function or is not variadic, it will
// return false.
@@ -1664,23 +1634,8 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type {
panic(err)
}
}
recursive := false
name := t.path + "/" + t.name
// Predefined types from universe or runtime may have a nil scope.
if t.scope != nil {
if st := t.scope.sym[t.name]; st != nil {
// Update the type recursive status. Several copies of type
// may exist per symbol, as a new type is created at each GTA
// pass (several needed due to out of order declarations), and
// a node can still point to a previous copy.
st.typ.recursive = st.typ.recursive || st.typ.isRecursive()
recursive = st.typ.isRecursive()
// It is possible that t.recursive is not inline with st.typ.recursive
// which will break recursion detection. Set it here to make sure it
// is correct.
t.recursive = recursive
}
}
if t.rtype != nil && !ctx.rebuilding {
return t.rtype
}
@@ -1738,7 +1693,7 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type {
fctx := ctx.Clone()
field := reflect.StructField{
Name: exportName(f.name), Type: f.typ.refType(fctx),
Tag: reflect.StructTag(f.tag), Anonymous: (f.embed && !recursive),
Tag: reflect.StructTag(f.tag), Anonymous: f.embed,
}
fields = append(fields, field)
// Find any nil type refs that indicates a rebuild is needed on this field.