From db6f005bef7c4c8253501229da3b5c4a7fa6b7f5 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 3 Apr 2026 16:59:11 +0900 Subject: [PATCH] Add msgpack wire protocol with halfsiphash checksums --- CMakeLists.txt | 7 +- cmd/load/main.go | 17 +- lib/halfsiphash/halfsiphash.go | 74 +++ lib/halfsiphash/halfsiphash_test.go | 31 ++ lib/msgpack/codes.go | 88 ++++ lib/msgpack/decode.go | 707 +++++++++++++++++++++++++ lib/msgpack/decode_map.go | 355 +++++++++++++ lib/msgpack/decode_number.go | 294 +++++++++++ lib/msgpack/decode_query.go | 156 ++++++ lib/msgpack/decode_slice.go | 197 +++++++ lib/msgpack/decode_string.go | 191 +++++++ lib/msgpack/decode_typgen.go | 46 ++ lib/msgpack/decode_value.go | 251 +++++++++ lib/msgpack/encode.go | 269 ++++++++++ lib/msgpack/encode_map.go | 224 ++++++++ lib/msgpack/encode_number.go | 251 +++++++++ lib/msgpack/encode_slice.go | 138 +++++ lib/msgpack/encode_value.go | 254 +++++++++ lib/msgpack/ext.go | 343 ++++++++++++ lib/msgpack/intern.go | 235 +++++++++ lib/msgpack/msgpack.go | 52 ++ lib/msgpack/safe.go | 13 + lib/msgpack/time.go | 150 ++++++ lib/msgpack/types.go | 413 +++++++++++++++ lib/msgpack/unsafe.go | 22 + lib/msgpack/version.go | 6 + lib/picoserial/picoserial.go | 36 +- lib/tagparser/parser.go | 76 +++ lib/tagparser/tagparser.go | 162 ++++++ lib/wire/wire.go | 40 ++ picomap.cpp | 37 ++ third_party/halfsiphash/halfsiphash.h | 79 +++ third_party/msgpackpp/msgpackpp.h | 717 ++++++++++++++++++++++++++ 33 files changed, 5928 insertions(+), 3 deletions(-) create mode 100644 lib/halfsiphash/halfsiphash.go create mode 100644 lib/halfsiphash/halfsiphash_test.go create mode 100644 lib/msgpack/codes.go create mode 100644 lib/msgpack/decode.go create mode 100644 lib/msgpack/decode_map.go create mode 100644 lib/msgpack/decode_number.go create mode 100644 lib/msgpack/decode_query.go create mode 100644 lib/msgpack/decode_slice.go create mode 100644 lib/msgpack/decode_string.go create mode 100644 lib/msgpack/decode_typgen.go create mode 100644 lib/msgpack/decode_value.go create mode 100644 lib/msgpack/encode.go create mode 100644 lib/msgpack/encode_map.go create mode 100644 lib/msgpack/encode_number.go create mode 100644 lib/msgpack/encode_slice.go create mode 100644 lib/msgpack/encode_value.go create mode 100644 lib/msgpack/ext.go create mode 100644 lib/msgpack/intern.go create mode 100644 lib/msgpack/msgpack.go create mode 100644 lib/msgpack/safe.go create mode 100644 lib/msgpack/time.go create mode 100644 lib/msgpack/types.go create mode 100644 lib/msgpack/unsafe.go create mode 100644 lib/msgpack/version.go create mode 100644 lib/tagparser/parser.go create mode 100644 lib/tagparser/tagparser.go create mode 100644 lib/wire/wire.go create mode 100644 third_party/halfsiphash/halfsiphash.h create mode 100644 third_party/msgpackpp/msgpackpp.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 53bffee..4e6e640 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,13 +6,18 @@ include(pico_sdk_import.cmake) project(picomap C CXX ASM) set(CMAKE_C_STANDARD 11) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 23) pico_sdk_init() add_executable(picomap picomap.cpp ) +target_include_directories(picomap PRIVATE + third_party/msgpackpp + third_party/halfsiphash +) + pico_enable_stdio_usb(picomap 1) pico_enable_stdio_uart(picomap 0) diff --git a/cmd/load/main.go b/cmd/load/main.go index 4204e88..5db160d 100644 --- a/cmd/load/main.go +++ b/cmd/load/main.go @@ -9,6 +9,7 @@ import ( "github.com/theater/picomap/lib/picoserial" "github.com/theater/picomap/lib/picotool" + "github.com/theater/picomap/lib/wire" ) func main() { @@ -40,9 +41,23 @@ func run(buildDir string) error { } if dev != "" { fmt.Printf("Sending 'b' to %s to enter BOOTSEL mode...\n", dev) - if err := picoserial.SendByte(dev, 'b'); err != nil { + resp, err := picoserial.SendByteAndRead(dev, 'b', 2*time.Second) + if err != nil { return err } + if len(resp) > 0 { + msg, err := wire.DecodeMessage(resp) + if err != nil { + fmt.Fprintf(os.Stderr, "warning: failed to decode response: %v\n", err) + } else { + switch msg.(type) { + case *wire.RebootingBootsel: + fmt.Println("Device confirmed reboot into BOOTSEL mode.") + default: + fmt.Printf("Unexpected response type: %T\n", msg) + } + } + } time.Sleep(2 * time.Second) } diff --git a/lib/halfsiphash/halfsiphash.go b/lib/halfsiphash/halfsiphash.go new file mode 100644 index 0000000..efc10d6 --- /dev/null +++ b/lib/halfsiphash/halfsiphash.go @@ -0,0 +1,74 @@ +package halfsiphash + +import "encoding/binary" + +func rotl(x uint32, b uint) uint32 { + return (x << b) | (x >> (32 - b)) +} + +func sipround(v0, v1, v2, v3 *uint32) { + *v0 += *v1 + *v1 = rotl(*v1, 5) + *v1 ^= *v0 + *v0 = rotl(*v0, 16) + *v2 += *v3 + *v3 = rotl(*v3, 8) + *v3 ^= *v2 + *v0 += *v3 + *v3 = rotl(*v3, 7) + *v3 ^= *v0 + *v2 += *v1 + *v1 = rotl(*v1, 13) + *v1 ^= *v2 + *v2 = rotl(*v2, 16) +} + +// Sum32 computes HalfSipHash-2-4 with an 8-byte key and returns a 4-byte hash. +func Sum32(data []byte, key [8]byte) uint32 { + k0 := binary.LittleEndian.Uint32(key[0:4]) + k1 := binary.LittleEndian.Uint32(key[4:8]) + + v0 := uint32(0) ^ k0 + v1 := uint32(0) ^ k1 + v2 := uint32(0x6c796765) ^ k0 + v3 := uint32(0x74656462) ^ k1 + + // Process full 4-byte blocks. + nblocks := len(data) / 4 + for i := 0; i < nblocks; i++ { + m := binary.LittleEndian.Uint32(data[i*4:]) + v3 ^= m + for j := 0; j < 2; j++ { + sipround(&v0, &v1, &v2, &v3) + } + v0 ^= m + } + + // Process remaining bytes. + b := uint32(len(data)) << 24 + tail := data[nblocks*4:] + switch len(tail) { + case 3: + b |= uint32(tail[2]) << 16 + fallthrough + case 2: + b |= uint32(tail[1]) << 8 + fallthrough + case 1: + b |= uint32(tail[0]) + } + + v3 ^= b + for i := 0; i < 2; i++ { + sipround(&v0, &v1, &v2, &v3) + } + v0 ^= b + + v2 ^= 0xff + + for i := 0; i < 4; i++ { + sipround(&v0, &v1, &v2, &v3) + } + + return v1 ^ v3 +} diff --git a/lib/halfsiphash/halfsiphash_test.go b/lib/halfsiphash/halfsiphash_test.go new file mode 100644 index 0000000..e59c8c3 --- /dev/null +++ b/lib/halfsiphash/halfsiphash_test.go @@ -0,0 +1,31 @@ +package halfsiphash + +import "testing" + +// Test vectors from the reference implementation. +// Key: 00 01 02 03 04 05 06 07 +// Input: sequence 00, 00 01, 00 01 02, ... +var vectors32 = []uint32{ + 0x5b9f35a9, + 0xb85a4727, + 0x03a662fa, + 0x04e7fe8a, + 0x89466e2a, + 0x69b6fac5, + 0x23fc6358, + 0xc563cf8b, +} + +func TestSum32(t *testing.T) { + key := [8]byte{0, 1, 2, 3, 4, 5, 6, 7} + for i, want := range vectors32 { + data := make([]byte, i) + for j := range data { + data[j] = byte(j) + } + got := Sum32(data, key) + if got != want { + t.Errorf("Sum32(len=%d) = %08x, want %08x", i, got, want) + } + } +} diff --git a/lib/msgpack/codes.go b/lib/msgpack/codes.go new file mode 100644 index 0000000..4180815 --- /dev/null +++ b/lib/msgpack/codes.go @@ -0,0 +1,88 @@ +package msgpack + +var ( + PosFixedNumHigh byte = 0x7f + NegFixedNumLow byte = 0xe0 + + Nil byte = 0xc0 + + False byte = 0xc2 + True byte = 0xc3 + + Float byte = 0xca + Double byte = 0xcb + + Uint8 byte = 0xcc + Uint16 byte = 0xcd + Uint32 byte = 0xce + Uint64 byte = 0xcf + + Int8 byte = 0xd0 + Int16 byte = 0xd1 + Int32 byte = 0xd2 + Int64 byte = 0xd3 + + FixedStrLow byte = 0xa0 + FixedStrHigh byte = 0xbf + FixedStrMask byte = 0x1f + Str8 byte = 0xd9 + Str16 byte = 0xda + Str32 byte = 0xdb + + Bin8 byte = 0xc4 + Bin16 byte = 0xc5 + Bin32 byte = 0xc6 + + FixedArrayLow byte = 0x90 + FixedArrayHigh byte = 0x9f + FixedArrayMask byte = 0xf + Array16 byte = 0xdc + Array32 byte = 0xdd + + FixedMapLow byte = 0x80 + FixedMapHigh byte = 0x8f + FixedMapMask byte = 0xf + Map16 byte = 0xde + Map32 byte = 0xdf + + FixExt1 byte = 0xd4 + FixExt2 byte = 0xd5 + FixExt4 byte = 0xd6 + FixExt8 byte = 0xd7 + FixExt16 byte = 0xd8 + Ext8 byte = 0xc7 + Ext16 byte = 0xc8 + Ext32 byte = 0xc9 +) + +func IsFixedNum(c byte) bool { + return c <= PosFixedNumHigh || c >= NegFixedNumLow +} + +func IsFixedMap(c byte) bool { + return c >= FixedMapLow && c <= FixedMapHigh +} + +func IsFixedArray(c byte) bool { + return c >= FixedArrayLow && c <= FixedArrayHigh +} + +func IsFixedString(c byte) bool { + return c >= FixedStrLow && c <= FixedStrHigh +} + +func IsString(c byte) bool { + return IsFixedString(c) || c == Str8 || c == Str16 || c == Str32 +} + +func IsBin(c byte) bool { + return c == Bin8 || c == Bin16 || c == Bin32 +} + +func IsFixedExt(c byte) bool { + return c >= FixExt1 && c <= FixExt16 +} + +func IsExt(c byte) bool { + return IsFixedExt(c) || c == Ext8 || c == Ext16 || c == Ext32 +} diff --git a/lib/msgpack/decode.go b/lib/msgpack/decode.go new file mode 100644 index 0000000..a4afbc7 --- /dev/null +++ b/lib/msgpack/decode.go @@ -0,0 +1,707 @@ +package msgpack + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sync" + "time" + +) + +const ( + bytesAllocLimit = 1 << 20 // 1mb + sliceAllocLimit = 1e6 // 1m elements + maxMapSize = 1e6 // 1m elements +) + +const ( + looseInterfaceDecodingFlag uint32 = 1 << iota + disallowUnknownFieldsFlag + usePreallocateValues + disableAllocLimitFlag +) + +type bufReader interface { + io.Reader + io.ByteScanner +} + +//------------------------------------------------------------------------------ + +var decPool = sync.Pool{ + New: func() interface{} { + return NewDecoder(nil) + }, +} + +func GetDecoder() *Decoder { + return decPool.Get().(*Decoder) +} + +func PutDecoder(dec *Decoder) { + dec.r = nil + dec.s = nil + decPool.Put(dec) +} + +//------------------------------------------------------------------------------ + +// Unmarshal decodes the MessagePack-encoded data and stores the result +// in the value pointed to by v. +func Unmarshal(data []byte, v interface{}) error { + dec := GetDecoder() + dec.UsePreallocateValues(true) + dec.Reset(bytes.NewReader(data)) + err := dec.Decode(v) + + PutDecoder(dec) + + return err +} + +// A Decoder reads and decodes MessagePack values from an input stream. +type Decoder struct { + r io.Reader + s io.ByteScanner + mapDecoder func(*Decoder) (interface{}, error) + structTag string + buf []byte + rec []byte + dict []string + flags uint32 +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder introduces its own buffering and may read data from r +// beyond the requested msgpack values. Buffering can be disabled +// by passing a reader that implements io.ByteScanner interface. +func NewDecoder(r io.Reader) *Decoder { + d := new(Decoder) + d.Reset(r) + return d +} + +// Reset discards any buffered data, resets all state, and switches the buffered +// reader to read from r. +func (d *Decoder) Reset(r io.Reader) { + d.ResetDict(r, nil) +} + +// ResetDict is like Reset, but also resets the dict. +func (d *Decoder) ResetDict(r io.Reader, dict []string) { + d.ResetReader(r) + d.flags = 0 + d.structTag = "" + d.dict = dict +} + +func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error { + oldDict := d.dict + d.dict = dict + err := fn(d) + d.dict = oldDict + return err +} + +func (d *Decoder) ResetReader(r io.Reader) { + d.mapDecoder = nil + d.dict = nil + + if br, ok := r.(bufReader); ok { + d.r = br + d.s = br + } else if r == nil { + d.r = nil + d.s = nil + } else { + br := bufio.NewReader(r) + d.r = br + d.s = br + } +} + +func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) { + d.mapDecoder = fn +} + +// UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose +// to decode msgpack value into Go interface{}. +func (d *Decoder) UseLooseInterfaceDecoding(on bool) { + if on { + d.flags |= looseInterfaceDecodingFlag + } else { + d.flags &= ^looseInterfaceDecodingFlag + } +} + +// SetCustomStructTag causes the decoder to use the supplied tag as a fallback option +// if there is no msgpack tag. +func (d *Decoder) SetCustomStructTag(tag string) { + d.structTag = tag +} + +// DisallowUnknownFields causes the Decoder to return an error when the destination +// is a struct and the input contains object keys which do not match any +// non-ignored, exported fields in the destination. +func (d *Decoder) DisallowUnknownFields(on bool) { + if on { + d.flags |= disallowUnknownFieldsFlag + } else { + d.flags &= ^disallowUnknownFieldsFlag + } +} + +// UseInternedStrings enables support for decoding interned strings. +func (d *Decoder) UseInternedStrings(on bool) { + if on { + d.flags |= useInternedStringsFlag + } else { + d.flags &= ^useInternedStringsFlag + } +} + +// UsePreallocateValues enables preallocating values in chunks +func (d *Decoder) UsePreallocateValues(on bool) { + if on { + d.flags |= usePreallocateValues + } else { + d.flags &= ^usePreallocateValues + } +} + +// DisableAllocLimit enables fully allocating slices/maps when the size is known +func (d *Decoder) DisableAllocLimit(on bool) { + if on { + d.flags |= disableAllocLimitFlag + } else { + d.flags &= ^disableAllocLimitFlag + } +} + +// Buffered returns a reader of the data remaining in the Decoder's buffer. +// The reader is valid until the next call to Decode. +func (d *Decoder) Buffered() io.Reader { + return d.r +} + +//nolint:gocyclo +func (d *Decoder) Decode(v interface{}) error { + var err error + switch v := v.(type) { + case *string: + if v != nil { + *v, err = d.DecodeString() + return err + } + case *[]byte: + if v != nil { + return d.decodeBytesPtr(v) + } + case *int: + if v != nil { + *v, err = d.DecodeInt() + return err + } + case *int8: + if v != nil { + *v, err = d.DecodeInt8() + return err + } + case *int16: + if v != nil { + *v, err = d.DecodeInt16() + return err + } + case *int32: + if v != nil { + *v, err = d.DecodeInt32() + return err + } + case *int64: + if v != nil { + *v, err = d.DecodeInt64() + return err + } + case *uint: + if v != nil { + *v, err = d.DecodeUint() + return err + } + case *uint8: + if v != nil { + *v, err = d.DecodeUint8() + return err + } + case *uint16: + if v != nil { + *v, err = d.DecodeUint16() + return err + } + case *uint32: + if v != nil { + *v, err = d.DecodeUint32() + return err + } + case *uint64: + if v != nil { + *v, err = d.DecodeUint64() + return err + } + case *bool: + if v != nil { + *v, err = d.DecodeBool() + return err + } + case *float32: + if v != nil { + *v, err = d.DecodeFloat32() + return err + } + case *float64: + if v != nil { + *v, err = d.DecodeFloat64() + return err + } + case *[]string: + return d.decodeStringSlicePtr(v) + case *map[string]string: + return d.decodeMapStringStringPtr(v) + case *map[string]interface{}: + return d.decodeMapStringInterfacePtr(v) + case *time.Duration: + if v != nil { + vv, err := d.DecodeInt64() + *v = time.Duration(vv) + return err + } + case *time.Time: + if v != nil { + *v, err = d.DecodeTime() + return err + } + } + + vv := reflect.ValueOf(v) + if !vv.IsValid() { + return errors.New("msgpack: Decode(nil)") + } + if vv.Kind() != reflect.Ptr { + return fmt.Errorf("msgpack: Decode(non-pointer %T)", v) + } + if vv.IsNil() { + return fmt.Errorf("msgpack: Decode(non-settable %T)", v) + } + + vv = vv.Elem() + if vv.Kind() == reflect.Interface { + if !vv.IsNil() { + vv = vv.Elem() + if vv.Kind() != reflect.Ptr { + return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String()) + } + } + } + + return d.DecodeValue(vv) +} + +func (d *Decoder) DecodeMulti(v ...interface{}) error { + for _, vv := range v { + if err := d.Decode(vv); err != nil { + return err + } + } + return nil +} + +func (d *Decoder) decodeInterfaceCond() (interface{}, error) { + if d.flags&looseInterfaceDecodingFlag != 0 { + return d.DecodeInterfaceLoose() + } + return d.DecodeInterface() +} + +func (d *Decoder) DecodeValue(v reflect.Value) error { + decode := getDecoder(v.Type()) + return decode(d, v) +} + +func (d *Decoder) DecodeNil() error { + c, err := d.readCode() + if err != nil { + return err + } + if c != Nil { + return fmt.Errorf("msgpack: invalid code=%x decoding nil", c) + } + return nil +} + +func (d *Decoder) decodeNilValue(v reflect.Value) error { + err := d.DecodeNil() + if v.IsNil() { + return err + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + v.Set(reflect.Zero(v.Type())) + return err +} + +func (d *Decoder) DecodeBool() (bool, error) { + c, err := d.readCode() + if err != nil { + return false, err + } + return d.bool(c) +} + +func (d *Decoder) bool(c byte) (bool, error) { + if c == Nil { + return false, nil + } + if c == False { + return false, nil + } + if c == True { + return true, nil + } + return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c) +} + +func (d *Decoder) DecodeDuration() (time.Duration, error) { + n, err := d.DecodeInt64() + if err != nil { + return 0, err + } + return time.Duration(n), nil +} + +// DecodeInterface decodes value into interface. It returns following types: +// - nil, +// - bool, +// - int8, int16, int32, int64, +// - uint8, uint16, uint32, uint64, +// - float32 and float64, +// - string, +// - []byte, +// - slices of any of the above, +// - maps of any of the above. +// +// DecodeInterface should be used only when you don't know the type of value +// you are decoding. For example, if you are decoding number it is better to use +// DecodeInt64 for negative numbers and DecodeUint64 for positive numbers. +func (d *Decoder) DecodeInterface() (interface{}, error) { + c, err := d.readCode() + if err != nil { + return nil, err + } + + if IsFixedNum(c) { + return int8(c), nil + } + if IsFixedMap(c) { + err = d.s.UnreadByte() + if err != nil { + return nil, err + } + return d.decodeMapDefault() + } + if IsFixedArray(c) { + return d.decodeSlice(c) + } + if IsFixedString(c) { + return d.string(c) + } + + switch c { + case Nil: + return nil, nil + case False, True: + return d.bool(c) + case Float: + return d.float32(c) + case Double: + return d.float64(c) + case Uint8: + return d.uint8() + case Uint16: + return d.uint16() + case Uint32: + return d.uint32() + case Uint64: + return d.uint64() + case Int8: + return d.int8() + case Int16: + return d.int16() + case Int32: + return d.int32() + case Int64: + return d.int64() + case Bin8, Bin16, Bin32: + return d.bytes(c, nil) + case Str8, Str16, Str32: + return d.string(c) + case Array16, Array32: + return d.decodeSlice(c) + case Map16, Map32: + err = d.s.UnreadByte() + if err != nil { + return nil, err + } + return d.decodeMapDefault() + case FixExt1, FixExt2, FixExt4, FixExt8, FixExt16, + Ext8, Ext16, Ext32: + return d.decodeInterfaceExt(c) + } + + return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) +} + +// DecodeInterfaceLoose is like DecodeInterface except that: +// - int8, int16, and int32 are converted to int64, +// - uint8, uint16, and uint32 are converted to uint64, +// - float32 is converted to float64. +// - []byte is converted to string. +func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) { + c, err := d.readCode() + if err != nil { + return nil, err + } + + if IsFixedNum(c) { + return int64(int8(c)), nil + } + if IsFixedMap(c) { + err = d.s.UnreadByte() + if err != nil { + return nil, err + } + return d.decodeMapDefault() + } + if IsFixedArray(c) { + return d.decodeSlice(c) + } + if IsFixedString(c) { + return d.string(c) + } + + switch c { + case Nil: + return nil, nil + case False, True: + return d.bool(c) + case Float, Double: + return d.float64(c) + case Uint8, Uint16, Uint32, Uint64: + return d.uint(c) + case Int8, Int16, Int32, Int64: + return d.int(c) + case Str8, Str16, Str32, + Bin8, Bin16, Bin32: + return d.string(c) + case Array16, Array32: + return d.decodeSlice(c) + case Map16, Map32: + err = d.s.UnreadByte() + if err != nil { + return nil, err + } + return d.decodeMapDefault() + case FixExt1, FixExt2, FixExt4, FixExt8, FixExt16, + Ext8, Ext16, Ext32: + return d.decodeInterfaceExt(c) + } + + return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) +} + +// Skip skips next value. +func (d *Decoder) Skip() error { + c, err := d.readCode() + if err != nil { + return err + } + + if IsFixedNum(c) { + return nil + } + if IsFixedMap(c) { + return d.skipMap(c) + } + if IsFixedArray(c) { + return d.skipSlice(c) + } + if IsFixedString(c) { + return d.skipBytes(c) + } + + switch c { + case Nil, False, True: + return nil + case Uint8, Int8: + return d.skipN(1) + case Uint16, Int16: + return d.skipN(2) + case Uint32, Int32, Float: + return d.skipN(4) + case Uint64, Int64, Double: + return d.skipN(8) + case Bin8, Bin16, Bin32: + return d.skipBytes(c) + case Str8, Str16, Str32: + return d.skipBytes(c) + case Array16, Array32: + return d.skipSlice(c) + case Map16, Map32: + return d.skipMap(c) + case FixExt1, FixExt2, FixExt4, FixExt8, FixExt16, + Ext8, Ext16, Ext32: + return d.skipExt(c) + } + + return fmt.Errorf("msgpack: unknown code %x", c) +} + +func (d *Decoder) DecodeRaw() (RawMessage, error) { + d.rec = make([]byte, 0) + if err := d.Skip(); err != nil { + return nil, err + } + msg := RawMessage(d.rec) + d.rec = nil + return msg, nil +} + +// PeekCode returns the next MessagePack code without advancing the reader. +// Subpackage msgpack/codes defines the list of available +func (d *Decoder) PeekCode() (byte, error) { + c, err := d.s.ReadByte() + if err != nil { + return 0, err + } + return c, d.s.UnreadByte() +} + +// ReadFull reads exactly len(buf) bytes into the buf. +func (d *Decoder) ReadFull(buf []byte) error { + _, err := readN(d.r, buf, len(buf)) + return err +} + +func (d *Decoder) hasNilCode() bool { + code, err := d.PeekCode() + return err == nil && code == Nil +} + +func (d *Decoder) readCode() (byte, error) { + c, err := d.s.ReadByte() + if err != nil { + return 0, err + } + if d.rec != nil { + d.rec = append(d.rec, c) + } + return c, nil +} + +func (d *Decoder) readFull(b []byte) error { + _, err := io.ReadFull(d.r, b) + if err != nil { + return err + } + if d.rec != nil { + d.rec = append(d.rec, b...) + } + return nil +} + +func (d *Decoder) readN(n int) ([]byte, error) { + var err error + if d.flags&disableAllocLimitFlag != 0 { + d.buf, err = readN(d.r, d.buf, n) + } else { + d.buf, err = readNGrow(d.r, d.buf, n) + } + if err != nil { + return nil, err + } + if d.rec != nil { + // TODO: read directly into d.rec? + d.rec = append(d.rec, d.buf...) + } + return d.buf, nil +} + +func readN(r io.Reader, b []byte, n int) ([]byte, error) { + if b == nil { + if n == 0 { + return make([]byte, 0), nil + } + b = make([]byte, 0, n) + } + + if n > cap(b) { + b = append(b, make([]byte, n-len(b))...) + } else if n <= cap(b) { + b = b[:n] + } + + _, err := io.ReadFull(r, b) + return b, err +} + +func readNGrow(r io.Reader, b []byte, n int) ([]byte, error) { + if b == nil { + if n == 0 { + return make([]byte, 0), nil + } + switch { + case n < 64: + b = make([]byte, 0, 64) + case n <= bytesAllocLimit: + b = make([]byte, 0, n) + default: + b = make([]byte, 0, bytesAllocLimit) + } + } + + if n <= cap(b) { + b = b[:n] + _, err := io.ReadFull(r, b) + return b, err + } + b = b[:cap(b)] + + var pos int + for { + alloc := min(n-len(b), bytesAllocLimit) + b = append(b, make([]byte, alloc)...) + + _, err := io.ReadFull(r, b[pos:]) + if err != nil { + return b, err + } + + if len(b) == n { + break + } + pos = len(b) + } + + return b, nil +} + +func min(a, b int) int { //nolint:unparam + if a <= b { + return a + } + return b +} diff --git a/lib/msgpack/decode_map.go b/lib/msgpack/decode_map.go new file mode 100644 index 0000000..ec43a22 --- /dev/null +++ b/lib/msgpack/decode_map.go @@ -0,0 +1,355 @@ +package msgpack + +import ( + "errors" + "fmt" + "reflect" + +) + +var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed") + +var ( + mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil)) + mapStringStringType = mapStringStringPtrType.Elem() + mapStringBoolPtrType = reflect.TypeOf((*map[string]bool)(nil)) + mapStringBoolType = mapStringBoolPtrType.Elem() +) + +var ( + mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil)) + mapStringInterfaceType = mapStringInterfacePtrType.Elem() +) + +func decodeMapValue(d *Decoder, v reflect.Value) error { + n, err := d.DecodeMapLen() + if err != nil { + return err + } + + typ := v.Type() + if n == -1 { + v.Set(reflect.Zero(typ)) + return nil + } + + if v.IsNil() { + ln := n + if d.flags&disableAllocLimitFlag == 0 { + ln = min(ln, maxMapSize) + } + v.Set(reflect.MakeMapWithSize(typ, ln)) + } + if n == 0 { + return nil + } + + return d.decodeTypedMapValue(v, n) +} + +func (d *Decoder) decodeMapDefault() (interface{}, error) { + if d.mapDecoder != nil { + return d.mapDecoder(d) + } + return d.DecodeMap() +} + +// DecodeMapLen decodes map length. Length is -1 when map is nil. +func (d *Decoder) DecodeMapLen() (int, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + + if IsExt(c) { + if err = d.skipExtHeader(c); err != nil { + return 0, err + } + + c, err = d.readCode() + if err != nil { + return 0, err + } + } + return d.mapLen(c) +} + +func (d *Decoder) mapLen(c byte) (int, error) { + if c == Nil { + return -1, nil + } + if c >= FixedMapLow && c <= FixedMapHigh { + return int(c & FixedMapMask), nil + } + if c == Map16 { + size, err := d.uint16() + return int(size), err + } + if c == Map32 { + size, err := d.uint32() + return int(size), err + } + return 0, unexpectedCodeError{code: c, hint: "map length"} +} + +func decodeMapStringStringValue(d *Decoder, v reflect.Value) error { + mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string) + return d.decodeMapStringStringPtr(mptr) +} + +func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error { + size, err := d.DecodeMapLen() + if err != nil { + return err + } + if size == -1 { + *ptr = nil + return nil + } + + m := *ptr + if m == nil { + ln := size + if d.flags&disableAllocLimitFlag == 0 { + ln = min(size, maxMapSize) + } + *ptr = make(map[string]string, ln) + m = *ptr + } + + for i := 0; i < size; i++ { + mk, err := d.DecodeString() + if err != nil { + return err + } + mv, err := d.DecodeString() + if err != nil { + return err + } + m[mk] = mv + } + + return nil +} + +func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error { + ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{}) + return d.decodeMapStringInterfacePtr(ptr) +} + +func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error { + m, err := d.DecodeMap() + if err != nil { + return err + } + *ptr = m + return nil +} + +func (d *Decoder) DecodeMap() (map[string]interface{}, error) { + n, err := d.DecodeMapLen() + if err != nil { + return nil, err + } + + if n == -1 { + return nil, nil + } + + m := make(map[string]interface{}, n) + + for i := 0; i < n; i++ { + mk, err := d.DecodeString() + if err != nil { + return nil, err + } + mv, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + m[mk] = mv + } + + return m, nil +} + +func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) { + n, err := d.DecodeMapLen() + if err != nil { + return nil, err + } + + if n == -1 { + return nil, nil + } + + m := make(map[interface{}]interface{}, n) + + for i := 0; i < n; i++ { + mk, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + + mv, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + + m[mk] = mv + } + + return m, nil +} + +// DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values. +// Key and value types may be different. +func (d *Decoder) DecodeTypedMap() (interface{}, error) { + n, err := d.DecodeMapLen() + if err != nil { + return nil, err + } + if n <= 0 { + return nil, nil + } + + key, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + + value, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + + keyType := reflect.TypeOf(key) + valueType := reflect.TypeOf(value) + + if !keyType.Comparable() { + return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String()) + } + + mapType := reflect.MapOf(keyType, valueType) + + ln := n + if d.flags&disableAllocLimitFlag == 0 { + ln = min(ln, maxMapSize) + } + + mapValue := reflect.MakeMapWithSize(mapType, ln) + mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value)) + + n-- + if err := d.decodeTypedMapValue(mapValue, n); err != nil { + return nil, err + } + + return mapValue.Interface(), nil +} + +func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error { + var ( + typ = v.Type() + keyType = typ.Key() + valueType = typ.Elem() + ) + for i := 0; i < n; i++ { + mk := d.newValue(keyType).Elem() + if err := d.DecodeValue(mk); err != nil { + return err + } + + mv := d.newValue(valueType).Elem() + if err := d.DecodeValue(mv); err != nil { + return err + } + + v.SetMapIndex(mk, mv) + } + + return nil +} + +func (d *Decoder) skipMap(c byte) error { + n, err := d.mapLen(c) + if err != nil { + return err + } + for i := 0; i < n; i++ { + if err := d.Skip(); err != nil { + return err + } + if err := d.Skip(); err != nil { + return err + } + } + return nil +} + +func decodeStructValue(d *Decoder, v reflect.Value) error { + c, err := d.readCode() + if err != nil { + return err + } + + n, err := d.mapLen(c) + if err == nil { + return d.decodeStruct(v, n) + } + + var err2 error + n, err2 = d.arrayLen(c) + if err2 != nil { + return err + } + + if n <= 0 { + v.Set(reflect.Zero(v.Type())) + return nil + } + + fields := structs.Fields(v.Type(), d.structTag) + if n != len(fields.List) { + return errArrayStruct + } + + for _, f := range fields.List { + if err := f.DecodeValue(d, v); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeStruct(v reflect.Value, n int) error { + if n == -1 { + v.Set(reflect.Zero(v.Type())) + return nil + } + + fields := structs.Fields(v.Type(), d.structTag) + for i := 0; i < n; i++ { + name, err := d.decodeStringTemp() + if err != nil { + return err + } + + if f := fields.Map[name]; f != nil { + if err := f.DecodeValue(d, v); err != nil { + return err + } + continue + } + + if d.flags&disallowUnknownFieldsFlag != 0 { + return fmt.Errorf("msgpack: unknown field %q", name) + } + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} diff --git a/lib/msgpack/decode_number.go b/lib/msgpack/decode_number.go new file mode 100644 index 0000000..8c9fa79 --- /dev/null +++ b/lib/msgpack/decode_number.go @@ -0,0 +1,294 @@ +package msgpack + +import ( + "fmt" + "math" + "reflect" + +) + +func (d *Decoder) skipN(n int) error { + _, err := d.readN(n) + return err +} + +func (d *Decoder) uint8() (uint8, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return c, nil +} + +func (d *Decoder) int8() (int8, error) { + n, err := d.uint8() + return int8(n), err +} + +func (d *Decoder) uint16() (uint16, error) { + b, err := d.readN(2) + if err != nil { + return 0, err + } + return (uint16(b[0]) << 8) | uint16(b[1]), nil +} + +func (d *Decoder) int16() (int16, error) { + n, err := d.uint16() + return int16(n), err +} + +func (d *Decoder) uint32() (uint32, error) { + b, err := d.readN(4) + if err != nil { + return 0, err + } + n := (uint32(b[0]) << 24) | + (uint32(b[1]) << 16) | + (uint32(b[2]) << 8) | + uint32(b[3]) + return n, nil +} + +func (d *Decoder) int32() (int32, error) { + n, err := d.uint32() + return int32(n), err +} + +func (d *Decoder) uint64() (uint64, error) { + b, err := d.readN(8) + if err != nil { + return 0, err + } + n := (uint64(b[0]) << 56) | + (uint64(b[1]) << 48) | + (uint64(b[2]) << 40) | + (uint64(b[3]) << 32) | + (uint64(b[4]) << 24) | + (uint64(b[5]) << 16) | + (uint64(b[6]) << 8) | + uint64(b[7]) + return n, nil +} + +func (d *Decoder) int64() (int64, error) { + n, err := d.uint64() + return int64(n), err +} + +// DecodeUint64 decodes msgpack int8/16/32/64 and uint8/16/32/64 +// into Go uint64. +func (d *Decoder) DecodeUint64() (uint64, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.uint(c) +} + +func (d *Decoder) uint(c byte) (uint64, error) { + if c == Nil { + return 0, nil + } + if IsFixedNum(c) { + return uint64(int8(c)), nil + } + switch c { + case Uint8: + n, err := d.uint8() + return uint64(n), err + case Int8: + n, err := d.int8() + return uint64(n), err + case Uint16: + n, err := d.uint16() + return uint64(n), err + case Int16: + n, err := d.int16() + return uint64(n), err + case Uint32: + n, err := d.uint32() + return uint64(n), err + case Int32: + n, err := d.int32() + return uint64(n), err + case Uint64, Int64: + return d.uint64() + } + return 0, fmt.Errorf("msgpack: invalid code=%x decoding uint64", c) +} + +// DecodeInt64 decodes msgpack int8/16/32/64 and uint8/16/32/64 +// into Go int64. +func (d *Decoder) DecodeInt64() (int64, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.int(c) +} + +func (d *Decoder) int(c byte) (int64, error) { + if c == Nil { + return 0, nil + } + if IsFixedNum(c) { + return int64(int8(c)), nil + } + switch c { + case Uint8: + n, err := d.uint8() + return int64(n), err + case Int8: + n, err := d.uint8() + return int64(int8(n)), err + case Uint16: + n, err := d.uint16() + return int64(n), err + case Int16: + n, err := d.uint16() + return int64(int16(n)), err + case Uint32: + n, err := d.uint32() + return int64(n), err + case Int32: + n, err := d.uint32() + return int64(int32(n)), err + case Uint64, Int64: + n, err := d.uint64() + return int64(n), err + } + return 0, fmt.Errorf("msgpack: invalid code=%x decoding int64", c) +} + +func (d *Decoder) DecodeFloat32() (float32, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.float32(c) +} + +func (d *Decoder) float32(c byte) (float32, error) { + if c == Float { + n, err := d.uint32() + if err != nil { + return 0, err + } + return math.Float32frombits(n), nil + } + + n, err := d.int(c) + if err != nil { + return 0, fmt.Errorf("msgpack: invalid code=%x decoding float32", c) + } + return float32(n), nil +} + +// DecodeFloat64 decodes msgpack float32/64 into Go float64. +func (d *Decoder) DecodeFloat64() (float64, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.float64(c) +} + +func (d *Decoder) float64(c byte) (float64, error) { + switch c { + case Float: + n, err := d.float32(c) + if err != nil { + return 0, err + } + return float64(n), nil + case Double: + n, err := d.uint64() + if err != nil { + return 0, err + } + return math.Float64frombits(n), nil + } + + n, err := d.int(c) + if err != nil { + return 0, fmt.Errorf("msgpack: invalid code=%x decoding float32", c) + } + return float64(n), nil +} + +func (d *Decoder) DecodeUint() (uint, error) { + n, err := d.DecodeUint64() + return uint(n), err +} + +func (d *Decoder) DecodeUint8() (uint8, error) { + n, err := d.DecodeUint64() + return uint8(n), err +} + +func (d *Decoder) DecodeUint16() (uint16, error) { + n, err := d.DecodeUint64() + return uint16(n), err +} + +func (d *Decoder) DecodeUint32() (uint32, error) { + n, err := d.DecodeUint64() + return uint32(n), err +} + +func (d *Decoder) DecodeInt() (int, error) { + n, err := d.DecodeInt64() + return int(n), err +} + +func (d *Decoder) DecodeInt8() (int8, error) { + n, err := d.DecodeInt64() + return int8(n), err +} + +func (d *Decoder) DecodeInt16() (int16, error) { + n, err := d.DecodeInt64() + return int16(n), err +} + +func (d *Decoder) DecodeInt32() (int32, error) { + n, err := d.DecodeInt64() + return int32(n), err +} + +func decodeFloat32Value(d *Decoder, v reflect.Value) error { + f, err := d.DecodeFloat32() + if err != nil { + return err + } + v.SetFloat(float64(f)) + return nil +} + +func decodeFloat64Value(d *Decoder, v reflect.Value) error { + f, err := d.DecodeFloat64() + if err != nil { + return err + } + v.SetFloat(f) + return nil +} + +func decodeInt64Value(d *Decoder, v reflect.Value) error { + n, err := d.DecodeInt64() + if err != nil { + return err + } + v.SetInt(n) + return nil +} + +func decodeUint64Value(d *Decoder, v reflect.Value) error { + n, err := d.DecodeUint64() + if err != nil { + return err + } + v.SetUint(n) + return nil +} diff --git a/lib/msgpack/decode_query.go b/lib/msgpack/decode_query.go new file mode 100644 index 0000000..633df80 --- /dev/null +++ b/lib/msgpack/decode_query.go @@ -0,0 +1,156 @@ +package msgpack + +import ( + "fmt" + "strconv" + "strings" + +) + +type queryResult struct { + query string + key string + values []interface{} + hasAsterisk bool +} + +func (q *queryResult) nextKey() { + ind := strings.IndexByte(q.query, '.') + if ind == -1 { + q.key = q.query + q.query = "" + return + } + q.key = q.query[:ind] + q.query = q.query[ind+1:] +} + +// Query extracts data specified by the query from the msgpack stream skipping +// any other data. Query consists of map keys and array indexes separated with dot, +// e.g. key1.0.key2. +func (d *Decoder) Query(query string) ([]interface{}, error) { + res := queryResult{ + query: query, + } + if err := d.query(&res); err != nil { + return nil, err + } + return res.values, nil +} + +func (d *Decoder) query(q *queryResult) error { + q.nextKey() + if q.key == "" { + v, err := d.decodeInterfaceCond() + if err != nil { + return err + } + q.values = append(q.values, v) + return nil + } + + code, err := d.PeekCode() + if err != nil { + return err + } + + switch { + case code == Map16 || code == Map32 || IsFixedMap(code): + err = d.queryMapKey(q) + case code == Array16 || code == Array32 || IsFixedArray(code): + err = d.queryArrayIndex(q) + default: + err = fmt.Errorf("msgpack: unsupported code=%x decoding key=%q", code, q.key) + } + return err +} + +func (d *Decoder) queryMapKey(q *queryResult) error { + n, err := d.DecodeMapLen() + if err != nil { + return err + } + if n == -1 { + return nil + } + + for i := 0; i < n; i++ { + key, err := d.decodeStringTemp() + if err != nil { + return err + } + + if key == q.key { + if err := d.query(q); err != nil { + return err + } + if q.hasAsterisk { + return d.skipNext((n - i - 1) * 2) + } + return nil + } + + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) queryArrayIndex(q *queryResult) error { + n, err := d.DecodeArrayLen() + if err != nil { + return err + } + if n == -1 { + return nil + } + + if q.key == "*" { + q.hasAsterisk = true + + query := q.query + for i := 0; i < n; i++ { + q.query = query + if err := d.query(q); err != nil { + return err + } + } + + q.hasAsterisk = false + return nil + } + + ind, err := strconv.Atoi(q.key) + if err != nil { + return err + } + + for i := 0; i < n; i++ { + if i == ind { + if err := d.query(q); err != nil { + return err + } + if q.hasAsterisk { + return d.skipNext(n - i - 1) + } + return nil + } + + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) skipNext(n int) error { + for i := 0; i < n; i++ { + if err := d.Skip(); err != nil { + return err + } + } + return nil +} diff --git a/lib/msgpack/decode_slice.go b/lib/msgpack/decode_slice.go new file mode 100644 index 0000000..b87e25b --- /dev/null +++ b/lib/msgpack/decode_slice.go @@ -0,0 +1,197 @@ +package msgpack + +import ( + "fmt" + "reflect" + +) + +var sliceStringPtrType = reflect.TypeOf((*[]string)(nil)) + +// DecodeArrayLen decodes array length. Length is -1 when array is nil. +func (d *Decoder) DecodeArrayLen() (int, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.arrayLen(c) +} + +func (d *Decoder) arrayLen(c byte) (int, error) { + if c == Nil { + return -1, nil + } else if c >= FixedArrayLow && c <= FixedArrayHigh { + return int(c & FixedArrayMask), nil + } + switch c { + case Array16: + n, err := d.uint16() + return int(n), err + case Array32: + n, err := d.uint32() + return int(n), err + } + return 0, fmt.Errorf("msgpack: invalid code=%x decoding array length", c) +} + +func decodeStringSliceValue(d *Decoder, v reflect.Value) error { + ptr := v.Addr().Convert(sliceStringPtrType).Interface().(*[]string) + return d.decodeStringSlicePtr(ptr) +} + +func (d *Decoder) decodeStringSlicePtr(ptr *[]string) error { + n, err := d.DecodeArrayLen() + if err != nil { + return err + } + if n == -1 { + return nil + } + + ss := makeStrings(*ptr, n, d.flags&disableAllocLimitFlag != 0) + for i := 0; i < n; i++ { + s, err := d.DecodeString() + if err != nil { + return err + } + ss = append(ss, s) + } + *ptr = ss + + return nil +} + +func makeStrings(s []string, n int, noLimit bool) []string { + if !noLimit && n > sliceAllocLimit { + n = sliceAllocLimit + } + + if s == nil { + return make([]string, 0, n) + } + + if cap(s) >= n { + return s[:0] + } + + s = s[:cap(s)] + s = append(s, make([]string, n-len(s))...) + return s[:0] +} + +func decodeSliceValue(d *Decoder, v reflect.Value) error { + n, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if n == -1 { + v.Set(reflect.Zero(v.Type())) + return nil + } + if n == 0 && v.IsNil() { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + return nil + } + + if v.Cap() >= n { + v.Set(v.Slice(0, n)) + } else if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Cap())) + } + + noLimit := d.flags&disableAllocLimitFlag != 1 + + if noLimit && n > v.Len() { + v.Set(growSliceValue(v, n, noLimit)) + } + + for i := 0; i < n; i++ { + if !noLimit && i >= v.Len() { + v.Set(growSliceValue(v, n, noLimit)) + } + + elem := v.Index(i) + if err := d.DecodeValue(elem); err != nil { + return err + } + } + + return nil +} + +func growSliceValue(v reflect.Value, n int, noLimit bool) reflect.Value { + diff := n - v.Len() + if !noLimit && diff > sliceAllocLimit { + diff = sliceAllocLimit + } + v = reflect.AppendSlice(v, reflect.MakeSlice(v.Type(), diff, diff)) + return v +} + +func decodeArrayValue(d *Decoder, v reflect.Value) error { + n, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if n == -1 { + return nil + } + if n > v.Len() { + return fmt.Errorf("%s len is %d, but msgpack has %d elements", v.Type(), v.Len(), n) + } + + for i := 0; i < n; i++ { + sv := v.Index(i) + if err := d.DecodeValue(sv); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) DecodeSlice() ([]interface{}, error) { + c, err := d.readCode() + if err != nil { + return nil, err + } + return d.decodeSlice(c) +} + +func (d *Decoder) decodeSlice(c byte) ([]interface{}, error) { + n, err := d.arrayLen(c) + if err != nil { + return nil, err + } + if n == -1 { + return nil, nil + } + + s := make([]interface{}, 0, n) + for i := 0; i < n; i++ { + v, err := d.decodeInterfaceCond() + if err != nil { + return nil, err + } + s = append(s, v) + } + + return s, nil +} + +func (d *Decoder) skipSlice(c byte) error { + n, err := d.arrayLen(c) + if err != nil { + return err + } + + for i := 0; i < n; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} diff --git a/lib/msgpack/decode_string.go b/lib/msgpack/decode_string.go new file mode 100644 index 0000000..d0ed366 --- /dev/null +++ b/lib/msgpack/decode_string.go @@ -0,0 +1,191 @@ +package msgpack + +import ( + "fmt" + "reflect" + +) + +func (d *Decoder) bytesLen(c byte) (int, error) { + if c == Nil { + return -1, nil + } + + if IsFixedString(c) { + return int(c & FixedStrMask), nil + } + + switch c { + case Str8, Bin8: + n, err := d.uint8() + return int(n), err + case Str16, Bin16: + n, err := d.uint16() + return int(n), err + case Str32, Bin32: + n, err := d.uint32() + return int(n), err + } + + return 0, fmt.Errorf("msgpack: invalid code=%x decoding string/bytes length", c) +} + +func (d *Decoder) DecodeString() (string, error) { + if intern := d.flags&useInternedStringsFlag != 0; intern || len(d.dict) > 0 { + return d.decodeInternedString(intern) + } + + c, err := d.readCode() + if err != nil { + return "", err + } + return d.string(c) +} + +func (d *Decoder) string(c byte) (string, error) { + n, err := d.bytesLen(c) + if err != nil { + return "", err + } + return d.stringWithLen(n) +} + +func (d *Decoder) stringWithLen(n int) (string, error) { + if n <= 0 { + return "", nil + } + b, err := d.readN(n) + return string(b), err +} + +func decodeStringValue(d *Decoder, v reflect.Value) error { + s, err := d.DecodeString() + if err != nil { + return err + } + v.SetString(s) + return nil +} + +func (d *Decoder) DecodeBytesLen() (int, error) { + c, err := d.readCode() + if err != nil { + return 0, err + } + return d.bytesLen(c) +} + +func (d *Decoder) DecodeBytes() ([]byte, error) { + c, err := d.readCode() + if err != nil { + return nil, err + } + return d.bytes(c, nil) +} + +func (d *Decoder) bytes(c byte, b []byte) ([]byte, error) { + n, err := d.bytesLen(c) + if err != nil { + return nil, err + } + if n == -1 { + return nil, nil + } + return readN(d.r, b, n) +} + +func (d *Decoder) decodeStringTemp() (string, error) { + if intern := d.flags&useInternedStringsFlag != 0; intern || len(d.dict) > 0 { + return d.decodeInternedString(intern) + } + + c, err := d.readCode() + if err != nil { + return "", err + } + + n, err := d.bytesLen(c) + if err != nil { + return "", err + } + if n == -1 { + return "", nil + } + + b, err := d.readN(n) + if err != nil { + return "", err + } + + return bytesToString(b), nil +} + +func (d *Decoder) decodeBytesPtr(ptr *[]byte) error { + c, err := d.readCode() + if err != nil { + return err + } + return d.bytesPtr(c, ptr) +} + +func (d *Decoder) bytesPtr(c byte, ptr *[]byte) error { + n, err := d.bytesLen(c) + if err != nil { + return err + } + if n == -1 { + *ptr = nil + return nil + } + + *ptr, err = readN(d.r, *ptr, n) + return err +} + +func (d *Decoder) skipBytes(c byte) error { + n, err := d.bytesLen(c) + if err != nil { + return err + } + if n <= 0 { + return nil + } + return d.skipN(n) +} + +func decodeBytesValue(d *Decoder, v reflect.Value) error { + c, err := d.readCode() + if err != nil { + return err + } + + b, err := d.bytes(c, v.Bytes()) + if err != nil { + return err + } + + v.SetBytes(b) + + return nil +} + +func decodeByteArrayValue(d *Decoder, v reflect.Value) error { + c, err := d.readCode() + if err != nil { + return err + } + + n, err := d.bytesLen(c) + if err != nil { + return err + } + if n == -1 { + return nil + } + if n > v.Len() { + return fmt.Errorf("%s len is %d, but msgpack has %d elements", v.Type(), v.Len(), n) + } + + b := v.Slice(0, n).Bytes() + return d.readFull(b) +} diff --git a/lib/msgpack/decode_typgen.go b/lib/msgpack/decode_typgen.go new file mode 100644 index 0000000..0b4c1d0 --- /dev/null +++ b/lib/msgpack/decode_typgen.go @@ -0,0 +1,46 @@ +package msgpack + +import ( + "reflect" + "sync" +) + +var cachedValues struct { + m map[reflect.Type]chan reflect.Value + sync.RWMutex +} + +func cachedValue(t reflect.Type) reflect.Value { + cachedValues.RLock() + ch := cachedValues.m[t] + cachedValues.RUnlock() + if ch != nil { + return <-ch + } + + cachedValues.Lock() + defer cachedValues.Unlock() + if ch = cachedValues.m[t]; ch != nil { + return <-ch + } + + ch = make(chan reflect.Value, 256) + go func() { + for { + ch <- reflect.New(t) + } + }() + if cachedValues.m == nil { + cachedValues.m = make(map[reflect.Type]chan reflect.Value, 8) + } + cachedValues.m[t] = ch + return <-ch +} + +func (d *Decoder) newValue(t reflect.Type) reflect.Value { + if d.flags&usePreallocateValues == 0 { + return reflect.New(t) + } + + return cachedValue(t) +} diff --git a/lib/msgpack/decode_value.go b/lib/msgpack/decode_value.go new file mode 100644 index 0000000..c44a674 --- /dev/null +++ b/lib/msgpack/decode_value.go @@ -0,0 +1,251 @@ +package msgpack + +import ( + "encoding" + "errors" + "fmt" + "reflect" +) + +var ( + interfaceType = reflect.TypeOf((*interface{})(nil)).Elem() + stringType = reflect.TypeOf((*string)(nil)).Elem() + boolType = reflect.TypeOf((*bool)(nil)).Elem() +) + +var valueDecoders []decoderFunc + +//nolint:gochecknoinits +func init() { + valueDecoders = []decoderFunc{ + reflect.Bool: decodeBoolValue, + reflect.Int: decodeInt64Value, + reflect.Int8: decodeInt64Value, + reflect.Int16: decodeInt64Value, + reflect.Int32: decodeInt64Value, + reflect.Int64: decodeInt64Value, + reflect.Uint: decodeUint64Value, + reflect.Uint8: decodeUint64Value, + reflect.Uint16: decodeUint64Value, + reflect.Uint32: decodeUint64Value, + reflect.Uint64: decodeUint64Value, + reflect.Float32: decodeFloat32Value, + reflect.Float64: decodeFloat64Value, + reflect.Complex64: decodeUnsupportedValue, + reflect.Complex128: decodeUnsupportedValue, + reflect.Array: decodeArrayValue, + reflect.Chan: decodeUnsupportedValue, + reflect.Func: decodeUnsupportedValue, + reflect.Interface: decodeInterfaceValue, + reflect.Map: decodeMapValue, + reflect.Ptr: decodeUnsupportedValue, + reflect.Slice: decodeSliceValue, + reflect.String: decodeStringValue, + reflect.Struct: decodeStructValue, + reflect.UnsafePointer: decodeUnsupportedValue, + } +} + +func getDecoder(typ reflect.Type) decoderFunc { + if v, ok := typeDecMap.Load(typ); ok { + return v.(decoderFunc) + } + fn := _getDecoder(typ) + typeDecMap.Store(typ, fn) + return fn +} + +func _getDecoder(typ reflect.Type) decoderFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if _, ok := typeDecMap.Load(typ.Elem()); ok { + return ptrValueDecoder(typ) + } + } + + if typ.Implements(customDecoderType) { + return nilAwareDecoder(typ, decodeCustomValue) + } + if typ.Implements(unmarshalerType) { + return nilAwareDecoder(typ, unmarshalValue) + } + if typ.Implements(binaryUnmarshalerType) { + return nilAwareDecoder(typ, unmarshalBinaryValue) + } + if typ.Implements(textUnmarshalerType) { + return nilAwareDecoder(typ, unmarshalTextValue) + } + + // Addressable struct field value. + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(customDecoderType) { + return addrDecoder(nilAwareDecoder(typ, decodeCustomValue)) + } + if ptr.Implements(unmarshalerType) { + return addrDecoder(nilAwareDecoder(typ, unmarshalValue)) + } + if ptr.Implements(binaryUnmarshalerType) { + return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue)) + } + if ptr.Implements(textUnmarshalerType) { + return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue)) + } + } + + switch kind { + case reflect.Ptr: + return ptrValueDecoder(typ) + case reflect.Slice: + elem := typ.Elem() + if elem.Kind() == reflect.Uint8 { + return decodeBytesValue + } + if elem == stringType { + return decodeStringSliceValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return decodeByteArrayValue + } + case reflect.Map: + if typ.Key() == stringType { + switch typ.Elem() { + case stringType: + return decodeMapStringStringValue + case interfaceType: + return decodeMapStringInterfaceValue + } + } + } + + return valueDecoders[kind] +} + +func ptrValueDecoder(typ reflect.Type) decoderFunc { + decoder := getDecoder(typ.Elem()) + return func(d *Decoder, v reflect.Value) error { + if d.hasNilCode() { + if !v.IsNil() { + v.Set(d.newValue(typ).Elem()) + } + return d.DecodeNil() + } + if v.IsNil() { + v.Set(d.newValue(typ.Elem())) + } + return decoder(d, v.Elem()) + } +} + +func addrDecoder(fn decoderFunc) decoderFunc { + return func(d *Decoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface()) + } + return fn(d, v.Addr()) + } +} + +func nilAwareDecoder(typ reflect.Type, fn decoderFunc) decoderFunc { + if nilable(typ.Kind()) { + return func(d *Decoder, v reflect.Value) error { + if d.hasNilCode() { + return d.decodeNilValue(v) + } + if v.IsNil() { + v.Set(d.newValue(typ.Elem())) + } + return fn(d, v) + } + } + + return func(d *Decoder, v reflect.Value) error { + if d.hasNilCode() { + return d.decodeNilValue(v) + } + return fn(d, v) + } +} + +func decodeBoolValue(d *Decoder, v reflect.Value) error { + flag, err := d.DecodeBool() + if err != nil { + return err + } + v.SetBool(flag) + return nil +} + +func decodeInterfaceValue(d *Decoder, v reflect.Value) error { + if v.IsNil() { + return d.interfaceValue(v) + } + return d.DecodeValue(v.Elem()) +} + +func (d *Decoder) interfaceValue(v reflect.Value) error { + vv, err := d.decodeInterfaceCond() + if err != nil { + return err + } + + if vv != nil { + if v.Type() == errorType { + if vv, ok := vv.(string); ok { + v.Set(reflect.ValueOf(errors.New(vv))) + return nil + } + } + + v.Set(reflect.ValueOf(vv)) + } + + return nil +} + +func decodeUnsupportedValue(d *Decoder, v reflect.Value) error { + return fmt.Errorf("msgpack: Decode(unsupported %s)", v.Type()) +} + +//------------------------------------------------------------------------------ + +func decodeCustomValue(d *Decoder, v reflect.Value) error { + decoder := v.Interface().(CustomDecoder) + return decoder.DecodeMsgpack(d) +} + +func unmarshalValue(d *Decoder, v reflect.Value) error { + var b []byte + + d.rec = make([]byte, 0, 64) + if err := d.Skip(); err != nil { + return err + } + b = d.rec + d.rec = nil + + unmarshaler := v.Interface().(Unmarshaler) + return unmarshaler.UnmarshalMsgpack(b) +} + +func unmarshalBinaryValue(d *Decoder, v reflect.Value) error { + data, err := d.DecodeBytes() + if err != nil { + return err + } + + unmarshaler := v.Interface().(encoding.BinaryUnmarshaler) + return unmarshaler.UnmarshalBinary(data) +} + +func unmarshalTextValue(d *Decoder, v reflect.Value) error { + data, err := d.DecodeBytes() + if err != nil { + return err + } + + unmarshaler := v.Interface().(encoding.TextUnmarshaler) + return unmarshaler.UnmarshalText(data) +} diff --git a/lib/msgpack/encode.go b/lib/msgpack/encode.go new file mode 100644 index 0000000..1c99f23 --- /dev/null +++ b/lib/msgpack/encode.go @@ -0,0 +1,269 @@ +package msgpack + +import ( + "bytes" + "io" + "reflect" + "sync" + "time" + +) + +const ( + sortMapKeysFlag uint32 = 1 << iota + arrayEncodedStructsFlag + useCompactIntsFlag + useCompactFloatsFlag + useInternedStringsFlag + omitEmptyFlag +) + +type writer interface { + io.Writer + WriteByte(byte) error +} + +type byteWriter struct { + io.Writer +} + +func newByteWriter(w io.Writer) byteWriter { + return byteWriter{ + Writer: w, + } +} + +func (bw byteWriter) WriteByte(c byte) error { + _, err := bw.Write([]byte{c}) + return err +} + +//------------------------------------------------------------------------------ + +var encPool = sync.Pool{ + New: func() interface{} { + return NewEncoder(nil) + }, +} + +func GetEncoder() *Encoder { + return encPool.Get().(*Encoder) +} + +func PutEncoder(enc *Encoder) { + enc.w = nil + encPool.Put(enc) +} + +// Marshal returns the MessagePack encoding of v. +func Marshal(v interface{}) ([]byte, error) { + enc := GetEncoder() + + var buf bytes.Buffer + enc.Reset(&buf) + + err := enc.Encode(v) + b := buf.Bytes() + + PutEncoder(enc) + + if err != nil { + return nil, err + } + return b, err +} + +type Encoder struct { + w writer + dict map[string]int + structTag string + buf []byte + timeBuf []byte + flags uint32 +} + +// NewEncoder returns a new encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + e := &Encoder{ + buf: make([]byte, 9), + } + e.Reset(w) + return e +} + +// Writer returns the Encoder's writer. +func (e *Encoder) Writer() io.Writer { + return e.w +} + +// Reset discards any buffered data, resets all state, and switches the writer to write to w. +func (e *Encoder) Reset(w io.Writer) { + e.ResetDict(w, nil) +} + +// ResetDict is like Reset, but also resets the dict. +func (e *Encoder) ResetDict(w io.Writer, dict map[string]int) { + e.ResetWriter(w) + e.flags = 0 + e.structTag = "" + e.dict = dict +} + +func (e *Encoder) WithDict(dict map[string]int, fn func(*Encoder) error) error { + oldDict := e.dict + e.dict = dict + err := fn(e) + e.dict = oldDict + return err +} + +func (e *Encoder) ResetWriter(w io.Writer) { + e.dict = nil + if bw, ok := w.(writer); ok { + e.w = bw + } else if w == nil { + e.w = nil + } else { + e.w = newByteWriter(w) + } +} + +// SetSortMapKeys causes the Encoder to encode map keys in increasing order. +// Supported map types are: +// - map[string]string +// - map[string]bool +// - map[string]interface{} +func (e *Encoder) SetSortMapKeys(on bool) *Encoder { + if on { + e.flags |= sortMapKeysFlag + } else { + e.flags &= ^sortMapKeysFlag + } + return e +} + +// SetCustomStructTag causes the Encoder to use a custom struct tag as +// fallback option if there is no msgpack tag. +func (e *Encoder) SetCustomStructTag(tag string) { + e.structTag = tag +} + +// SetOmitEmpty causes the Encoder to omit empty values by default. +func (e *Encoder) SetOmitEmpty(on bool) { + if on { + e.flags |= omitEmptyFlag + } else { + e.flags &= ^omitEmptyFlag + } +} + +// UseArrayEncodedStructs causes the Encoder to encode Go structs as msgpack arrays. +func (e *Encoder) UseArrayEncodedStructs(on bool) { + if on { + e.flags |= arrayEncodedStructsFlag + } else { + e.flags &= ^arrayEncodedStructsFlag + } +} + +// UseCompactEncoding causes the Encoder to chose the most compact encoding. +// For example, it allows to encode small Go int64 as msgpack int8 saving 7 bytes. +func (e *Encoder) UseCompactInts(on bool) { + if on { + e.flags |= useCompactIntsFlag + } else { + e.flags &= ^useCompactIntsFlag + } +} + +// UseCompactFloats causes the Encoder to chose a compact integer encoding +// for floats that can be represented as integers. +func (e *Encoder) UseCompactFloats(on bool) { + if on { + e.flags |= useCompactFloatsFlag + } else { + e.flags &= ^useCompactFloatsFlag + } +} + +// UseInternedStrings causes the Encoder to intern strings. +func (e *Encoder) UseInternedStrings(on bool) { + if on { + e.flags |= useInternedStringsFlag + } else { + e.flags &= ^useInternedStringsFlag + } +} + +func (e *Encoder) Encode(v interface{}) error { + switch v := v.(type) { + case nil: + return e.EncodeNil() + case string: + return e.EncodeString(v) + case []byte: + return e.EncodeBytes(v) + case int: + return e.EncodeInt(int64(v)) + case int64: + return e.encodeInt64Cond(v) + case uint: + return e.EncodeUint(uint64(v)) + case uint64: + return e.encodeUint64Cond(v) + case bool: + return e.EncodeBool(v) + case float32: + return e.EncodeFloat32(v) + case float64: + return e.EncodeFloat64(v) + case time.Duration: + return e.encodeInt64Cond(int64(v)) + case time.Time: + return e.EncodeTime(v) + } + return e.EncodeValue(reflect.ValueOf(v)) +} + +func (e *Encoder) EncodeMulti(v ...interface{}) error { + for _, vv := range v { + if err := e.Encode(vv); err != nil { + return err + } + } + return nil +} + +func (e *Encoder) EncodeValue(v reflect.Value) error { + fn := getEncoder(v.Type()) + return fn(e, v) +} + +func (e *Encoder) EncodeNil() error { + return e.writeCode(Nil) +} + +func (e *Encoder) EncodeBool(value bool) error { + if value { + return e.writeCode(True) + } + return e.writeCode(False) +} + +func (e *Encoder) EncodeDuration(d time.Duration) error { + return e.EncodeInt(int64(d)) +} + +func (e *Encoder) writeCode(c byte) error { + return e.w.WriteByte(c) +} + +func (e *Encoder) write(b []byte) error { + _, err := e.w.Write(b) + return err +} + +func (e *Encoder) writeString(s string) error { + _, err := e.w.Write(stringToBytes(s)) + return err +} diff --git a/lib/msgpack/encode_map.go b/lib/msgpack/encode_map.go new file mode 100644 index 0000000..1dcb650 --- /dev/null +++ b/lib/msgpack/encode_map.go @@ -0,0 +1,224 @@ +package msgpack + +import ( + "math" + "reflect" + "sort" + +) + +func encodeMapValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + + if err := e.EncodeMapLen(v.Len()); err != nil { + return err + } + + iter := v.MapRange() + for iter.Next() { + if err := e.EncodeValue(iter.Key()); err != nil { + return err + } + if err := e.EncodeValue(iter.Value()); err != nil { + return err + } + } + + return nil +} + +func encodeMapStringBoolValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + + if err := e.EncodeMapLen(v.Len()); err != nil { + return err + } + + m := v.Convert(mapStringBoolType).Interface().(map[string]bool) + if e.flags&sortMapKeysFlag != 0 { + return e.encodeSortedMapStringBool(m) + } + + for mk, mv := range m { + if err := e.EncodeString(mk); err != nil { + return err + } + if err := e.EncodeBool(mv); err != nil { + return err + } + } + + return nil +} + +func encodeMapStringStringValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + + if err := e.EncodeMapLen(v.Len()); err != nil { + return err + } + + m := v.Convert(mapStringStringType).Interface().(map[string]string) + if e.flags&sortMapKeysFlag != 0 { + return e.encodeSortedMapStringString(m) + } + + for mk, mv := range m { + if err := e.EncodeString(mk); err != nil { + return err + } + if err := e.EncodeString(mv); err != nil { + return err + } + } + + return nil +} + +func encodeMapStringInterfaceValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + m := v.Convert(mapStringInterfaceType).Interface().(map[string]interface{}) + if e.flags&sortMapKeysFlag != 0 { + return e.EncodeMapSorted(m) + } + return e.EncodeMap(m) +} + +func (e *Encoder) EncodeMap(m map[string]interface{}) error { + if m == nil { + return e.EncodeNil() + } + if err := e.EncodeMapLen(len(m)); err != nil { + return err + } + for mk, mv := range m { + if err := e.EncodeString(mk); err != nil { + return err + } + if err := e.Encode(mv); err != nil { + return err + } + } + return nil +} + +func (e *Encoder) EncodeMapSorted(m map[string]interface{}) error { + if m == nil { + return e.EncodeNil() + } + if err := e.EncodeMapLen(len(m)); err != nil { + return err + } + + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, k := range keys { + if err := e.EncodeString(k); err != nil { + return err + } + if err := e.Encode(m[k]); err != nil { + return err + } + } + + return nil +} + +func (e *Encoder) encodeSortedMapStringBool(m map[string]bool) error { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + err := e.EncodeString(k) + if err != nil { + return err + } + if err = e.EncodeBool(m[k]); err != nil { + return err + } + } + + return nil +} + +func (e *Encoder) encodeSortedMapStringString(m map[string]string) error { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + err := e.EncodeString(k) + if err != nil { + return err + } + if err = e.EncodeString(m[k]); err != nil { + return err + } + } + + return nil +} + +func (e *Encoder) EncodeMapLen(l int) error { + if l < 16 { + return e.writeCode(FixedMapLow | byte(l)) + } + if l <= math.MaxUint16 { + return e.write2(Map16, uint16(l)) + } + return e.write4(Map32, uint32(l)) +} + +func encodeStructValue(e *Encoder, strct reflect.Value) error { + structFields := structs.Fields(strct.Type(), e.structTag) + if e.flags&arrayEncodedStructsFlag != 0 || structFields.AsArray { + return encodeStructValueAsArray(e, strct, structFields.List) + } + fields := structFields.OmitEmpty(e, strct) + + if err := e.EncodeMapLen(len(fields)); err != nil { + return err + } + + for _, f := range fields { + if err := e.EncodeString(f.name); err != nil { + return err + } + if err := f.EncodeValue(e, strct); err != nil { + return err + } + } + + return nil +} + +func encodeStructValueAsArray(e *Encoder, strct reflect.Value, fields []*field) error { + if err := e.EncodeArrayLen(len(fields)); err != nil { + return err + } + for _, f := range fields { + if err := f.EncodeValue(e, strct); err != nil { + return err + } + } + return nil +} diff --git a/lib/msgpack/encode_number.go b/lib/msgpack/encode_number.go new file mode 100644 index 0000000..7914e80 --- /dev/null +++ b/lib/msgpack/encode_number.go @@ -0,0 +1,251 @@ +package msgpack + +import ( + "math" + "reflect" + +) + +// EncodeUint8 encodes an uint8 in 2 bytes preserving type of the number. +func (e *Encoder) EncodeUint8(n uint8) error { + return e.write1(Uint8, n) +} + +func (e *Encoder) encodeUint8Cond(n uint8) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeUint(uint64(n)) + } + return e.EncodeUint8(n) +} + +// EncodeUint16 encodes an uint16 in 3 bytes preserving type of the number. +func (e *Encoder) EncodeUint16(n uint16) error { + return e.write2(Uint16, n) +} + +func (e *Encoder) encodeUint16Cond(n uint16) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeUint(uint64(n)) + } + return e.EncodeUint16(n) +} + +// EncodeUint32 encodes an uint16 in 5 bytes preserving type of the number. +func (e *Encoder) EncodeUint32(n uint32) error { + return e.write4(Uint32, n) +} + +func (e *Encoder) encodeUint32Cond(n uint32) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeUint(uint64(n)) + } + return e.EncodeUint32(n) +} + +// EncodeUint64 encodes an uint16 in 9 bytes preserving type of the number. +func (e *Encoder) EncodeUint64(n uint64) error { + return e.write8(Uint64, n) +} + +func (e *Encoder) encodeUint64Cond(n uint64) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeUint(n) + } + return e.EncodeUint64(n) +} + +// EncodeInt8 encodes an int8 in 2 bytes preserving type of the number. +func (e *Encoder) EncodeInt8(n int8) error { + return e.write1(Int8, uint8(n)) +} + +func (e *Encoder) encodeInt8Cond(n int8) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeInt(int64(n)) + } + return e.EncodeInt8(n) +} + +// EncodeInt16 encodes an int16 in 3 bytes preserving type of the number. +func (e *Encoder) EncodeInt16(n int16) error { + return e.write2(Int16, uint16(n)) +} + +func (e *Encoder) encodeInt16Cond(n int16) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeInt(int64(n)) + } + return e.EncodeInt16(n) +} + +// EncodeInt32 encodes an int32 in 5 bytes preserving type of the number. +func (e *Encoder) EncodeInt32(n int32) error { + return e.write4(Int32, uint32(n)) +} + +func (e *Encoder) encodeInt32Cond(n int32) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeInt(int64(n)) + } + return e.EncodeInt32(n) +} + +// EncodeInt64 encodes an int64 in 9 bytes preserving type of the number. +func (e *Encoder) EncodeInt64(n int64) error { + return e.write8(Int64, uint64(n)) +} + +func (e *Encoder) encodeInt64Cond(n int64) error { + if e.flags&useCompactIntsFlag != 0 { + return e.EncodeInt(n) + } + return e.EncodeInt64(n) +} + +// EncodeUnsignedNumber encodes an uint64 in 1, 2, 3, 5, or 9 bytes. +// Type of the number is lost during encoding. +func (e *Encoder) EncodeUint(n uint64) error { + if n <= math.MaxInt8 { + return e.w.WriteByte(byte(n)) + } + if n <= math.MaxUint8 { + return e.EncodeUint8(uint8(n)) + } + if n <= math.MaxUint16 { + return e.EncodeUint16(uint16(n)) + } + if n <= math.MaxUint32 { + return e.EncodeUint32(uint32(n)) + } + return e.EncodeUint64(n) +} + +// EncodeNumber encodes an int64 in 1, 2, 3, 5, or 9 bytes. +// Type of the number is lost during encoding. +func (e *Encoder) EncodeInt(n int64) error { + if n >= 0 { + return e.EncodeUint(uint64(n)) + } + if n >= int64(int8(NegFixedNumLow)) { + return e.w.WriteByte(byte(n)) + } + if n >= math.MinInt8 { + return e.EncodeInt8(int8(n)) + } + if n >= math.MinInt16 { + return e.EncodeInt16(int16(n)) + } + if n >= math.MinInt32 { + return e.EncodeInt32(int32(n)) + } + return e.EncodeInt64(n) +} + +func (e *Encoder) EncodeFloat32(n float32) error { + if e.flags&useCompactFloatsFlag != 0 { + if float32(int64(n)) == n { + return e.EncodeInt(int64(n)) + } + } + return e.write4(Float, math.Float32bits(n)) +} + +func (e *Encoder) EncodeFloat64(n float64) error { + if e.flags&useCompactFloatsFlag != 0 { + // Both NaN and Inf convert to int64(-0x8000000000000000) + // If n is NaN then it never compares true with any other value + // If n is Inf then it doesn't convert from int64 back to +/-Inf + // In both cases the comparison works. + if float64(int64(n)) == n { + return e.EncodeInt(int64(n)) + } + } + return e.write8(Double, math.Float64bits(n)) +} + +func (e *Encoder) write1(code byte, n uint8) error { + e.buf = e.buf[:2] + e.buf[0] = code + e.buf[1] = n + return e.write(e.buf) +} + +func (e *Encoder) write2(code byte, n uint16) error { + e.buf = e.buf[:3] + e.buf[0] = code + e.buf[1] = byte(n >> 8) + e.buf[2] = byte(n) + return e.write(e.buf) +} + +func (e *Encoder) write4(code byte, n uint32) error { + e.buf = e.buf[:5] + e.buf[0] = code + e.buf[1] = byte(n >> 24) + e.buf[2] = byte(n >> 16) + e.buf[3] = byte(n >> 8) + e.buf[4] = byte(n) + return e.write(e.buf) +} + +func (e *Encoder) write8(code byte, n uint64) error { + e.buf = e.buf[:9] + e.buf[0] = code + e.buf[1] = byte(n >> 56) + e.buf[2] = byte(n >> 48) + e.buf[3] = byte(n >> 40) + e.buf[4] = byte(n >> 32) + e.buf[5] = byte(n >> 24) + e.buf[6] = byte(n >> 16) + e.buf[7] = byte(n >> 8) + e.buf[8] = byte(n) + return e.write(e.buf) +} + +func encodeUintValue(e *Encoder, v reflect.Value) error { + return e.EncodeUint(v.Uint()) +} + +func encodeIntValue(e *Encoder, v reflect.Value) error { + return e.EncodeInt(v.Int()) +} + +func encodeUint8CondValue(e *Encoder, v reflect.Value) error { + return e.encodeUint8Cond(uint8(v.Uint())) +} + +func encodeUint16CondValue(e *Encoder, v reflect.Value) error { + return e.encodeUint16Cond(uint16(v.Uint())) +} + +func encodeUint32CondValue(e *Encoder, v reflect.Value) error { + return e.encodeUint32Cond(uint32(v.Uint())) +} + +func encodeUint64CondValue(e *Encoder, v reflect.Value) error { + return e.encodeUint64Cond(v.Uint()) +} + +func encodeInt8CondValue(e *Encoder, v reflect.Value) error { + return e.encodeInt8Cond(int8(v.Int())) +} + +func encodeInt16CondValue(e *Encoder, v reflect.Value) error { + return e.encodeInt16Cond(int16(v.Int())) +} + +func encodeInt32CondValue(e *Encoder, v reflect.Value) error { + return e.encodeInt32Cond(int32(v.Int())) +} + +func encodeInt64CondValue(e *Encoder, v reflect.Value) error { + return e.encodeInt64Cond(v.Int()) +} + +func encodeFloat32Value(e *Encoder, v reflect.Value) error { + return e.EncodeFloat32(float32(v.Float())) +} + +func encodeFloat64Value(e *Encoder, v reflect.Value) error { + return e.EncodeFloat64(v.Float()) +} diff --git a/lib/msgpack/encode_slice.go b/lib/msgpack/encode_slice.go new file mode 100644 index 0000000..98fbb51 --- /dev/null +++ b/lib/msgpack/encode_slice.go @@ -0,0 +1,138 @@ +package msgpack + +import ( + "math" + "reflect" + +) + +var stringSliceType = reflect.TypeOf(([]string)(nil)) + +func encodeStringValue(e *Encoder, v reflect.Value) error { + return e.EncodeString(v.String()) +} + +func encodeByteSliceValue(e *Encoder, v reflect.Value) error { + return e.EncodeBytes(v.Bytes()) +} + +func encodeByteArrayValue(e *Encoder, v reflect.Value) error { + if err := e.EncodeBytesLen(v.Len()); err != nil { + return err + } + + if v.CanAddr() { + b := v.Slice(0, v.Len()).Bytes() + return e.write(b) + } + + e.buf = grow(e.buf, v.Len()) + reflect.Copy(reflect.ValueOf(e.buf), v) + return e.write(e.buf) +} + +func grow(b []byte, n int) []byte { + if cap(b) >= n { + return b[:n] + } + b = b[:cap(b)] + b = append(b, make([]byte, n-len(b))...) + return b +} + +func (e *Encoder) EncodeBytesLen(l int) error { + if l < 256 { + return e.write1(Bin8, uint8(l)) + } + if l <= math.MaxUint16 { + return e.write2(Bin16, uint16(l)) + } + return e.write4(Bin32, uint32(l)) +} + +func (e *Encoder) encodeStringLen(l int) error { + if l < 32 { + return e.writeCode(FixedStrLow | byte(l)) + } + if l < 256 { + return e.write1(Str8, uint8(l)) + } + if l <= math.MaxUint16 { + return e.write2(Str16, uint16(l)) + } + return e.write4(Str32, uint32(l)) +} + +func (e *Encoder) EncodeString(v string) error { + if intern := e.flags&useInternedStringsFlag != 0; intern || len(e.dict) > 0 { + return e.encodeInternedString(v, intern) + } + return e.encodeNormalString(v) +} + +func (e *Encoder) encodeNormalString(v string) error { + if err := e.encodeStringLen(len(v)); err != nil { + return err + } + return e.writeString(v) +} + +func (e *Encoder) EncodeBytes(v []byte) error { + if v == nil { + return e.EncodeNil() + } + if err := e.EncodeBytesLen(len(v)); err != nil { + return err + } + return e.write(v) +} + +func (e *Encoder) EncodeArrayLen(l int) error { + if l < 16 { + return e.writeCode(FixedArrayLow | byte(l)) + } + if l <= math.MaxUint16 { + return e.write2(Array16, uint16(l)) + } + return e.write4(Array32, uint32(l)) +} + +func encodeStringSliceValue(e *Encoder, v reflect.Value) error { + ss := v.Convert(stringSliceType).Interface().([]string) + return e.encodeStringSlice(ss) +} + +func (e *Encoder) encodeStringSlice(s []string) error { + if s == nil { + return e.EncodeNil() + } + if err := e.EncodeArrayLen(len(s)); err != nil { + return err + } + for _, v := range s { + if err := e.EncodeString(v); err != nil { + return err + } + } + return nil +} + +func encodeSliceValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + return encodeArrayValue(e, v) +} + +func encodeArrayValue(e *Encoder, v reflect.Value) error { + l := v.Len() + if err := e.EncodeArrayLen(l); err != nil { + return err + } + for i := 0; i < l; i++ { + if err := e.EncodeValue(v.Index(i)); err != nil { + return err + } + } + return nil +} diff --git a/lib/msgpack/encode_value.go b/lib/msgpack/encode_value.go new file mode 100644 index 0000000..1d6303a --- /dev/null +++ b/lib/msgpack/encode_value.go @@ -0,0 +1,254 @@ +package msgpack + +import ( + "encoding" + "fmt" + "reflect" +) + +var valueEncoders []encoderFunc + +//nolint:gochecknoinits +func init() { + valueEncoders = []encoderFunc{ + reflect.Bool: encodeBoolValue, + reflect.Int: encodeIntValue, + reflect.Int8: encodeInt8CondValue, + reflect.Int16: encodeInt16CondValue, + reflect.Int32: encodeInt32CondValue, + reflect.Int64: encodeInt64CondValue, + reflect.Uint: encodeUintValue, + reflect.Uint8: encodeUint8CondValue, + reflect.Uint16: encodeUint16CondValue, + reflect.Uint32: encodeUint32CondValue, + reflect.Uint64: encodeUint64CondValue, + reflect.Float32: encodeFloat32Value, + reflect.Float64: encodeFloat64Value, + reflect.Complex64: encodeUnsupportedValue, + reflect.Complex128: encodeUnsupportedValue, + reflect.Array: encodeArrayValue, + reflect.Chan: encodeUnsupportedValue, + reflect.Func: encodeUnsupportedValue, + reflect.Interface: encodeInterfaceValue, + reflect.Map: encodeMapValue, + reflect.Ptr: encodeUnsupportedValue, + reflect.Slice: encodeSliceValue, + reflect.String: encodeStringValue, + reflect.Struct: encodeStructValue, + reflect.UnsafePointer: encodeUnsupportedValue, + } +} + +func getEncoder(typ reflect.Type) encoderFunc { + if v, ok := typeEncMap.Load(typ); ok { + return v.(encoderFunc) + } + fn := _getEncoder(typ) + typeEncMap.Store(typ, fn) + return fn +} + +func _getEncoder(typ reflect.Type) encoderFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if _, ok := typeEncMap.Load(typ.Elem()); ok { + return ptrEncoderFunc(typ) + } + } + + if typ.Implements(customEncoderType) { + return encodeCustomValue + } + if typ.Implements(marshalerType) { + return marshalValue + } + if typ.Implements(binaryMarshalerType) { + return marshalBinaryValue + } + if typ.Implements(textMarshalerType) { + return marshalTextValue + } + + // Addressable struct field value. + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(customEncoderType) { + return encodeCustomValuePtr + } + if ptr.Implements(marshalerType) { + return marshalValuePtr + } + if ptr.Implements(binaryMarshalerType) { + return marshalBinaryValueAddr + } + if ptr.Implements(textMarshalerType) { + return marshalTextValueAddr + } + } + + if typ == errorType { + return encodeErrorValue + } + + switch kind { + case reflect.Ptr: + return ptrEncoderFunc(typ) + case reflect.Slice: + elem := typ.Elem() + if elem.Kind() == reflect.Uint8 { + return encodeByteSliceValue + } + if elem == stringType { + return encodeStringSliceValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return encodeByteArrayValue + } + case reflect.Map: + if typ.Key() == stringType { + switch typ.Elem() { + case stringType: + return encodeMapStringStringValue + case boolType: + return encodeMapStringBoolValue + case interfaceType: + return encodeMapStringInterfaceValue + } + } + } + + return valueEncoders[kind] +} + +func ptrEncoderFunc(typ reflect.Type) encoderFunc { + encoder := getEncoder(typ.Elem()) + return func(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + return encoder(e, v.Elem()) + } +} + +func encodeCustomValuePtr(e *Encoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface()) + } + encoder := v.Addr().Interface().(CustomEncoder) + return encoder.EncodeMsgpack(e) +} + +func encodeCustomValue(e *Encoder, v reflect.Value) error { + if nilable(v.Kind()) && v.IsNil() { + return e.EncodeNil() + } + + encoder := v.Interface().(CustomEncoder) + return encoder.EncodeMsgpack(e) +} + +func marshalValuePtr(e *Encoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface()) + } + return marshalValue(e, v.Addr()) +} + +func marshalValue(e *Encoder, v reflect.Value) error { + if nilable(v.Kind()) && v.IsNil() { + return e.EncodeNil() + } + + marshaler := v.Interface().(Marshaler) + b, err := marshaler.MarshalMsgpack() + if err != nil { + return err + } + _, err = e.w.Write(b) + return err +} + +func encodeBoolValue(e *Encoder, v reflect.Value) error { + return e.EncodeBool(v.Bool()) +} + +func encodeInterfaceValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + return e.EncodeValue(v.Elem()) +} + +func encodeErrorValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + return e.EncodeString(v.Interface().(error).Error()) +} + +func encodeUnsupportedValue(e *Encoder, v reflect.Value) error { + return fmt.Errorf("msgpack: Encode(unsupported %s)", v.Type()) +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} + +func nilableType(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return nilable(t.Kind()) +} + +//------------------------------------------------------------------------------ + +func marshalBinaryValueAddr(e *Encoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface()) + } + return marshalBinaryValue(e, v.Addr()) +} + +func marshalBinaryValue(e *Encoder, v reflect.Value) error { + if nilable(v.Kind()) && v.IsNil() { + return e.EncodeNil() + } + + marshaler := v.Interface().(encoding.BinaryMarshaler) + data, err := marshaler.MarshalBinary() + if err != nil { + return err + } + + return e.EncodeBytes(data) +} + +//------------------------------------------------------------------------------ + +func marshalTextValueAddr(e *Encoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface()) + } + return marshalTextValue(e, v.Addr()) +} + +func marshalTextValue(e *Encoder, v reflect.Value) error { + if nilable(v.Kind()) && v.IsNil() { + return e.EncodeNil() + } + + marshaler := v.Interface().(encoding.TextMarshaler) + data, err := marshaler.MarshalText() + if err != nil { + return err + } + + return e.EncodeBytes(data) +} diff --git a/lib/msgpack/ext.go b/lib/msgpack/ext.go new file mode 100644 index 0000000..42e0fb5 --- /dev/null +++ b/lib/msgpack/ext.go @@ -0,0 +1,343 @@ +package msgpack + +import ( + "bytes" + "fmt" + "math" + "reflect" +) + +type extInfo struct { + Type reflect.Type + Decoder func(d *Decoder, v reflect.Value, extLen int) error +} + +var extTypes = make(map[int8]*extInfo) + +type MarshalerUnmarshaler interface { + Marshaler + Unmarshaler +} + +func RegisterExt(extID int8, value interface{}) { + typ := reflect.TypeOf(value) + + marshalerType := reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() + + // Encoder: use MarshalMsgpack if available, otherwise use default struct encoding. + implMarshal := typ.Implements(marshalerType) || + (typ.Kind() == reflect.Ptr && typ.Elem().Implements(marshalerType)) + if implMarshal { + RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) { + return v.Interface().(Marshaler).MarshalMsgpack() + }) + } else { + RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + enc.UseArrayEncodedStructs(true) + if err := enc.Encode(v.Interface()); err != nil { + return nil, err + } + return buf.Bytes(), nil + }) + } + + // Decoder: use UnmarshalMsgpack if available, otherwise use default struct decoding. + implUnmarshal := typ.Implements(unmarshalerType) || + (typ.Kind() == reflect.Ptr && typ.Elem().Implements(unmarshalerType)) + if implUnmarshal { + RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error { + b, err := d.readN(extLen) + if err != nil { + return err + } + return v.Interface().(Unmarshaler).UnmarshalMsgpack(b) + }) + } else { + structDecoder := _getDecoder(typ) + if typ.Kind() == reflect.Ptr { + structDecoder = _getDecoder(typ.Elem()) + } + RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error { + b, err := d.readN(extLen) + if err != nil { + return err + } + dec := NewDecoder(bytes.NewReader(b)) + if v.Kind() == reflect.Ptr { + return structDecoder(dec, v.Elem()) + } + return structDecoder(dec, v) + }) + } +} + +func UnregisterExt(extID int8) { + unregisterExtEncoder(extID) + unregisterExtDecoder(extID) +} + +func RegisterExtEncoder( + extID int8, + value interface{}, + encoder func(enc *Encoder, v reflect.Value) ([]byte, error), +) { + unregisterExtEncoder(extID) + + typ := reflect.TypeOf(value) + extEncoder := makeExtEncoder(extID, typ, encoder) + typeEncMap.Store(extID, typ) + typeEncMap.Store(typ, extEncoder) + if typ.Kind() == reflect.Ptr { + typeEncMap.Store(typ.Elem(), makeExtEncoderAddr(extEncoder)) + } +} + +func unregisterExtEncoder(extID int8) { + t, ok := typeEncMap.Load(extID) + if !ok { + return + } + typeEncMap.Delete(extID) + typ := t.(reflect.Type) + typeEncMap.Delete(typ) + if typ.Kind() == reflect.Ptr { + typeEncMap.Delete(typ.Elem()) + } +} + +func makeExtEncoder( + extID int8, + typ reflect.Type, + encoder func(enc *Encoder, v reflect.Value) ([]byte, error), +) encoderFunc { + nilable := typ.Kind() == reflect.Ptr + + return func(e *Encoder, v reflect.Value) error { + if nilable && v.IsNil() { + return e.EncodeNil() + } + + b, err := encoder(e, v) + if err != nil { + return err + } + + if err := e.EncodeExtHeader(extID, len(b)); err != nil { + return err + } + + return e.write(b) + } +} + +func makeExtEncoderAddr(extEncoder encoderFunc) encoderFunc { + return func(e *Encoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: EncodeExt(nonaddressable %T)", v.Interface()) + } + return extEncoder(e, v.Addr()) + } +} + +func RegisterExtDecoder( + extID int8, + value interface{}, + decoder func(dec *Decoder, v reflect.Value, extLen int) error, +) { + unregisterExtDecoder(extID) + + typ := reflect.TypeOf(value) + extDecoder := makeExtDecoder(extID, typ, decoder) + extTypes[extID] = &extInfo{ + Type: typ, + Decoder: decoder, + } + + typeDecMap.Store(extID, typ) + typeDecMap.Store(typ, extDecoder) + if typ.Kind() == reflect.Ptr { + typeDecMap.Store(typ.Elem(), makeExtDecoderAddr(extDecoder)) + } +} + +func unregisterExtDecoder(extID int8) { + t, ok := typeDecMap.Load(extID) + if !ok { + return + } + typeDecMap.Delete(extID) + delete(extTypes, extID) + typ := t.(reflect.Type) + typeDecMap.Delete(typ) + if typ.Kind() == reflect.Ptr { + typeDecMap.Delete(typ.Elem()) + } +} + +func makeExtDecoder( + wantedExtID int8, + typ reflect.Type, + decoder func(d *Decoder, v reflect.Value, extLen int) error, +) decoderFunc { + return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error { + extID, extLen, err := d.DecodeExtHeader() + if err != nil { + return err + } + if extID != wantedExtID { + return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID) + } + return decoder(d, v, extLen) + }) +} + +func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc { + return func(d *Decoder, v reflect.Value) error { + if !v.CanAddr() { + return fmt.Errorf("msgpack: DecodeExt(nonaddressable %T)", v.Interface()) + } + return extDecoder(d, v.Addr()) + } +} + +func (e *Encoder) EncodeExtHeader(extID int8, extLen int) error { + if err := e.encodeExtLen(extLen); err != nil { + return err + } + if err := e.w.WriteByte(byte(extID)); err != nil { + return err + } + return nil +} + +func (e *Encoder) encodeExtLen(l int) error { + switch l { + case 1: + return e.writeCode(FixExt1) + case 2: + return e.writeCode(FixExt2) + case 4: + return e.writeCode(FixExt4) + case 8: + return e.writeCode(FixExt8) + case 16: + return e.writeCode(FixExt16) + } + if l <= math.MaxUint8 { + return e.write1(Ext8, uint8(l)) + } + if l <= math.MaxUint16 { + return e.write2(Ext16, uint16(l)) + } + return e.write4(Ext32, uint32(l)) +} + +func (d *Decoder) DecodeExtHeader() (extID int8, extLen int, err error) { + c, err := d.readCode() + if err != nil { + return + } + return d.extHeader(c) +} + +func (d *Decoder) extHeader(c byte) (int8, int, error) { + extLen, err := d.parseExtLen(c) + if err != nil { + return 0, 0, err + } + + extID, err := d.readCode() + if err != nil { + return 0, 0, err + } + + return int8(extID), extLen, nil +} + +func (d *Decoder) parseExtLen(c byte) (int, error) { + switch c { + case FixExt1: + return 1, nil + case FixExt2: + return 2, nil + case FixExt4: + return 4, nil + case FixExt8: + return 8, nil + case FixExt16: + return 16, nil + case Ext8: + n, err := d.uint8() + return int(n), err + case Ext16: + n, err := d.uint16() + return int(n), err + case Ext32: + n, err := d.uint32() + return int(n), err + default: + return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext len", c) + } +} + +func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) { + extID, extLen, err := d.extHeader(c) + if err != nil { + return nil, err + } + + info, ok := extTypes[extID] + if !ok { + return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID) + } + + v := d.newValue(info.Type).Elem() + if nilable(v.Kind()) && v.IsNil() { + v.Set(d.newValue(info.Type.Elem())) + } + + if err := info.Decoder(d, v, extLen); err != nil { + return nil, err + } + + return v.Interface(), nil +} + +func (d *Decoder) skipExt(c byte) error { + n, err := d.parseExtLen(c) + if err != nil { + return err + } + return d.skipN(n + 1) +} + +func (d *Decoder) skipExtHeader(c byte) error { + // Read ext type. + _, err := d.readCode() + if err != nil { + return err + } + // Read ext body len. + for i := 0; i < extHeaderLen(c); i++ { + _, err := d.readCode() + if err != nil { + return err + } + } + return nil +} + +func extHeaderLen(c byte) int { + switch c { + case Ext8: + return 1 + case Ext16: + return 2 + case Ext32: + return 4 + } + return 0 +} diff --git a/lib/msgpack/intern.go b/lib/msgpack/intern.go new file mode 100644 index 0000000..9af4218 --- /dev/null +++ b/lib/msgpack/intern.go @@ -0,0 +1,235 @@ +package msgpack + +import ( + "fmt" + "math" + "reflect" + +) + +const ( + minInternedStringLen = 3 + maxDictLen = math.MaxUint16 +) + +var internedStringExtID = int8(math.MinInt8) + +func init() { + extTypes[internedStringExtID] = &extInfo{ + Type: stringType, + Decoder: decodeInternedStringExt, + } +} + +func decodeInternedStringExt(d *Decoder, v reflect.Value, extLen int) error { + idx, err := d.decodeInternedStringIndex(extLen) + if err != nil { + return err + } + + s, err := d.internedStringAtIndex(idx) + if err != nil { + return err + } + + v.SetString(s) + return nil +} + +//------------------------------------------------------------------------------ + +func encodeInternedInterfaceValue(e *Encoder, v reflect.Value) error { + if v.IsNil() { + return e.EncodeNil() + } + + v = v.Elem() + if v.Kind() == reflect.String { + return e.encodeInternedString(v.String(), true) + } + return e.EncodeValue(v) +} + +func encodeInternedStringValue(e *Encoder, v reflect.Value) error { + return e.encodeInternedString(v.String(), true) +} + +func (e *Encoder) encodeInternedString(s string, intern bool) error { + // Interned string takes at least 3 bytes. Plain string 1 byte + string len. + if idx, ok := e.dict[s]; ok { + return e.encodeInternedStringIndex(idx) + } + + if intern && len(s) >= minInternedStringLen && len(e.dict) < maxDictLen { + if e.dict == nil { + e.dict = make(map[string]int) + } + idx := len(e.dict) + e.dict[s] = idx + } + + return e.encodeNormalString(s) +} + +func (e *Encoder) encodeInternedStringIndex(idx int) error { + if idx <= math.MaxUint8 { + if err := e.writeCode(FixExt1); err != nil { + return err + } + return e.write1(byte(internedStringExtID), uint8(idx)) + } + + if idx <= math.MaxUint16 { + if err := e.writeCode(FixExt2); err != nil { + return err + } + return e.write2(byte(internedStringExtID), uint16(idx)) + } + + if uint64(idx) <= math.MaxUint32 { + if err := e.writeCode(FixExt4); err != nil { + return err + } + return e.write4(byte(internedStringExtID), uint32(idx)) + } + + return fmt.Errorf("msgpack: interned string index=%d is too large", idx) +} + +//------------------------------------------------------------------------------ + +func decodeInternedInterfaceValue(d *Decoder, v reflect.Value) error { + s, err := d.decodeInternedString(true) + if err == nil { + v.Set(reflect.ValueOf(s)) + return nil + } + if err != nil { + if _, ok := err.(unexpectedCodeError); !ok { + return err + } + } + + if err := d.s.UnreadByte(); err != nil { + return err + } + return decodeInterfaceValue(d, v) +} + +func decodeInternedStringValue(d *Decoder, v reflect.Value) error { + s, err := d.decodeInternedString(true) + if err != nil { + return err + } + + v.SetString(s) + return nil +} + +func (d *Decoder) decodeInternedString(intern bool) (string, error) { + c, err := d.readCode() + if err != nil { + return "", err + } + + if IsFixedString(c) { + n := int(c & FixedStrMask) + return d.decodeInternedStringWithLen(n, intern) + } + + switch c { + case Nil: + return "", nil + case FixExt1, FixExt2, FixExt4: + typeID, extLen, err := d.extHeader(c) + if err != nil { + return "", err + } + if typeID != internedStringExtID { + err := fmt.Errorf("msgpack: got ext type=%d, wanted %d", + typeID, internedStringExtID) + return "", err + } + + idx, err := d.decodeInternedStringIndex(extLen) + if err != nil { + return "", err + } + + return d.internedStringAtIndex(idx) + case Str8, Bin8: + n, err := d.uint8() + if err != nil { + return "", err + } + return d.decodeInternedStringWithLen(int(n), intern) + case Str16, Bin16: + n, err := d.uint16() + if err != nil { + return "", err + } + return d.decodeInternedStringWithLen(int(n), intern) + case Str32, Bin32: + n, err := d.uint32() + if err != nil { + return "", err + } + return d.decodeInternedStringWithLen(int(n), intern) + } + + return "", unexpectedCodeError{ + code: c, + hint: "interned string", + } +} + +func (d *Decoder) decodeInternedStringIndex(extLen int) (int, error) { + switch extLen { + case 1: + n, err := d.uint8() + if err != nil { + return 0, err + } + return int(n), nil + case 2: + n, err := d.uint16() + if err != nil { + return 0, err + } + return int(n), nil + case 4: + n, err := d.uint32() + if err != nil { + return 0, err + } + return int(n), nil + } + + err := fmt.Errorf("msgpack: unsupported ext len=%d decoding interned string", extLen) + return 0, err +} + +func (d *Decoder) internedStringAtIndex(idx int) (string, error) { + if idx >= len(d.dict) { + err := fmt.Errorf("msgpack: interned string at index=%d does not exist", idx) + return "", err + } + return d.dict[idx], nil +} + +func (d *Decoder) decodeInternedStringWithLen(n int, intern bool) (string, error) { + if n <= 0 { + return "", nil + } + + s, err := d.stringWithLen(n) + if err != nil { + return "", err + } + + if intern && len(s) >= minInternedStringLen && len(d.dict) < maxDictLen { + d.dict = append(d.dict, s) + } + + return s, nil +} diff --git a/lib/msgpack/msgpack.go b/lib/msgpack/msgpack.go new file mode 100644 index 0000000..4fa000b --- /dev/null +++ b/lib/msgpack/msgpack.go @@ -0,0 +1,52 @@ +package msgpack + +import "fmt" + +type Marshaler interface { + MarshalMsgpack() ([]byte, error) +} + +type Unmarshaler interface { + UnmarshalMsgpack([]byte) error +} + +type CustomEncoder interface { + EncodeMsgpack(*Encoder) error +} + +type CustomDecoder interface { + DecodeMsgpack(*Decoder) error +} + +//------------------------------------------------------------------------------ + +type RawMessage []byte + +var ( + _ CustomEncoder = (RawMessage)(nil) + _ CustomDecoder = (*RawMessage)(nil) +) + +func (m RawMessage) EncodeMsgpack(enc *Encoder) error { + return enc.write(m) +} + +func (m *RawMessage) DecodeMsgpack(dec *Decoder) error { + msg, err := dec.DecodeRaw() + if err != nil { + return err + } + *m = msg + return nil +} + +//------------------------------------------------------------------------------ + +type unexpectedCodeError struct { + hint string + code byte +} + +func (err unexpectedCodeError) Error() string { + return fmt.Sprintf("msgpack: unexpected code=%x decoding %s", err.code, err.hint) +} diff --git a/lib/msgpack/safe.go b/lib/msgpack/safe.go new file mode 100644 index 0000000..8352c9d --- /dev/null +++ b/lib/msgpack/safe.go @@ -0,0 +1,13 @@ +// +build appengine + +package msgpack + +// bytesToString converts byte slice to string. +func bytesToString(b []byte) string { + return string(b) +} + +// stringToBytes converts string to byte slice. +func stringToBytes(s string) []byte { + return []byte(s) +} diff --git a/lib/msgpack/time.go b/lib/msgpack/time.go new file mode 100644 index 0000000..a97bf05 --- /dev/null +++ b/lib/msgpack/time.go @@ -0,0 +1,150 @@ +package msgpack + +import ( + "encoding/binary" + "fmt" + "reflect" + "time" + +) + +var timeExtID int8 = -1 + +func init() { + RegisterExtEncoder(timeExtID, time.Time{}, timeEncoder) + RegisterExtDecoder(timeExtID, time.Time{}, timeDecoder) +} + +func timeEncoder(e *Encoder, v reflect.Value) ([]byte, error) { + return e.encodeTime(v.Interface().(time.Time)), nil +} + +func timeDecoder(d *Decoder, v reflect.Value, extLen int) error { + tm, err := d.decodeTime(extLen) + if err != nil { + return err + } + + if tm.IsZero() { + // Zero time does not have timezone information. + tm = tm.UTC() + } + + ptr := v.Addr().Interface().(*time.Time) + *ptr = tm + + return nil +} + +func (e *Encoder) EncodeTime(tm time.Time) error { + b := e.encodeTime(tm) + if err := e.encodeExtLen(len(b)); err != nil { + return err + } + if err := e.w.WriteByte(byte(timeExtID)); err != nil { + return err + } + return e.write(b) +} + +func (e *Encoder) encodeTime(tm time.Time) []byte { + if e.timeBuf == nil { + e.timeBuf = make([]byte, 12) + } + + secs := uint64(tm.Unix()) + if secs>>34 == 0 { + data := uint64(tm.Nanosecond())<<34 | secs + + if data&0xffffffff00000000 == 0 { + b := e.timeBuf[:4] + binary.BigEndian.PutUint32(b, uint32(data)) + return b + } + + b := e.timeBuf[:8] + binary.BigEndian.PutUint64(b, data) + return b + } + + b := e.timeBuf[:12] + binary.BigEndian.PutUint32(b, uint32(tm.Nanosecond())) + binary.BigEndian.PutUint64(b[4:], secs) + return b +} + +func (d *Decoder) DecodeTime() (time.Time, error) { + c, err := d.readCode() + if err != nil { + return time.Time{}, err + } + + // Legacy format. + if c == FixedArrayLow|2 { + sec, err := d.DecodeInt64() + if err != nil { + return time.Time{}, err + } + + nsec, err := d.DecodeInt64() + if err != nil { + return time.Time{}, err + } + + return time.Unix(sec, nsec), nil + } + + if IsString(c) { + s, err := d.string(c) + if err != nil { + return time.Time{}, err + } + return time.Parse(time.RFC3339Nano, s) + } + + extID, extLen, err := d.extHeader(c) + if err != nil { + return time.Time{}, err + } + + // NodeJS seems to use extID 13. + if extID != timeExtID && extID != 13 { + return time.Time{}, fmt.Errorf("msgpack: invalid time ext id=%d", extID) + } + + tm, err := d.decodeTime(extLen) + if err != nil { + return tm, err + } + + if tm.IsZero() { + // Zero time does not have timezone information. + return tm.UTC(), nil + } + return tm, nil +} + +func (d *Decoder) decodeTime(extLen int) (time.Time, error) { + b, err := d.readN(extLen) + if err != nil { + return time.Time{}, err + } + + switch len(b) { + case 4: + sec := binary.BigEndian.Uint32(b) + return time.Unix(int64(sec), 0), nil + case 8: + sec := binary.BigEndian.Uint64(b) + nsec := int64(sec >> 34) + sec &= 0x00000003ffffffff + return time.Unix(int64(sec), nsec), nil + case 12: + nsec := binary.BigEndian.Uint32(b) + sec := binary.BigEndian.Uint64(b[4:]) + return time.Unix(int64(sec), int64(nsec)), nil + default: + err = fmt.Errorf("msgpack: invalid ext len=%d decoding time", extLen) + return time.Time{}, err + } +} diff --git a/lib/msgpack/types.go b/lib/msgpack/types.go new file mode 100644 index 0000000..63ba086 --- /dev/null +++ b/lib/msgpack/types.go @@ -0,0 +1,413 @@ +package msgpack + +import ( + "encoding" + "fmt" + "log" + "reflect" + "sync" + + "github.com/theater/picomap/lib/tagparser" +) + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +var ( + customEncoderType = reflect.TypeOf((*CustomEncoder)(nil)).Elem() + customDecoderType = reflect.TypeOf((*CustomDecoder)(nil)).Elem() +) + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) + +var ( + binaryMarshalerType = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem() + binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() +) + +var ( + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +type ( + encoderFunc func(*Encoder, reflect.Value) error + decoderFunc func(*Decoder, reflect.Value) error +) + +var ( + typeEncMap sync.Map + typeDecMap sync.Map +) + +// Register registers encoder and decoder functions for a value. +// This is low level API and in most cases you should prefer implementing +// CustomEncoder/CustomDecoder or Marshaler/Unmarshaler interfaces. +func Register(value interface{}, enc encoderFunc, dec decoderFunc) { + typ := reflect.TypeOf(value) + if enc != nil { + typeEncMap.Store(typ, enc) + } + if dec != nil { + typeDecMap.Store(typ, dec) + } +} + +//------------------------------------------------------------------------------ + +const defaultStructTag = "msgpack" + +var structs = newStructCache() + +type structCache struct { + m sync.Map +} + +type structCacheKey struct { + typ reflect.Type + tag string +} + +func newStructCache() *structCache { + return new(structCache) +} + +func (m *structCache) Fields(typ reflect.Type, tag string) *fields { + key := structCacheKey{tag: tag, typ: typ} + + if v, ok := m.m.Load(key); ok { + return v.(*fields) + } + + fs := getFields(typ, tag) + m.m.Store(key, fs) + + return fs +} + +//------------------------------------------------------------------------------ + +type field struct { + encoder encoderFunc + decoder decoderFunc + name string + index []int + omitEmpty bool +} + +func (f *field) Omit(e *Encoder, strct reflect.Value) bool { + v, ok := fieldByIndex(strct, f.index) + if !ok { + return true + } + forced := e.flags&omitEmptyFlag != 0 + return (f.omitEmpty || forced) && e.isEmptyValue(v) +} + +func (f *field) EncodeValue(e *Encoder, strct reflect.Value) error { + v, ok := fieldByIndex(strct, f.index) + if !ok { + return e.EncodeNil() + } + return f.encoder(e, v) +} + +func (f *field) DecodeValue(d *Decoder, strct reflect.Value) error { + v := fieldByIndexAlloc(strct, f.index) + return f.decoder(d, v) +} + +//------------------------------------------------------------------------------ + +type fields struct { + Type reflect.Type + Map map[string]*field + List []*field + AsArray bool + + hasOmitEmpty bool +} + +func newFields(typ reflect.Type) *fields { + return &fields{ + Type: typ, + Map: make(map[string]*field, typ.NumField()), + List: make([]*field, 0, typ.NumField()), + } +} + +func (fs *fields) Add(field *field) { + fs.warnIfFieldExists(field.name) + fs.Map[field.name] = field + fs.List = append(fs.List, field) + if field.omitEmpty { + fs.hasOmitEmpty = true + } +} + +func (fs *fields) warnIfFieldExists(name string) { + if _, ok := fs.Map[name]; ok { + log.Printf("msgpack: %s already has field=%s", fs.Type, name) + } +} + +func (fs *fields) OmitEmpty(e *Encoder, strct reflect.Value) []*field { + forced := e.flags&omitEmptyFlag != 0 + if !fs.hasOmitEmpty && !forced { + return fs.List + } + + fields := make([]*field, 0, len(fs.List)) + + for _, f := range fs.List { + if !f.Omit(e, strct) { + fields = append(fields, f) + } + } + + return fields +} + +func getFields(typ reflect.Type, fallbackTag string) *fields { + fs := newFields(typ) + + var omitEmpty bool + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + tagStr := f.Tag.Get(defaultStructTag) + if tagStr == "" && fallbackTag != "" { + tagStr = f.Tag.Get(fallbackTag) + } + + tag := tagparser.Parse(tagStr) + if tag.Name == "-" { + continue + } + + if f.Name == "_msgpack" { + fs.AsArray = tag.HasOption("as_array") || tag.HasOption("asArray") + if tag.HasOption("omitempty") { + omitEmpty = true + } + } + + if f.PkgPath != "" && !f.Anonymous { + continue + } + + field := &field{ + name: tag.Name, + index: f.Index, + omitEmpty: omitEmpty || tag.HasOption("omitempty"), + } + + if tag.HasOption("intern") { + switch f.Type.Kind() { + case reflect.Interface: + field.encoder = encodeInternedInterfaceValue + field.decoder = decodeInternedInterfaceValue + case reflect.String: + field.encoder = encodeInternedStringValue + field.decoder = decodeInternedStringValue + default: + err := fmt.Errorf("msgpack: intern strings are not supported on %s", f.Type) + panic(err) + } + } else { + field.encoder = getEncoder(f.Type) + field.decoder = getDecoder(f.Type) + } + + if field.name == "" { + field.name = f.Name + } + + if f.Anonymous && !tag.HasOption("noinline") { + inline := tag.HasOption("inline") + if inline { + inlineFields(fs, f.Type, field, fallbackTag) + } else { + inline = shouldInline(fs, f.Type, field, fallbackTag) + } + + if inline { + if _, ok := fs.Map[field.name]; ok { + log.Printf("msgpack: %s already has field=%s", fs.Type, field.name) + } + fs.Map[field.name] = field + continue + } + } + + fs.Add(field) + + if alias, ok := tag.Options["alias"]; ok { + fs.warnIfFieldExists(alias) + fs.Map[alias] = field + } + } + return fs +} + +var ( + encodeStructValuePtr uintptr + decodeStructValuePtr uintptr +) + +//nolint:gochecknoinits +func init() { + encodeStructValuePtr = reflect.ValueOf(encodeStructValue).Pointer() + decodeStructValuePtr = reflect.ValueOf(decodeStructValue).Pointer() +} + +func inlineFields(fs *fields, typ reflect.Type, f *field, tag string) { + inlinedFields := getFields(typ, tag).List + for _, field := range inlinedFields { + if _, ok := fs.Map[field.name]; ok { + // Don't inline shadowed fields. + continue + } + field.index = append(f.index, field.index...) + fs.Add(field) + } +} + +func shouldInline(fs *fields, typ reflect.Type, f *field, tag string) bool { + var encoder encoderFunc + var decoder decoderFunc + + if typ.Kind() == reflect.Struct { + encoder = f.encoder + decoder = f.decoder + } else { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + encoder = getEncoder(typ) + decoder = getDecoder(typ) + } + if typ.Kind() != reflect.Struct { + return false + } + } + + if reflect.ValueOf(encoder).Pointer() != encodeStructValuePtr { + return false + } + if reflect.ValueOf(decoder).Pointer() != decodeStructValuePtr { + return false + } + + inlinedFields := getFields(typ, tag).List + for _, field := range inlinedFields { + if _, ok := fs.Map[field.name]; ok { + // Don't auto inline if there are shadowed fields. + return false + } + } + + for _, field := range inlinedFields { + field.index = append(f.index, field.index...) + fs.Add(field) + } + return true +} + +type isZeroer interface { + IsZero() bool +} + +func (e *Encoder) isEmptyValue(v reflect.Value) bool { + kind := v.Kind() + + for kind == reflect.Interface { + if v.IsNil() { + return true + } + v = v.Elem() + kind = v.Kind() + } + + if z, ok := v.Interface().(isZeroer); ok { + return nilable(kind) && v.IsNil() || z.IsZero() + } + + switch kind { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Struct: + structFields := structs.Fields(v.Type(), e.structTag) + fields := structFields.OmitEmpty(e, v) + return len(fields) == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Ptr: + return v.IsNil() + default: + return false + } +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + var ok bool + v, ok = indirectNil(v) + if !ok { + return v + } + } + v = v.Field(idx) + } + + return v +} + +func indirectNil(v reflect.Value) (reflect.Value, bool) { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + if !v.CanSet() { + return v, false + } + elemType := v.Type().Elem() + if elemType.Kind() != reflect.Struct { + return v, false + } + v.Set(cachedValue(elemType)) + } + v = v.Elem() + } + return v, true +} diff --git a/lib/msgpack/unsafe.go b/lib/msgpack/unsafe.go new file mode 100644 index 0000000..192ac47 --- /dev/null +++ b/lib/msgpack/unsafe.go @@ -0,0 +1,22 @@ +// +build !appengine + +package msgpack + +import ( + "unsafe" +) + +// bytesToString converts byte slice to string. +func bytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// stringToBytes converts string to byte slice. +func stringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/lib/msgpack/version.go b/lib/msgpack/version.go new file mode 100644 index 0000000..ca10205 --- /dev/null +++ b/lib/msgpack/version.go @@ -0,0 +1,6 @@ +package msgpack + +// Version is the current release version. +func Version() string { + return "5.4.1" +} diff --git a/lib/picoserial/picoserial.go b/lib/picoserial/picoserial.go index 4bf7618..633a7d5 100644 --- a/lib/picoserial/picoserial.go +++ b/lib/picoserial/picoserial.go @@ -2,6 +2,7 @@ package picoserial import ( "fmt" + "time" "go.bug.st/serial" "go.bug.st/serial/enumerator" @@ -20,8 +21,12 @@ func FindDevice() (string, error) { return "", nil } +func Open(portName string) (serial.Port, error) { + return serial.Open(portName, &serial.Mode{BaudRate: 115200}) +} + func SendByte(portName string, b byte) error { - port, err := serial.Open(portName, &serial.Mode{BaudRate: 115200}) + port, err := Open(portName) if err != nil { return fmt.Errorf("opening %s: %w", portName, err) } @@ -33,3 +38,32 @@ func SendByte(portName string, b byte) error { } return nil } + +// SendByteAndRead sends a byte and reads the response with a timeout. +func SendByteAndRead(portName string, b byte, timeout time.Duration) ([]byte, error) { + port, err := Open(portName) + if err != nil { + return nil, fmt.Errorf("opening %s: %w", portName, err) + } + defer port.Close() + + port.SetReadTimeout(timeout) + + _, err = port.Write([]byte{b}) + if err != nil { + return nil, fmt.Errorf("writing to %s: %w", portName, err) + } + + var resp []byte + buf := make([]byte, 256) + for { + n, err := port.Read(buf) + if n > 0 { + resp = append(resp, buf[:n]...) + } + if err != nil || n == 0 { + break + } + } + return resp, nil +} diff --git a/lib/tagparser/parser.go b/lib/tagparser/parser.go new file mode 100644 index 0000000..a6b7a7c --- /dev/null +++ b/lib/tagparser/parser.go @@ -0,0 +1,76 @@ +package tagparser + +import "bytes" + +type Parser struct { + b []byte + i int +} + +func newParser(b []byte) *Parser { + return &Parser{b: b} +} + +func newStringParser(s string) *Parser { + return newParser([]byte(s)) +} + +func (p *Parser) Bytes() []byte { + return p.b[p.i:] +} + +func (p *Parser) Valid() bool { + return p.i < len(p.b) +} + +func (p *Parser) Read() byte { + if p.Valid() { + c := p.b[p.i] + p.Advance() + return c + } + return 0 +} + +func (p *Parser) Peek() byte { + if p.Valid() { + return p.b[p.i] + } + return 0 +} + +func (p *Parser) Advance() { + p.i++ +} + +func (p *Parser) Skip(skip byte) bool { + if p.Peek() == skip { + p.Advance() + return true + } + return false +} + +func (p *Parser) SkipBytes(skip []byte) bool { + if len(skip) > len(p.b[p.i:]) { + return false + } + if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + return false + } + p.i += len(skip) + return true +} + +func (p *Parser) ReadSep(sep byte) ([]byte, bool) { + ind := bytes.IndexByte(p.b[p.i:], sep) + if ind == -1 { + b := p.b[p.i:] + p.i = len(p.b) + return b, false + } + + b := p.b[p.i : p.i+ind] + p.i += ind + 1 + return b, true +} diff --git a/lib/tagparser/tagparser.go b/lib/tagparser/tagparser.go new file mode 100644 index 0000000..a0692c3 --- /dev/null +++ b/lib/tagparser/tagparser.go @@ -0,0 +1,162 @@ +package tagparser + +import "strings" + +type Tag struct { + Name string + Options map[string]string +} + +func (t *Tag) HasOption(name string) bool { + _, ok := t.Options[name] + return ok +} + +func Parse(s string) *Tag { + p := &tagParser{ + Parser: newStringParser(s), + } + p.parseKey() + return &p.Tag +} + +type tagParser struct { + *Parser + + Tag Tag + hasName bool + key string +} + +func (p *tagParser) setTagOption(key, value string) { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + + if !p.hasName { + p.hasName = true + if key == "" { + p.Tag.Name = value + return + } + } + if p.Tag.Options == nil { + p.Tag.Options = make(map[string]string) + } + if key == "" { + p.Tag.Options[value] = "" + } else { + p.Tag.Options[key] = value + } +} + +func (p *tagParser) parseKey() { + p.key = "" + + var b []byte + for p.Valid() { + c := p.Read() + switch c { + case ',': + p.Skip(' ') + p.setTagOption("", string(b)) + p.parseKey() + return + case ':': + p.key = string(b) + p.parseValue() + return + case '\'': + p.parseQuotedValue() + return + default: + b = append(b, c) + } + } + + if len(b) > 0 { + p.setTagOption("", string(b)) + } +} + +func (p *tagParser) parseValue() { + const quote = '\'' + c := p.Peek() + if c == quote { + p.Skip(quote) + p.parseQuotedValue() + return + } + + var b []byte + for p.Valid() { + c = p.Read() + switch c { + case '\\': + b = append(b, p.Read()) + case '(': + b = append(b, c) + b = p.readBrackets(b) + case ',': + p.Skip(' ') + p.setTagOption(p.key, string(b)) + p.parseKey() + return + default: + b = append(b, c) + } + } + p.setTagOption(p.key, string(b)) +} + +func (p *tagParser) readBrackets(b []byte) []byte { + var lvl int +loop: + for p.Valid() { + c := p.Read() + switch c { + case '\\': + b = append(b, p.Read()) + case '(': + b = append(b, c) + lvl++ + case ')': + b = append(b, c) + lvl-- + if lvl < 0 { + break loop + } + default: + b = append(b, c) + } + } + return b +} + +func (p *tagParser) parseQuotedValue() { + const quote = '\'' + var b []byte + for p.Valid() { + bb, ok := p.ReadSep(quote) + if !ok { + b = append(b, bb...) + break + } + + // keep the escaped single-quote, and continue until we've found the + // one that isn't. + if len(bb) > 0 && bb[len(bb)-1] == '\\' { + b = append(b, bb[:len(bb)-1]...) + b = append(b, quote) + continue + } + + b = append(b, bb...) + break + } + + p.setTagOption(p.key, string(b)) + if p.Skip(',') { + p.Skip(' ') + } + p.parseKey() +} diff --git a/lib/wire/wire.go b/lib/wire/wire.go new file mode 100644 index 0000000..fead255 --- /dev/null +++ b/lib/wire/wire.go @@ -0,0 +1,40 @@ +package wire + +import ( + "fmt" + + "github.com/theater/picomap/lib/halfsiphash" + "github.com/theater/picomap/lib/msgpack" +) + +var HashKey = [8]byte{} + +type RebootingBootsel struct{} + +type Envelope struct { + Checksum uint32 + Payload []byte +} + +func init() { + msgpack.RegisterExt(0, (*Envelope)(nil)) + msgpack.RegisterExt(1, (*RebootingBootsel)(nil)) +} + +func DecodeMessage(data []byte) (any, error) { + var env Envelope + if err := msgpack.Unmarshal(data, &env); err != nil { + return nil, fmt.Errorf("decode envelope: %w", err) + } + + expected := halfsiphash.Sum32(env.Payload, HashKey) + if env.Checksum != expected { + return nil, fmt.Errorf("checksum mismatch: got %08x, want %08x", env.Checksum, expected) + } + + var inner any + if err := msgpack.Unmarshal(env.Payload, &inner); err != nil { + return nil, fmt.Errorf("decode inner: %w", err) + } + return inner, nil +} diff --git a/picomap.cpp b/picomap.cpp index 2c503e6..d58dbbf 100644 --- a/picomap.cpp +++ b/picomap.cpp @@ -1,6 +1,38 @@ #include +#include +#include #include "pico/stdlib.h" #include "pico/bootrom.h" +#include "msgpackpp.h" +#include "halfsiphash.h" + +static constexpr uint8_t hash_key[8] = {}; + +struct RebootingBootsel { + static constexpr int8_t ext_id = 1; + auto as_tuple() const { return std::tie(); } +}; + +struct Envelope { + static constexpr int8_t ext_id = 0; + uint32_t checksum; + std::vector payload; + auto as_tuple() const { return std::tie(checksum, payload); } +}; + +static std::vector pack_envelope(const std::vector &payload) { + uint32_t checksum = halfsiphash::hash32(payload.data(), payload.size(), hash_key); + msgpackpp::packer p; + p.pack(Envelope{checksum, payload}); + return p.get_payload(); +} + +static void send_bytes(const std::vector &data) { + for (auto b : data) { + putchar(b); + } + stdio_flush(); +} int main() { stdio_init_all(); @@ -10,6 +42,11 @@ int main() { if (c == 'p') { printf("p"); } else if (c == 'b') { + msgpackpp::packer inner; + inner.pack(RebootingBootsel{}); + auto msg = pack_envelope(inner.get_payload()); + send_bytes(msg); + sleep_ms(100); reset_usb_boot(0, 0); } } diff --git a/third_party/halfsiphash/halfsiphash.h b/third_party/halfsiphash/halfsiphash.h new file mode 100644 index 0000000..5a79b03 --- /dev/null +++ b/third_party/halfsiphash/halfsiphash.h @@ -0,0 +1,79 @@ +#pragma once +#include +#include + +namespace halfsiphash { + +namespace detail { + +constexpr uint32_t rotl(uint32_t x, int b) { + return (x << b) | (x >> (32 - b)); +} + +constexpr uint32_t load_le32(const uint8_t *p) { + return static_cast(p[0]) + | (static_cast(p[1]) << 8) + | (static_cast(p[2]) << 16) + | (static_cast(p[3]) << 24); +} + +inline void store_le32(uint8_t *p, uint32_t v) { + p[0] = static_cast(v); + p[1] = static_cast(v >> 8); + p[2] = static_cast(v >> 16); + p[3] = static_cast(v >> 24); +} + +inline void sipround(uint32_t &v0, uint32_t &v1, uint32_t &v2, uint32_t &v3) { + v0 += v1; v1 = rotl(v1, 5); v1 ^= v0; v0 = rotl(v0, 16); + v2 += v3; v3 = rotl(v3, 8); v3 ^= v2; + v0 += v3; v3 = rotl(v3, 7); v3 ^= v0; + v2 += v1; v1 = rotl(v1, 13); v1 ^= v2; v2 = rotl(v2, 16); +} + +} // namespace detail + +// Compute HalfSipHash-2-4 with an 8-byte key, returning a 32-bit hash. +inline uint32_t hash32(const uint8_t *data, size_t len, const uint8_t key[8]) { + using namespace detail; + + uint32_t k0 = load_le32(key); + uint32_t k1 = load_le32(key + 4); + + uint32_t v0 = 0 ^ k0; + uint32_t v1 = 0 ^ k1; + uint32_t v2 = UINT32_C(0x6c796765) ^ k0; + uint32_t v3 = UINT32_C(0x74656462) ^ k1; + + const uint8_t *end = data + len - (len % 4); + for (const uint8_t *p = data; p != end; p += 4) { + uint32_t m = load_le32(p); + v3 ^= m; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + v0 ^= m; + } + + uint32_t b = static_cast(len) << 24; + switch (len & 3) { + case 3: b |= static_cast(end[2]) << 16; [[fallthrough]]; + case 2: b |= static_cast(end[1]) << 8; [[fallthrough]]; + case 1: b |= static_cast(end[0]); break; + case 0: break; + } + + v3 ^= b; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + v0 ^= b; + + v2 ^= 0xff; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + + return v1 ^ v3; +} + +} // namespace halfsiphash diff --git a/third_party/msgpackpp/msgpackpp.h b/third_party/msgpackpp/msgpackpp.h new file mode 100644 index 0000000..c56109b --- /dev/null +++ b/third_party/msgpackpp/msgpackpp.h @@ -0,0 +1,717 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace msgpackpp { + +enum class error_code { + overflow, + empty, + lack, + invalid, + type_error, +}; + +// MessagePack format byte constants. Ranges are handled via helper functions +// rather than enumerating every value. +namespace format { + // Fixed ranges (use is_* helpers below) + constexpr uint8_t POSITIVE_FIXINT_MIN = 0x00; + constexpr uint8_t POSITIVE_FIXINT_MAX = 0x7F; + constexpr uint8_t FIXMAP_MIN = 0x80; + constexpr uint8_t FIXMAP_MAX = 0x8F; + constexpr uint8_t FIXARRAY_MIN = 0x90; + constexpr uint8_t FIXARRAY_MAX = 0x9F; + constexpr uint8_t FIXSTR_MIN = 0xA0; + constexpr uint8_t FIXSTR_MAX = 0xBF; + constexpr uint8_t NEGATIVE_FIXINT_MIN = 0xE0; + constexpr uint8_t NEGATIVE_FIXINT_MAX = 0xFF; + + // Specific type bytes + constexpr uint8_t NIL = 0xC0; + constexpr uint8_t NEVER_USED = 0xC1; + constexpr uint8_t FALSE = 0xC2; + constexpr uint8_t TRUE = 0xC3; + constexpr uint8_t BIN8 = 0xC4; + constexpr uint8_t BIN16 = 0xC5; + constexpr uint8_t BIN32 = 0xC6; + constexpr uint8_t EXT8 = 0xC7; + constexpr uint8_t EXT16 = 0xC8; + constexpr uint8_t EXT32 = 0xC9; + constexpr uint8_t FLOAT32 = 0xCA; + constexpr uint8_t FLOAT64 = 0xCB; + constexpr uint8_t UINT8 = 0xCC; + constexpr uint8_t UINT16 = 0xCD; + constexpr uint8_t UINT32 = 0xCE; + constexpr uint8_t UINT64 = 0xCF; + constexpr uint8_t INT8 = 0xD0; + constexpr uint8_t INT16 = 0xD1; + constexpr uint8_t INT32 = 0xD2; + constexpr uint8_t INT64 = 0xD3; + constexpr uint8_t FIXEXT1 = 0xD4; + constexpr uint8_t FIXEXT2 = 0xD5; + constexpr uint8_t FIXEXT4 = 0xD6; + constexpr uint8_t FIXEXT8 = 0xD7; + constexpr uint8_t FIXEXT16 = 0xD8; + constexpr uint8_t STR8 = 0xD9; + constexpr uint8_t STR16 = 0xDA; + constexpr uint8_t STR32 = 0xDB; + constexpr uint8_t ARRAY16 = 0xDC; + constexpr uint8_t ARRAY32 = 0xDD; + constexpr uint8_t MAP16 = 0xDE; + constexpr uint8_t MAP32 = 0xDF; + + constexpr bool is_positive_fixint(uint8_t b) { return b <= POSITIVE_FIXINT_MAX; } + constexpr bool is_fixmap(uint8_t b) { return b >= FIXMAP_MIN && b <= FIXMAP_MAX; } + constexpr bool is_fixarray(uint8_t b) { return b >= FIXARRAY_MIN && b <= FIXARRAY_MAX; } + constexpr bool is_fixstr(uint8_t b) { return b >= FIXSTR_MIN && b <= FIXSTR_MAX; } + constexpr bool is_negative_fixint(uint8_t b) { return b >= NEGATIVE_FIXINT_MIN; } +} // namespace format + +template +using result = std::expected; + +// Read a big-endian number from position m_p+1. +template +result body_number(const uint8_t *p, int size) { + if (size < 1 + static_cast(sizeof(T))) { + return std::unexpected(error_code::lack); + } + if constexpr (sizeof(T) == 1) { + return static_cast(p[1]); + } else if constexpr (sizeof(T) == 2) { + return static_cast((p[1] << 8) | p[2]); + } else if constexpr (sizeof(T) == 4) { + uint8_t buf[] = {p[4], p[3], p[2], p[1]}; + T val; + __builtin_memcpy(&val, buf, sizeof(T)); + return val; + } else if constexpr (sizeof(T) == 8) { + uint8_t buf[] = {p[8], p[7], p[6], p[5], p[4], p[3], p[2], p[1]}; + T val; + __builtin_memcpy(&val, buf, sizeof(T)); + return val; + } else { + return std::unexpected(error_code::invalid); + } +} + +// Returns {header_bytes, body_bytes} for a given format byte. +// For container types (array/map), body_bytes is not meaningful (returns 0). +struct body_info { + int header; // bytes before the body (includes format byte + length fields + ext type byte) + uint32_t body; // body size in bytes (0 for containers, computed for variable-length) +}; + +inline result get_body_info(const uint8_t *p, int size) { + if (size < 1) return std::unexpected(error_code::empty); + uint8_t b = p[0]; + + using namespace format; + + if (is_positive_fixint(b)) return body_info{1, 0}; + if (is_negative_fixint(b)) return body_info{1, 0}; + if (is_fixmap(b)) return body_info{1, 0}; // container + if (is_fixarray(b)) return body_info{1, 0}; // container + if (is_fixstr(b)) return body_info{1, static_cast(b & 0x1F)}; + + switch (b) { + case NIL: case FALSE: case TRUE: + return body_info{1, 0}; + case NEVER_USED: + return std::unexpected(error_code::invalid); + + case BIN8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1, *n}; } + case BIN16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2, *n}; } + case BIN32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4, *n}; } + + case EXT8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1+1, *n}; } + case EXT16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2+1, *n}; } + case EXT32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4+1, *n}; } + + case FLOAT32: return body_info{1, 4}; + case FLOAT64: return body_info{1, 8}; + case UINT8: return body_info{1, 1}; + case UINT16: return body_info{1, 2}; + case UINT32: return body_info{1, 4}; + case UINT64: return body_info{1, 8}; + case INT8: return body_info{1, 1}; + case INT16: return body_info{1, 2}; + case INT32: return body_info{1, 4}; + case INT64: return body_info{1, 8}; + + case FIXEXT1: return body_info{1+1, 1}; + case FIXEXT2: return body_info{1+1, 2}; + case FIXEXT4: return body_info{1+1, 4}; + case FIXEXT8: return body_info{1+1, 8}; + case FIXEXT16: return body_info{1+1, 16}; + + case STR8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1, *n}; } + case STR16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2, *n}; } + case STR32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4, *n}; } + + case ARRAY16: case ARRAY32: + case MAP16: case MAP32: + return body_info{1 + (b == ARRAY16 || b == MAP16 ? 2 : 4), 0}; // container + + default: + return std::unexpected(error_code::invalid); + } +} + +class packer { +public: + using buffer = std::vector; + +private: + std::shared_ptr m_buffer; + + template void push_big_endian(T n) { + auto p = reinterpret_cast(&n) + (sizeof(T) - 1); + for (size_t i = 0; i < sizeof(T); ++i, --p) { + m_buffer->push_back(*p); + } + } + + template void push(const Range &r) { + m_buffer->insert(m_buffer->end(), std::begin(r), std::end(r)); + } + +public: + packer() : m_buffer(std::make_shared()) {} + packer(const std::shared_ptr &buf) : m_buffer(buf) {} + + packer(const packer &) = delete; + packer &operator=(const packer &) = delete; + + using pack_result = result>; + + pack_result pack_nil() { + m_buffer->push_back(format::NIL); + return *this; + } + + pack_result pack_bool(bool v) { + m_buffer->push_back(v ? format::TRUE : format::FALSE); + return *this; + } + + template + pack_result pack_integer(T n) { + if constexpr (std::is_signed_v) { + if (n >= 0 && n <= 0x7F) { + m_buffer->push_back(static_cast(n)); + } else if (n >= -32 && n < 0) { + m_buffer->push_back(static_cast(n)); // negative fixint + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buffer->push_back(format::INT8); + m_buffer->push_back(static_cast(n)); + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buffer->push_back(format::INT16); + push_big_endian(static_cast(n)); + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buffer->push_back(format::INT32); + push_big_endian(static_cast(n)); + } else { + m_buffer->push_back(format::INT64); + push_big_endian(static_cast(n)); + } + } else { + if (n <= 0x7F) { + m_buffer->push_back(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::UINT8); + m_buffer->push_back(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::UINT16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::UINT32); + push_big_endian(static_cast(n)); + } else { + m_buffer->push_back(format::UINT64); + push_big_endian(static_cast(n)); + } + } + return *this; + } + + pack_result pack_float(float n) { + m_buffer->push_back(format::FLOAT32); + push_big_endian(n); + return *this; + } + + pack_result pack_double(double n) { + m_buffer->push_back(format::FLOAT64); + push_big_endian(n); + return *this; + } + + template + pack_result pack_str(const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + if (sz < 32) { + m_buffer->push_back(format::FIXSTR_MIN | static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::STR8); + m_buffer->push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::STR16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::STR32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + push(r); + return *this; + } + + pack_result pack_str(const char *s) { + return pack_str(std::string_view(s)); + } + + template + pack_result pack_bin(const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::BIN8); + m_buffer->push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::BIN16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::BIN32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + push(r); + return *this; + } + + pack_result pack_array(size_t n) { + if (n <= 15) { + m_buffer->push_back(format::FIXARRAY_MIN | static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::ARRAY16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::ARRAY32); + push_big_endian(static_cast(n)); + } else { + return std::unexpected(error_code::overflow); + } + return *this; + } + + pack_result pack_map(size_t n) { + if (n <= 15) { + m_buffer->push_back(format::FIXMAP_MIN | static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::MAP16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buffer->push_back(format::MAP32); + push_big_endian(static_cast(n)); + } else { + return std::unexpected(error_code::overflow); + } + return *this; + } + + template + pack_result pack_ext(char type, const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + + switch (sz) { + case 1: m_buffer->push_back(format::FIXEXT1); break; + case 2: m_buffer->push_back(format::FIXEXT2); break; + case 4: m_buffer->push_back(format::FIXEXT4); break; + case 8: m_buffer->push_back(format::FIXEXT8); break; + case 16: m_buffer->push_back(format::FIXEXT16); break; + default: + if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::EXT8); + m_buffer->push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::EXT16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buffer->push_back(format::EXT32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + } + m_buffer->push_back(static_cast(type)); + push(r); + return *this; + } + + // Generic pack: dispatches based on type. + // Integers + template + requires std::is_integral_v && (!std::is_same_v) + pack_result pack(T n) { return pack_integer(n); } + + pack_result pack(bool v) { return pack_bool(v); } + pack_result pack(float v) { return pack_float(v); } + pack_result pack(double v) { return pack_double(v); } + pack_result pack(const char *v) { return pack_str(v); } + pack_result pack(std::string_view v) { return pack_str(v); } + pack_result pack(const std::string &v) { return pack_str(v); } + + // Binary (vector) + pack_result pack(const std::vector &v) { return pack_bin(v); } + + // Tuples → msgpack array + template + pack_result pack(const std::tuple &t) { + auto r = pack_array(sizeof...(Ts)); + if (!r) return r; + return pack_tuple_elements(t, std::index_sequence_for{}); + } + + // Structs with ext_id and as_tuple() → ext wrapping msgpack array + template + requires requires(const T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } + pack_result pack(const T &v) { + packer inner; + auto r = inner.pack(v.as_tuple()); + if (!r) return r; + return pack_ext(T::ext_id, inner.get_payload()); + } + + // Structs with as_tuple() but no ext_id → plain msgpack array + template + requires (requires(const T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to; }) + pack_result pack(const T &v) { + return pack(v.as_tuple()); + } + +private: + template + pack_result pack_tuple_elements(const Tuple &t, std::index_sequence) { + pack_result r = *this; + ((r = r ? r->get().pack(std::get(t)) : r), ...); + return r; + } + +public: + const buffer &get_payload() const { return *m_buffer; } +}; + +class parser { + const uint8_t *m_p = nullptr; + int m_size = 0; + + result header_byte() const { + if (m_size < 1) return std::unexpected(error_code::empty); + return m_p[0]; + } + +public: + parser() = default; + + parser(const std::vector &v) + : m_p(v.data()), m_size(static_cast(v.size())) {} + + parser(const uint8_t *p, int size) + : m_p(p), m_size(size < 0 ? 0 : size) {} + + bool is_empty() const { return m_size == 0; } + const uint8_t *data() const { return m_p; } + int size() const { return m_size; } + + result advance(int n) const { + if (n > m_size) return std::unexpected(error_code::lack); + return parser(m_p + n, m_size - n); + } + + // Navigate to the next value in a sequence. + result next() const { + auto hdr = header_byte(); + if (!hdr) return std::unexpected(hdr.error()); + uint8_t b = *hdr; + + if (is_array()) { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto cnt = count(); + if (!cnt) return std::unexpected(cnt.error()); + auto cur = advance(info->header); + if (!cur) return std::unexpected(cur.error()); + for (uint32_t i = 0; i < *cnt; ++i) { + auto n = cur->next(); + if (!n) return std::unexpected(n.error()); + cur = *n; + } + return *cur; + } else if (is_map()) { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto cnt = count(); + if (!cnt) return std::unexpected(cnt.error()); + auto cur = advance(info->header); + if (!cur) return std::unexpected(cur.error()); + for (uint32_t i = 0; i < *cnt; ++i) { + // key + auto k = cur->next(); + if (!k) return std::unexpected(k.error()); + cur = *k; + // value + auto v = cur->next(); + if (!v) return std::unexpected(v.error()); + cur = *v; + } + return *cur; + } else { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto total = info->header + static_cast(info->body); + return advance(total); + } + } + + // Type checks + bool is_nil() const { + auto h = header_byte(); + return h && *h == format::NIL; + } + + bool is_bool() const { + auto h = header_byte(); + return h && (*h == format::TRUE || *h == format::FALSE); + } + + bool is_number() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_positive_fixint(b)) return true; + if (format::is_negative_fixint(b)) return true; + return b >= format::FLOAT32 && b <= format::INT64; + } + + bool is_string() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixstr(b)) return true; + return b == format::STR8 || b == format::STR16 || b == format::STR32; + } + + bool is_binary() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + return b == format::BIN8 || b == format::BIN16 || b == format::BIN32; + } + + bool is_ext() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + return (b >= format::FIXEXT1 && b <= format::FIXEXT16) || + b == format::EXT8 || b == format::EXT16 || b == format::EXT32; + } + + bool is_array() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixarray(b)) return true; + return b == format::ARRAY16 || b == format::ARRAY32; + } + + bool is_map() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixmap(b)) return true; + return b == format::MAP16 || b == format::MAP32; + } + + // Value accessors + + result get_bool() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + if (*h == format::TRUE) return true; + if (*h == format::FALSE) return false; + return std::unexpected(error_code::type_error); + } + + result get_string() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + size_t offset, len; + if (format::is_fixstr(b)) { + len = b & 0x1F; + offset = 1; + } else if (b == format::STR8) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 1; + } else if (b == format::STR16) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 2; + } else if (b == format::STR32) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 4; + } else { + return std::unexpected(error_code::type_error); + } + if (static_cast(offset + len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::string_view(reinterpret_cast(m_p + offset), len); + } + + result get_binary_view() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + size_t offset, len; + + if (b == format::BIN8) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 1; + } else if (b == format::BIN16) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 2; + } else if (b == format::BIN32) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 4; + } else { + return std::unexpected(error_code::type_error); + } + if (static_cast(offset + len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::string_view(reinterpret_cast(m_p + offset), len); + } + + // Returns {ext_type, data_view}. + result> get_ext() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + int8_t ext_type; + size_t data_offset, data_len; + + switch (b) { + case format::FIXEXT1: ext_type = m_p[1]; data_offset = 2; data_len = 1; break; + case format::FIXEXT2: ext_type = m_p[1]; data_offset = 2; data_len = 2; break; + case format::FIXEXT4: ext_type = m_p[1]; data_offset = 2; data_len = 4; break; + case format::FIXEXT8: ext_type = m_p[1]; data_offset = 2; data_len = 8; break; + case format::FIXEXT16: ext_type = m_p[1]; data_offset = 2; data_len = 16; break; + case format::EXT8: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[2]; data_offset = 3; data_len = *n; + break; + } + case format::EXT16: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[3]; data_offset = 4; data_len = *n; + break; + } + case format::EXT32: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[5]; data_offset = 6; data_len = *n; + break; + } + default: + return std::unexpected(error_code::type_error); + } + if (static_cast(data_offset + data_len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::tuple{ext_type, + std::string_view(reinterpret_cast(m_p + data_offset), data_len)}; + } + + template + result get_number() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + + if (format::is_positive_fixint(b)) return static_cast(b); + if (format::is_negative_fixint(b)) return static_cast(static_cast(b)); + + switch (b) { + case format::UINT8: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT8: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::FLOAT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::FLOAT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + default: + return std::unexpected(error_code::type_error); + } + } + + result count() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + + if (format::is_fixarray(b)) return static_cast(b & 0x0F); + if (format::is_fixmap(b)) return static_cast(b & 0x0F); + + switch (b) { + case format::ARRAY16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::ARRAY32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return *n; } + case format::MAP16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::MAP32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return *n; } + default: + return std::unexpected(error_code::type_error); + } + } + + result first_item() const { + if (!is_array() && !is_map()) return std::unexpected(error_code::type_error); + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + return advance(info->header); + } + + parser operator[](int index) const { + auto cur = first_item(); + if (!cur) return {}; + for (int i = 0; i < index; ++i) { + auto n = cur->next(); + if (!n) return {}; + cur = *n; + } + return *cur; + } +}; + +} // namespace msgpackpp