From 39430c34bb50aa2fd91a0d0251e920b6800d35f4 Mon Sep 17 00:00:00 2001 From: Marc Vertes Date: Wed, 1 Jul 2020 14:39:47 +0200 Subject: [PATCH] fix: untyped constant converson to default type (#729) * fix: untyped constant cconverson to default type In definition assign expression, the source type is propagated to the assigned value. If the source is an untyped constant, the destination type must be set to the default type of the constant definition. A fixType function is provided to perform this. In addition, the type conversion and check of constants is refactored for simplifications. Fixes #727. * test: fix _test/const14.go --- _test/const14.go | 13 +++ interp/cfg.go | 5 +- interp/run.go | 221 ++++++++++++----------------------------------- interp/scope.go | 18 ++++ 4 files changed, 89 insertions(+), 168 deletions(-) create mode 100644 _test/const14.go diff --git a/_test/const14.go b/_test/const14.go new file mode 100644 index 00000000..835858f7 --- /dev/null +++ b/_test/const14.go @@ -0,0 +1,13 @@ +package main + +import "compress/flate" + +func f1(i int) { println("i:", i) } + +func main() { + i := flate.BestSpeed + f1(i) +} + +// Output: +// i: 1 diff --git a/interp/cfg.go b/interp/cfg.go index e39be2e9..f2b9741f 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -470,7 +470,10 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { if src.typ.isBinMethod { dest.typ = &itype{cat: valueT, rtype: src.typ.methodCallType()} } else { - dest.typ = src.typ + // In a new definition, propagate the source type to the destination + // type. If the source is an untyped constant, make sure that the + // type matches a default type. + dest.typ = sc.fixType(src.typ) } } if dest.typ.sizedef { diff --git a/interp/run.go b/interp/run.go index abfd7d67..28a264a0 100644 --- a/interp/run.go +++ b/interp/run.go @@ -2703,11 +2703,7 @@ func convertLiteralValue(n *node, t reflect.Type) { // Skip non-constant values, undefined target type or interface target type. case n.rval.IsValid(): // Convert constant value to target type. - if n.typ != nil && n.typ.cat != valueT { - convertConstantValue(n) - } else { - convertConstantValueTo(n, t) - } + convertConstantValue(n) n.rval = n.rval.Convert(t) default: // Create a zero value of target type. @@ -2715,6 +2711,20 @@ func convertLiteralValue(n *node, t reflect.Type) { } } +var bitlen = [...]int{ + reflect.Int: 64, + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, + reflect.Uint: 64, + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, + reflect.Uintptr: 64, +} + func convertConstantValue(n *node) { if !n.rval.IsValid() { return @@ -2723,187 +2733,64 @@ func convertConstantValue(n *node) { if !ok { return } - t := n.typ - for t != nil && t.cat == aliasT { - // If it is an alias, get the actual type - t = t.val - } v := n.rval - switch t.cat { - case intT, int8T, int16T, int32T, int64T: - i, _ := constant.Int64Val(c) - l := constant.BitLen(c) - switch t.cat { - case intT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int", c.ExactString())) - } - v = reflect.ValueOf(int(i)) - case int8T: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows int8", c.ExactString())) - } - v = reflect.ValueOf(int8(i)) - case int16T: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows int16", c.ExactString())) - } - v = reflect.ValueOf(int16(i)) - case int32T: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows int32", c.ExactString())) - } - v = reflect.ValueOf(int32(i)) - case int64T: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) - } - v = reflect.ValueOf(i) - } - case uintT, uint8T, uint16T, uint32T, uint64T: - i, _ := constant.Uint64Val(c) - l := constant.BitLen(c) - switch t.cat { - case uintT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint", c.ExactString())) - } - v = reflect.ValueOf(uint(i)) - case uint8T: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows uint8", c.ExactString())) - } - v = reflect.ValueOf(uint8(i)) - case uint16T: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows uint16", c.ExactString())) - } - v = reflect.ValueOf(uint16(i)) - case uint32T: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows uint32", c.ExactString())) - } - v = reflect.ValueOf(uint32(i)) - case uint64T: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint64", c.ExactString())) - } - v = reflect.ValueOf(i) - case uintptrT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uintptr", c.ExactString())) - } - v = reflect.ValueOf(i) - } - case float32T: - f, _ := constant.Float32Val(c) - v = reflect.ValueOf(f) - case float64T: - f, _ := constant.Float64Val(c) - v = reflect.ValueOf(f) - case complex64T: - r, _ := constant.Float32Val(constant.Real(c)) - i, _ := constant.Float32Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) - case complex128T: - r, _ := constant.Float64Val(constant.Real(c)) - i, _ := constant.Float64Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) - } - n.rval = v -} - -func convertConstantValueTo(n *node, typ reflect.Type) { - if !n.rval.IsValid() { - return - } - c, ok := n.rval.Interface().(constant.Value) - if !ok { - return - } - - v := n.rval - switch typ.Kind() { + typ := n.typ.TypeOf() + kind := typ.Kind() + switch kind { + case reflect.Bool: + v = reflect.ValueOf(constant.BoolVal(c)).Convert(typ) + case reflect.String: + v = reflect.ValueOf(constant.StringVal(c)).Convert(typ) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, _ := constant.Int64Val(c) l := constant.BitLen(c) - switch typ.Kind() { - case reflect.Int: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int", c.ExactString())) - } - v = reflect.ValueOf(int(i)) - case reflect.Int8: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows int8", c.ExactString())) - } - v = reflect.ValueOf(int8(i)) - case reflect.Int16: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows int16", c.ExactString())) - } - v = reflect.ValueOf(int16(i)) - case reflect.Int32: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows int32", c.ExactString())) - } - v = reflect.ValueOf(int32(i)) - case reflect.Int64: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) - } - v = reflect.ValueOf(i) + if l > bitlen[kind] { + panic(fmt.Sprintf("constant %s overflows int%d", c.ExactString(), bitlen[kind])) } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = reflect.ValueOf(i).Convert(typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: i, _ := constant.Uint64Val(c) l := constant.BitLen(c) - switch typ.Kind() { - case reflect.Uint: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint", c.ExactString())) - } - v = reflect.ValueOf(uint(i)) - case reflect.Uint8: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows uint8", c.ExactString())) - } - v = reflect.ValueOf(uint8(i)) - case reflect.Uint16: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows uint16", c.ExactString())) - } - v = reflect.ValueOf(uint16(i)) - case reflect.Uint32: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows uint32", c.ExactString())) - } - v = reflect.ValueOf(uint32(i)) - case reflect.Uint64: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint64", c.ExactString())) - } - v = reflect.ValueOf(i) - case reflect.Uintptr: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uintptr", c.ExactString())) - } - v = reflect.ValueOf(i) + if l > bitlen[kind] { + panic(fmt.Sprintf("constant %s overflows uint%d", c.ExactString(), bitlen[kind])) } + v = reflect.ValueOf(i).Convert(typ) case reflect.Float32: f, _ := constant.Float32Val(c) - v = reflect.ValueOf(f) + v = reflect.ValueOf(f).Convert(typ) case reflect.Float64: f, _ := constant.Float64Val(c) - v = reflect.ValueOf(f) + v = reflect.ValueOf(f).Convert(typ) case reflect.Complex64: r, _ := constant.Float32Val(constant.Real(c)) i, _ := constant.Float32Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) + v = reflect.ValueOf(complex(r, i)).Convert(typ) case reflect.Complex128: r, _ := constant.Float64Val(constant.Real(c)) i, _ := constant.Float64Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) + v = reflect.ValueOf(complex(r, i)).Convert(typ) + default: + // Type kind is from internal constant representation. Only use default types here. + switch c.Kind() { + case constant.Bool: + v = reflect.ValueOf(constant.BoolVal(c)) + case constant.String: + v = reflect.ValueOf(constant.StringVal(c)) + case constant.Int: + i, x := constant.Int64Val(c) + if !x { + panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) + } + v = reflect.ValueOf(int(i)) + case constant.Float: + f, _ := constant.Float64Val(c) + v = reflect.ValueOf(f) + case constant.Complex: + r, _ := constant.Float64Val(constant.Real(c)) + i, _ := constant.Float64Val(constant.Imag(c)) + v = reflect.ValueOf(complex(r, i)) + } } n.rval = v } diff --git a/interp/scope.go b/interp/scope.go index 66daaae7..a8b59892 100644 --- a/interp/scope.go +++ b/interp/scope.go @@ -153,6 +153,24 @@ func (s *scope) rangeChanType(n *node) *itype { return nil } +// fixType returns the input type, or a valid default type for untyped constant. +func (s *scope) fixType(t *itype) *itype { + if !t.untyped || t.cat != valueT { + return t + } + switch typ := t.TypeOf(); typ.Kind() { + case reflect.Int64: + return s.getType("int") + case reflect.Uint64: + return s.getType("uint") + case reflect.Float64: + return s.getType("float64") + case reflect.Complex128: + return s.getType("complex128") + } + return t +} + func (s *scope) getType(ident string) *itype { var t *itype if sym, _, found := s.lookup(ident); found {