feat: add index and composite literal type checking

This adds type checking to both `IndexExpr` and `CompositeLitExpr` as well as handling any required constant type conversion.

This includes a change to the type propagation to the children of a composite literal. Previously in most cases the composite literal type was propagated to its children. This does not work with type checking as the actual child type is needed.
This commit is contained in:
Nicholas Wiersma
2020-08-11 15:58:04 +02:00
committed by GitHub
parent 88569f5df7
commit cdc352cee2
6 changed files with 361 additions and 45 deletions

32
_test/math3.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"crypto/md5"
"fmt"
)
func md5Crypt(password, salt, magic []byte) []byte {
d := md5.New()
d.Write(password)
d.Write(magic)
d.Write(salt)
d2 := md5.New()
d2.Write(password)
d2.Write(salt)
for i, mixin := 0, d2.Sum(nil); i < len(password); i++ {
d.Write([]byte{mixin[i%16]})
}
return d.Sum(nil)
}
func main() {
b := md5Crypt([]byte("1"), []byte("2"), []byte("3"))
fmt.Println(b)
}
// Output:
// [187 141 73 89 101 229 33 106 226 63 117 234 117 149 230 21]

View File

@@ -245,6 +245,7 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
if n.typ, err = nodeType(interp, sc, n.child[0]); err != nil {
return false
}
// Indicate that the first child is the type
n.nleft = 1
} else {
// Get type from ancestor (implicit type)
@@ -258,18 +259,28 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
return false
}
}
child := n.child
if n.nleft > 0 {
n.child[0].typ = n.typ
child = n.child[1:]
}
// Propagate type to children, to handle implicit types
for _, c := range n.child {
for _, c := range child {
switch c.kind {
case binaryExpr, unaryExpr:
case binaryExpr, unaryExpr, compositeLitExpr:
// Do not attempt to propagate composite type to operator expressions,
// it breaks constant folding.
case callExpr:
case keyValueExpr, typeAssertExpr, indexExpr:
c.typ = n.typ
default:
if c.ident == nilIdent {
c.typ = sc.getType(nilIdent)
continue
}
if c.typ, err = nodeType(interp, sc, c); err != nil {
return false
}
default:
c.typ = n.typ
}
}
@@ -701,13 +712,22 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
}
n.findex = sc.add(n.typ)
typ := t.TypeOf()
switch k := typ.Kind(); k {
case reflect.Map:
if typ.Kind() == reflect.Map {
err = check.assignment(n.child[1], t.key, "map index")
n.gen = getIndexMap
case reflect.Array, reflect.Slice, reflect.String:
break
}
l := -1
switch k := typ.Kind(); k {
case reflect.Array:
l = typ.Len()
fallthrough
case reflect.Slice, reflect.String:
n.gen = getIndexArray
case reflect.Ptr:
if typ2 := typ.Elem(); typ2.Kind() == reflect.Array {
l = typ2.Len()
n.gen = getIndexArray
} else {
err = n.cfgErrorf("type %v does not support indexing", typ)
@@ -716,6 +736,8 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
err = n.cfgErrorf("type is not an array, slice, string or map: %v", t.id())
}
err = check.index(n.child[1], l)
case blockStmt:
wireChild(n)
if len(n.child) > 0 {
@@ -923,6 +945,46 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
case compositeLitExpr:
wireChild(n)
underlying := func(t *itype) *itype {
for {
switch t.cat {
case ptrT, aliasT:
t = t.val
continue
default:
return t
}
}
}
child := n.child
if n.nleft > 0 {
child = child[1:]
}
switch n.typ.cat {
case arrayT:
err = check.arrayLitExpr(child, underlying(n.typ.val), n.typ.size)
case mapT:
err = check.mapLitExpr(child, n.typ.key, underlying(n.typ.val))
case structT:
err = check.structLitExpr(child, n.typ)
case valueT:
rtype := n.typ.rtype
switch rtype.Kind() {
case reflect.Struct:
err = check.structBinLitExpr(child, rtype)
case reflect.Map:
ktyp := &itype{cat: valueT, rtype: rtype.Key()}
vtyp := &itype{cat: valueT, rtype: rtype.Elem()}
err = check.mapLitExpr(child, ktyp, vtyp)
}
}
if err != nil {
break
}
n.findex = sc.add(n.typ)
// TODO: Check that composite literal expr matches corresponding type
n.gen = compositeGenerator(n)

View File

@@ -311,6 +311,40 @@ func TestEvalCompositeArray(t *testing.T) {
i := interp.New(interp.Options{})
runTests(t, i, []testCase{
{src: "a := []int{1, 2, 7: 20, 30}", res: "[1 2 0 0 0 0 0 20 30]"},
{src: `a := []int{1, 1.2}`, err: "1:42: 6/5 truncated to int"},
{src: `a := []int{0:1, 0:1}`, err: "1:46: duplicate index 0 in array or slice literal"},
{src: `a := []int{1.1:1, 1.2:"test"}`, err: "1:39: index float64 must be integer constant"},
{src: `a := [2]int{1, 1.2}`, err: "1:43: 6/5 truncated to int"},
{src: `a := [1]int{1, 2}`, err: "1:43: index 1 is out of bounds (>= 1)"},
})
}
func TestEvalCompositeMap(t *testing.T) {
i := interp.New(interp.Options{})
runTests(t, i, []testCase{
{src: `a := map[string]int{"one":1, "two":2}`, res: "map[one:1 two:2]"},
{src: `a := map[string]int{1:1, 2:2}`, err: "1:48: cannot convert 1 to string"},
{src: `a := map[string]int{"one":1, "two":2.2}`, err: "1:63: 11/5 truncated to int"},
{src: `a := map[string]int{1, "two":2}`, err: "1:48: missing key in map literal"},
{src: `a := map[string]int{"one":1, "one":2}`, err: "1:57: duplicate key one in map literal"},
})
}
func TestEvalCompositeStruct(t *testing.T) {
i := interp.New(interp.Options{})
runTests(t, i, []testCase{
{src: `a := struct{A,B,C int}{}`, res: "{0 0 0}"},
{src: `a := struct{A,B,C int}{1,2,3}`, res: "{1 2 3}"},
{src: `a := struct{A,B,C int}{1,2.2,3}`, err: "1:53: 11/5 truncated to int"},
{src: `a := struct{A,B,C int}{1,2}`, err: "1:53: too few values in struct literal"},
{src: `a := struct{A,B,C int}{1,2,3,4}`, err: "1:57: too many values in struct literal"},
{src: `a := struct{A,B,C int}{1,B:2,3}`, err: "1:53: mixture of field:value and value elements in struct literal"},
{src: `a := struct{A,B,C int}{A:1,B:2,C:3}`, res: "{1 2 3}"},
{src: `a := struct{A,B,C int}{B:2}`, res: "{0 2 0}"},
{src: `a := struct{A,B,C int}{A:1,D:2,C:3}`, err: "1:55: unknown field D in struct literal"},
{src: `a := struct{A,B,C int}{A:1,A:2,C:3}`, err: "1:55: duplicate field name A in struct literal"},
{src: `a := struct{A,B,C int}{A:1,B:2.2,C:3}`, err: "1:57: 11/5 truncated to int"},
{src: `a := struct{A,B,C int}{A:1,2,C:3}`, err: "1:55: mixture of field:value and value elements in struct literal"},
})
}

View File

@@ -1315,7 +1315,6 @@ func getIndexMap(n *node) {
z := reflect.New(n.child[0].typ.frameType().Elem()).Elem()
if n.child[1].rval.IsValid() { // constant map index
convertConstantValue(n.child[1])
mi := n.child[1].rval
switch {
@@ -1409,7 +1408,6 @@ func getIndexMap2(n *node) {
return
}
if n.child[1].rval.IsValid() { // constant map index
convertConstantValue(n.child[1])
mi := n.child[1].rval
switch {
case !doValue:

View File

@@ -28,7 +28,33 @@ func (check typecheck) op(p opPredicates, a action, n, c *node, t reflect.Type)
return nil
}
// addressExpr type checks an assign expression.
// assignment checks if n can be assigned to typ.
//
// Use typ == nil to indicate assignment to an untyped blank identifier.
func (check typecheck) assignment(n *node, typ *itype, context string) error {
if n.typ.untyped {
if typ == nil || isInterface(typ) {
if typ == nil && n.typ.cat == nilT {
return n.cfgErrorf("use of untyped nil in %s", context)
}
typ = n.typ.defaultType()
}
if err := check.convertUntyped(n, typ); err != nil {
return err
}
}
if typ == nil {
return nil
}
if !n.typ.assignableTo(typ) {
return n.cfgErrorf("cannot use type %s as type %s in %s", n.typ.id(), typ.id(), context)
}
return nil
}
// assignExpr type checks an assign expression.
//
// This is done per pair of assignments.
func (check typecheck) assignExpr(n, dest, src *node) error {
@@ -39,20 +65,7 @@ func (check typecheck) assignExpr(n, dest, src *node) error {
dest.typ = dest.typ.defaultType()
}
if src.typ.untyped {
typ := dest.typ
if typ.isNil() || isInterface(typ) {
typ = src.typ.defaultType()
}
if err := check.convertUntyped(src, typ); err != nil {
return err
}
}
if !src.typ.assignableTo(dest.typ) {
return src.cfgErrorf("cannot use type %s as type %s in assignment", src.typ.id(), dest.typ.id())
}
return nil
return check.assignment(src, dest.typ, "assignment")
}
// assignment operations.
@@ -224,6 +237,203 @@ func (check typecheck) binaryExpr(n *node) error {
return nil
}
func (check typecheck) index(n *node, max int) error {
if err := check.convertUntyped(n, &itype{cat: intT, name: "int"}); err != nil {
return err
}
if !isInt(n.typ.TypeOf()) {
return n.cfgErrorf("index %s must be integer", n.typ.id())
}
if !n.rval.IsValid() || max < 1 {
return nil
}
if int(vInt(n.rval)) >= max {
return n.cfgErrorf("index %s is out of bounds", n.typ.id())
}
return nil
}
// arrayLitExpr type checks an array composite literal expression.
func (check typecheck) arrayLitExpr(child []*node, typ *itype, length int) error {
visited := make(map[int]bool, len(child))
index := 0
for _, c := range child {
n := c
switch {
case c.kind == keyValueExpr:
if err := check.index(c.child[0], length); err != nil {
return c.cfgErrorf("index %s must be integer constant", c.child[0].typ.id())
}
n = c.child[1]
index = int(vInt(c.child[0].rval))
case length > 0 && index >= length:
return c.cfgErrorf("index %d is out of bounds (>= %d)", index, length)
}
if visited[index] {
return n.cfgErrorf("duplicate index %d in array or slice literal", index)
}
visited[index] = true
index++
if err := check.assignment(n, typ, "array or slice literal"); err != nil {
return err
}
}
return nil
}
// mapLitExpr type checks an map composite literal expression.
func (check typecheck) mapLitExpr(child []*node, ktyp, vtyp *itype) error {
visited := make(map[interface{}]bool, len(child))
for _, c := range child {
if c.kind != keyValueExpr {
return c.cfgErrorf("missing key in map literal")
}
key, val := c.child[0], c.child[1]
if err := check.assignment(key, ktyp, "map literal"); err != nil {
return err
}
if key.rval.IsValid() {
kval := key.rval.Interface()
if visited[kval] {
return c.cfgErrorf("duplicate key %s in map literal", kval)
}
visited[kval] = true
}
if err := check.assignment(val, vtyp, "map literal"); err != nil {
return err
}
}
return nil
}
// structLitExpr type checks an struct composite literal expression.
func (check typecheck) structLitExpr(child []*node, typ *itype) error {
if len(child) == 0 {
return nil
}
if child[0].kind == keyValueExpr {
// All children must be keyValueExpr
visited := make([]bool, len(typ.field))
for _, c := range child {
if c.kind != keyValueExpr {
return c.cfgErrorf("mixture of field:value and value elements in struct literal")
}
key, val := c.child[0], c.child[1]
name := key.ident
if name == "" {
return c.cfgErrorf("invalid field name %s in struct literal", key.typ.id())
}
i := typ.fieldIndex(name)
if i < 0 {
return c.cfgErrorf("unknown field %s in struct literal", name)
}
field := typ.field[i]
if err := check.assignment(val, field.typ, "struct literal"); err != nil {
return err
}
if visited[i] {
return c.cfgErrorf("duplicate field name %s in struct literal", name)
}
visited[i] = true
}
return nil
}
// No children can be keyValueExpr
for i, c := range child {
if c.kind == keyValueExpr {
return c.cfgErrorf("mixture of field:value and value elements in struct literal")
}
if i >= len(typ.field) {
return c.cfgErrorf("too many values in struct literal")
}
field := typ.field[i]
// TODO(nick): check if this field is not exported and in a different package.
if err := check.assignment(c, field.typ, "struct literal"); err != nil {
return err
}
}
if len(child) < len(typ.field) {
return child[len(child)-1].cfgErrorf("too few values in struct literal")
}
return nil
}
// structBinLitExpr type checks an struct composite literal expression on a binary type.
func (check typecheck) structBinLitExpr(child []*node, typ reflect.Type) error {
if len(child) == 0 {
return nil
}
if child[0].kind == keyValueExpr {
// All children must be keyValueExpr
visited := make(map[string]bool, typ.NumField())
for _, c := range child {
if c.kind != keyValueExpr {
return c.cfgErrorf("mixture of field:value and value elements in struct literal")
}
key, val := c.child[0], c.child[1]
name := key.ident
if name == "" {
return c.cfgErrorf("invalid field name %s in struct literal", key.typ.id())
}
field, ok := typ.FieldByName(name)
if !ok {
return c.cfgErrorf("unknown field %s in struct literal", name)
}
if err := check.assignment(val, &itype{cat: valueT, rtype: field.Type}, "struct literal"); err != nil {
return err
}
if visited[field.Name] {
return c.cfgErrorf("duplicate field name %s in struct literal", name)
}
visited[field.Name] = true
}
return nil
}
// No children can be keyValueExpr
for i, c := range child {
if c.kind == keyValueExpr {
return c.cfgErrorf("mixture of field:value and value elements in struct literal")
}
if i >= typ.NumField() {
return c.cfgErrorf("too many values in struct literal")
}
field := typ.Field(i)
if !canExport(field.Name) {
return c.cfgErrorf("implicit assignment to unexported field %s in %s literal", field.Name, typ)
}
if err := check.assignment(c, &itype{cat: valueT, rtype: field.Type}, "struct literal"); err != nil {
return err
}
}
if len(child) < typ.NumField() {
return child[len(child)-1].cfgErrorf("too few values in struct literal")
}
return nil
}
var errCantConvert = errors.New("cannot convert")
func (check typecheck) convertUntyped(n *node, typ *itype) error {

View File

@@ -334,10 +334,6 @@ func vInt(v reflect.Value) (i int64) {
case reflect.Complex64, reflect.Complex128:
i = int64(real(v.Complex()))
}
if v.Type().Implements(constVal) {
c := v.Interface().(constant.Value)
i, _ = constant.Int64Val(constant.ToInt(c))
}
return
}
@@ -352,11 +348,6 @@ func vUint(v reflect.Value) (i uint64) {
case reflect.Complex64, reflect.Complex128:
i = uint64(real(v.Complex()))
}
if v.Type().Implements(constVal) {
c := v.Interface().(constant.Value)
iv, _ := constant.Int64Val(constant.ToInt(c))
i = uint64(iv)
}
return
}
@@ -371,13 +362,6 @@ func vComplex(v reflect.Value) (c complex128) {
case reflect.Complex64, reflect.Complex128:
c = v.Complex()
}
if v.Type().Implements(constVal) {
con := v.Interface().(constant.Value)
con = constant.ToComplex(con)
rel, _ := constant.Float64Val(constant.Real(con))
img, _ := constant.Float64Val(constant.Imag(con))
c = complex(rel, img)
}
return
}
@@ -392,10 +376,6 @@ func vFloat(v reflect.Value) (i float64) {
case reflect.Complex64, reflect.Complex128:
i = real(v.Complex())
}
if v.Type().Implements(constVal) {
c := v.Interface().(constant.Value)
i, _ = constant.Float64Val(constant.ToFloat(c))
}
return
}