feat: bcs encoding and decoding

This commit is contained in:
Ashwin Prasad
2024-12-25 01:14:34 +05:30
parent e02160cf4c
commit 5f1cda5572
4 changed files with 794 additions and 0 deletions

358
mystenbcs/decode.go Normal file
View File

@@ -0,0 +1,358 @@
package mystenbcs
// https://github.com/fardream/go-bcs/blob/main/bcs/decode.go#L19
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"reflect"
)
// Unmarshal decodes the bcs serialized bytes in data into v and returns the
// count reported by [Decoder.Decode].
//
// See the notes on [Marshal] for how values are serialized and deserialized.
//
// Each value is decoded by the first rule that applies:
//  1. an [Unmarshaler] is decoded by its "UnmarshalBCS" method;
//  2. an [Enum] that is not an [Unmarshaler] uses the [Enum] specialization;
//  3. anything else goes through the standard process.
func Unmarshal(data []byte, v any) (int, error) {
	src := bytes.NewReader(data)
	dec := NewDecoder(src)
	return dec.Decode(v)
}
// Decoder takes an [io.Reader] and decodes values from it.
type Decoder struct {
// reader is the source that every decode method reads from.
reader io.Reader
// byteBuffer is the one-byte scratch space that readByte reads into.
byteBuffer [1]byte
}
// NewDecoder returns a [Decoder] that reads its input from r.
func NewDecoder(r io.Reader) *Decoder {
	dec := new(Decoder)
	dec.reader = r
	return dec
}
// Decode decodes a value from the decoder into v and returns the number of
// bytes it consumed from the decoder. v must be a non-nil pointer; anything
// else is rejected with an error before any input is read.
//
//   - If the value is an [Unmarshaler], its UnmarshalBCS method is called.
//   - If the value is an [Enum], it gets the special handling for [Enum].
func (d *Decoder) Decode(v any) (int, error) {
	target := reflect.ValueOf(v)
	if target.Kind() == reflect.Pointer && !target.IsNil() {
		return d.decode(target)
	}
	return 0, fmt.Errorf("not a pointer or nil pointer")
}
// decode is the main lifter. In order, it:
//   - rejects an invalid [reflect.Value] with an error;
//   - skips values that cannot be turned into an interface
//     (see [reflect.Value.CanInterface]), reporting zero bytes read;
//   - allocates a nil pointer so that the interface checks below, and any
//     UnmarshalBCS method, always see a usable receiver;
//   - hands the value to UnmarshalBCS if it implements [Unmarshaler];
//   - uses [Decoder.decodeEnum] if it implements [Enum];
//   - otherwise switches on the kind: pointers and interfaces are decoded
//     through their element, functions, channels, uintptr and unsafe pointers
//     are silently ignored, and everything else goes to [Decoder.decodeVanilla].
//
// It returns the number of bytes consumed and the first error encountered.
func (d *Decoder) decode(v reflect.Value) (int, error) {
	if !v.IsValid() {
		return 0, fmt.Errorf("cannot decode into an invalid value")
	}
	// if v cannot interface, ignore
	if !v.CanInterface() {
		return 0, nil
	}

	// Allocate before the interface checks: a nil *T whose pointer type
	// implements Unmarshaler or Enum would otherwise be used as a nil receiver
	// (Unmarshaler) or rejected outright (Enum).
	if v.Kind() == reflect.Pointer && v.IsNil() {
		if !v.CanSet() {
			return 0, fmt.Errorf("cannot allocate for nil pointer of type %s", v.Type())
		}
		v.Set(reflect.New(v.Type().Elem()))
	}

	// Unmarshaler
	if u, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler {
		return u.UnmarshalBCS(d.reader)
	}

	// Enum
	if _, isEnum := v.Interface().(Enum); isEnum {
		switch v.Kind() {
		case reflect.Pointer, reflect.Interface:
			if v.IsNil() {
				return 0, fmt.Errorf("trying to decode into nil pointer/interface")
			}
			return d.decodeEnum(v.Elem())
		default:
			return d.decodeEnum(v)
		}
	}

	// switch kind
	switch v.Kind() {
	case reflect.Pointer:
		// nil pointers were allocated above.
		return d.decode(v.Elem())
	case reflect.Interface:
		if v.IsNil() {
			return 0, fmt.Errorf("cannot decode into nil interface")
		}
		return d.decode(v.Elem())
	case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer:
		// silently ignore
		return 0, nil
	default:
		return d.decodeVanilla(v)
	}
}
// decodeVanilla decodes the kinds that need no special handling: bool, the
// fixed-width integers, struct, slice, array, and string. v must be settable,
// otherwise an error is returned without reading anything.
//
// For the fixed-width integers the reported byte count is always the width of
// the type, whether or not the underlying read succeeded.
func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) {
	kind := v.Kind()
	if !v.CanSet() {
		return 0, fmt.Errorf("cannot change value of kind %s", kind.String())
	}

	switch kind {
	case reflect.Bool:
		raw, n, err := d.readByte()
		if err != nil {
			return n, err
		}
		// any non-zero byte is accepted as true
		v.SetBool(raw != 0)
		return n, nil

	case reflect.Int8, reflect.Uint8,
		reflect.Int16, reflect.Uint16,
		reflect.Int32, reflect.Uint32,
		reflect.Int64, reflect.Uint64:
		// the in-memory size of these kinds equals their encoded width
		width := int(v.Type().Size())
		return width, binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())

	case reflect.Struct:
		return d.decodeStruct(v)

	case reflect.Slice:
		// byte slices have a dedicated, bulk-reading decoder
		if v.Type().Elem().Kind() == reflect.Uint8 {
			return d.decodeByteSlice(v)
		}
		return d.decodeSlice(v)

	case reflect.Array:
		return d.decodeArray(v)

	case reflect.String:
		return d.decodeString(v)
	}

	return 0, fmt.Errorf("unsupported vanilla decoding type: %s", kind.String())
}
// decodeString reads a ULEB128 length prefix followed by that many bytes and
// stores the bytes in v as a string. It returns the total number of bytes
// consumed, including the length prefix.
//
// The payload is pulled with [io.CopyN] rather than a single Read call: a
// Reader may legally return fewer bytes than requested without that being an
// error, and growing the buffer as data arrives avoids allocating whatever
// size an untrusted length prefix claims. Input that ends before the declared
// length is reported as [io.ErrUnexpectedEOF].
func (d *Decoder) decodeString(v reflect.Value) (int, error) {
	size, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}

	if size == 0 {
		v.SetString("")
		return n, nil
	}

	var buf bytes.Buffer
	read, err := io.CopyN(&buf, d.reader, int64(size))
	n += int(read)
	if err != nil {
		if err == io.EOF {
			// the stream ended in the middle of the declared payload
			err = io.ErrUnexpectedEOF
		}
		return n, fmt.Errorf("wrong number of bytes read for string, want: %d, got %d: %w", size, read, err)
	}

	v.SetString(buf.String())

	return n, nil
}
// readByte reads one byte from the input. It returns the byte, the number of
// bytes read (0 or 1), and an error if no byte could be read.
//
// Per the [io.Reader] contract a Read may return data together with an error
// (for example the final byte along with io.EOF), so a byte that was read is
// always returned; the reader reports the pending error again on the next
// call. A read of zero bytes without an error is reported as
// [io.ErrUnexpectedEOF].
func (d *Decoder) readByte() (byte, int, error) {
	b := d.byteBuffer[:]
	n, err := d.reader.Read(b)
	if n > 0 {
		return b[0], n, nil
	}
	if err != nil {
		return 0, n, err
	}
	return 0, n, io.ErrUnexpectedEOF
}
// decodeStruct decodes the fields of struct v in declaration order and returns
// the number of bytes consumed.
//
//   - Fields that cannot be turned into an interface (unexported) are skipped.
//   - Fields tagged as ignored are skipped.
//   - Fields tagged optional must be pointers or interfaces. A flag byte is
//     read first: 0 resets the field to its zero value (nil), anything else
//     means a value follows. Pointer fields get a freshly allocated target;
//     interface fields are decoded into whatever they currently hold.
//   - All other fields are decoded with [Decoder.decode].
func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
	t := v.Type()

	var n int

	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		if !field.CanInterface() {
			continue
		}
		tag, err := parseTagValue(t.Field(i).Tag.Get(tagName))
		if err != nil {
			return n, err
		}

		switch {
		case tag.isIgnored(): // ignored
			continue

		case tag.isOptional(): // optional
			kind := field.Kind()
			if kind != reflect.Pointer && kind != reflect.Interface {
				// mirrors the check done when encoding; Elem() below would panic otherwise
				return n, fmt.Errorf("optional field can only be pointer or interface")
			}

			present, k, err := d.readByte()
			n += k
			if err != nil {
				return n, err
			}

			if present == 0 {
				// zero value of the FIELD type (nil), not of the enclosing struct
				field.Set(reflect.Zero(field.Type()))
				continue
			}

			if kind == reflect.Pointer {
				field.Set(reflect.New(field.Type().Elem()))
				k, err = d.decode(field.Elem())
			} else {
				// an interface cannot be allocated here; decode reports an
				// error if it is nil
				k, err = d.decode(field)
			}
			n += k
			if err != nil {
				return n, err
			}

		default:
			k, err := d.decode(field)
			n += k
			if err != nil {
				return n, err
			}
		}
	}

	return n, nil
}
// decodeEnum decodes an [Enum]. v must be a struct: a ULEB128 variant index is
// read and the struct field at that index is decoded. It returns the number of
// bytes consumed, including the index.
//
// The index comes from the input, so it is validated against the number of
// fields instead of being allowed to panic inside reflect.
func (d *Decoder) decodeEnum(v reflect.Value) (int, error) {
	if v.Kind() != reflect.Struct {
		return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String())
	}

	enumID, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}

	if enumID < 0 || enumID >= v.NumField() {
		return n, fmt.Errorf("enum variant index %d out of range for %s with %d fields", enumID, v.Type(), v.NumField())
	}

	k, err := d.decode(v.Field(enumID))
	n += k

	return n, err
}
// decodeByteSlice reads a ULEB128 length prefix followed by that many bytes and
// stores them in v, which must be a settable slice whose element kind is
// uint8. It returns the total number of bytes consumed.
//
//   - A zero length truncates v to length 0, so data from an earlier decode
//     does not survive; a nil slice stays nil.
//   - The payload is pulled with [io.CopyN]: a single Read may return fewer
//     bytes than requested, and growing the buffer as data arrives avoids
//     allocating whatever size an untrusted length prefix claims.
//   - The result is stored with SetBytes, which also works when the slice or
//     its element is a named type.
func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) {
	size, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}

	if size == 0 {
		v.SetLen(0)
		return n, nil
	}

	var buf bytes.Buffer
	read, err := io.CopyN(&buf, d.reader, int64(size))
	n += int(read)
	if err != nil {
		if err == io.EOF {
			// the stream ended in the middle of the declared payload
			err = io.ErrUnexpectedEOF
		}
		return n, fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d: %w", size, read, err)
	}

	v.SetBytes(buf.Bytes())

	return n, nil
}
// decodeArray decodes a fixed-length array. No length prefix is read: exactly
// v.Len() elements are decoded one after another. Each element is decoded into
// a freshly allocated value which is then stored in the array; for arrays of
// pointers the stored value is the pointer to that fresh allocation. It
// returns the number of bytes consumed.
func (d *Decoder) decodeArray(v reflect.Value) (int, error) {
	elemType := v.Type().Elem()
	viaPointer := elemType.Kind() == reflect.Pointer

	// type of the scratch value each element is decoded into
	targetType := elemType
	if viaPointer {
		targetType = elemType.Elem()
	}

	consumed := 0
	for i, count := 0, v.Len(); i < count; i++ {
		holder := reflect.New(targetType)

		k, err := d.decode(holder.Elem())
		consumed += k
		if err != nil {
			return consumed, err
		}

		if viaPointer {
			v.Index(i).Set(holder)
			continue
		}
		v.Index(i).Set(holder.Elem())
	}

	return consumed, nil
}
// decodeSlice decodes a slice: a ULEB128 element count followed by that many
// elements. Each element is decoded into a freshly allocated value; for slices
// of pointers the pointer to that allocation is appended. v is only assigned
// once every element decoded successfully. It returns the number of bytes
// consumed, including the count.
//
// The count comes from the input, so it is used only as a bounded capacity
// hint: a corrupt or hostile prefix must not be able to trigger a huge
// allocation (or a panic inside reflect.MakeSlice) before any element has
// been read. The slice still grows to the real size through Append.
func (d *Decoder) decodeSlice(v reflect.Value) (int, error) {
	// upper bound for the up-front allocation, in elements
	const maxPrealloc = 1024

	// get the length of the slice.
	size, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}

	// element type of the slice
	elementType := v.Type().Elem()
	viaPointer := elementType.Kind() == reflect.Pointer

	// type of the scratch value each element is decoded into
	targetType := elementType
	if viaPointer {
		targetType = elementType.Elem()
	}

	capHint := size
	if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	tmp := reflect.MakeSlice(v.Type(), 0, capHint)

	for i := 0; i < size; i++ {
		ind := reflect.New(targetType)
		k, err := d.decode(ind.Elem())
		n += k
		if err != nil {
			return n, err
		}
		if viaPointer {
			tmp = reflect.Append(tmp, ind)
		} else {
			tmp = reflect.Append(tmp, ind.Elem())
		}
	}

	v.Set(tmp)

	return n, nil
}

319
mystenbcs/encode.go Normal file
View File

@@ -0,0 +1,319 @@
package mystenbcs
// https://github.com/fardream/go-bcs/blob/main/bcs/encode.go
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"reflect"
)
// Marshaler customizes the marshalling behavior for a type.
type Marshaler interface {
// MarshalBCS returns the bcs encoding of the receiver. The encoder writes the
// returned bytes to its output as they are, or returns the error.
MarshalBCS() ([]byte, error)
}
// Unmarshaler customizes the unmarshalling behavior for a type.
//
// Compared with other Unmarshalers in golang, the Unmarshaler here takes
// an [io.Reader] instead of []byte, since it is difficult to delimit the byte streams without unmarshalling.
type Unmarshaler interface {
// UnmarshalBCS decodes the receiver from the reader and returns the number
// of bytes read, and potentially an error. The decoder passes its own
// reader, so bytes consumed here are gone for whatever is decoded next.
UnmarshalBCS(io.Reader) (int, error)
}
// Enum marks a struct type as a bcs enum. When encoding, the index of the
// first exported, non-ignored, non-nil field is written as a ULEB128 integer
// followed by that field's value; when decoding, the index is read back and
// the field at that index is decoded.
type Enum interface {
// IsBcsEnum doesn't do anything. Its function is to indicate this is an enum for bcs de/serialization.
IsBcsEnum()
}
// Encoder takes an [io.Writer] and encodes values into it.
type Encoder struct {
// w is the destination that every encoded value is written to.
w io.Writer
}
// NewEncoder returns an [Encoder] that writes its output to w.
func NewEncoder(w io.Writer) *Encoder {
	enc := new(Encoder)
	enc.w = w
	return enc
}
// Encode writes the bcs encoding of v to the encoder's writer.
//
//   - If the value is a [Marshaler], its MarshalBCS implementation is called
//     and the result is written out.
//   - If the value is an [Enum], it gets the special handling for [Enum].
func (e *Encoder) Encode(v any) error {
	value := reflect.ValueOf(v)
	return e.encode(value)
}
// encode writes the bcs encoding of v to the encoder's writer.
//
// Resolution order:
//  1. an invalid value (for example from Encode(nil)) is an error;
//  2. values that cannot be turned into an interface (unexported) are skipped;
//  3. a [Marshaler] is encoded through MarshalBCS;
//  4. an [Enum] is encoded through encodeEnum; a nil enum pointer is an error;
//  5. otherwise the kind decides: bools and integers are written little
//     endian, pointers and interfaces are encoded through their element,
//     slices/arrays/strings/structs use their dedicated encoders, channels,
//     functions, uintptr and unsafe pointers are silently ignored, and any
//     other kind is an error.
//
// A nil pointer is encoded as the zero value of the type it points to (the
// optional flag is handled by the enclosing struct, not here); a nil interface
// has no type to derive a zero value from and is reported as an error.
func (e *Encoder) encode(v reflect.Value) error {
	if !v.IsValid() {
		return fmt.Errorf("cannot encode an invalid (nil) value")
	}

	// if v not CanInterface,
	// this value is an unexported value, skip it.
	if !v.CanInterface() {
		return nil
	}

	// test for the two interfaces we defined.
	// 1. Marshaler
	// 2. Enum.
	i := v.Interface()
	if m, isMarshaler := i.(Marshaler); isMarshaler {
		out, err := m.MarshalBCS()
		if err != nil {
			return err
		}
		_, err = e.w.Write(out)
		return err
	}

	if _, isEnum := i.(Enum); isEnum {
		target := reflect.Indirect(v)
		if !target.IsValid() {
			// v was a nil pointer; there is no struct to pick a variant from
			return fmt.Errorf("cannot encode nil enum pointer")
		}
		return e.encodeEnum(target)
	}

	kind := v.Kind()

	switch kind {
	case reflect.Bool, // boolean
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, // all the ints
		reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // all the uints
		// use little endian to encode those.
		return binary.Write(e.w, binary.LittleEndian, v.Interface())

	case reflect.Pointer: // pointer
		// if v is nil pointer, use the zero value of the pointed-to type.
		// we don't check for optional flag here.
		// that should be checked when the container struct is encoded
		// if this pointer is contained in a struct.
		if v.IsNil() {
			return e.encode(reflect.Zero(v.Type().Elem()))
		}
		return e.encode(v.Elem())

	case reflect.Interface:
		if v.IsNil() {
			return fmt.Errorf("cannot encode nil interface")
		}
		return e.encode(v.Elem())

	case reflect.Slice: // slices
		// check if the element is uint8 or byteslice
		if byteSlice, ok := (v.Interface()).([]byte); ok {
			return e.encodeByteSlice(byteSlice)
		}
		return e.encodeSlice(v)

	case reflect.Array: // encode array
		return e.encodeArray(v)

	case reflect.String:
		return e.encodeByteSlice([]byte(v.String()))

	case reflect.Struct:
		return e.encodeStruct(v)

	case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer: // channel, func, pointers
		return nil

	default:
		return fmt.Errorf("unsupported kind: %s, consider make the field ignored by using - tag or provide a customized Marshaler implementation", kind.String())
	}
}
// encodeEnum encodes an [Enum]. v must be a struct whose exported, non-ignored
// fields are all pointers or interfaces. The first such field that is not nil
// is the active variant: its field index is written as a ULEB128 integer,
// followed by the encoding of the value the field holds.
//
// It is an error if v is not a struct, if a considered field is neither a
// pointer nor an interface, or if every field is nil.
func (e *Encoder) encodeEnum(v reflect.Value) error {
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("only support struct for Enum, got %s", v.Kind().String())
	}

	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		// ignore fields that are not exported
		if !field.CanInterface() {
			continue
		}
		// check the tag
		tag, err := parseTagValue(t.Field(i).Tag.Get(tagName))
		if err != nil {
			return err
		}
		if tag.isIgnored() {
			continue
		}

		fieldKind := field.Kind()
		if fieldKind != reflect.Pointer && fieldKind != reflect.Interface {
			return fmt.Errorf("enum only supports fields that are either pointers or interfaces, unless they are ignored")
		}
		if field.IsNil() {
			continue
		}

		if _, err := e.w.Write(ULEB128Encode(i)); err != nil {
			return err
		}
		// Encode what the field holds (the pointee, or the interface's dynamic
		// value). Encoding v itself here would write the enclosing enum again
		// instead of the selected variant.
		return e.encode(field.Elem())
	}

	return fmt.Errorf("no field is set in the enum")
}
// encodeByteSlice writes b as a ULEB128 length prefix followed by the raw
// bytes. It is specialized because the bytes can be copied to the output
// directly instead of being encoded element by element.
func (e *Encoder) encodeByteSlice(b []byte) error {
	prefix := ULEB128Encode(len(b))
	_, err := e.w.Write(prefix)
	if err == nil {
		_, err = e.w.Write(b)
	}
	return err
}
// encodeArray encodes the elements of array v one after another, without a
// length prefix, and stops at the first error.
func (e *Encoder) encodeArray(v reflect.Value) error {
	for i, count := 0, v.Len(); i < count; i++ {
		err := e.encode(v.Index(i))
		if err != nil {
			return err
		}
	}
	return nil
}
// encodeSlice writes the element count of slice v as a ULEB128 integer and
// then encodes the elements one after another, stopping at the first error.
func (e *Encoder) encodeSlice(v reflect.Value) error {
	count := v.Len()

	_, err := e.w.Write(ULEB128Encode(count))
	if err != nil {
		return err
	}

	for i := 0; i < count; i++ {
		err := e.encode(v.Index(i))
		if err != nil {
			return err
		}
	}
	return nil
}
// encodeStruct encodes the fields of struct v in declaration order.
//
//   - Fields that cannot be turned into an interface (unexported) are skipped.
//   - Fields tagged as ignored are skipped.
//   - Fields tagged optional must be pointers or interfaces: a nil field is
//     written as the single byte 0, a non-nil field as the byte 1 followed by
//     the encoding of the value it holds.
//   - All other fields are encoded with encode.
func (e *Encoder) encodeStruct(v reflect.Value) error {
	structType := v.Type()

	for idx, total := 0, v.NumField(); idx < total; idx++ {
		field := v.Field(idx)
		// if a field is not exported, ignore
		if !field.CanInterface() {
			continue
		}

		tag, err := parseTagValue(structType.Field(idx).Tag.Get(tagName))
		if err != nil {
			return err
		}
		if tag.isIgnored() {
			continue
		}

		if !tag.isOptional() {
			// plain field
			if err := e.encode(field); err != nil {
				return err
			}
			continue
		}

		if k := field.Kind(); k != reflect.Pointer && k != reflect.Interface {
			return fmt.Errorf("optional field can only be pointer or interface")
		}

		// presence flag: 0 for nil, 1 when a value follows
		flag := byte(1)
		if field.IsNil() {
			flag = 0
		}
		if _, err := e.w.Write([]byte{flag}); err != nil {
			return err
		}
		if flag == 1 {
			if err := e.encode(field.Elem()); err != nil {
				return err
			}
		}
	}

	return nil
}
// Marshal a value into bcs bytes.
//
// Many constructs supported by bcs don't exist in golang or move-lang.
//
// - [Enum] is used to simulate the effects of rust enum.
// - Use tag `optional` to indicate an optional value in rust.
// the field must be pointer or interface.
// - Use tag `-` to ignore fields.
// - Unexported fields are ignored.
//
// Note that bcs doesn't have schema, and field names are irrelevant. The fields
// of struct are serialized in the order that they are defined.
//
// Pointers are serialized as the type they point to. Nil pointers will be serialized
// as zero value of the type they point to unless it's marked as `optional`.
//
// Arrays are serialized as fixed length vectors (that is, each element is serialized individually without prefixing
// the length of the array).
//
// Vanilla maps are not supported. The code returns an error when a map is encountered, to call out that
// the field must either be ignored or be given a customized marshal function.
//
// Channels, functions are silently ignored.
//
// During marshalling process, how v is marshalled depends on if v implemented [Marshaler] or [Enum]
// 1. if [Marshaler], use "MarshalBCS" method.
// 2. if not [Marshaler] but [Enum], use specialization for [Enum].
// 3. otherwise standard process.
func Marshal(v any) ([]byte, error) {
	// encode into an in-memory buffer and hand back its contents
	out := new(bytes.Buffer)

	err := NewEncoder(out).Encode(v)
	if err != nil {
		return nil, err
	}

	return out.Bytes(), nil
}
// Option is an optional value of type T with its own bcs encoding: its
// MarshalBCS method writes the single byte 0 when None is set, and the byte 1
// followed by the encoding of Some otherwise.
type Option[T any] struct {
// Some is the wrapped value; it is only encoded when None is false.
Some T
// None marks the option as empty.
None bool
}
// MarshalBCS implements [Marshaler]. An empty option (None set) is encoded as
// the single byte 0; otherwise the byte 1 is followed by the encoding of Some.
// If Some cannot be marshalled, the error is returned with no payload rather
// than with a truncated one.
func (p *Option[T]) MarshalBCS() ([]byte, error) {
	if p.None {
		return []byte{0}, nil
	}
	b, err := Marshal(p.Some)
	if err != nil {
		return nil, err
	}
	return append([]byte{1}, b...), nil
}
// UnmarshalBCS implements [Unmarshaler]. It reads the one-byte presence flag
// and, when the flag is 1, decodes Some from the bytes that follow. It returns
// the number of bytes consumed, including the flag byte.
//
// Only the bytes that belong to the option are read from r, so values encoded
// after it in the same stream remain available to the caller. A flag of 0
// marks the option as None and resets Some to its zero value; any flag other
// than 0 or 1 is an error.
func (p *Option[T]) UnmarshalBCS(r io.Reader) (int, error) {
	var flag [1]byte
	n, err := io.ReadFull(r, flag[:])
	if err != nil {
		return n, err
	}

	switch flag[0] {
	case 0:
		var zero T
		p.Some, p.None = zero, true
		return n, nil
	case 1:
		p.None = false
		k, err := NewDecoder(r).Decode(&p.Some)
		return n + k, err
	default:
		return n, fmt.Errorf("invalid option flag %d, want 0 or 1", flag[0])
	}
}
// MustMarshal is like [Marshal] but panics with the error instead of
// returning it.
func MustMarshal(v any) []byte {
	out, err := Marshal(v)
	if err == nil {
		return out
	}
	panic(err)
}

45
mystenbcs/tag.go Normal file
View File

@@ -0,0 +1,45 @@
package mystenbcs
// https://github.com/fardream/go-bcs/blob/main/bcs/tag.go
import (
"fmt"
"strings"
)
// tagName is the struct tag key whose value is parsed for field options.
const tagName = "bcs"

// tagValue is a bit set of the options found in a field's tag.
type tagValue int64

const (
	tagValue_Optional tagValue = 1 << iota // optional
	tagValue_Ignore                        // -
)

// parseTagValue parses the comma separated options of a tag value.
//
// Segments are trimmed of surrounding white space and empty segments are
// skipped. "optional" sets the optional bit. "-" makes the whole field ignored
// and ends parsing immediately, discarding any option seen before it. Any
// other segment is reported as an error.
func parseTagValue(tag string) (tagValue, error) {
	var flags tagValue

	for _, raw := range strings.Split(tag, ",") {
		switch opt := strings.TrimSpace(raw); opt {
		case "":
			// nothing between two commas, or an empty tag
		case "optional":
			flags |= tagValue_Optional
		case "-":
			return tagValue_Ignore, nil
		default:
			return 0, fmt.Errorf("unknown tag: %s in %s", opt, tag)
		}
	}

	return flags, nil
}

// isOptional reports whether the optional bit is set.
func (t tagValue) isOptional() bool {
	return t&tagValue_Optional == tagValue_Optional
}

// isIgnored reports whether the ignore bit is set.
func (t tagValue) isIgnored() bool {
	return t&tagValue_Ignore == tagValue_Ignore
}

72
mystenbcs/uleb128.go Normal file
View File

@@ -0,0 +1,72 @@
package mystenbcs
// https://github.com/fardream/go-bcs/blob/main/bcs/uleb128.go
import (
"encoding/binary"
"fmt"
"io"
)
// MaxUleb128Length is the max possible number of bytes for an ULEB128 encoded integer.
// Go's widest integer is uint64, so the length is 10.
const MaxUleb128Length = 10

// ULEB128SupportedTypes is a constraint interface that limits the input to
// [ULEB128Encode] and [ULEB128Decode] to signed and unsigned integers.
type ULEB128SupportedTypes interface {
	~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint | ~int8 | ~int16 | ~int32 | ~int64 | ~int
}

// ULEB128Encode converts an integer into []byte (see [wikipedia] and [bcs])
//
// This reuses [binary.PutUvarint] in standard library. The input is converted
// to uint64 before it is encoded.
//
// [wikipedia]: https://en.wikipedia.org/wiki/LEB128
// [bcs]: https://github.com/diem/bcs#uleb128-encoded-integers
func ULEB128Encode[T ULEB128SupportedTypes](input T) []byte {
	result := make([]byte, MaxUleb128Length)
	i := binary.PutUvarint(result, uint64(input))
	return result[:i]
}

// ULEB128Decode decodes [io.Reader] into an integer, returns the resulted value, the number of bytes read, and a possible error.
//
// [binary.ReadUvarint] is not used here because
//   - it doesn't support returning the number of bytes read.
//   - it accepts only [io.ByteReader], but the recommended way of creating one from [bufio.NewReader] will read more than 1 byte at a
//     time to fill the buffer.
//
// At most [MaxUleb128Length] bytes are read. An error is returned if the
// reader fails or ends before the final byte, if the value does not fit in T,
// or if no final byte is found within the limit. Read errors are wrapped, so
// they can be matched with errors.Is; input that ends in the middle of a value
// is reported as [io.ErrUnexpectedEOF].
func ULEB128Decode[T ULEB128SupportedTypes](r io.Reader) (T, int, error) {
	buf := make([]byte, 1)
	var v T
	var shift uint
	var n int

	for n < MaxUleb128Length {
		// ReadFull keeps a byte that the reader returned together with an error.
		if _, err := io.ReadFull(r, buf); err != nil {
			if n > 0 && err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return 0, n, fmt.Errorf("uleb128: reading byte %d: %w", n, err)
		}
		n++

		b := buf[0]
		ld := T(b & 0x7f)

		// shifting the payload into place must not lose bits (or, for signed
		// types, reach the sign bit)
		if (ld<<shift)>>shift != ld {
			return v, n, fmt.Errorf("overflow at index %d: %v", n-1, ld)
		}

		ld <<= shift
		v = ld + v
		if v < ld {
			return v, n, fmt.Errorf("overflow after adding index %d: %v %v", n-1, ld, v)
		}

		// Test the continuation bit on the raw byte: after conversion to a
		// narrow signed type such as int8 the byte would always compare <= 127.
		if b&0x80 == 0 {
			return v, n, nil
		}
		shift += 7
	}

	return 0, n, fmt.Errorf("failed to find most significant bytes after reading %d bytes", n)
}