diff --git a/mystenbcs/decode.go b/mystenbcs/decode.go new file mode 100644 index 0000000..9d00b78 --- /dev/null +++ b/mystenbcs/decode.go @@ -0,0 +1,358 @@ +package mystenbcs + +// https://github.com/fardream/go-bcs/blob/main/bcs/decode.go#L19 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" +) + +// Unmarshal unmarshals the bcs serialized data into v. +// +// Refer to notes in [Marshal] for details how data serialized/deserialized. +// +// During the unmarshalling process +// 1. if [Unmarshaler], use "UnmarshalBCS" method. +// 2. if not [Unmarshaler] but [Enum], use the specialization for [Enum]. +// 3. otherwise standard process. +func Unmarshal(data []byte, v any) (int, error) { + return NewDecoder(bytes.NewReader(data)).Decode(v) +} + +// Decoder takes an [io.Reader] and decodes value from it. +type Decoder struct { + reader io.Reader + byteBuffer [1]byte +} + +// NewDecoder creates a new [Decoder] from an [io.Reader] +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + reader: r, + } +} + +// DecodeWithSize decodes a value from the decoder, and returns the number of bytes it consumed from the decoder. +// +// - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called. +// - If the value is [Enum], it will be special handled for [Enum] +func (d *Decoder) Decode(v any) (int, error) { + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() != reflect.Pointer || reflectValue.IsNil() { + return 0, fmt.Errorf("not a pointer or nil pointer") + } + + return d.decode(reflectValue) +} + +// decode is the main lifter, it first checks if a value can be [reflect.Value.CanInterface], +// then checks if the value implements [Unmarshaler] or [Enum], and then switch on the kind of the value: +// - pointer, create a new one and decode into its element. +// - interface, decode into element. +// - function, channel, unsafe pointers, ignore +// - otherwise call [decodeVanilla]. +func (d *Decoder) decode(v reflect.Value) (int, error) { + // if v cannot interface, ignore + if !v.CanInterface() { + return 0, nil + } + + // Unmarshaler + if i, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler { + return i.UnmarshalBCS(d.reader) + } + + // Enum + if _, isEnum := v.Interface().(Enum); isEnum { + switch v.Kind() { + case reflect.Pointer, reflect.Interface: + if v.IsNil() { + return 0, fmt.Errorf("trying to decode into nil pointer/interface") + } + return d.decodeEnum(v.Elem()) + default: + return d.decodeEnum(v) + } + } + + // switch kind + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return d.decode(v.Elem()) + + case reflect.Interface: + if v.IsNil() { + return 0, fmt.Errorf("cannot decode into nil interface") + } + return d.decode(v.Elem()) + + case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: + // silently ignore + return 0, nil + default: + return d.decodeVanilla(v) + } +} + +// decodeVanilla decodes bool, ints, slice, struct, array, and string. +func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) { + kind := v.Kind() + if !v.CanSet() { + return 0, fmt.Errorf("cannot change value of kind %s", kind.String()) + } + + switch kind { + case reflect.Bool: + t, n, err := d.readByte() + if err != nil { + return n, err + } + + if t == 0 { + v.SetBool(false) + } else { + v.SetBool(true) + } + + return n, nil + + case reflect.Int8, reflect.Uint8: + return 1, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int16, reflect.Uint16: + return 2, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int32, reflect.Uint32: + return 4, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) + case reflect.Int64, reflect.Uint64: + return 8, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface()) + + case reflect.Struct: + return d.decodeStruct(v) + + case reflect.Slice: + sliceType := v.Type().Elem() + if sliceType.Kind() == reflect.Uint8 { + return d.decodeByteSlice(v) + } + + return d.decodeSlice(v) + + case reflect.Array: + return d.decodeArray(v) + + case reflect.String: + return d.decodeString(v) + + default: + return 0, fmt.Errorf("unsupported vanilla decoding type: %s", kind.String()) + } +} + +// decodeString +func (d *Decoder) decodeString(v reflect.Value) (int, error) { + size, n, err := ULEB128Decode[int](d.reader) + if err != nil { + return n, err + } + + if size == 0 { + v.SetString("") + return n, nil + } + + tmp := make([]byte, size) + + read, err := d.reader.Read(tmp) + n += read + if err != nil { + return n, err + } + + if size != read { + return n, fmt.Errorf("wrong number of bytes read for string, want: %d, got %d", size, read) + } + + v.SetString(string(tmp)) + + return n, nil +} + +// readByte reads one byte from the input, error if no byte is read. +func (d *Decoder) readByte() (byte, int, error) { + b := d.byteBuffer[:] + n, err := d.reader.Read(b) + if err != nil { + return 0, n, err + } + if n == 0 { + return 0, n, io.ErrUnexpectedEOF + } + + return b[0], n, nil +} + +func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { + t := v.Type() + + var n int + +fieldLoop: + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if !field.CanInterface() { + continue fieldLoop + } + tag, err := parseTagValue(t.Field(i).Tag.Get(tagName)) + if err != nil { + return n, err + } + + switch { + case tag.isIgnored(): // ignored + continue fieldLoop + case tag.isOptional(): // optional + isOptional, k, err := d.readByte() + n += k + if err != nil { + return n, err + } + if isOptional == 0 { + field.Set(reflect.Zero(v.Type())) + } else { + field.Set(reflect.New(field.Type().Elem())) + k, err := d.decode(field.Elem()) + n += k + if err != nil { + return n, err + } + } + default: + k, err := d.decode(field) + n += k + if err != nil { + return n, err + } + } + } + + return n, nil +} + +func (d *Decoder) decodeEnum(v reflect.Value) (int, error) { + if v.Kind() != reflect.Struct { + return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String()) + } + enumId, n, err := ULEB128Decode[int](d.reader) + if err != nil { + return n, err + } + + field := v.Field(enumId) + + k, err := d.decode(field) + n += k + + return n, err +} + +func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) { + size, n, err := ULEB128Decode[int](d.reader) + if err != nil { + return n, err + } + + if size == 0 { + return n, nil + } + + tmp := make([]byte, size) + + read, err := d.reader.Read(tmp) + n += read + if err != nil { + return n, err + } + + if size != read { + return n, fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d", size, read) + } + + v.Set(reflect.ValueOf(tmp)) + + return n, nil +} + +func (d *Decoder) decodeArray(v reflect.Value) (int, error) { + size := v.Len() + t := v.Type() + elementType := t.Elem() + + var n int + if elementType.Kind() == reflect.Pointer { + for i := 0; i < size; i++ { + idx := reflect.New(elementType.Elem()) + k, err := d.decode(idx.Elem()) + n += k + if err != nil { + return n, err + } + v.Index(i).Set(idx) + } + } else { + for i := 0; i < size; i++ { + idx := reflect.New(elementType) + k, err := d.decode(idx.Elem()) + n += k + if err != nil { + return n, err + } + v.Index(i).Set(idx.Elem()) + } + } + + return n, nil +} + +func (d *Decoder) decodeSlice(v reflect.Value) (int, error) { + // get the length of the slice. + size, n, err := ULEB128Decode[int](d.reader) + if err != nil { + return n, err + } + + // element type of the slice + elementType := v.Type().Elem() + // make a new slice + tmp := reflect.MakeSlice(v.Type(), 0, size) + + if elementType.Kind() == reflect.Pointer { + for i := 0; i < size; i++ { + ind := reflect.New(elementType.Elem()) + k, err := d.decode(ind.Elem()) + n += k + if err != nil { + return n, err + } + tmp = reflect.Append(tmp, ind) + } + } else { + for i := 0; i < size; i++ { + ind := reflect.New(elementType) + k, err := d.decode(ind.Elem()) + n += k + if err != nil { + return n, err + } + tmp = reflect.Append(tmp, ind.Elem()) + } + } + + v.Set(tmp) + + return n, nil +} diff --git a/mystenbcs/encode.go b/mystenbcs/encode.go new file mode 100644 index 0000000..2514c48 --- /dev/null +++ b/mystenbcs/encode.go @@ -0,0 +1,319 @@ +package mystenbcs + +// https://github.com/fardream/go-bcs/blob/main/bcs/encode.go + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" +) + +// Marshaler customizes the marshalling behavior for a type +type Marshaler interface { + MarshalBCS() ([]byte, error) +} + +// Unmarshaler customizes the unmarshalling behavior for a type. +// +// Compared with other Unmarshalers in golang, the Unmarshaler here takes +// a [io.Reader] instead of []byte, since it is difficult to delimit the byte streams without unmarshalling. +// Method [UnmarshalBCS] returns the number of bytes read, and potentially an error. +type Unmarshaler interface { + UnmarshalBCS(io.Reader) (int, error) +} + +type Enum interface { + // IsBcsEnum doesn't do anything. Its function is to indicate this is an enum for bcs de/serialization. + IsBcsEnum() +} + +// Encoder takes an [io.Writer] and encodes value into it. +type Encoder struct { + w io.Writer +} + +// NewEncoder creates a new [Encoder] from an [io.Writer] +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: w, + } +} + +// Encode a value v into the encoder. +// +// - If the value is [Marshaler], the corresponding +// MarshalBCS implementation will be called. +// - If the value is [Enum], it will be special handled for [Enum]. +func (e *Encoder) Encode(v any) error { + return e.encode(reflect.ValueOf(v)) +} + +// encode a value +func (e *Encoder) encode(v reflect.Value) error { + // if v not CanInterface, + // this value is an unexported value, skip it. + if !v.CanInterface() { + return nil + } + + // test for the two interfaces we defined. + // 1. Marshaler + // 2. Enum. + i := v.Interface() + if m, ismarshaler := i.(Marshaler); ismarshaler { + bytes, err := m.MarshalBCS() + if err != nil { + return err + } + + _, err = e.w.Write(bytes) + + return err + } + if _, isenum := i.(Enum); isenum { + return e.encodeEnum(reflect.Indirect(v)) + } + + kind := v.Kind() + + switch kind { + case reflect.Bool, // boolean + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, // all the ints + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // all the uints + // use little endian to encode those. + return binary.Write(e.w, binary.LittleEndian, v.Interface()) + + case reflect.Pointer: // pointer + // if v is nil pointer, use the zero value for v. + // we don't check for optional flag here. + // that should be checked when the container struct is encoded + // if this pointer is contained in a struct. + return e.encode(reflect.Indirect(v)) + + case reflect.Interface: + return e.encode(v.Elem()) + + case reflect.Slice: // slices + // check if the element is uint8 or byteslice + if byteSlice, ok := (v.Interface()).([]byte); ok { + return e.encodeByteSlice(byteSlice) + } + return e.encodeSlice(v) + + case reflect.Array: // encode array + return e.encodeArray(v) + + case reflect.String: + str := []byte(v.String()) + return e.encodeByteSlice(str) + + case reflect.Struct: + return e.encodeStruct(v) + + case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: // channel, func, pointers + return nil + + default: + return fmt.Errorf("unsupported kind: %s, consider make the field ignored by using - tag or provide a customized Marshaler implementation", kind.String()) + } +} + +// encodeEnum encodes an [Enum] +func (e *Encoder) encodeEnum(v reflect.Value) error { + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + // ignore fields that are not exported + if !field.CanInterface() { + continue + } + + fieldType := t.Field(i) + // check the tag + tag, err := parseTagValue(fieldType.Tag.Get(tagName)) + if err != nil { + return err + } + if tag.isIgnored() { + continue + } + fieldKind := field.Kind() + if fieldKind != reflect.Pointer && fieldKind != reflect.Interface { + return fmt.Errorf("enum only supports fields that are either pointers or interfaces, unless they are ignored") + } + if !field.IsNil() { + if _, err := e.w.Write(ULEB128Encode(i)); err != nil { + return err + } + if fieldKind == reflect.Pointer { + return e.encode(reflect.Indirect(field)) + } else { + return e.encode(v) + } + } + } + + return fmt.Errorf("no field is set in the enum") +} + +// encodeByteSlice is specialized since bytes those can be simply put into the output. +func (e *Encoder) encodeByteSlice(b []byte) error { + l := len(b) + if _, err := e.w.Write(ULEB128Encode(l)); err != nil { + return err + } + + if _, err := e.w.Write(b); err != nil { + return err + } + + return nil +} + +func (e *Encoder) encodeArray(v reflect.Value) error { + length := v.Len() + for i := 0; i < length; i++ { + if err := e.encode(v.Index(i)); err != nil { + return err + } + } + + return nil +} + +func (e *Encoder) encodeSlice(v reflect.Value) error { + length := v.Len() + if _, err := e.w.Write(ULEB128Encode(length)); err != nil { + return err + } + + for i := 0; i < length; i++ { + if err := e.encode(v.Index(i)); err != nil { + return err + } + } + + return nil +} + +func (e *Encoder) encodeStruct(v reflect.Value) error { + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + // if a field is not exported, ignore + if !field.CanInterface() { + continue + } + tag, err := parseTagValue(t.Field(i).Tag.Get(tagName)) + if err != nil { + return err + } + switch { + case tag.isIgnored(): + continue + case tag.isOptional(): + if field.Kind() != reflect.Pointer && field.Kind() != reflect.Interface { + return fmt.Errorf("optional field can only be pointer or interface") + } + if field.IsNil() { + _, err := e.w.Write([]byte{0}) + if err != nil { + return err + } + } else { + if _, err := e.w.Write([]byte{1}); err != nil { + return err + } + if err := e.encode(field.Elem()); err != nil { + return err + } + } + continue + default: + // finally + if err := e.encode(field); err != nil { + return err + } + } + } + + return nil +} + +// Marshal a value into bcs bytes. +// +// Many constructs supported by bcs don't exist in golang or move-lang. +// +// - [Enum] is used to simulate the effects of rust enum. +// - Use tag `optional` to indicate an optional value in rust. +// the field must be pointer or interface. +// - Use tag `-` to ignore fields. +// - Unexported fields are ignored. +// +// Note that bcs doesn't have schema, and field names are irrelevant. The fields +// of struct are serialized in the order that they are defined. +// +// Pointers are serialized as the type they point to. Nil pointers will be serialized +// as zero value of the type they point to unless it's marked as `optional`. +// +// Arrays are serialized as fixed length vector (or serialize the each object individually without prefixing +// the length of the array). +// +// Vanilla maps are not supported, however, the code will error if map is encountered to call out they are +// not supported and either ignore or provide a customized marshal function. +// +// Channels, functions are silently ignored. +// +// During marshalling process, how v is marshalled depends on if v implemented [Marshaler] or [Enum] +// 1. if [Marshaler], use "MarshalBCS" method. +// 2. if not [Marshaler] but [Enum], use specialization for [Enum]. +// 3. otherwise standard process. +func Marshal(v any) ([]byte, error) { + var b bytes.Buffer + e := NewEncoder(&b) + + if err := e.Encode(v); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +type Option[T any] struct { + Some T + None bool +} + +func (p *Option[T]) MarshalBCS() ([]byte, error) { + if p.None { + return []byte{0}, nil + } + b, err := Marshal(p.Some) + return append([]byte{1}, b...), err +} + +func (p *Option[T]) UnmarshalBCS(r io.Reader) (int, error) { + buf := new(bytes.Buffer) + io.Copy(buf, r) + tmp := buf.Bytes() + if len(tmp) == 1 { + p.None = true + return 1, nil + } + b := tmp[1:] + return Unmarshal(b, &p.Some) +} + +// MustMarshal [Marshal] v, and panics if error. +func MustMarshal(v any) []byte { + result, err := Marshal(v) + if err != nil { + panic(err) + } + + return result +} diff --git a/mystenbcs/tag.go b/mystenbcs/tag.go new file mode 100644 index 0000000..e29666e --- /dev/null +++ b/mystenbcs/tag.go @@ -0,0 +1,45 @@ +package mystenbcs + +// https://github.com/fardream/go-bcs/blob/main/bcs/tag.go +import ( + "fmt" + "strings" +) + +const tagName = "bcs" + +type tagValue int64 + +const ( + tagValue_Optional tagValue = 1 << iota // optional + tagValue_Ignore // - +) + +func parseTagValue(tag string) (tagValue, error) { + var r tagValue + tagSegs := strings.Split(tag, ",") + for _, seg := range tagSegs { + seg := strings.TrimSpace(seg) + if seg == "" { + continue + } + switch seg { + case "optional": + r |= tagValue_Optional + case "-": + return tagValue_Ignore, nil + default: + return 0, fmt.Errorf("unknown tag: %s in %s", seg, tag) + } + } + + return r, nil +} + +func (t tagValue) isOptional() bool { + return t&tagValue_Optional != 0 +} + +func (t tagValue) isIgnored() bool { + return t&tagValue_Ignore != 0 +} diff --git a/mystenbcs/uleb128.go b/mystenbcs/uleb128.go new file mode 100644 index 0000000..826515a --- /dev/null +++ b/mystenbcs/uleb128.go @@ -0,0 +1,72 @@ +package mystenbcs + +// https://github.com/fardream/go-bcs/blob/main/bcs/uleb128.go + +import ( + "encoding/binary" + "fmt" + "io" +) + +// MaxUleb128Length is the max possible number of bytes for an ULEB128 encoded integer. +// Go's widest integer is uint64, so the length is 10. +const MaxUleb128Length = 10 + +// ULEB128SupportedTypes is a contraint interface that limits the input to +// [ULEB128Encode] and [ULEB128Decode] to signed and unsigned integers. +type ULEB128SupportedTypes interface { + ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint | ~int8 | ~int16 | ~int32 | ~int64 | ~int +} + +// ULEB128Encode converts an integer into []byte (see [wikipedia] and [bcs]) +// +// This reuses [binary.PutUvarint] in standard library. +// +// [wikipedia]: https://en.wikipedia.org/wiki/LEB128 +// [bcs]: https://github.com/diem/bcs#uleb128-encoded-integers +func ULEB128Encode[T ULEB128SupportedTypes](input T) []byte { + result := make([]byte, 10) + i := binary.PutUvarint(result, uint64(input)) + return result[:i] +} + +// ULEB128Decode decodes [io.Reader] into an integer, returns the resulted value, the number of byte read, and a possible error. +// +// [binary.ReadUvarint] is not used here because +// - it doesn't support returning the number of bytes read. +// - it accepts only [io.ByteReader], but the recommended way of creating one from [bufio.NewReader] will read more than 1 byte at the +// to fill the buffer. +func ULEB128Decode[T ULEB128SupportedTypes](r io.Reader) (T, int, error) { + buf := make([]byte, 1) + var v, shift T + var n int + for n < 10 { + i, err := r.Read(buf) + if i == 0 { + return 0, n, fmt.Errorf("zero read in. possible EOF") + } + if err != nil { + return 0, n, err + } + n += i + + d := T(buf[0]) + ld := d & 127 + if (ld<>shift != ld { + return v, n, fmt.Errorf("overflow at index %d: %v", n-1, ld) + } + + ld <<= shift + v = ld + v + if v < ld { + return v, n, fmt.Errorf("overflow after adding index %d: %v %v", n-1, ld, v) + } + if d <= 127 { + return v, n, nil + } + + shift += 7 + } + + return 0, n, fmt.Errorf("failed to find most significant bytes after reading %d bytes", n) +}