diff --git a/interp/cfg.go b/interp/cfg.go index b93a1b9c..1c8e9d66 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -66,6 +66,41 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { return false } switch n.kind { + case binaryExpr, unaryExpr, parenExpr: + if isBoolAction(n) { + break + } + // Gather assigned type if set, to give context for type propagation at post-order. + switch n.anc.kind { + case assignStmt, defineStmt: + a := n.anc + i := childPos(n) - a.nright + if len(a.child) > a.nright+a.nleft { + i-- + } + dest := a.child[i] + if dest.typ != nil && !isInterface(dest.typ) { + // Interface type are not propagated, and will be resolved at post-order. + n.typ = dest.typ + } + case binaryExpr, unaryExpr, parenExpr: + n.typ = n.anc.typ + } + + case defineStmt: + // Determine type of variables initialized at declaration, so it can be propagated. + if n.nleft+n.nright == len(n.child) { + // No type was specified on the left hand side, it will resolved at post-order. + break + } + n.typ, err = nodeType(interp, sc, n.child[n.nleft]) + if err != nil { + break + } + for i := 0; i < n.nleft; i++ { + n.child[i].typ = n.typ + } + case blockStmt: if n.anc != nil && n.anc.kind == rangeStmt { // For range block: ensure that array or map type is propagated to iterators @@ -447,7 +482,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { var atyp *itype if n.nleft+n.nright < len(n.child) { if atyp, err = nodeType(interp, sc, n.child[n.nleft]); err != nil { - return + break } } @@ -644,7 +679,12 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { } switch n.action { - case aRem, aShl, aShr: + case aRem: + n.typ = c0.typ + case aShl, aShr: + if c0.typ.untyped { + break + } n.typ = c0.typ case aEqual, aNotEqual: n.typ = sc.getType("bool") @@ -860,7 +900,12 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) { n.gen = nop n.findex = -1 n.typ = c0.typ - n.rval = c1.rval.Convert(c0.typ.rtype) + if c, ok := c1.rval.Interface().(constant.Value); ok { + i, _ := constant.Int64Val(constant.ToInt(c)) + n.rval = reflect.ValueOf(i).Convert(c0.typ.rtype) + } else { + n.rval = c1.rval.Convert(c0.typ.rtype) + } default: n.gen = convert n.typ = c0.typ @@ -2474,3 +2519,12 @@ func isArithmeticAction(n *node) bool { return false } } + +func isBoolAction(n *node) bool { + switch n.action { + case aEqual, aGreater, aGreaterEqual, aLand, aLor, aLower, aLowerEqual, aNot, aNotEqual: + return true + default: + return false + } +} diff --git a/interp/interp_eval_test.go b/interp/interp_eval_test.go index 5156a00b..6f3a23ff 100644 --- a/interp/interp_eval_test.go +++ b/interp/interp_eval_test.go @@ -71,6 +71,16 @@ func TestEvalArithmetic(t *testing.T) { }) } +func TestEvalShift(t *testing.T) { + i := interp.New(interp.Options{}) + runTests(t, i, []testCase{ + {src: "a, b, m := uint32(1), uint32(2), uint32(0); m = a + (1 << b)", res: "5"}, + {src: "c := uint(1); d := uint(+(-(1 << c)))", res: "18446744073709551614"}, + {src: "e, f := uint32(0), uint32(0); f = 1 << -(e * 2)", res: "1"}, + {pre: func() { eval(t, i, "const k uint = 1 << 17") }, src: "int(k)", res: "131072"}, + }) +} + func TestEvalStar(t *testing.T) { i := interp.New(interp.Options{}) runTests(t, i, []testCase{ diff --git a/interp/run.go b/interp/run.go index 5eafc3d4..5b1e39ac 100644 --- a/interp/run.go +++ b/interp/run.go @@ -1766,6 +1766,11 @@ func neg(n *node) { dest(f).SetInt(-value(f).Int()) return next } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n.exec = func(f *frame) bltn { + dest(f).SetUint(-value(f).Uint()) + return next + } case reflect.Float32, reflect.Float64: n.exec = func(f *frame) bltn { dest(f).SetFloat(-value(f).Float()) diff --git a/interp/typecheck.go b/interp/typecheck.go index 87c917f4..307df9c8 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -217,6 +217,7 @@ var binaryOpPredicates = opPredicates{ // binaryExpr type checks a binary expression. func (check typecheck) binaryExpr(n *node) error { c0, c1 := n.child[0], n.child[1] + a := n.action if isAssignAction(a) { a-- @@ -226,6 +227,21 @@ func (check typecheck) binaryExpr(n *node) error { return check.shift(n) } + switch n.action { + case aRem: + if zeroConst(c1) { + return n.cfgErrorf("invalid operation: division by zero") + } + case aQuo: + if zeroConst(c1) { + return n.cfgErrorf("invalid operation: division by zero") + } + if c0.rval.IsValid() && c1.rval.IsValid() { + // Avoid constant conversions below to ensure correct constant integer quotient. + return nil + } + } + _ = check.convertUntyped(c0, c1.typ) _ = check.convertUntyped(c1, c0.typ) @@ -241,16 +257,13 @@ func (check typecheck) binaryExpr(n *node) error { if err := check.op(binaryOpPredicates, a, n, c0, t0); err != nil { return err } - - switch n.action { - case aQuo, aRem: - if (c0.typ.untyped || isInt(t0)) && c1.typ.untyped && constant.Sign(c1.rval.Interface().(constant.Value)) == 0 { - return n.cfgErrorf("invalid operation: division by zero") - } - } return nil } +func zeroConst(n *node) bool { + return n.typ.untyped && constant.Sign(n.rval.Interface().(constant.Value)) == 0 +} + func (check typecheck) index(n *node, max int) error { if err := check.convertUntyped(n, &itype{cat: intT, name: "int"}); err != nil { return err