Files
sui-go-sdk/mystenbcs/decode.go
2025-05-21 10:34:45 +08:00

388 lines
8.2 KiB
Go

package mystenbcs
// https://github.com/fardream/go-bcs/blob/main/bcs/decode.go#L19
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"reflect"
)
// Unmarshal decodes the BCS-serialized bytes in data into v and reports how
// many bytes of data were consumed. Bytes left over after the decoded value
// are not reported as an error.
//
// See [Marshal] for details on how Go values are serialized/deserialized.
//
// Decoding proceeds as follows:
//  1. a value implementing [Unmarshaler] decodes itself via UnmarshalBCS;
//  2. otherwise a value implementing [Enum] uses the enum specialization;
//  3. anything else goes through the standard reflection-based process.
func Unmarshal(data []byte, v any) (int, error) {
	dec := NewDecoder(bytes.NewReader(data))
	return dec.Decode(v)
}
// Decoder reads BCS-encoded values from an underlying [io.Reader].
//
// Construct one with [NewDecoder]; the zero value has no reader attached.
type Decoder struct {
	// reader is the source that every decoding step consumes bytes from.
	reader io.Reader
	// byteBuffer is one-byte scratch space reused for single-byte reads.
	byteBuffer [1]byte
}
// NewDecoder returns a [Decoder] that reads BCS data from r.
func NewDecoder(r io.Reader) *Decoder {
	dec := new(Decoder)
	dec.reader = r
	return dec
}
// Decode decodes a value from the decoder's reader into v and returns the
// number of bytes it consumed from the reader.
//
// v must be a non-nil pointer; otherwise an error is returned and nothing is
// read.
//
//   - If the value is [Unmarshaler], the corresponding UnmarshalBCS will be called.
//   - If the value is [Enum], it is handled by the [Enum] specialization.
func (d *Decoder) Decode(v any) (int, error) {
	target := reflect.ValueOf(v)
	// reflect.ValueOf(nil) has Kind Invalid, so the Kind test also rejects a
	// nil interface before IsNil is reached.
	if target.Kind() != reflect.Pointer || target.IsNil() {
		return 0, fmt.Errorf("not a pointer or nil pointer")
	}
	return d.decode(target)
}
// decode is the main dispatcher. It
//   - skips values that cannot be [reflect.Value.CanInterface] (e.g. reached
//     through unexported struct fields) without consuming input;
//   - allocates settable nil pointers before any interface check, so methods
//     are never invoked on a nil receiver;
//   - uses UnmarshalBCS when the value implements [Unmarshaler];
//   - uses [Decoder.decodeEnum] when the value implements [Enum];
//   - otherwise switches on the kind of the value:
//     pointer → decode into the pointee; interface → decode into the contained
//     value (error if nil); chan/func/uintptr/unsafe pointer → silently
//     ignored; everything else → [Decoder.decodeVanilla].
func (d *Decoder) decode(v reflect.Value) (int, error) {
	if !v.CanInterface() {
		return 0, nil
	}

	// Without this, a nil *T whose type implements Unmarshaler (e.g. a fresh
	// struct field) would have UnmarshalBCS called with a nil receiver.
	if v.Kind() == reflect.Pointer && v.IsNil() && v.CanSet() {
		v.Set(reflect.New(v.Type().Elem()))
	}

	// Unmarshaler
	if u, isUnmarshaler := v.Interface().(Unmarshaler); isUnmarshaler {
		return u.UnmarshalBCS(d.reader)
	}

	// Enum
	if _, isEnum := v.Interface().(Enum); isEnum {
		switch v.Kind() {
		case reflect.Pointer, reflect.Interface:
			if v.IsNil() {
				v.Set(reflect.New(v.Type().Elem()))
			}
			return d.decodeEnum(v.Elem())
		default:
			return d.decodeEnum(v)
		}
	}

	switch v.Kind() {
	case reflect.Pointer:
		if v.IsNil() {
			v.Set(reflect.New(v.Type().Elem()))
		}
		return d.decode(v.Elem())
	case reflect.Interface:
		if v.IsNil() {
			return 0, fmt.Errorf("cannot decode into nil interface")
		}
		return d.decode(v.Elem())
	case reflect.Chan, reflect.Func, reflect.Uintptr, reflect.UnsafePointer:
		// silently ignore
		return 0, nil
	default:
		return d.decodeVanilla(v)
	}
}
// decodeVanilla decodes the kinds that need no special handling: bool,
// fixed-width integers, struct, slice, array and string.
//
// v must be settable; otherwise an error is returned before anything is read.
// Integers are read little-endian at their natural width. A bool must be
// encoded as exactly 0x00 or 0x01; any other byte is rejected as malformed.
func (d *Decoder) decodeVanilla(v reflect.Value) (int, error) {
	kind := v.Kind()
	if !v.CanSet() {
		return 0, fmt.Errorf("cannot change value of kind %s", kind.String())
	}
	switch kind {
	case reflect.Bool:
		b, n, err := d.readByte()
		if err != nil {
			return n, err
		}
		switch b {
		case 0:
			v.SetBool(false)
		case 1:
			v.SetBool(true)
		default:
			// BCS only admits 0x00 and 0x01 for booleans.
			return n, fmt.Errorf("invalid bool value: %d", b)
		}
		return n, nil
	case reflect.Int8, reflect.Uint8,
		reflect.Int16, reflect.Uint16,
		reflect.Int32, reflect.Uint32,
		reflect.Int64, reflect.Uint64:
		// For these kinds Size() is 1, 2, 4 or 8 — exactly the encoded width.
		return int(v.Type().Size()), binary.Read(d.reader, binary.LittleEndian, v.Addr().Interface())
	case reflect.Struct:
		return d.decodeStruct(v)
	case reflect.Slice:
		if v.Type().Elem().Kind() == reflect.Uint8 {
			return d.decodeByteSlice(v)
		}
		return d.decodeSlice(v)
	case reflect.Array:
		if v.Type().Elem().Kind() == reflect.Uint8 {
			return d.decodeByteArray(v)
		}
		return d.decodeArray(v)
	case reflect.String:
		return d.decodeString(v)
	default:
		return 0, fmt.Errorf("unsupported vanilla decoding type: %s", kind.String())
	}
}
// decodeString decodes a BCS string into v: a ULEB128 length prefix followed
// by that many raw bytes, which are stored without UTF-8 validation.
// It returns the number of bytes consumed, including the prefix.
func (d *Decoder) decodeString(v reflect.Value) (int, error) {
	size, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}
	// make panics on a negative length; report it as malformed input instead.
	if size < 0 {
		return n, fmt.Errorf("invalid string length %d", size)
	}
	if size == 0 {
		v.SetString("")
		return n, nil
	}
	tmp := make([]byte, size)
	// io.ReadFull, not a single Read: a Reader may return fewer bytes than
	// requested, or the final bytes together with io.EOF.
	read, err := io.ReadFull(d.reader, tmp)
	n += read
	if err != nil {
		return n, fmt.Errorf("wrong number of bytes read for string, want: %d, got %d: %w", size, read, err)
	}
	v.SetString(string(tmp))
	return n, nil
}
// readByte reads exactly one byte from the input. It returns the byte and the
// number of bytes consumed (0 or 1), or an error if no byte could be read.
//
// io.ReadFull is used because an io.Reader may legitimately return the final
// byte together with io.EOF; treating that as a failure would reject valid
// input.
func (d *Decoder) readByte() (byte, int, error) {
	n, err := io.ReadFull(d.reader, d.byteBuffer[:])
	if err != nil {
		return 0, n, err
	}
	return d.byteBuffer[0], n, nil
}
// decodeStruct decodes the fields of struct v in declaration order and
// returns the total number of bytes consumed.
//
// Fields that cannot be interfaced (unexported) are skipped. For the rest, the
// struct tag parsed by parseTagValue decides the handling:
//   - ignored: skipped without consuming input;
//   - optional: a one-byte flag follows. 0 zeroes the field. 1 allocates a
//     new value and decodes into it; this case requires a pointer field. Any
//     other flag value is rejected as malformed;
//   - otherwise: decoded with [Decoder.decode].
func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
	structType := v.Type()
	total := 0
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		if !field.CanInterface() {
			continue
		}
		meta := structType.Field(i)
		tag, err := parseTagValue(meta.Tag.Get(tagName))
		if err != nil {
			return total, err
		}
		switch {
		case tag.isIgnored():
			continue
		case tag.isOptional():
			present, k, err := d.readByte()
			total += k
			if err != nil {
				return total, err
			}
			switch present {
			case 0:
				field.Set(reflect.Zero(field.Type()))
			case 1:
				// Only a pointer field can receive a freshly allocated value;
				// other kinds would panic in the Set below.
				if field.Kind() != reflect.Pointer {
					return total, fmt.Errorf("optional field %s must be a pointer, got %s", meta.Name, field.Kind())
				}
				field.Set(reflect.New(field.Type().Elem()))
				k, err := d.decode(field.Elem())
				total += k
				if err != nil {
					return total, err
				}
			default:
				// BCS encodes Option as 0x00 (None) or 0x01 (Some).
				return total, fmt.Errorf("invalid option tag %d for field %s", present, meta.Name)
			}
		default:
			k, err := d.decode(field)
			total += k
			if err != nil {
				return total, err
			}
		}
	}
	return total, nil
}
// decodeEnum decodes a BCS enum into struct v. The input is a ULEB128 variant
// index followed by that variant's payload. The index selects, by declaration
// position, the struct field that receives the payload.
//
// An index outside the struct's field range is reported as an error rather
// than panicking. When v is settable it is zeroed first, so a variant left
// over from an earlier decode into the same value does not survive.
func (d *Decoder) decodeEnum(v reflect.Value) (int, error) {
	if v.Kind() != reflect.Struct {
		return 0, fmt.Errorf("only support struct for Enum, got %s", v.Kind().String())
	}
	variant, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}
	if variant < 0 || variant >= v.NumField() {
		return n, fmt.Errorf("enum variant %d out of range for %s with %d variants", variant, v.Type(), v.NumField())
	}
	if v.CanSet() {
		v.Set(reflect.Zero(v.Type()))
	}
	k, err := d.decode(v.Field(variant))
	return n + k, err
}
// decodeByteSlice decodes a BCS byte vector into v: a ULEB128 length prefix
// followed by that many raw bytes. It returns the number of bytes consumed,
// including the prefix.
//
// A zero length sets v to nil, so stale contents of a reused slice are not
// kept. SetBytes is used rather than Set so that slices whose element type is
// a named byte type (kind Uint8) are also accepted.
func (d *Decoder) decodeByteSlice(v reflect.Value) (int, error) {
	size, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}
	// make panics on a negative length; report it as malformed input instead.
	if size < 0 {
		return n, fmt.Errorf("invalid []byte length %d", size)
	}
	if size == 0 {
		v.SetBytes(nil)
		return n, nil
	}
	tmp := make([]byte, size)
	// io.ReadFull, not a single Read: a Reader may return fewer bytes than
	// requested, or the final bytes together with io.EOF.
	read, err := io.ReadFull(d.reader, tmp)
	n += read
	if err != nil {
		return n, fmt.Errorf("wrong number of bytes read for []byte, want: %d, got %d: %w", size, read, err)
	}
	v.SetBytes(tmp)
	return n, nil
}
// decodeByteArray decodes a fixed-size byte array. No length prefix is read:
// exactly v.Len() raw bytes are consumed. v is left unchanged on error.
func (d *Decoder) decodeByteArray(v reflect.Value) (int, error) {
	arraySize := v.Len()
	if arraySize == 0 {
		return 0, nil
	}
	tmp := make([]byte, arraySize)
	// io.ReadFull, not a single Read: a Reader may return fewer bytes than
	// requested, or the final bytes together with io.EOF.
	read, err := io.ReadFull(d.reader, tmp)
	if err != nil {
		return read, fmt.Errorf("wrong number of bytes read for [%d]byte, want: %d, got %d: %w", arraySize, arraySize, read, err)
	}
	// Slicing an array requires addressability. v is settable here (the caller
	// decodeVanilla checks CanSet), and settable values are addressable. One
	// copy replaces a reflective SetUint per element.
	copy(v.Slice(0, arraySize).Bytes(), tmp)
	return read, nil
}
// decodeArray decodes a fixed-length array of non-byte elements. No length
// prefix is read: exactly v.Len() elements are decoded in order. Each element
// is decoded into a fresh value and then stored into v.
//
// For pointer elements, the freshly allocated pointer itself is passed to
// decode, matching decodeSlice, so an Unmarshaler implemented on the pointer
// type is honored. Non-pointer elements are decoded by value.
func (d *Decoder) decodeArray(v reflect.Value) (int, error) {
	elemType := v.Type().Elem()
	elemIsPointer := elemType.Kind() == reflect.Pointer
	var n int
	for i := 0; i < v.Len(); i++ {
		if elemIsPointer {
			ptr := reflect.New(elemType.Elem())
			k, err := d.decode(ptr)
			n += k
			if err != nil {
				return n, err
			}
			v.Index(i).Set(ptr)
			continue
		}
		holder := reflect.New(elemType)
		k, err := d.decode(holder.Elem())
		n += k
		if err != nil {
			return n, err
		}
		v.Index(i).Set(holder.Elem())
	}
	return n, nil
}
// decodeSlice decodes a BCS sequence of non-byte elements into slice v: a
// ULEB128 element count followed by the elements in order. On success v is
// replaced with a newly built slice, which is non-nil even for a zero count.
//
// Pointer elements are decoded through a freshly allocated pointer, so an
// Unmarshaler implemented on the pointer type is honored. Other elements are
// decoded by value.
func (d *Decoder) decodeSlice(v reflect.Value) (int, error) {
	count, n, err := ULEB128Decode[int](d.reader)
	if err != nil {
		return n, err
	}
	// MakeSlice panics on a negative capacity; report it as malformed input.
	if count < 0 {
		return n, fmt.Errorf("invalid slice length %d", count)
	}

	sliceType := v.Type()
	elemType := sliceType.Elem()
	elemIsPointer := elemType.Kind() == reflect.Pointer

	// The count comes from the input and may be bogus. Bound the up-front
	// allocation so that a huge declared length cannot exhaust memory before
	// any element is read. reflect.Append grows the slice beyond this as needed.
	const maxPrealloc = 1024
	capacity := count
	if capacity > maxPrealloc {
		capacity = maxPrealloc
	}
	out := reflect.MakeSlice(sliceType, 0, capacity)

	for i := 0; i < count; i++ {
		var item reflect.Value
		if elemIsPointer {
			ptr := reflect.New(elemType.Elem())
			k, err := d.decode(ptr)
			n += k
			if err != nil {
				return n, err
			}
			item = ptr
		} else {
			holder := reflect.New(elemType)
			k, err := d.decode(holder.Elem())
			n += k
			if err != nil {
				return n, err
			}
			item = holder.Elem()
		}
		out = reflect.Append(out, item)
	}
	v.Set(out)
	return n, nil
}