Compare commits

...

9 Commits

Author SHA1 Message Date
Marc Vertes
3cd37645eb fix: correct isValueUntyped() to handle typed constants 2019-12-19 15:38:05 +01:00
Marc Vertes
e1ac83f7d8 fix: correct type extraction for returned value 2019-12-17 10:18:06 +01:00
Marc Vertes
4f93be7f19 fix: emulate struct by interface{} only for recursive struct types 2019-12-16 19:00:07 +01:00
Marc Vertes
7a0c09f5eb fix: detect untyped values when importing from binary packages 2019-12-13 11:18:04 +01:00
Marc Vertes
275391c1e8 fix: struct type detection, collision between field and type name 2019-12-12 14:40:05 +01:00
Marc Vertes
273df8af9f fix: improve interface type checks using method sets 2019-12-11 14:46:06 +01:00
Marc Vertes
0d2c39d155 fix: implicit import package name was not correctly generated 2019-12-11 11:54:05 +01:00
Marc Vertes
1ff1a50753 fix: add method checks for interface types 2019-12-09 18:24:04 +01:00
Marc Vertes
488e491bf8 fix: improve type switch clause with assign 2019-11-27 23:00:04 +01:00
17 changed files with 350 additions and 30 deletions

3
_test/foo-bar/foo-bar.go Normal file
View File

@@ -0,0 +1,3 @@
package bar
var Name = "foo-bar"

10
_test/import7.go Normal file
View File

@@ -0,0 +1,10 @@
package main
import "github.com/containous/yaegi/_test/foo-bar"
func main() {
println(bar.Name)
}
// Output:
// foo-bar

17
_test/interface14.go Normal file
View File

@@ -0,0 +1,17 @@
package main
type T struct{}
func (t *T) Error() string { return "T: error" }
var invalidT = &T{}
func main() {
var err error
if err != invalidT {
println("ok")
}
}
// Output:
// ok

28
_test/interface15.go Normal file
View File

@@ -0,0 +1,28 @@
package main
type Fooer interface {
Foo() string
}
type Barer interface {
//fmt.Stringer
Fooer
Bar()
}
type T struct{}
func (t *T) Foo() string { return "T: foo" }
func (*T) Bar() { println("in bar") }
var t = &T{}
func main() {
var f Barer
if f != t {
println("ok")
}
}
// Output:
// ok

25
_test/interface16.go Normal file
View File

@@ -0,0 +1,25 @@
package main
import "fmt"
type Barer interface {
fmt.Stringer
Bar()
}
type T struct{}
func (*T) String() string { return "T: nothing" }
func (*T) Bar() { println("in bar") }
var t = &T{}
func main() {
var f Barer
if f != t {
println("ok")
}
}
// Output:
// ok

17
_test/interface17.go Normal file
View File

@@ -0,0 +1,17 @@
package main
type T struct{}
func (t T) Error() string { return "T: error" }
var invalidT = T{}
func main() {
var err error
if err != invalidT {
println("ok")
}
}
// Output:
// ok

18
_test/interface18.go Normal file
View File

@@ -0,0 +1,18 @@
package main
type T struct{}
func (t *T) Error() string { return "T: error" }
func (*T) Foo() { println("foo") }
var invalidT = &T{}
func main() {
var err error
if err != invalidT {
println("ok")
}
}
// Output:
// ok

15
_test/nil0.go Normal file
View File

@@ -0,0 +1,15 @@
package main
import "fmt"
func f() (host, port string, err error) {
return "", "", nil
}
func main() {
h, p, err := f()
fmt.Println(h, p, err)
}
// Output:
// <nil>

11
_test/str4.go Normal file
View File

@@ -0,0 +1,11 @@
package main
import "unicode/utf8"
func main() {
r, _ := utf8.DecodeRuneInString("Hello")
println(r < utf8.RuneSelf)
}
// Output:
// true

19
_test/struct29.go Normal file
View File

@@ -0,0 +1,19 @@
package main
type T1 struct {
A []T2
B []T2
}
type T2 struct {
name string
}
var t = T1{}
func main() {
println("ok")
}
// Output:
// ok

19
_test/struct30.go Normal file
View File

@@ -0,0 +1,19 @@
package main
type T1 struct {
A []T2
M map[uint64]T2
}
type T2 struct {
name string
}
var t = T1{}
func main() {
println("ok")
}
// Output:
// ok

21
_test/switch22.go Normal file
View File

@@ -0,0 +1,21 @@
package main
type T struct {
Name string
}
func f(t interface{}) {
switch ext := t.(type) {
case *T:
println("*T", ext.Name)
default:
println("unknown")
}
}
func main() {
f(&T{"truc"})
}
// Output:
// *T truc

15
_test/time11.go Normal file
View File

@@ -0,0 +1,15 @@
package main
import (
"fmt"
"time"
)
const df = time.Minute * 30
func main() {
fmt.Printf("df: %v %T\n", df, df)
}
// Output:
// df: 30m0s time.Duration

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"log"
"math"
"path"
"reflect"
"regexp"
"unicode"
)
@@ -32,6 +32,8 @@ var constBltn = map[string]func(*node){
"real": realConst,
}
var identifier = regexp.MustCompile(`([\pL_][\pL_\d]*)$`)
// cfg generates a control flow graph (CFG) from AST (wiring successors in AST)
// and pre-compute frame sizes and indexes for all un-named (temporary) and named
// variables. A list of nodes of init functions is returned.
@@ -170,23 +172,21 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
var typ *itype
if len(n.child) == 2 {
// 1 type in clause: define the var with this type in the case clause scope
switch sym, _, ok := sc.lookup(n.child[0].ident); {
case ok && sym.kind == typeSym:
typ = sym.typ
case n.child[0].kind == selectorExpr:
if typ, err = nodeType(interp, sc, n.child[0]); err != nil {
return false
}
switch {
case n.child[0].ident == "nil":
typ = sc.getType("interface{}")
default:
case !n.child[0].isType(sc):
err = n.cfgErrorf("%s is not a type", n.child[0].ident)
return false
default:
typ, err = nodeType(interp, sc, n.child[0])
}
} else {
// define the var with the type in the switch guard expression
typ = sn.child[1].child[1].child[0].typ
}
if err != nil {
return false
}
nod := n.lastChild().child[0]
index := sc.add(typ)
sc.sym[nod.ident] = &symbol{index: index, kind: varSym, typ: typ}
@@ -318,7 +318,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
name = n.child[0].ident
} else {
ipath = n.child[0].rval.String()
name = path.Base(ipath)
name = identifier.FindString(ipath)
}
if interp.binPkg[ipath] != nil && name != "." {
sc.sym[name] = &symbol{kind: pkgSym, typ: &itype{cat: binPkgT, path: ipath}}
@@ -549,7 +549,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf()
// Shift operator type is inherited from first parameter only
// All other binary operators require both parameter types to be the same
if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && c0.typ.id() != c1.typ.id() {
if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equals(c1.typ) {
err = n.cfgErrorf("mismatched types %s and %s", c0.typ.id(), c1.typ.id())
break
}
@@ -602,8 +602,6 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
constOp[n.action](n)
}
switch {
//case n.typ != nil && n.typ.cat == BoolT && isAncBranch(n):
// n.findex = -1
case n.rval.IsValid():
n.gen = nop
n.findex = -1
@@ -1068,7 +1066,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
// nil: Set node value to zero of return type
f := sc.def
var typ *itype
if typ, err = nodeType(interp, sc, f.child[2].child[1].child[i].lastChild()); err != nil {
if typ, err = nodeType(interp, sc, f.child[2].child[1].fieldType(i)); err != nil {
return
}
if typ.cat == funcT {
@@ -1155,7 +1153,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
n.typ = &itype{cat: valueT, rtype: s.Type().Elem()}
} else {
n.kind = rvalueExpr
n.typ = &itype{cat: valueT, rtype: s.Type()}
n.typ = &itype{cat: valueT, rtype: s.Type(), untyped: isValueUntyped(s)}
n.rval = s
}
n.gen = nop
@@ -1528,7 +1526,7 @@ func isBinType(v reflect.Value) bool { return v.IsValid() && v.Kind() == reflect
// isType returns true if node refers to a type definition, false otherwise
func (n *node) isType(sc *scope) bool {
switch n.kind {
case arrayType, chanType, funcType, mapType, structType, rtypeExpr:
case arrayType, chanType, funcType, interfaceType, mapType, structType, rtypeExpr:
return true
case parenExpr, starExpr:
if len(n.child) == 1 {
@@ -1536,7 +1534,7 @@ func (n *node) isType(sc *scope) bool {
}
case selectorExpr:
pkg, name := n.child[0].ident, n.child[1].ident
if sym, _, ok := sc.lookup(pkg); ok {
if sym, _, ok := sc.lookup(pkg); ok && sym.kind == pkgSym {
path := sym.typ.path
if p, ok := n.interp.binPkg[path]; ok && isBinType(p[name]) {
return true // Imported binary type
@@ -1643,6 +1641,29 @@ func (n *node) isNatural() bool {
return false
}
// fieldType returns the nth parameter field node (type) of a fieldList node
func (n *node) fieldType(m int) *node {
k := 0
l := len(n.child)
for i := 0; i < l; i++ {
cl := len(n.child[i].child)
if cl < 2 {
if k == m {
return n.child[i].lastChild()
}
k++
continue
}
for j := 0; j < cl-1; j++ {
if k == m {
return n.child[i].lastChild()
}
k++
}
}
return nil
}
// lastChild returns the last child of a node
func (n *node) lastChild() *node { return n.child[len(n.child)-1] }
@@ -1825,3 +1846,13 @@ func arrayTypeLen(n *node) int {
}
return max + 1
}
// isValueUntyped returns true if value is untyped
func isValueUntyped(v reflect.Value) bool {
// Consider only constant values.
if v.CanSet() {
return false
}
t := v.Type()
return t.String() == t.Kind().String()
}

View File

@@ -1,7 +1,6 @@
package interp
import (
"path"
"reflect"
)
@@ -31,6 +30,7 @@ func (interp *Interpreter) gta(root *node, rpath string) ([]*node, error) {
case defineStmt:
var atyp *itype
if n.nleft+n.nright < len(n.child) {
// Type is declared explicitly in the assign expression.
if atyp, err = nodeType(interp, sc, n.child[n.nleft]); err != nil {
return false
}
@@ -126,7 +126,7 @@ func (interp *Interpreter) gta(root *node, rpath string) ([]*node, error) {
name = n.child[0].ident
} else {
ipath = n.child[0].rval.String()
name = path.Base(ipath)
name = identifier.FindString(ipath)
}
// Try to import a binary package first, or a source package
if interp.binPkg[ipath] != nil {

View File

@@ -239,7 +239,7 @@ func isRecursiveStruct(t *itype, rtype reflect.Type) bool {
if t.cat == structT && rtype.Kind() == reflect.Interface {
return true
}
if t.cat == ptrT {
if t.cat == ptrT && t.rtype != nil {
return isRecursiveStruct(t.val, t.rtype.Elem())
}
return false
@@ -260,7 +260,7 @@ func assign(n *node) {
switch {
case dest.typ.cat == interfaceT:
svalue[i] = genValueInterface(src)
case dest.typ.cat == valueT && dest.typ.rtype.Kind() == reflect.Interface:
case (dest.typ.cat == valueT || dest.typ.cat == errorT) && dest.typ.rtype.Kind() == reflect.Interface:
svalue[i] = genInterfaceWrapper(src, dest.typ.rtype)
case dest.typ.cat == valueT && src.typ.cat == funcT:
svalue[i] = genFunctionWrapper(src)

View File

@@ -401,6 +401,9 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
t.method = m
sym.typ = t
}
if t.node == nil {
t.node = n
}
} else {
t.incomplete = true
sc.sym[n.ident] = &symbol{kind: typeSym, typ: t}
@@ -636,6 +639,74 @@ func (t *itype) finalize() (*itype, error) {
return t, err
}
// Equals returns true if the given type is identical to the receiver one.
func (t *itype) equals(o *itype) bool {
switch ti, oi := isInterface(t), isInterface(o); {
case ti && oi:
return t.methods().equals(o.methods())
case ti && !oi:
return o.methods().contains(t.methods())
case oi && !ti:
return t.methods().contains(o.methods())
default:
return t.id() == o.id()
}
}
// MethodSet defines the set of methods signatures as strings, indexed per method name.
type methodSet map[string]string
// Contains returns true if the method set m contains the method set n.
func (m methodSet) contains(n methodSet) bool {
for k, v := range n {
if m[k] != v {
return false
}
}
return true
}
// Equal returns true if the method set m is equal to the method set n.
func (m methodSet) equals(n methodSet) bool {
return m.contains(n) && n.contains(m)
}
// Methods returns a map of method type strings, indexed by method names.
func (t *itype) methods() methodSet {
res := make(methodSet)
switch t.cat {
case interfaceT:
// Get methods from recursive analysis of interface fields
for _, f := range t.field {
if f.typ.cat == funcT {
res[f.name] = f.typ.TypeOf().String()
} else {
for k, v := range f.typ.methods() {
res[k] = v
}
}
}
case valueT, errorT:
// Get method from corresponding reflect.Type
for i := t.rtype.NumMethod() - 1; i >= 0; i-- {
m := t.rtype.Method(i)
res[m.Name] = m.Type.String()
}
case ptrT:
// Consider only methods where receiver is a pointer to type t
for _, m := range t.val.method {
if m.child[0].child[0].lastChild().typ.cat == ptrT {
res[m.ident] = m.typ.TypeOf().String()
}
}
default:
for _, m := range t.method {
res[m.ident] = m.typ.TypeOf().String()
}
}
return res
}
// id returns a unique type identificator string
func (t *itype) id() string {
// TODO: if res is nil, build identity from String()
@@ -814,7 +885,9 @@ func (t *itype) refType(defined map[string]bool) reflect.Type {
panic(err)
}
}
if t.val != nil && defined[t.val.name] {
if t.val != nil && defined[t.val.name] && !t.val.incomplete && t.val.rtype == nil {
// Replace reference to self (direct or indirect) by an interface{} to handle
// recursive types with reflect.
t.val.rtype = interf
}
switch t.cat {
@@ -834,15 +907,9 @@ func (t *itype) refType(defined map[string]bool) reflect.Type {
in := make([]reflect.Type, len(t.arg))
out := make([]reflect.Type, len(t.ret))
for i, v := range t.arg {
if defined[v.name] {
v.rtype = interf
}
in[i] = v.refType(defined)
}
for i, v := range t.ret {
if defined[v.name] {
v.rtype = interf
}
out[i] = v.refType(defined)
}
t.rtype = reflect.FuncOf(in, out, false)
@@ -933,7 +1000,11 @@ func isInterface(t *itype) bool {
return isInterfaceSrc(t) || t.TypeOf().Kind() == reflect.Interface
}
func isStruct(t *itype) bool { return t.TypeOf().Kind() == reflect.Struct }
func isStruct(t *itype) bool {
// Test first for a struct category, because a recursive interpreter struct may be
// represented by an interface{} at reflect level.
return t.cat == structT || t.TypeOf().Kind() == reflect.Struct
}
func isBool(t *itype) bool { return t.TypeOf().Kind() == reflect.Bool }