fix: correct handling of constant expressions in type declarations (#204)

This commit is contained in:
Marc Vertes
2019-06-04 16:50:32 +02:00
committed by Ludovic Fernandez
parent 99eac2d333
commit 455a37e678
6 changed files with 277 additions and 51 deletions

16
_test/type11.go Normal file
View File

@@ -0,0 +1,16 @@
package main
import (
"compress/gzip"
"fmt"
"sync"
)
var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool
func main() {
fmt.Printf("%T\n", gzipWriterPools)
}
// Output:
// [10]*sync.Pool

View File

@@ -173,6 +173,32 @@ func {{$name}}(n *Node) {
{{- end}}
}
}
func {{$name}}Const(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
{{- if $op.Str}}
case isString(t):
n.rval.SetString(v0.String() {{$op.Name}} v1.String())
{{- end}}
{{- if $op.Float}}
case isComplex(t):
n.rval.SetComplex(vComplex(v0) {{$op.Name}} vComplex(v1))
case isFloat(t):
n.rval.SetFloat(vFloat(v0) {{$op.Name}} vFloat(v1))
{{- end}}
case isUint(t):
n.rval.SetUint(vUint(v0) {{$op.Name}} vUint(v1))
case isInt(t):
{{- if $op.Shift}}
n.rval.SetInt(vInt(v0) {{$op.Name}} vUint(v1))
{{- else}}
n.rval.SetInt(vInt(v0) {{$op.Name}} vInt(v1))
{{- end}}
}
}
{{end}}
// Assign operators
{{range $name, $op := .Arithmetic}}
@@ -332,7 +358,7 @@ func {{$name}}(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():

View File

@@ -11,6 +11,19 @@ import (
// A CfgError represents an error during CFG build stage
type CfgError error
var constOp = map[Action]func(*Node){
Add: addConst,
Sub: subConst,
Mul: mulConst,
Quo: quoConst,
Rem: remConst,
And: andConst,
Or: orConst,
Shl: shlConst,
Shr: shrConst,
AndNot: andnotConst,
}
// Cfg generates a control flow graph (CFG) from AST (wiring successors in AST)
// and pre-compute frame sizes and indexes for all un-named (temporary) and named
// variables. A list of nodes of init functions is returned.
@@ -336,7 +349,7 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
} else {
sym, level, _ = scope.lookup(dest.ident)
}
switch t0, t1 := dest.typ, src.typ; n.action {
switch t0, t1 := dest.typ.TypeOf(), src.typ.TypeOf(); n.action {
case AddAssign:
if !(isNumber(t0) && isNumber(t1) || isString(t0) && isString(t1)) || isInt(t0) && isFloat(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
@@ -483,9 +496,10 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
case BinaryExpr:
wireChild(n)
nilSym := interp.universe.sym["nil"]
t0, t1 := n.child[0].typ, n.child[1].typ
if !t0.untyped && !t1.untyped && t0.id() != t1.id() {
err = n.cfgError("mismatched types %s and %s", t0.id(), t1.id())
c0, c1 := n.child[0], n.child[1]
t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf()
if !c0.typ.untyped && !c1.typ.untyped && c0.typ.id() != c1.typ.id() {
err = n.cfgError("mismatched types %s and %s", c0.typ.id(), c1.typ.id())
break
}
switch n.action {
@@ -505,7 +519,7 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
if !(isInt(t0) && isUint(t1)) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
}
n.typ = t0
n.typ = c0.typ
case Equal, NotEqual:
if isNumber(t0) && !isNumber(t1) || isString(t0) && !isString(t1) {
err = n.cfgError("illegal operand types for '%v' operator", n.action)
@@ -527,9 +541,19 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
if err != nil {
break
}
if c0.rval.IsValid() && c1.rval.IsValid() && constOp[n.action] != nil {
if n.typ == nil {
n.typ, err = nodeType(interp, scope, n)
}
n.typ.TypeOf() // init reflect type
constOp[n.action](n)
}
switch {
//case n.typ != nil && n.typ.cat == BoolT && isAncBranch(n):
// n.findex = -1
case n.rval.IsValid():
n.gen = nop
n.findex = -1
case n.anc.kind == AssignStmt && n.anc.action == Assign:
dest := n.anc.child[childPos(n)-n.anc.nright]
n.typ = dest.typ
@@ -559,10 +583,12 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
case BlockStmt:
wireChild(n)
if len(n.child) > 0 {
n.findex = n.lastChild().findex
n.val = n.lastChild().val
n.sym = n.lastChild().sym
n.typ = n.lastChild().typ
l := n.lastChild()
n.findex = l.findex
n.val = l.val
n.sym = l.sym
n.typ = l.typ
n.rval = l.rval
}
scope = scope.pop()
@@ -575,10 +601,12 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
case DeclStmt, ExprStmt, SendStmt:
wireChild(n)
n.findex = n.lastChild().findex
n.val = n.lastChild().val
n.sym = n.lastChild().sym
n.typ = n.lastChild().typ
l := n.lastChild()
n.findex = l.findex
n.val = l.val
n.sym = l.sym
n.typ = l.typ
n.rval = l.rval
case Break:
if len(n.child) > 0 {
@@ -616,26 +644,27 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
}
if len(n.child) == 3 {
if c2.typ.cat == ArrayT && c2.typ.val.id() == n.typ.val.id() ||
isByteArray(c1.typ) && isString(c2.typ) {
isByteArray(c1.typ.TypeOf()) && isString(c2.typ.TypeOf()) {
n.gen = appendSlice
}
}
case "cap", "copy", "len":
n.typ = scope.getType("int")
case "complex":
switch t0, t1 := n.child[1].typ, n.child[2].typ; {
c0, c1 := n.child[1], n.child[2]
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isFloat32(t0) && isFloat32(t1):
n.typ = scope.getType("complex64")
case isFloat64(t0) && isFloat64(t1):
n.typ = scope.getType("complex128")
case isUntypedNumber(t0) && isUntypedNumber(t1):
case c0.typ.untyped && isNumber(t0) && c1.typ.untyped && isNumber(t1):
n.typ = &Type{cat: ValueT, rtype: complexType}
case isUntypedNumber(t0) && isFloat32(t1) || isUntypedNumber(t1) && isFloat32(t0):
case c0.typ.untyped && isFloat32(t1) || c1.typ.untyped && isFloat32(t0):
n.typ = scope.getType("complex64")
case isUntypedNumber(t0) && isFloat64(t1) || isUntypedNumber(t1) && isFloat64(t0):
case c0.typ.untyped && isFloat64(t1) || c1.typ.untyped && isFloat64(t0):
n.typ = scope.getType("complex128")
default:
err = n.cfgError("invalid types %s and %s", t0.TypeOf().Kind(), t1.TypeOf().Kind())
err = n.cfgError("invalid types %s and %s", t0.Kind(), t1.Kind())
}
case "real", "imag":
switch k := n.child[1].typ.TypeOf().Kind(); {
@@ -643,7 +672,7 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
n.typ = scope.getType("float32")
case k == reflect.Complex128:
n.typ = scope.getType("float64")
case isUntypedNumber(n.child[1].typ):
case n.child[1].typ.untyped && isNumber(n.child[1].typ.TypeOf()):
n.typ = &Type{cat: ValueT, rtype: floatType}
default:
err = n.cfgError("invalid complex type %s", k)
@@ -668,7 +697,7 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
}
case n.child[0].isType(scope):
// Type conversion expression
if isInt(n.child[0].typ) && n.child[1].kind == BasicLit && isFloat(n.child[1].typ) {
if isInt(n.child[0].typ.TypeOf()) && n.child[1].kind == BasicLit && isFloat(n.child[1].typ.TypeOf()) {
err = n.cfgError("truncated to integer")
}
n.gen = convert
@@ -916,8 +945,10 @@ func (interp *Interpreter) Cfg(root *Node) ([]*Node, error) {
case ParenExpr:
wireChild(n)
n.findex = n.lastChild().findex
n.typ = n.lastChild().typ
c := n.lastChild()
n.findex = c.findex
n.typ = c.typ
n.rval = c.rval
case RangeStmt:
if scope.rangeChanType(n) != nil {

View File

@@ -148,6 +148,24 @@ func add(n *Node) {
}
}
func addConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isString(t):
n.rval.SetString(v0.String() + v1.String())
case isComplex(t):
n.rval.SetComplex(vComplex(v0) + vComplex(v1))
case isFloat(t):
n.rval.SetFloat(vFloat(v0) + vFloat(v1))
case isUint(t):
n.rval.SetUint(vUint(v0) + vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) + vInt(v1))
}
}
func and(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -214,6 +232,18 @@ func and(n *Node) {
}
}
func andConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) & vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) & vInt(v1))
}
}
func andnot(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -280,6 +310,18 @@ func andnot(n *Node) {
}
}
func andnotConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) &^ vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) &^ vInt(v1))
}
}
func mul(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -398,6 +440,22 @@ func mul(n *Node) {
}
}
func mulConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isComplex(t):
n.rval.SetComplex(vComplex(v0) * vComplex(v1))
case isFloat(t):
n.rval.SetFloat(vFloat(v0) * vFloat(v1))
case isUint(t):
n.rval.SetUint(vUint(v0) * vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) * vInt(v1))
}
}
func or(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -464,6 +522,18 @@ func or(n *Node) {
}
}
func orConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) | vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) | vInt(v1))
}
}
func quo(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -582,6 +652,22 @@ func quo(n *Node) {
}
}
func quoConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isComplex(t):
n.rval.SetComplex(vComplex(v0) / vComplex(v1))
case isFloat(t):
n.rval.SetFloat(vFloat(v0) / vFloat(v1))
case isUint(t):
n.rval.SetUint(vUint(v0) / vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) / vInt(v1))
}
}
func rem(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -648,6 +734,18 @@ func rem(n *Node) {
}
}
func remConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) % vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) % vInt(v1))
}
}
func shl(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -714,6 +812,18 @@ func shl(n *Node) {
}
}
func shlConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) << vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) << vUint(v1))
}
}
func shr(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -780,6 +890,18 @@ func shr(n *Node) {
}
}
func shrConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) >> vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) >> vUint(v1))
}
}
func sub(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -898,6 +1020,22 @@ func sub(n *Node) {
}
}
func subConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isComplex(t):
n.rval.SetComplex(vComplex(v0) - vComplex(v1))
case isFloat(t):
n.rval.SetFloat(vFloat(v0) - vFloat(v1))
case isUint(t):
n.rval.SetUint(vUint(v0) - vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) - vInt(v1))
}
}
func xor(n *Node) {
dest := genValue(n)
next := getExec(n.tnext)
@@ -964,6 +1102,18 @@ func xor(n *Node) {
}
}
func xorConst(n *Node) {
v0, v1 := n.child[0].rval, n.child[1].rval
t := n.typ.rtype
n.rval = reflect.New(t).Elem()
switch {
case isUint(t):
n.rval.SetUint(vUint(v0) ^ vUint(v1))
case isInt(t):
n.rval.SetInt(vInt(v0) ^ vInt(v1))
}
}
// Assign operators
func addAssign(n *Node) {
@@ -1719,7 +1869,7 @@ func equal(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():
@@ -2077,7 +2227,7 @@ func greater(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():
@@ -2366,7 +2516,7 @@ func greaterEqual(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():
@@ -2655,7 +2805,7 @@ func lower(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():
@@ -2944,7 +3094,7 @@ func lowerEqual(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():
@@ -3233,7 +3383,7 @@ func notEqual(n *Node) {
dest := genValue(n)
c0, c1 := n.child[0], n.child[1]
switch t0, t1 := c0.typ, c1.typ; {
switch t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf(); {
case isString(t0) || isString(t1):
switch {
case c0.rval.IsValid():

View File

@@ -1621,7 +1621,7 @@ func appendSlice(n *Node) {
value := genValue(n.child[1])
value0 := genValue(n.child[2])
if isString(n.child[2].typ) {
if isString(n.child[2].typ.TypeOf()) {
typ := reflect.TypeOf([]byte{})
n.exec = func(f *Frame) Builtin {
dest(f).Set(reflect.AppendSlice(value(f), value0(f).Convert(typ)))

View File

@@ -152,6 +152,11 @@ func nodeType(interp *Interpreter, scope *Scope, n *Node) (*Type, error) {
t.incomplete = true
}
} else {
// Evaluate constant array size expression
_, err = interp.Cfg(n.child[0])
if err != nil {
return nil, err
}
t.incomplete = true
}
}
@@ -222,7 +227,7 @@ func nodeType(interp *Interpreter, scope *Scope, n *Node) (*Type, error) {
if t.untyped {
var t1 *Type
t1, err = nodeType(interp, scope, n.child[1])
if !(t1.untyped && isInt(t1) && isFloat(t)) {
if !(t1.untyped && isInt(t1.TypeOf()) && isFloat(t.TypeOf())) {
t = t1
}
}
@@ -690,46 +695,44 @@ func isShiftOperand(n *Node) bool {
func isStruct(t *Type) bool { return t.TypeOf().Kind() == reflect.Struct }
func isInt(t *Type) bool {
switch t.TypeOf().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
func isInt(t reflect.Type) bool {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
}
return false
}
func isUint(t *Type) bool {
switch t.TypeOf().Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
func isUint(t reflect.Type) bool {
switch t.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
}
return false
}
func isComplex(t *Type) bool {
switch t.TypeOf().Kind() {
func isComplex(t reflect.Type) bool {
switch t.Kind() {
case reflect.Complex64, reflect.Complex128:
return true
}
return false
}
func isFloat(t *Type) bool {
switch t.TypeOf().Kind() {
func isFloat(t reflect.Type) bool {
switch t.Kind() {
case reflect.Float32, reflect.Float64:
return true
}
return false
}
func isByteArray(t *Type) bool {
r := t.TypeOf()
k := r.Kind()
return (k == reflect.Array || k == reflect.Slice) && r.Elem().Kind() == reflect.Uint8
func isByteArray(t reflect.Type) bool {
k := t.Kind()
return (k == reflect.Array || k == reflect.Slice) && t.Elem().Kind() == reflect.Uint8
}
func isFloat32(t *Type) bool { return t.TypeOf().Kind() == reflect.Float32 }
func isFloat64(t *Type) bool { return t.TypeOf().Kind() == reflect.Float64 }
func isUntypedNumber(t *Type) bool { return t.untyped && (isInt(t) || isFloat(t) || isComplex(t)) }
func isNumber(t *Type) bool { return isInt(t) || isFloat(t) || isComplex(t) }
func isString(t *Type) bool { return t.TypeOf().Kind() == reflect.String }
func isFloat32(t reflect.Type) bool { return t.Kind() == reflect.Float32 }
func isFloat64(t reflect.Type) bool { return t.Kind() == reflect.Float64 }
func isNumber(t reflect.Type) bool { return isInt(t) || isFloat(t) || isComplex(t) }
func isString(t reflect.Type) bool { return t.Kind() == reflect.String }