From cdc352cee24cdf4afc64ea63146b0a1bae367c3e Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Tue, 11 Aug 2020 15:58:04 +0200 Subject: [PATCH] feat: add index and composite literal type checking This adds type checking to both `IndexExpr` and `CompositeLitExpr` as well as handling any required constant type conversion. This includes a change to the type propagation to the children of a composite literal. Previously in most cases the composite literal type was propagated to its children. This does not work with type checking as the actual child type is needed. --- _test/math3.go | 32 +++++ interp/cfg.go | 78 ++++++++++-- interp/interp_eval_test.go | 34 ++++++ interp/run.go | 2 - interp/typecheck.go | 240 ++++++++++++++++++++++++++++++++++--- interp/value.go | 20 ---- 6 files changed, 361 insertions(+), 45 deletions(-) create mode 100644 _test/math3.go diff --git a/_test/math3.go b/_test/math3.go new file mode 100644 index 00000000..025207c7 --- /dev/null +++ b/_test/math3.go @@ -0,0 +1,32 @@ +package main + +import ( + "crypto/md5" + "fmt" +) + +func md5Crypt(password, salt, magic []byte) []byte { + d := md5.New() + d.Write(password) + d.Write(magic) + d.Write(salt) + + d2 := md5.New() + d2.Write(password) + d2.Write(salt) + + for i, mixin := 0, d2.Sum(nil); i < len(password); i++ { + d.Write([]byte{mixin[i%16]}) + } + + return d.Sum(nil) +} + +func main() { + b := md5Crypt([]byte("1"), []byte("2"), []byte("3")) + + fmt.Println(b) +} + +// Output: +// [187 141 73 89 101 229 33 106 226 63 117 234 117 149 230 21] diff --git a/interp/cfg.go b/interp/cfg.go index 21e39802..3f95a133 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -245,6 +245,7 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { if n.typ, err = nodeType(interp, sc, n.child[0]); err != nil { return false } + // Indicate that the first child is the type n.nleft = 1 } else { // Get type from ancestor (implicit type) @@ -258,18 +259,28 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { return false } } + + child := n.child + if n.nleft > 0 { + n.child[0].typ = n.typ + child = n.child[1:] + } // Propagate type to children, to handle implicit types - for _, c := range n.child { + for _, c := range child { switch c.kind { - case binaryExpr, unaryExpr: + case binaryExpr, unaryExpr, compositeLitExpr: // Do not attempt to propagate composite type to operator expressions, // it breaks constant folding. - case callExpr: + case keyValueExpr, typeAssertExpr, indexExpr: + c.typ = n.typ + default: + if c.ident == nilIdent { + c.typ = sc.getType(nilIdent) + continue + } if c.typ, err = nodeType(interp, sc, c); err != nil { return false } - default: - c.typ = n.typ } } @@ -701,13 +712,22 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { } n.findex = sc.add(n.typ) typ := t.TypeOf() - switch k := typ.Kind(); k { - case reflect.Map: + if typ.Kind() == reflect.Map { + err = check.assignment(n.child[1], t.key, "map index") n.gen = getIndexMap - case reflect.Array, reflect.Slice, reflect.String: + break + } + + l := -1 + switch k := typ.Kind(); k { + case reflect.Array: + l = typ.Len() + fallthrough + case reflect.Slice, reflect.String: n.gen = getIndexArray case reflect.Ptr: if typ2 := typ.Elem(); typ2.Kind() == reflect.Array { + l = typ2.Len() n.gen = getIndexArray } else { err = n.cfgErrorf("type %v does not support indexing", typ) @@ -716,6 +736,8 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { err = n.cfgErrorf("type is not an array, slice, string or map: %v", t.id()) } + err = check.index(n.child[1], l) + case blockStmt: wireChild(n) if len(n.child) > 0 { @@ -923,6 +945,46 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { case compositeLitExpr: wireChild(n) + + underlying := func(t *itype) *itype { + for { + switch t.cat { + case ptrT, aliasT: + t = t.val + continue + default: + return t + } + } + } + + child := n.child + if n.nleft > 0 { + child = child[1:] + } + + switch n.typ.cat { + case arrayT: + err = check.arrayLitExpr(child, underlying(n.typ.val), n.typ.size) + case mapT: + err = check.mapLitExpr(child, n.typ.key, underlying(n.typ.val)) + case structT: + err = check.structLitExpr(child, n.typ) + case valueT: + rtype := n.typ.rtype + switch rtype.Kind() { + case reflect.Struct: + err = check.structBinLitExpr(child, rtype) + case reflect.Map: + ktyp := &itype{cat: valueT, rtype: rtype.Key()} + vtyp := &itype{cat: valueT, rtype: rtype.Elem()} + err = check.mapLitExpr(child, ktyp, vtyp) + } + } + if err != nil { + break + } + n.findex = sc.add(n.typ) // TODO: Check that composite literal expr matches corresponding type n.gen = compositeGenerator(n) diff --git a/interp/interp_eval_test.go b/interp/interp_eval_test.go index a9532a1f..5fb45771 100644 --- a/interp/interp_eval_test.go +++ b/interp/interp_eval_test.go @@ -311,6 +311,40 @@ func TestEvalCompositeArray(t *testing.T) { i := interp.New(interp.Options{}) runTests(t, i, []testCase{ {src: "a := []int{1, 2, 7: 20, 30}", res: "[1 2 0 0 0 0 0 20 30]"}, + {src: `a := []int{1, 1.2}`, err: "1:42: 6/5 truncated to int"}, + {src: `a := []int{0:1, 0:1}`, err: "1:46: duplicate index 0 in array or slice literal"}, + {src: `a := []int{1.1:1, 1.2:"test"}`, err: "1:39: index float64 must be integer constant"}, + {src: `a := [2]int{1, 1.2}`, err: "1:43: 6/5 truncated to int"}, + {src: `a := [1]int{1, 2}`, err: "1:43: index 1 is out of bounds (>= 1)"}, + }) +} + +func TestEvalCompositeMap(t *testing.T) { + i := interp.New(interp.Options{}) + runTests(t, i, []testCase{ + {src: `a := map[string]int{"one":1, "two":2}`, res: "map[one:1 two:2]"}, + {src: `a := map[string]int{1:1, 2:2}`, err: "1:48: cannot convert 1 to string"}, + {src: `a := map[string]int{"one":1, "two":2.2}`, err: "1:63: 11/5 truncated to int"}, + {src: `a := map[string]int{1, "two":2}`, err: "1:48: missing key in map literal"}, + {src: `a := map[string]int{"one":1, "one":2}`, err: "1:57: duplicate key one in map literal"}, + }) +} + +func TestEvalCompositeStruct(t *testing.T) { + i := interp.New(interp.Options{}) + runTests(t, i, []testCase{ + {src: `a := struct{A,B,C int}{}`, res: "{0 0 0}"}, + {src: `a := struct{A,B,C int}{1,2,3}`, res: "{1 2 3}"}, + {src: `a := struct{A,B,C int}{1,2.2,3}`, err: "1:53: 11/5 truncated to int"}, + {src: `a := struct{A,B,C int}{1,2}`, err: "1:53: too few values in struct literal"}, + {src: `a := struct{A,B,C int}{1,2,3,4}`, err: "1:57: too many values in struct literal"}, + {src: `a := struct{A,B,C int}{1,B:2,3}`, err: "1:53: mixture of field:value and value elements in struct literal"}, + {src: `a := struct{A,B,C int}{A:1,B:2,C:3}`, res: "{1 2 3}"}, + {src: `a := struct{A,B,C int}{B:2}`, res: "{0 2 0}"}, + {src: `a := struct{A,B,C int}{A:1,D:2,C:3}`, err: "1:55: unknown field D in struct literal"}, + {src: `a := struct{A,B,C int}{A:1,A:2,C:3}`, err: "1:55: duplicate field name A in struct literal"}, + {src: `a := struct{A,B,C int}{A:1,B:2.2,C:3}`, err: "1:57: 11/5 truncated to int"}, + {src: `a := struct{A,B,C int}{A:1,2,C:3}`, err: "1:55: mixture of field:value and value elements in struct literal"}, }) } diff --git a/interp/run.go b/interp/run.go index ded67b14..c61cd11b 100644 --- a/interp/run.go +++ b/interp/run.go @@ -1315,7 +1315,6 @@ func getIndexMap(n *node) { z := reflect.New(n.child[0].typ.frameType().Elem()).Elem() if n.child[1].rval.IsValid() { // constant map index - convertConstantValue(n.child[1]) mi := n.child[1].rval switch { @@ -1409,7 +1408,6 @@ func getIndexMap2(n *node) { return } if n.child[1].rval.IsValid() { // constant map index - convertConstantValue(n.child[1]) mi := n.child[1].rval switch { case !doValue: diff --git a/interp/typecheck.go b/interp/typecheck.go index a949c95d..0e7e0e9c 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -28,7 +28,33 @@ func (check typecheck) op(p opPredicates, a action, n, c *node, t reflect.Type) return nil } -// addressExpr type checks an assign expression. +// assignment checks if n can be assigned to typ. +// +// Use typ == nil to indicate assignment to an untyped blank identifier. +func (check typecheck) assignment(n *node, typ *itype, context string) error { + if n.typ.untyped { + if typ == nil || isInterface(typ) { + if typ == nil && n.typ.cat == nilT { + return n.cfgErrorf("use of untyped nil in %s", context) + } + typ = n.typ.defaultType() + } + if err := check.convertUntyped(n, typ); err != nil { + return err + } + } + + if typ == nil { + return nil + } + + if !n.typ.assignableTo(typ) { + return n.cfgErrorf("cannot use type %s as type %s in %s", n.typ.id(), typ.id(), context) + } + return nil +} + +// assignExpr type checks an assign expression. // // This is done per pair of assignments. func (check typecheck) assignExpr(n, dest, src *node) error { @@ -39,20 +65,7 @@ func (check typecheck) assignExpr(n, dest, src *node) error { dest.typ = dest.typ.defaultType() } - if src.typ.untyped { - typ := dest.typ - if typ.isNil() || isInterface(typ) { - typ = src.typ.defaultType() - } - if err := check.convertUntyped(src, typ); err != nil { - return err - } - } - - if !src.typ.assignableTo(dest.typ) { - return src.cfgErrorf("cannot use type %s as type %s in assignment", src.typ.id(), dest.typ.id()) - } - return nil + return check.assignment(src, dest.typ, "assignment") } // assignment operations. @@ -224,6 +237,203 @@ func (check typecheck) binaryExpr(n *node) error { return nil } +func (check typecheck) index(n *node, max int) error { + if err := check.convertUntyped(n, &itype{cat: intT, name: "int"}); err != nil { + return err + } + + if !isInt(n.typ.TypeOf()) { + return n.cfgErrorf("index %s must be integer", n.typ.id()) + } + + if !n.rval.IsValid() || max < 1 { + return nil + } + + if int(vInt(n.rval)) >= max { + return n.cfgErrorf("index %s is out of bounds", n.typ.id()) + } + + return nil +} + +// arrayLitExpr type checks an array composite literal expression. +func (check typecheck) arrayLitExpr(child []*node, typ *itype, length int) error { + visited := make(map[int]bool, len(child)) + index := 0 + for _, c := range child { + n := c + switch { + case c.kind == keyValueExpr: + if err := check.index(c.child[0], length); err != nil { + return c.cfgErrorf("index %s must be integer constant", c.child[0].typ.id()) + } + n = c.child[1] + index = int(vInt(c.child[0].rval)) + case length > 0 && index >= length: + return c.cfgErrorf("index %d is out of bounds (>= %d)", index, length) + } + + if visited[index] { + return n.cfgErrorf("duplicate index %d in array or slice literal", index) + } + visited[index] = true + index++ + + if err := check.assignment(n, typ, "array or slice literal"); err != nil { + return err + } + } + return nil +} + +// mapLitExpr type checks an map composite literal expression. +func (check typecheck) mapLitExpr(child []*node, ktyp, vtyp *itype) error { + visited := make(map[interface{}]bool, len(child)) + for _, c := range child { + if c.kind != keyValueExpr { + return c.cfgErrorf("missing key in map literal") + } + + key, val := c.child[0], c.child[1] + if err := check.assignment(key, ktyp, "map literal"); err != nil { + return err + } + + if key.rval.IsValid() { + kval := key.rval.Interface() + if visited[kval] { + return c.cfgErrorf("duplicate key %s in map literal", kval) + } + visited[kval] = true + } + + if err := check.assignment(val, vtyp, "map literal"); err != nil { + return err + } + } + return nil +} + +// structLitExpr type checks an struct composite literal expression. +func (check typecheck) structLitExpr(child []*node, typ *itype) error { + if len(child) == 0 { + return nil + } + + if child[0].kind == keyValueExpr { + // All children must be keyValueExpr + visited := make([]bool, len(typ.field)) + for _, c := range child { + if c.kind != keyValueExpr { + return c.cfgErrorf("mixture of field:value and value elements in struct literal") + } + + key, val := c.child[0], c.child[1] + name := key.ident + if name == "" { + return c.cfgErrorf("invalid field name %s in struct literal", key.typ.id()) + } + i := typ.fieldIndex(name) + if i < 0 { + return c.cfgErrorf("unknown field %s in struct literal", name) + } + field := typ.field[i] + + if err := check.assignment(val, field.typ, "struct literal"); err != nil { + return err + } + + if visited[i] { + return c.cfgErrorf("duplicate field name %s in struct literal", name) + } + visited[i] = true + } + return nil + } + + // No children can be keyValueExpr + for i, c := range child { + if c.kind == keyValueExpr { + return c.cfgErrorf("mixture of field:value and value elements in struct literal") + } + + if i >= len(typ.field) { + return c.cfgErrorf("too many values in struct literal") + } + field := typ.field[i] + // TODO(nick): check if this field is not exported and in a different package. + + if err := check.assignment(c, field.typ, "struct literal"); err != nil { + return err + } + } + if len(child) < len(typ.field) { + return child[len(child)-1].cfgErrorf("too few values in struct literal") + } + return nil +} + +// structBinLitExpr type checks an struct composite literal expression on a binary type. +func (check typecheck) structBinLitExpr(child []*node, typ reflect.Type) error { + if len(child) == 0 { + return nil + } + + if child[0].kind == keyValueExpr { + // All children must be keyValueExpr + visited := make(map[string]bool, typ.NumField()) + for _, c := range child { + if c.kind != keyValueExpr { + return c.cfgErrorf("mixture of field:value and value elements in struct literal") + } + + key, val := c.child[0], c.child[1] + name := key.ident + if name == "" { + return c.cfgErrorf("invalid field name %s in struct literal", key.typ.id()) + } + field, ok := typ.FieldByName(name) + if !ok { + return c.cfgErrorf("unknown field %s in struct literal", name) + } + + if err := check.assignment(val, &itype{cat: valueT, rtype: field.Type}, "struct literal"); err != nil { + return err + } + + if visited[field.Name] { + return c.cfgErrorf("duplicate field name %s in struct literal", name) + } + visited[field.Name] = true + } + return nil + } + + // No children can be keyValueExpr + for i, c := range child { + if c.kind == keyValueExpr { + return c.cfgErrorf("mixture of field:value and value elements in struct literal") + } + + if i >= typ.NumField() { + return c.cfgErrorf("too many values in struct literal") + } + field := typ.Field(i) + if !canExport(field.Name) { + return c.cfgErrorf("implicit assignment to unexported field %s in %s literal", field.Name, typ) + } + + if err := check.assignment(c, &itype{cat: valueT, rtype: field.Type}, "struct literal"); err != nil { + return err + } + } + if len(child) < typ.NumField() { + return child[len(child)-1].cfgErrorf("too few values in struct literal") + } + return nil +} + var errCantConvert = errors.New("cannot convert") func (check typecheck) convertUntyped(n *node, typ *itype) error { diff --git a/interp/value.go b/interp/value.go index 18e5c0c8..edc3bc7e 100644 --- a/interp/value.go +++ b/interp/value.go @@ -334,10 +334,6 @@ func vInt(v reflect.Value) (i int64) { case reflect.Complex64, reflect.Complex128: i = int64(real(v.Complex())) } - if v.Type().Implements(constVal) { - c := v.Interface().(constant.Value) - i, _ = constant.Int64Val(constant.ToInt(c)) - } return } @@ -352,11 +348,6 @@ func vUint(v reflect.Value) (i uint64) { case reflect.Complex64, reflect.Complex128: i = uint64(real(v.Complex())) } - if v.Type().Implements(constVal) { - c := v.Interface().(constant.Value) - iv, _ := constant.Int64Val(constant.ToInt(c)) - i = uint64(iv) - } return } @@ -371,13 +362,6 @@ func vComplex(v reflect.Value) (c complex128) { case reflect.Complex64, reflect.Complex128: c = v.Complex() } - if v.Type().Implements(constVal) { - con := v.Interface().(constant.Value) - con = constant.ToComplex(con) - rel, _ := constant.Float64Val(constant.Real(con)) - img, _ := constant.Float64Val(constant.Imag(con)) - c = complex(rel, img) - } return } @@ -392,10 +376,6 @@ func vFloat(v reflect.Value) (i float64) { case reflect.Complex64, reflect.Complex128: i = real(v.Complex()) } - if v.Type().Implements(constVal) { - c := v.Interface().(constant.Value) - i, _ = constant.Float64Val(constant.ToFloat(c)) - } return }