interp: improve handling of generic types

When generating a new type, the parameter type was not correctly duplicated in the new AST. This is fixed by making copyNode recursive if needed. The out of order processing of generic types has also been fixed.

Fixes #1488
This commit is contained in:
Marc Vertes
2023-02-08 11:48:05 +01:00
committed by GitHub
parent 0e3ea5732a
commit f3dbce93a4
12 changed files with 424 additions and 114 deletions

33
_test/gen11.go Normal file
View File

@@ -0,0 +1,33 @@
package main
import (
"encoding/json"
"fmt"
"net/netip"
)
type Slice[T any] struct {
x []T
}
type IPPrefixSlice struct {
x Slice[netip.Prefix]
}
func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }
// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}
func main() {
t := IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}
// Output:
// {{[]}}
// null <nil>

31
_test/gen12.go Normal file
View File

@@ -0,0 +1,31 @@
package main
import (
"fmt"
)
func MapOf[K comparable, V any](m map[K]V) Map[K, V] {
return Map[K, V]{m}
}
type Map[K comparable, V any] struct {
ж map[K]V
}
func (v MapView) Int() Map[string, int] { return MapOf(v.ж.Int) }
type VMap struct {
Int map[string]int
}
type MapView struct {
ж *VMap
}
func main() {
mv := MapView{&VMap{}}
fmt.Println(mv.ж)
}
// Output:
// &{map[]}

18
_test/gen13.go Normal file
View File

@@ -0,0 +1,18 @@
package main
type Map[K comparable, V any] struct {
ж map[K]V
}
func (m Map[K, V]) Has(k K) bool {
_, ok := m.ж[k]
return ok
}
func main() {
m := Map[string, float64]{}
println(m.Has("test"))
}
// Output:
// false

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"net/netip"
"reflect"
)
@@ -17,6 +18,10 @@ func unmarshalJSON[T any](b []byte, x *[]T) error {
return json.Unmarshal(b, x)
}
func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}
type StructView[T any] interface {
Valid() bool
AsStruct() T
@@ -31,10 +36,6 @@ type ViewCloner[T any, V StructView[T]] interface {
Clone() T
}
func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}
func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalJSON(b, &v.ж) }
@@ -51,6 +52,10 @@ func SliceOf[T any](x []T) Slice[T] {
return Slice[T]{x}
}
type IPPrefixSlice struct {
ж Slice[netip.Prefix]
}
type viewStruct struct {
Int int
Strings Slice[string]

23
_test/issue-1488.go Normal file
View File

@@ -0,0 +1,23 @@
package main
import "fmt"
type vector interface {
[]int | [3]int
}
func sum[V vector](v V) (out int) {
for i := 0; i < len(v); i++ {
out += v[i]
}
return
}
func main() {
va := [3]int{1, 2, 3}
vs := []int{1, 2, 3}
fmt.Println(sum[[3]int](va), sum[[]int](vs))
}
// Output:
// 6 6

14
_test/p6.go Normal file
View File

@@ -0,0 +1,14 @@
package main
import (
"fmt"
"github.com/traefik/yaegi/_test/p6"
)
func main() {
t := p6.IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}

21
_test/p6/p6.go Normal file
View File

@@ -0,0 +1,21 @@
package p6
import (
"encoding/json"
"net/netip"
)
type Slice[T any] struct {
x []T
}
type IPPrefixSlice struct {
x Slice[netip.Prefix]
}
func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }
// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}

View File

@@ -322,8 +322,60 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
}
}
if n.typ == nil {
err = n.cfgErrorf("undefined type")
return false
// A nil type indicates either an error or a generic type.
// A child indexExpr or indexListExpr is used for type parameters,
// it indicates an instanciated generic.
if n.child[0].kind != indexExpr && n.child[0].kind != indexListExpr {
err = n.cfgErrorf("undefined type")
return false
}
t0, err1 := nodeType(interp, sc, n.child[0].child[0])
if err1 != nil {
return false
}
if t0.cat != genericT {
err = n.cfgErrorf("undefined type")
return false
}
// We have a composite literal of generic type, instantiate it.
lt := []*itype{}
for _, n1 := range n.child[0].child[1:] {
t1, err1 := nodeType(interp, sc, n1)
if err1 != nil {
return false
}
lt = append(lt, t1)
}
var g *node
g, _, err = genAST(sc, t0.node.anc, lt)
if err != nil {
return false
}
n.child[0] = g.lastChild()
n.typ, err = nodeType(interp, sc, n.child[0])
if err != nil {
return false
}
// Generate methods if any.
for _, nod := range t0.method {
gm, _, err2 := genAST(nod.scope, nod, lt)
if err2 != nil {
err = err2
return false
}
gm.typ, err = nodeType(interp, nod.scope, gm.child[2])
if err != nil {
return false
}
if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil {
return false
}
if err = genRun(gm); err != nil {
return false
}
n.typ.addMethod(gm)
}
n.nleft = 1 // Indictate the type of composite literal.
}
}
@@ -439,6 +491,19 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if typ, err = nodeType(interp, sc, recvTypeNode); err != nil {
return false
}
if typ.cat == nilT {
// This may happen when instantiating generic methods.
s2, _, ok := sc.lookup(typ.id())
if !ok {
err = n.cfgErrorf("type not found: %s", typ.id())
break
}
typ = s2.typ
if typ.cat == nilT {
err = n.cfgErrorf("nil type: %s", typ.id())
break
}
}
recvTypeNode.typ = typ
n.child[2].typ.recv = typ
n.typ.recv = typ
@@ -871,16 +936,18 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
n.typ = t
return
}
g, err := genAST(sc, t.node.anc, []*node{c1})
g, found, err := genAST(sc, t.node.anc, []*itype{c1.typ})
if err != nil {
return
}
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
if !found {
if _, err = interp.cfg(g, t.node.anc.scope, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
}
}
// Replace generic func node by instantiated one.
n.anc.child[childPos(n)] = g
@@ -1030,17 +1097,23 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
case c0.kind == indexListExpr:
// Instantiate a generic function then call it.
fun := c0.child[0].sym.node
g, err := genAST(sc, fun, c0.child[1:])
lt := []*itype{}
for _, c := range c0.child[1:] {
lt = append(lt, c.typ)
}
g, found, err := genAST(sc, fun, lt)
if err != nil {
return
}
_, err = interp.cfg(g, nil, importPath, pkgName)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
if !found {
_, err = interp.cfg(g, fun.scope, importPath, pkgName)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
}
}
n.child[0] = g
c0 = n.child[0]
@@ -1212,23 +1285,26 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if isGeneric(c0.typ) {
fun := c0.typ.node.anc
var g *node
var types []*node
var types []*itype
var found bool
// Infer type parameter from function call arguments.
if types, err = inferTypesFromCall(sc, fun, n.child[1:]); err != nil {
break
}
// Generate an instantiated AST from the generic function one.
if g, err = genAST(sc, fun, types); err != nil {
if g, found, err = genAST(sc, fun, types); err != nil {
break
}
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
if !found {
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, fun.scope, importPath, pkgName); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
}
}
n.child[0] = g
c0 = n.child[0]
@@ -1487,6 +1563,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
sym, level, found := sc.lookup(n.ident)
if !found {
if n.typ != nil {
// Node is a generic instance with an already populated type.
break
}
// retry with the filename, in case ident is a package name.
sym, level, found = sc.lookup(filepath.Join(n.ident, baseName))
if !found {
@@ -1916,7 +1996,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
err = n.cfgErrorf("undefined selector: %s", n.child[1].ident)
}
}
if err == nil && n.findex != -1 {
if err == nil && n.findex != -1 && n.typ.cat != genericT {
n.findex = sc.add(n.typ)
}
@@ -2375,11 +2455,13 @@ func (n *node) cfgErrorf(format string, a ...interface{}) *cfgError {
func genRun(nod *node) error {
var err error
seen := map[*node]bool{}
nod.Walk(func(n *node) bool {
if err != nil {
if err != nil || seen[n] {
return false
}
seen[n] = true
switch n.kind {
case funcType:
if len(n.anc.child) == 4 {

View File

@@ -5,8 +5,11 @@ import (
"sync/atomic"
)
// adot produces an AST dot(1) directed acyclic graph for the given node. For debugging only.
// func (n *node) adot() { n.astDot(dotWriter(n.interp.dotCmd), n.ident) }
// genAST returns a new AST where generic types are replaced by instantiated types.
func genAST(sc *scope, root *node, types []*node) (*node, error) {
func genAST(sc *scope, root *node, types []*itype) (*node, bool, error) {
typeParam := map[string]*node{}
pindex := 0
tname := ""
@@ -14,9 +17,20 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
recvrPtr := false
fixNodes := []*node{}
var gtree func(*node, *node) (*node, error)
sname := root.child[0].ident + "["
if root.kind == funcDecl {
sname = root.child[1].ident + "["
}
// Input type parameters must be resolved prior AST generation, as compilation
// of generated AST may occur in a different scope.
for _, t := range types {
sname += t.id() + ","
}
sname = strings.TrimSuffix(sname, ",") + "]"
gtree = func(n, anc *node) (*node, error) {
nod := copyNode(n, anc)
nod := copyNode(n, anc, false)
switch n.kind {
case funcDecl, funcType:
nod.val = nod
@@ -27,7 +41,8 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
if !ok {
break
}
nod = copyNode(nt, anc)
nod = copyNode(nt, anc, true)
nod.typ = nt.typ
case indexExpr:
// Catch a possible recursive generic type definition
@@ -37,7 +52,7 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
if root.child[0].ident != n.child[0].ident {
break
}
nod := copyNode(n.child[0], anc)
nod := copyNode(n.child[0], anc, false)
fixNodes = append(fixNodes, nod)
return nod, nil
@@ -51,10 +66,16 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
if pindex >= len(types) {
return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
}
if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil {
t, err := nodeType(c.interp, sc, c.child[l])
if err != nil {
return nil, err
}
typeParam[cc.ident] = types[pindex]
if err := checkConstraint(types[pindex], t); err != nil {
return nil, err
}
typeParam[cc.ident] = copyNode(cc, cc.anc, false)
typeParam[cc.ident].ident = types[pindex].id()
typeParam[cc.ident].typ = types[pindex]
pindex++
}
}
@@ -65,9 +86,9 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
// Node is the receiver of a generic method.
if root.kind == funcDecl && n.anc == root && childPos(n) == 0 && len(n.child) > 0 {
rtn := n.child[0].child[1]
if rtn.kind == indexExpr || (rtn.kind == starExpr && rtn.child[0].kind == indexExpr) {
// Method receiver is a generic type.
if rtn.kind == starExpr && rtn.child[0].kind == indexExpr {
// Method receiver is a generic type if it takes some type parameters.
if rtn.kind == indexExpr || rtn.kind == indexListExpr || (rtn.kind == starExpr && (rtn.child[0].kind == indexExpr || rtn.child[0].kind == indexListExpr)) {
if rtn.kind == starExpr {
// Method receiver is a pointer on a generic type.
rtn = rtn.child[0]
recvrPtr = true
@@ -77,11 +98,10 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
if pindex >= len(types) {
return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
}
it, err := nodeType(n.interp, sc, types[pindex])
if err != nil {
return nil, err
}
typeParam[cc.ident] = types[pindex]
it := types[pindex]
typeParam[cc.ident] = copyNode(cc, cc.anc, false)
typeParam[cc.ident].ident = it.id()
typeParam[cc.ident].typ = it
rtname += it.id() + ","
pindex++
}
@@ -99,14 +119,17 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
if pindex >= len(types) {
return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
}
it, err := nodeType(n.interp, sc, types[pindex])
it := types[pindex]
t, err := nodeType(c.interp, sc, c.child[l])
if err != nil {
return nil, err
}
if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil {
if err := checkConstraint(types[pindex], t); err != nil {
return nil, err
}
typeParam[cc.ident] = types[pindex]
typeParam[cc.ident] = copyNode(cc, cc.anc, false)
typeParam[cc.ident].ident = it.id()
typeParam[cc.ident].typ = it
tname += it.id() + ","
pindex++
}
@@ -115,6 +138,7 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
return nod, nil
}
}
for _, c := range n.child {
gn, err := gtree(c, nod)
if err != nil {
@@ -125,10 +149,16 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
return nod, nil
}
if nod, found := root.interp.generic[sname]; found {
return nod, true, nil
}
r, err := gtree(root, root.anc)
if err != nil {
return nil, err
return nil, false, err
}
root.interp.generic[sname] = r
r.param = append(r.param, types...)
if tname != "" {
for _, nod := range fixNodes {
nod.ident = tname
@@ -145,11 +175,11 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) {
nod.ident = rtname
nod.child = nil
}
// r.astDot(dotWriter(root.interp.dotCmd), root.child[1].ident) // Used for debugging only.
return r, nil
// r.adot() // Used for debugging only.
return r, false, nil
}
func copyNode(n, anc *node) *node {
func copyNode(n, anc *node, recursive bool) *node {
var i interface{}
nindex := atomic.AddInt64(&n.interp.nindex, 1)
nod := &node{
@@ -170,25 +200,30 @@ func copyNode(n, anc *node) *node {
meta: n.meta,
}
nod.start = nod
if recursive {
for _, c := range n.child {
nod.child = append(nod.child, copyNode(c, nod, true))
}
}
return nod
}
func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*node, error) {
func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*itype, error) {
ftn := fun.typ.node
// Fill the map of parameter types, indexed by type param ident.
types := map[string]*itype{}
paramTypes := map[string]*itype{}
for _, c := range ftn.child[0].child {
typ, err := nodeType(fun.interp, sc, c.lastChild())
if err != nil {
return nil, err
}
for _, cc := range c.child[:len(c.child)-1] {
types[cc.ident] = typ
paramTypes[cc.ident] = typ
}
}
var inferTypes func(*itype, *itype) ([]*node, error)
inferTypes = func(param, input *itype) ([]*node, error) {
var inferTypes func(*itype, *itype) ([]*itype, error)
inferTypes = func(param, input *itype) ([]*itype, error) {
switch param.cat {
case chanT, ptrT, sliceT:
return inferTypes(param.val, input.val)
@@ -205,65 +240,68 @@ func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*node, error) {
return append(k, v...), nil
case structT:
nods := []*node{}
lt := []*itype{}
for i, f := range param.field {
nl, err := inferTypes(f.typ, input.field[i].typ)
if err != nil {
return nil, err
}
nods = append(nods, nl...)
lt = append(lt, nl...)
}
return nods, nil
return lt, nil
case funcT:
nods := []*node{}
lt := []*itype{}
for i, t := range param.arg {
if i >= len(input.arg) {
break
}
nl, err := inferTypes(t, input.arg[i])
if err != nil {
return nil, err
}
nods = append(nods, nl...)
lt = append(lt, nl...)
}
for i, t := range param.ret {
if i >= len(input.ret) {
break
}
nl, err := inferTypes(t, input.ret[i])
if err != nil {
return nil, err
}
nods = append(nods, nl...)
lt = append(lt, nl...)
}
return lt, nil
case nilT:
if paramTypes[param.name] != nil {
return []*itype{input}, nil
}
return nods, nil
case genericT:
return []*node{input.node}, nil
return []*itype{input}, nil
}
return nil, nil
}
nodes := []*node{}
types := []*itype{}
for i, c := range ftn.child[1].child {
typ, err := nodeType(fun.interp, sc, c.lastChild())
if err != nil {
return nil, err
}
nods, err := inferTypes(typ, args[i].typ)
lt, err := inferTypes(typ, args[i].typ)
if err != nil {
return nil, err
}
nodes = append(nodes, nods...)
types = append(types, lt...)
}
return nodes, nil
return types, nil
}
func checkConstraint(sc *scope, input, constraint *node) error {
ct, err := nodeType(constraint.interp, sc, constraint)
if err != nil {
return err
}
it, err := nodeType(input.interp, sc, input)
if err != nil {
return err
}
func checkConstraint(it, ct *itype) error {
if len(ct.constraint) == 0 && len(ct.ulconstraint) == 0 {
return nil
}
@@ -277,5 +315,5 @@ func checkConstraint(sc *scope, input, constraint *node) error {
return nil
}
}
return input.cfgErrorf("%s does not implement %s", input.typ.id(), ct.id())
return it.node.cfgErrorf("%s does not implement %s", it.id(), ct.id())
}

View File

@@ -21,6 +21,9 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([
if err != nil {
return false
}
if n.scope == nil {
n.scope = sc
}
switch n.kind {
case constDecl:
// Early parse of constDecl subtree, to compute all constant
@@ -166,7 +169,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([
typName = c.child[0].ident
genericMethod = true
}
case indexExpr:
case indexExpr, indexListExpr:
genericMethod = true
}
}
@@ -189,6 +192,14 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([
}
rcvrtype.addMethod(n)
rtn.typ = rcvrtype
if rcvrtype.cat == genericT {
// generate methods for already instantiated receivers
for _, it := range rcvrtype.instance {
if err = genMethod(interp, sc, it, n, it.node.anc.param); err != nil {
return false
}
}
}
case ident == "init":
// init functions do not get declared as per the Go spec.
default:

View File

@@ -28,6 +28,7 @@ type node struct {
debug *nodeDebugData // debug info
child []*node // child subtrees (AST)
anc *node // ancestor (AST)
param []*itype // generic parameter nodes (AST)
start *node // entry point in subtree (CFG)
tnext *node // true branch successor (CFG)
fnext *node // false branch successor (CFG)
@@ -215,6 +216,7 @@ type Interpreter struct {
pkgNames map[string]string // package names, indexed by import path
done chan struct{} // for cancellation of channel operations
roots []*node
generic map[string]*node
hooks *hooks // symbol hooks
@@ -335,6 +337,7 @@ func New(options Options) *Interpreter {
pkgNames: map[string]string{},
rdir: map[string]bool{},
hooks: &hooks{},
generic: map[string]*node{},
}
if i.opt.stdin = options.Stdin; i.opt.stdin == nil {

View File

@@ -126,6 +126,7 @@ type itype struct {
method []*node // Associated methods or nil
constraint []*itype // For interfaceT: list of types part of interface set
ulconstraint []*itype // For interfaceT: list of underlying types part of interface set
instance []*itype // For genericT: list of instantiated types
name string // name of type within its package for a defined type
path string // for a defined type, the package import path
length int // length of array if ArrayT
@@ -786,6 +787,11 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
} else {
t = sym.typ
}
if t == nil {
if t, err = nodeType2(interp, sc, sym.node, seen); err != nil {
return nil, err
}
}
if t.incomplete && t.cat == linkedT && t.val != nil && t.val.cat != nilT {
t.incomplete = false
}
@@ -807,7 +813,11 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
return nil, err
}
if lt.incomplete {
t.incomplete = true
if t == nil {
t = lt
} else {
t.incomplete = true
}
break
}
switch lt.cat {
@@ -828,7 +838,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
break
}
// A generic type is being instantiated. Generate it.
t, err = genType(interp, sc, name, lt, []*node{t1.node}, seen)
t, err = genType(interp, sc, name, lt, []*itype{t1}, seen)
if err != nil {
return nil, err
}
@@ -840,6 +850,15 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
if lt, err = nodeType2(interp, sc, n.child[0], seen); err != nil {
return nil, err
}
if lt.incomplete {
if t == nil {
t = lt
} else {
t.incomplete = true
}
break
}
// Index list expressions can be used only in context of generic types.
if lt.cat != genericT {
err = n.cfgErrorf("not a generic type: %s", lt.id())
@@ -847,7 +866,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
}
name := lt.id() + "["
out := false
tnodes := []*node{}
types := []*itype{}
for _, c := range n.child[1:] {
t1, err := nodeType2(interp, sc, c, seen)
if err != nil {
@@ -858,19 +877,19 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
out = true
break
}
tnodes = append(tnodes, t1.node)
types = append(types, t1)
name += t1.id() + ","
}
if out {
break
}
name += "]"
name = strings.TrimSuffix(name, ",") + "]"
if sym, _, found := sc.lookup(name); found {
t = sym.typ
break
}
// A generic type is being instantiated. Generate it.
t, err = genType(interp, sc, name, lt, tnodes, seen)
t, err = genType(interp, sc, name, lt, types, seen)
case interfaceType:
if sname := typeName(n); sname != "" {
@@ -1016,7 +1035,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
sname := structName(n)
if sname != "" {
sym, _, found = sc.lookup(sname)
if found && sym.kind == typeSym {
if found && sym.kind == typeSym && sym.typ != nil {
t = structOf(sym.typ, sym.typ.field, withNode(n), withScope(sc))
} else {
t = structOf(nil, nil, withNode(n), withScope(sc))
@@ -1062,6 +1081,9 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
t = structOf(t, fields, withNode(n), withScope(sc))
t.incomplete = incomplete
if sname != "" {
if sc.sym[sname] == nil {
sc.sym[sname] = &symbol{index: -1, kind: typeSym, node: n}
}
sc.sym[sname].typ = t
}
@@ -1094,9 +1116,9 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype,
return t, err
}
func genType(interp *Interpreter, sc *scope, name string, lt *itype, tnodes, seen []*node) (t *itype, err error) {
func genType(interp *Interpreter, sc *scope, name string, lt *itype, types []*itype, seen []*node) (t *itype, err error) {
// A generic type is being instantiated. Generate it.
g, err := genAST(sc, lt.node.anc, tnodes)
g, _, err := genAST(sc, lt.node.anc, types)
if err != nil {
return nil, err
}
@@ -1104,39 +1126,48 @@ func genType(interp *Interpreter, sc *scope, name string, lt *itype, tnodes, see
if err != nil {
return nil, err
}
lt.instance = append(lt.instance, t)
// Add generated symbol in the scope of generic source and user.
sc.sym[name] = &symbol{index: -1, kind: typeSym, typ: t, node: g}
// Instantiate type methods (if any).
var pt *itype
if len(lt.method) > 0 {
pt = ptrOf(t, withNode(g), withScope(sc))
if lt.scope.sym[name] == nil {
lt.scope.sym[name] = sc.sym[name]
}
for _, nod := range lt.method {
gm, err := genAST(sc, nod, tnodes)
if err != nil {
return nil, err
}
if gm.typ, err = nodeType(interp, sc, gm.child[2]); err != nil {
return nil, err
}
t.addMethod(gm)
if rtn := gm.child[0].child[0].lastChild(); rtn.kind == starExpr {
// The receiver is a pointer on a generic type.
pt.addMethod(gm)
rtn.typ = pt
}
// Compile method CFG.
if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil {
return nil, err
}
// Generate closures for function body.
if err = genRun(gm); err != nil {
if err := genMethod(interp, sc, t, nod, types); err != nil {
return nil, err
}
}
return t, err
}
func genMethod(interp *Interpreter, sc *scope, t *itype, nod *node, types []*itype) error {
gm, _, err := genAST(sc, nod, types)
if err != nil {
return err
}
if gm.typ, err = nodeType(interp, sc, gm.child[2]); err != nil {
return err
}
t.addMethod(gm)
// If the receiver is a pointer to a generic type, generate also the pointer type.
if rtn := gm.child[0].child[0].lastChild(); rtn != nil && rtn.kind == starExpr {
pt := ptrOf(t, withNode(t.node), withScope(sc))
pt.addMethod(gm)
rtn.typ = pt
}
// Compile the method AST in the scope of the generic type.
scop := nod.typ.scope
if _, err = interp.cfg(gm, scop, scop.pkgID, scop.pkgName); err != nil {
return err
}
// Generate closures for function body.
return genRun(gm)
}
// findPackageType searches the top level scope for a package type.
func findPackageType(interp *Interpreter, sc *scope, n *node) *itype {
// Find the root scope, the package symbols will exist there.