diff --git a/_test/op6.go b/_test/op6.go new file mode 100644 index 00000000..fab2c2a8 --- /dev/null +++ b/_test/op6.go @@ -0,0 +1,17 @@ +package main + +type T int + +func (t T) Error() string { return "T: error" } + +var invalidT T + +func main() { + var err error + if err != invalidT { + println("ok") + } +} + +// Output: +// ok diff --git a/_test/op7.go b/_test/op7.go new file mode 100644 index 00000000..dda60bdb --- /dev/null +++ b/_test/op7.go @@ -0,0 +1,17 @@ +package main + +type T int + +func (t T) Error() string { return "T: error" } + +var invalidT T + +func main() { + var err error + if err > invalidT { + println("ok") + } +} + +// Error: +// _test/op7.go:11:5: invalid operation: operator > not defined on error diff --git a/_test/op8.go b/_test/op8.go new file mode 100644 index 00000000..1aa4e28b --- /dev/null +++ b/_test/op8.go @@ -0,0 +1,21 @@ +package main + +type I interface { + Get() interface{} +} + +type T struct{} + +func (T) Get() interface{} { + return nil +} + +func main() { + var i I = T{} + var ei interface{} + + println(i != ei) +} + +// Output: +// true diff --git a/_test/op9.go b/_test/op9.go new file mode 100644 index 00000000..cbe37b5a --- /dev/null +++ b/_test/op9.go @@ -0,0 +1,11 @@ +package main + +func main() { + var i complex128 = 1i + var f complex128 = 0.4i + + print(i > f) +} + +// Error: +// _test/op9.go:7:8: invalid operation: operator > not defined on complex128 diff --git a/internal/genop/genop.go b/internal/genop/genop.go index 111d963f..9434b82d 100644 --- a/internal/genop/genop.go +++ b/internal/genop/genop.go @@ -412,6 +412,80 @@ func {{$name}}(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { diff --git a/interp/ast.go b/interp/ast.go index 091ef1e4..be1d91f4 100644 --- a/interp/ast.go +++ b/interp/ast.go @@ -250,6 +250,7 @@ const ( aTypeAssert aXor aXorAssign + aMax ) var actions = [...]string{ diff --git a/interp/cfg.go b/interp/cfg.go index c00557e5..a3346321 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "regexp" + "strings" "unicode" ) @@ -639,39 +640,77 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { nilSym := interp.universe.sym["nil"] c0, c1 := n.child[0], n.child[1] t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf() - // Shift operator type is inherited from first parameter only. - // All other binary operators require both parameter types to be the same. - if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equals(c1.typ) { - err = n.cfgErrorf("mismatched types %s and %s", c0.typ.id(), c1.typ.id()) + + isConstVal := func(n *node) bool { + return n.rval.IsValid() && isConstantValue(n.rval.Type()) + } + + // Type check the binary expression. Mimics Gos logic as closely and possible. + c := c0 + if isConstVal(c) { + c = c1 + } + + if isShiftNode(n) { + if !c1.isNatural() { + err = n.cfgErrorf("invalid operation: shift count type %v, must be integer", strings.TrimLeft(c1.typ.id(), ".")) + break + } + + if !c0.isInteger() { + err = n.cfgErrorf("invalid operation: shift of type %v", strings.TrimLeft(c0.typ.id(), ".")) + break + } + } + if !isShiftNode(n) && isComparisonNode(n) && !isConstVal(c) && !c0.typ.equals(c1.typ) { + if isInterface(c1.typ) && !isInterface(c0.typ) && !c0.typ.comparable() { + err = n.cfgErrorf("invalid operation: operator %v not defined on %s", n.action, strings.TrimLeft(c0.typ.id(), ".")) + break + } + + if isInterface(c0.typ) && !isInterface(c1.typ) && !c1.typ.comparable() { + err = n.cfgErrorf("invalid operation: operator %v not defined on %s", n.action, strings.TrimLeft(c1.typ.id(), ".")) + break + } + } + if !isShiftNode(n) && !isConstVal(c) && !c0.typ.equals(c1.typ) && t0 != nil && t1 != nil { + switch { + case isConstVal(c0) && isNumber(t1) || isConstVal(c1) && isNumber(t0): // const <-> numberic case + case t0.Kind() == reflect.Uint8 && t1.Kind() == reflect.Int32 || t1.Kind() == reflect.Uint8 && t0.Kind() == reflect.Int32: // byte <-> rune case + case isInterface(c0.typ) && isInterface(c1.typ): // interface <-> interface case + default: + err = n.cfgErrorf("invalid operation: mismatched types %s and %s", strings.TrimLeft(c0.typ.id(), "."), strings.TrimLeft(c1.typ.id(), ".")) + } + if err != nil { + break + } + } + + cat := c.typ.cat + switch { + case isConstVal(c): + cat = catOfConst(c.rval) + case c.typ.cat == valueT: + cat = catOf(c.typ.rtype) + case c.typ.cat == aliasT: + cat = c.typ.val.cat + } + if !isShiftNode(n) && !okFor[n.action][cat] { + err = n.cfgErrorf("invalid operation: operator %v not defined on %s", n.action, strings.TrimLeft(c0.typ.id(), ".")) break } + if !isShiftNode(n) && isConstVal(c0) && isConstVal(c1) { + // If both are constants, check the left type as well. + if !okFor[n.action][catOfConst(c0.rval)] { + err = n.cfgErrorf("invalid operation: operator %v not defined on %s", n.action, strings.TrimLeft(c0.typ.id(), ".")) + break + } + } + switch n.action { - case aAdd: - if !(isNumber(t0) && isNumber(t1) || isString(t0) && isString(t1)) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } - case aSub, aMul, aQuo: - if !(isNumber(t0) && isNumber(t1)) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } - case aAnd, aOr, aXor, aAndNot: - if !(isInt(t0) && isInt(t1)) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } - case aRem: - if !(c0.isInteger() && c1.isInteger()) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } - n.typ = c0.typ - case aShl, aShr: - if !(c0.isInteger() && c1.isNatural()) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } + case aRem, aShl, aShr: n.typ = c0.typ case aEqual, aNotEqual: - if isNumber(t0) && !isNumber(t1) || isString(t0) && !isString(t1) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } n.typ = sc.getType("bool") if n.child[0].sym == nilSym || n.child[1].sym == nilSym { if n.action == aEqual { @@ -681,9 +720,6 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { } } case aGreater, aGreaterEqual, aLower, aLowerEqual: - if isNumber(t0) && !isNumber(t1) || isString(t0) && !isString(t1) { - err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) - } n.typ = sc.getType("bool") } if err != nil { diff --git a/interp/interp_consistent_test.go b/interp/interp_consistent_test.go index d3a9b88c..590513fb 100644 --- a/interp/interp_consistent_test.go +++ b/interp/interp_consistent_test.go @@ -46,6 +46,8 @@ func TestInterpConsistencyBuild(t *testing.T) { file.Name() == "import6.go" || // expect error file.Name() == "io0.go" || // use random number file.Name() == "op1.go" || // expect error + file.Name() == "op7.go" || // expect error + file.Name() == "op9.go" || // expect error file.Name() == "bltn0.go" || // expect error file.Name() == "method16.go" || // private struct field file.Name() == "switch8.go" || // expect error diff --git a/interp/interp_eval_test.go b/interp/interp_eval_test.go index cee64ae6..45ec5a65 100644 --- a/interp/interp_eval_test.go +++ b/interp/interp_eval_test.go @@ -34,24 +34,24 @@ func TestEvalArithmetic(t *testing.T) { {desc: "add_FI", src: "2.3 + 3", res: "5.3"}, {desc: "add_IF", src: "2 + 3.3", res: "5.3"}, {desc: "add_SS", src: `"foo" + "bar"`, res: "foobar"}, - {desc: "add_SI", src: `"foo" + 1`, err: "1:28: illegal operand types for '+' operator"}, - {desc: "sub_SS", src: `"foo" - "bar"`, err: "1:28: illegal operand types for '-' operator"}, + {desc: "add_SI", src: `"foo" + 1`, err: "1:28: invalid operation: mismatched types string and int"}, + {desc: "sub_SS", src: `"foo" - "bar"`, err: "1:28: invalid operation: operator - not defined on string"}, {desc: "sub_II", src: "7 - 3", res: "4"}, {desc: "sub_FI", src: "7.2 - 3", res: "4.2"}, {desc: "sub_IF", src: "7 - 3.2", res: "3.8"}, {desc: "mul_II", src: "2 * 3", res: "6"}, {desc: "mul_FI", src: "2.2 * 3", res: "6.6"}, {desc: "mul_IF", src: "3 * 2.2", res: "6.6"}, - {desc: "rem_FI", src: "8.2 % 4", err: "1:28: illegal operand types for '%' operator"}, + {desc: "rem_FI", src: "8.2 % 4", err: "1:28: invalid operation: operator % not defined on float64"}, {desc: "shl_II", src: "1 << 8", res: "256"}, - {desc: "shl_IN", src: "1 << -1", err: "1:28: illegal operand types for '<<' operator"}, + {desc: "shl_IN", src: "1 << -1", err: "1:28: invalid operation: shift count type int, must be integer"}, {desc: "shl_IF", src: "1 << 1.0", res: "2"}, - {desc: "shl_IF1", src: "1 << 1.1", err: "1:28: illegal operand types for '<<' operator"}, + {desc: "shl_IF1", src: "1 << 1.1", err: "1:28: invalid operation: shift count type float64, must be integer"}, {desc: "shl_IF2", src: "1.0 << 1", res: "2"}, {desc: "shr_II", src: "1 >> 8", res: "0"}, - {desc: "shr_IN", src: "1 >> -1", err: "1:28: illegal operand types for '>>' operator"}, + {desc: "shr_IN", src: "1 >> -1", err: "1:28: invalid operation: shift count type int, must be integer"}, {desc: "shr_IF", src: "1 >> 1.0", res: "0"}, - {desc: "shr_IF1", src: "1 >> 1.1", err: "1:28: illegal operand types for '>>' operator"}, + {desc: "shr_IF1", src: "1 >> 1.1", err: "1:28: invalid operation: shift count type float64, must be integer"}, {desc: "neg_I", src: "-2", res: "-2"}, {desc: "pos_I", src: "+2", res: "2"}, {desc: "bitnot_I", src: "^2", res: "-3"}, @@ -300,7 +300,7 @@ func TestEvalComparison(t *testing.T) { var b = Bar("test") var c = a == b `, - err: "7:13: mismatched types main.Foo and main.Bar", + err: "7:13: invalid operation: mismatched types main.Foo and main.Bar", }, }) } diff --git a/interp/op.go b/interp/op.go index 1026ec76..b61ebf7d 100644 --- a/interp/op.go +++ b/interp/op.go @@ -2033,6 +2033,80 @@ func equal(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { @@ -2462,6 +2536,80 @@ func greater(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { @@ -2751,6 +2899,80 @@ func greaterEqual(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { @@ -3040,6 +3262,80 @@ func lower(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { @@ -3329,6 +3625,80 @@ func lowerEqual(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { @@ -3618,6 +3988,80 @@ func notEqual(n *node) { dest := genValueOutput(n, reflect.TypeOf(true)) c0, c1 := n.child[0], n.child[1] + if c0.typ.cat == aliasT || c1.typ.cat == aliasT { + switch { + case c0.rval.IsValid(): + i0 := c0.rval.Interface() + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + case c1.rval.IsValid(): + i1 := c1.rval.Interface() + v0 := genValue(c0) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + default: + v0 := genValue(c0) + v1 := genValue(c1) + if n.fnext != nil { + fnext := getExec(n.fnext) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + if i0 != i1 { + dest(f).SetBool(true) + return tnext + } + dest(f).SetBool(false) + return fnext + } + } else { + dest := genValue(n) + n.exec = func(f *frame) bltn { + i0 := v0(f).Interface() + i1 := v1(f).Interface() + dest(f).SetBool(i0 != i1) + return tnext + } + } + } + return + } + switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); { case isString(t0) || isString(t1): switch { diff --git a/interp/type.go b/interp/type.go index 63d49159..891ed47d 100644 --- a/interp/type.go +++ b/interp/type.go @@ -691,6 +691,7 @@ func fieldName(n *node) string { } var zeroValues [maxT]reflect.Value +var okFor [aMax][maxT]bool func init() { zeroValues[boolT] = reflect.ValueOf(false) @@ -711,6 +712,77 @@ func init() { zeroValues[uint32T] = reflect.ValueOf(uint32(0)) zeroValues[uint64T] = reflect.ValueOf(uint64(0)) zeroValues[uintptrT] = reflect.ValueOf(uintptr(0)) + + // Calculate the action -> type allowances + var ( + okForEq [maxT]bool + okForCmp [maxT]bool + okForAdd [maxT]bool + okForAnd [maxT]bool + okForBool [maxT]bool + okForArith [maxT]bool + ) + for cat := tcat(0); cat < maxT; cat++ { + if (cat >= intT && cat <= int64T) || (cat >= uintT && cat <= uintptrT) { + okForEq[cat] = true + okForCmp[cat] = true + okForAdd[cat] = true + okForAnd[cat] = true + okForArith[cat] = true + } + if cat == float32T || cat == float64T { + okForEq[cat] = true + okForCmp[cat] = true + okForAdd[cat] = true + okForArith[cat] = true + } + if cat == complex64T || cat == complex128T { + okForEq[cat] = true + okForAdd[cat] = true + okForArith[cat] = true + } + } + + okForAdd[stringT] = true + + okForBool[boolT] = true + + okForEq[nilT] = true + okForEq[ptrT] = true + okForEq[interfaceT] = true + okForEq[errorT] = true + okForEq[chanT] = true + okForEq[stringT] = true + okForEq[boolT] = true + okForEq[mapT] = true // nil only + okForEq[funcT] = true // nil only + okForEq[arrayT] = true // array: only if element type is comparable slice: nil only + okForEq[structT] = true // only if all struct fields are comparable + + okForCmp[stringT] = true + + okFor[aAdd] = okForAdd + okFor[aAnd] = okForAnd + okFor[aLand] = okForBool + okFor[aAndNot] = okForAnd + okFor[aQuo] = okForArith + okFor[aEqual] = okForEq + okFor[aGreaterEqual] = okForCmp + okFor[aGreater] = okForCmp + okFor[aLowerEqual] = okForCmp + okFor[aLower] = okForCmp + okFor[aRem] = okForAnd + okFor[aMul] = okForArith + okFor[aNotEqual] = okForEq + okFor[aOr] = okForAnd + okFor[aLor] = okForBool + okFor[aSub] = okForArith + okFor[aXor] = okForAnd + okFor[aShl] = okForAnd + okFor[aShr] = okForAnd + okFor[aNeg] = okForArith + okFor[aNot] = okForBool + okFor[aPos] = okForArith } // Finalize returns a type pointer and error. It reparses a type from the @@ -855,6 +927,12 @@ func isComplete(t *itype, visited map[string]bool) bool { return true } +// comparable returns true if the type is comparable. +func (t *itype) comparable() bool { + typ := t.TypeOf() + return t.cat == nilT || typ != nil && typ.Comparable() +} + // Equals returns true if the given type is identical to the receiver one. func (t *itype) equals(o *itype) bool { switch ti, oi := isInterface(t), isInterface(o); { @@ -1279,6 +1357,88 @@ func (t *itype) implements(it *itype) bool { return t.methods().contains(it.methods()) } +var errType = reflect.TypeOf((*error)(nil)).Elem() + +func catOf(t reflect.Type) tcat { + if t == nil { + return nilT + } + if t == errType { + return errorT + } + switch t.Kind() { + case reflect.Bool: + return boolT + case reflect.Int: + return intT + case reflect.Int8: + return int8T + case reflect.Int16: + return int16T + case reflect.Int32: + return int32T + case reflect.Int64: + return int64T + case reflect.Uint: + return uintT + case reflect.Uint8: + return uint8T + case reflect.Uint16: + return uint16T + case reflect.Uint32: + return uint32T + case reflect.Uint64: + return uint64T + case reflect.Uintptr: + return uintptrT + case reflect.Float32: + return float32T + case reflect.Float64: + return float64T + case reflect.Complex64: + return complex64T + case reflect.Complex128: + return complex128T + case reflect.Array, reflect.Slice: + return arrayT + case reflect.Chan: + return chanT + case reflect.Func: + return funcT + case reflect.Interface: + return interfaceT + case reflect.Map: + return mapT + case reflect.Ptr: + return ptrT + case reflect.String: + return stringT + case reflect.Struct: + return structT + case reflect.UnsafePointer: + return uintptrT + } + return nilT +} + +func catOfConst(v reflect.Value) tcat { + c, ok := v.Interface().(constant.Value) + if !ok { + return nilT + } + + switch c.Kind() { + case constant.Int: + return intT + case constant.Float: + return float64T + case constant.Complex: + return complex128T + default: + return nilT + } +} + func constToInt(c constant.Value) int { if constant.BitLen(c) > 64 { panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) @@ -1305,6 +1465,14 @@ func isShiftNode(n *node) bool { return false } +func isComparisonNode(n *node) bool { + switch n.action { + case aEqual, aNotEqual, aGreater, aGreaterEqual, aLower, aLowerEqual: + return true + } + return false +} + // chanElement returns the channel element type. func chanElement(t *itype) *itype { switch t.cat { @@ -1331,7 +1499,7 @@ func isInterfaceSrc(t *itype) bool { } func isInterface(t *itype) bool { - return isInterfaceSrc(t) || t.TypeOf().Kind() == reflect.Interface + return isInterfaceSrc(t) || t.TypeOf() != nil && t.TypeOf().Kind() == reflect.Interface } func isStruct(t *itype) bool {