diff --git a/_test/const14.go b/_test/const14.go new file mode 100644 index 00000000..835858f7 --- /dev/null +++ b/_test/const14.go @@ -0,0 +1,13 @@ +package main + +import "compress/flate" + +func f1(i int) { println("i:", i) } + +func main() { + i := flate.BestSpeed + f1(i) +} + +// Output: +// i: 1 diff --git a/interp/cfg.go b/interp/cfg.go index e39be2e9..f2b9741f 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -470,7 +470,10 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) { if src.typ.isBinMethod { dest.typ = &itype{cat: valueT, rtype: src.typ.methodCallType()} } else { - dest.typ = src.typ + // In a new definition, propagate the source type to the destination + // type. If the source is an untyped constant, make sure that the + // type matches a default type. + dest.typ = sc.fixType(src.typ) } } if dest.typ.sizedef { diff --git a/interp/run.go b/interp/run.go index abfd7d67..28a264a0 100644 --- a/interp/run.go +++ b/interp/run.go @@ -2703,11 +2703,7 @@ func convertLiteralValue(n *node, t reflect.Type) { // Skip non-constant values, undefined target type or interface target type. case n.rval.IsValid(): // Convert constant value to target type. - if n.typ != nil && n.typ.cat != valueT { - convertConstantValue(n) - } else { - convertConstantValueTo(n, t) - } + convertConstantValue(n) n.rval = n.rval.Convert(t) default: // Create a zero value of target type. @@ -2715,6 +2711,20 @@ func convertLiteralValue(n *node, t reflect.Type) { } } +var bitlen = [...]int{ + reflect.Int: 64, + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, + reflect.Uint: 64, + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, + reflect.Uintptr: 64, +} + func convertConstantValue(n *node) { if !n.rval.IsValid() { return @@ -2723,187 +2733,64 @@ func convertConstantValue(n *node) { if !ok { return } - t := n.typ - for t != nil && t.cat == aliasT { - // If it is an alias, get the actual type - t = t.val - } v := n.rval - switch t.cat { - case intT, int8T, int16T, int32T, int64T: - i, _ := constant.Int64Val(c) - l := constant.BitLen(c) - switch t.cat { - case intT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int", c.ExactString())) - } - v = reflect.ValueOf(int(i)) - case int8T: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows int8", c.ExactString())) - } - v = reflect.ValueOf(int8(i)) - case int16T: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows int16", c.ExactString())) - } - v = reflect.ValueOf(int16(i)) - case int32T: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows int32", c.ExactString())) - } - v = reflect.ValueOf(int32(i)) - case int64T: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) - } - v = reflect.ValueOf(i) - } - case uintT, uint8T, uint16T, uint32T, uint64T: - i, _ := constant.Uint64Val(c) - l := constant.BitLen(c) - switch t.cat { - case uintT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint", c.ExactString())) - } - v = reflect.ValueOf(uint(i)) - case uint8T: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows uint8", c.ExactString())) - } - v = reflect.ValueOf(uint8(i)) - case uint16T: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows uint16", c.ExactString())) - } - v = reflect.ValueOf(uint16(i)) - case uint32T: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows uint32", c.ExactString())) - } - v = reflect.ValueOf(uint32(i)) - case uint64T: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint64", c.ExactString())) - } - v = reflect.ValueOf(i) - case uintptrT: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uintptr", c.ExactString())) - } - v = reflect.ValueOf(i) - } - case float32T: - f, _ := constant.Float32Val(c) - v = reflect.ValueOf(f) - case float64T: - f, _ := constant.Float64Val(c) - v = reflect.ValueOf(f) - case complex64T: - r, _ := constant.Float32Val(constant.Real(c)) - i, _ := constant.Float32Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) - case complex128T: - r, _ := constant.Float64Val(constant.Real(c)) - i, _ := constant.Float64Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) - } - n.rval = v -} - -func convertConstantValueTo(n *node, typ reflect.Type) { - if !n.rval.IsValid() { - return - } - c, ok := n.rval.Interface().(constant.Value) - if !ok { - return - } - - v := n.rval - switch typ.Kind() { + typ := n.typ.TypeOf() + kind := typ.Kind() + switch kind { + case reflect.Bool: + v = reflect.ValueOf(constant.BoolVal(c)).Convert(typ) + case reflect.String: + v = reflect.ValueOf(constant.StringVal(c)).Convert(typ) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, _ := constant.Int64Val(c) l := constant.BitLen(c) - switch typ.Kind() { - case reflect.Int: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int", c.ExactString())) - } - v = reflect.ValueOf(int(i)) - case reflect.Int8: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows int8", c.ExactString())) - } - v = reflect.ValueOf(int8(i)) - case reflect.Int16: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows int16", c.ExactString())) - } - v = reflect.ValueOf(int16(i)) - case reflect.Int32: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows int32", c.ExactString())) - } - v = reflect.ValueOf(int32(i)) - case reflect.Int64: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) - } - v = reflect.ValueOf(i) + if l > bitlen[kind] { + panic(fmt.Sprintf("constant %s overflows int%d", c.ExactString(), bitlen[kind])) } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = reflect.ValueOf(i).Convert(typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: i, _ := constant.Uint64Val(c) l := constant.BitLen(c) - switch typ.Kind() { - case reflect.Uint: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint", c.ExactString())) - } - v = reflect.ValueOf(uint(i)) - case reflect.Uint8: - if l > 8 { - panic(fmt.Sprintf("constant %s overflows uint8", c.ExactString())) - } - v = reflect.ValueOf(uint8(i)) - case reflect.Uint16: - if l > 16 { - panic(fmt.Sprintf("constant %s overflows uint16", c.ExactString())) - } - v = reflect.ValueOf(uint16(i)) - case reflect.Uint32: - if l > 32 { - panic(fmt.Sprintf("constant %s overflows uint32", c.ExactString())) - } - v = reflect.ValueOf(uint32(i)) - case reflect.Uint64: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uint64", c.ExactString())) - } - v = reflect.ValueOf(i) - case reflect.Uintptr: - if l > 64 { - panic(fmt.Sprintf("constant %s overflows uintptr", c.ExactString())) - } - v = reflect.ValueOf(i) + if l > bitlen[kind] { + panic(fmt.Sprintf("constant %s overflows uint%d", c.ExactString(), bitlen[kind])) } + v = reflect.ValueOf(i).Convert(typ) case reflect.Float32: f, _ := constant.Float32Val(c) - v = reflect.ValueOf(f) + v = reflect.ValueOf(f).Convert(typ) case reflect.Float64: f, _ := constant.Float64Val(c) - v = reflect.ValueOf(f) + v = reflect.ValueOf(f).Convert(typ) case reflect.Complex64: r, _ := constant.Float32Val(constant.Real(c)) i, _ := constant.Float32Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) + v = reflect.ValueOf(complex(r, i)).Convert(typ) case reflect.Complex128: r, _ := constant.Float64Val(constant.Real(c)) i, _ := constant.Float64Val(constant.Imag(c)) - v = reflect.ValueOf(complex(r, i)) + v = reflect.ValueOf(complex(r, i)).Convert(typ) + default: + // Type kind is from internal constant representation. Only use default types here. + switch c.Kind() { + case constant.Bool: + v = reflect.ValueOf(constant.BoolVal(c)) + case constant.String: + v = reflect.ValueOf(constant.StringVal(c)) + case constant.Int: + i, x := constant.Int64Val(c) + if !x { + panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) + } + v = reflect.ValueOf(int(i)) + case constant.Float: + f, _ := constant.Float64Val(c) + v = reflect.ValueOf(f) + case constant.Complex: + r, _ := constant.Float64Val(constant.Real(c)) + i, _ := constant.Float64Val(constant.Imag(c)) + v = reflect.ValueOf(complex(r, i)) + } } n.rval = v } diff --git a/interp/scope.go b/interp/scope.go index 66daaae7..a8b59892 100644 --- a/interp/scope.go +++ b/interp/scope.go @@ -153,6 +153,24 @@ func (s *scope) rangeChanType(n *node) *itype { return nil } +// fixType returns the input type, or a valid default type for untyped constant. +func (s *scope) fixType(t *itype) *itype { + if !t.untyped || t.cat != valueT { + return t + } + switch typ := t.TypeOf(); typ.Kind() { + case reflect.Int64: + return s.getType("int") + case reflect.Uint64: + return s.getType("uint") + case reflect.Float64: + return s.getType("float64") + case reflect.Complex128: + return s.getType("complex128") + } + return t +} + func (s *scope) getType(ident string) *itype { var t *itype if sym, _, found := s.lookup(ident); found {