feat: type checking for binary operators (#116)

Catch illegal combinations for all binary operators.
Memoize type conversion to reflect.type.
Add some unit tests for arithmetic and assign operations.
This commit is contained in:
Marc Vertes
2019-03-15 09:34:01 +01:00
committed by Ludovic Fernandez
parent 5a10046944
commit 842d22a8c2
5 changed files with 105 additions and 37 deletions

View File

@@ -282,10 +282,29 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
sym, level, _ = scope.lookup(dest.ident)
}
wireChild(n)
// Detect invalid float truncate
if isInt(dest.typ) && isFloat(src.typ) {
err = src.cfgError("invalid float truncate")
break
switch t0, t1 := dest.typ, src.typ; n.action {
case AddAssign:
if !(isNumber(t0) && isNumber(t1) || isString(t0) && isString(t1)) || isInt(t0) && isFloat(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case SubAssign, MulAssign, QuoAssign:
if !(isNumber(t0) && isNumber(t1)) || isInt(t0) && isFloat(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case RemAssign, AndAssign, OrAssign, XorAssign, AndNotAssign:
if !(isInt(t0) && isInt(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case ShlAssign, ShrAssign:
if !(isInt(t0) && isUint(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
default:
// Detect invalid float truncate
if isInt(dest.typ) && isFloat(src.typ) {
err = src.cfgError("invalid float truncate")
return
}
}
n.findex = dest.findex
n.val = dest.val
@@ -398,28 +417,53 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
case BinaryExpr:
wireChild(n)
nilSym := interp.universe.sym["nil"]
if t0, t1 := n.child[0].typ, n.child[1].typ; !t0.untyped && !t1.untyped && t0.id() != t1.id() {
t0, t1 := n.child[0].typ, n.child[1].typ
if !t0.untyped && !t1.untyped && t0.id() != t1.id() {
err = n.cfgError("mismatched types %s and %s", t0.id(), t1.id())
break
}
switch n.action {
case NotEqual:
n.typ = scope.getType("bool")
if n.child[0].sym == nilSym || n.child[1].sym == nilSym {
n.gen = isNotNil
case Add:
if !(isNumber(t0) && isNumber(t1) || isString(t0) && isString(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case Sub, Mul, Quo:
if !(isNumber(t0) && isNumber(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case Rem, And, Or, Xor, AndNot:
if !(isInt(t0) && isInt(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case Shl, Shr:
if !(isInt(t0) && isUint(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case Equal, NotEqual:
if isNumber(t0) && !isNumber(t1) || isString(t0) && !isString(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
case Equal:
n.typ = scope.getType("bool")
if n.child[0].sym == nilSym || n.child[1].sym == nilSym {
n.gen = isNil
if n.action == Equal {
n.gen = isNil
} else {
n.gen = isNotNil
}
}
case Greater, GreaterEqual, Lower, LowerEqual:
if isNumber(t0) && !isNumber(t1) || isString(t0) && !isString(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
n.typ = scope.getType("bool")
default:
n.typ, err = nodeType(interp, scope, n)
}
// TODO: Possible optimisation: if type is bool and not in assignment or call, then skip result store
n.findex = scope.add(n.typ)
if err == nil {
if n.typ == nil {
n.typ, err = nodeType(interp, scope, n)
}
n.findex = scope.add(n.typ)
}
case IndexExpr:
wireChild(n)
@@ -443,10 +487,13 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
scope = scope.pop()
case ConstDecl:
wireChild(n)
iotaValue = 0
wireChild(n)
case DeclStmt, ExprStmt, VarDecl, SendStmt:
case VarDecl:
wireChild(n)
case DeclStmt, ExprStmt, SendStmt:
wireChild(n)
n.findex = n.lastChild().findex
n.val = n.lastChild().val
@@ -483,6 +530,9 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
}
case n.child[0].isType(scope):
// Type conversion expression
if isInt(n.child[0].typ) && n.child[1].kind == BasicLit && isFloat(n.child[1].typ) {
err = n.cfgError("truncated to integer")
}
n.gen = convert
n.typ = n.child[0].typ
n.findex = scope.add(n.typ)
@@ -1056,6 +1106,7 @@ func genRun(node *Node) error {
// function body entry point
setExec(n.anc.child[3].start)
}
// continue in function body as there may be inner function definitions
case ConstDecl, VarDecl:
setExec(n.start)
return false

View File

@@ -68,7 +68,7 @@ func (interp *Interpreter) Gta(root *Node, rpath string) error {
if n.typ, err = nodeType(interp, scope, n.child[2]); err != nil {
return false
}
scope.sym[n.child[1].ident] = &Symbol{kind: Func, typ: n.typ, node: n}
scope.sym[n.child[1].ident] = &Symbol{kind: Func, typ: n.typ, node: n, index: -1}
if len(n.child[0].child) > 0 {
// function is a method, add it to the related type
var receiverType *Type

View File

@@ -130,7 +130,7 @@ func TestInterpErrorConsistency(t *testing.T) {
}{
{
fileName: "op1.go",
expectedInterp: "5:7: invalid float truncate",
expectedInterp: "5:2: illegal operand types for '+=' operator",
expectedExec: "5:4: constant 1.3 truncated to integer",
},
{

View File

@@ -31,9 +31,27 @@ func TestEvalArithmetic(t *testing.T) {
{desc: "add_II", src: "2 + 3", res: "5"},
{desc: "add_FI", src: "2.3 + 3", res: "5.3"},
{desc: "add_IF", src: "2 + 3.3", res: "5.3"},
{desc: "add_SS", src: `"foo" + "bar"`, res: "foobar"},
{desc: "add_SI", src: `"foo" + 1`, err: "1:22: illegal operand types for '+' operator"},
{desc: "sub_SS", src: `"foo" - "bar"`, err: "1:22: illegal operand types for '-' operator"},
{desc: "sub_II", src: "7 - 3", res: "4"},
{desc: "sub_FI", src: "7.2 - 3", res: "4.2"},
{desc: "sub_IF", src: "7 - 3.2", res: "3.8"},
{desc: "mul_II", src: "2 * 3", res: "6"},
{desc: "mul_FI", src: "2.2 * 3", res: "6.6000000000000005"},
{desc: "mul_IF", src: "3 * 2.2", res: "6.6000000000000005"},
{desc: "rem_FI", src: "8.0 % 4", err: "1:22: illegal operand types for '%' operator"},
})
}
func TestEvalAssign(t *testing.T) {
i := interp.New(interp.Opt{})
runTests(t, i, []testCase{
{src: `a := "Hello"; a += " world"`, res: "Hello world"},
{src: `b := "Hello"; b += 1`, err: "1:36: illegal operand types for '+=' operator"},
{src: `c := "Hello"; c -= " world"`, err: "1:36: illegal operand types for '-=' operator"},
{src: "e := 64.0; e %= 64", err: "1:33: illegal operand types for '%=' operator"},
{src: "f := int64(3.2)", err: "1:27: truncated to integer"},
})
}

View File

@@ -502,21 +502,23 @@ func exportName(s string) string {
// TypeOf returns the reflection type of dynamic interpreter type t.
func (t *Type) TypeOf() reflect.Type {
if t.rtype != nil {
return t.rtype
}
switch t.cat {
case ArrayT:
if t.size > 0 {
return reflect.ArrayOf(t.size, t.val.TypeOf())
t.rtype = reflect.ArrayOf(t.size, t.val.TypeOf())
} else {
t.rtype = reflect.SliceOf(t.val.TypeOf())
}
return reflect.SliceOf(t.val.TypeOf())
case BinPkgT, BuiltinT, InterfaceT, SrcPkgT:
return nil
case ChanT:
return reflect.ChanOf(reflect.BothDir, t.val.TypeOf())
t.rtype = reflect.ChanOf(reflect.BothDir, t.val.TypeOf())
case ErrorT:
return reflect.TypeOf(new(error)).Elem()
t.rtype = reflect.TypeOf(new(error)).Elem()
case FuncT:
in := make([]reflect.Type, len(t.arg))
@@ -527,13 +529,13 @@ func (t *Type) TypeOf() reflect.Type {
for i, v := range t.ret {
out[i] = v.TypeOf()
}
return reflect.FuncOf(in, out, false)
t.rtype = reflect.FuncOf(in, out, false)
case MapT:
return reflect.MapOf(t.key.TypeOf(), t.val.TypeOf())
t.rtype = reflect.MapOf(t.key.TypeOf(), t.val.TypeOf())
case PtrT:
return reflect.PtrTo(t.val.TypeOf())
t.rtype = reflect.PtrTo(t.val.TypeOf())
case StructT:
var fields []reflect.StructField
@@ -541,21 +543,18 @@ func (t *Type) TypeOf() reflect.Type {
field := reflect.StructField{Name: exportName(f.name), Type: f.typ.TypeOf()}
fields = append(fields, field)
}
return reflect.StructOf(fields)
case ValueT:
return t.rtype
t.rtype = reflect.StructOf(fields)
default:
z, _ := t.zero()
if z.IsValid() {
return z.Type()
if z, _ := t.zero(); z.IsValid() {
t.rtype = z.Type()
}
var empty reflect.Type
return empty
}
return t.rtype
}
func isNumber(t *Type) bool { return isInt(t) || isFloat(t) }
func isInt(t *Type) bool {
typ := t.TypeOf()
if typ == nil {