Add msgpack wire protocol with halfsiphash checksums
This commit is contained in:
343
lib/msgpack/ext.go
Normal file
343
lib/msgpack/ext.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package msgpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type extInfo struct {
|
||||
Type reflect.Type
|
||||
Decoder func(d *Decoder, v reflect.Value, extLen int) error
|
||||
}
|
||||
|
||||
var extTypes = make(map[int8]*extInfo)
|
||||
|
||||
type MarshalerUnmarshaler interface {
|
||||
Marshaler
|
||||
Unmarshaler
|
||||
}
|
||||
|
||||
func RegisterExt(extID int8, value interface{}) {
|
||||
typ := reflect.TypeOf(value)
|
||||
|
||||
marshalerType := reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem()
|
||||
|
||||
// Encoder: use MarshalMsgpack if available, otherwise use default struct encoding.
|
||||
implMarshal := typ.Implements(marshalerType) ||
|
||||
(typ.Kind() == reflect.Ptr && typ.Elem().Implements(marshalerType))
|
||||
if implMarshal {
|
||||
RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) {
|
||||
return v.Interface().(Marshaler).MarshalMsgpack()
|
||||
})
|
||||
} else {
|
||||
RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := NewEncoder(&buf)
|
||||
enc.UseArrayEncodedStructs(true)
|
||||
if err := enc.Encode(v.Interface()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Decoder: use UnmarshalMsgpack if available, otherwise use default struct decoding.
|
||||
implUnmarshal := typ.Implements(unmarshalerType) ||
|
||||
(typ.Kind() == reflect.Ptr && typ.Elem().Implements(unmarshalerType))
|
||||
if implUnmarshal {
|
||||
RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error {
|
||||
b, err := d.readN(extLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return v.Interface().(Unmarshaler).UnmarshalMsgpack(b)
|
||||
})
|
||||
} else {
|
||||
structDecoder := _getDecoder(typ)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
structDecoder = _getDecoder(typ.Elem())
|
||||
}
|
||||
RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error {
|
||||
b, err := d.readN(extLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dec := NewDecoder(bytes.NewReader(b))
|
||||
if v.Kind() == reflect.Ptr {
|
||||
return structDecoder(dec, v.Elem())
|
||||
}
|
||||
return structDecoder(dec, v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func UnregisterExt(extID int8) {
|
||||
unregisterExtEncoder(extID)
|
||||
unregisterExtDecoder(extID)
|
||||
}
|
||||
|
||||
func RegisterExtEncoder(
|
||||
extID int8,
|
||||
value interface{},
|
||||
encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
|
||||
) {
|
||||
unregisterExtEncoder(extID)
|
||||
|
||||
typ := reflect.TypeOf(value)
|
||||
extEncoder := makeExtEncoder(extID, typ, encoder)
|
||||
typeEncMap.Store(extID, typ)
|
||||
typeEncMap.Store(typ, extEncoder)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typeEncMap.Store(typ.Elem(), makeExtEncoderAddr(extEncoder))
|
||||
}
|
||||
}
|
||||
|
||||
func unregisterExtEncoder(extID int8) {
|
||||
t, ok := typeEncMap.Load(extID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
typeEncMap.Delete(extID)
|
||||
typ := t.(reflect.Type)
|
||||
typeEncMap.Delete(typ)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typeEncMap.Delete(typ.Elem())
|
||||
}
|
||||
}
|
||||
|
||||
func makeExtEncoder(
|
||||
extID int8,
|
||||
typ reflect.Type,
|
||||
encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
|
||||
) encoderFunc {
|
||||
nilable := typ.Kind() == reflect.Ptr
|
||||
|
||||
return func(e *Encoder, v reflect.Value) error {
|
||||
if nilable && v.IsNil() {
|
||||
return e.EncodeNil()
|
||||
}
|
||||
|
||||
b, err := encoder(e, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.EncodeExtHeader(extID, len(b)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.write(b)
|
||||
}
|
||||
}
|
||||
|
||||
func makeExtEncoderAddr(extEncoder encoderFunc) encoderFunc {
|
||||
return func(e *Encoder, v reflect.Value) error {
|
||||
if !v.CanAddr() {
|
||||
return fmt.Errorf("msgpack: EncodeExt(nonaddressable %T)", v.Interface())
|
||||
}
|
||||
return extEncoder(e, v.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterExtDecoder(
|
||||
extID int8,
|
||||
value interface{},
|
||||
decoder func(dec *Decoder, v reflect.Value, extLen int) error,
|
||||
) {
|
||||
unregisterExtDecoder(extID)
|
||||
|
||||
typ := reflect.TypeOf(value)
|
||||
extDecoder := makeExtDecoder(extID, typ, decoder)
|
||||
extTypes[extID] = &extInfo{
|
||||
Type: typ,
|
||||
Decoder: decoder,
|
||||
}
|
||||
|
||||
typeDecMap.Store(extID, typ)
|
||||
typeDecMap.Store(typ, extDecoder)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typeDecMap.Store(typ.Elem(), makeExtDecoderAddr(extDecoder))
|
||||
}
|
||||
}
|
||||
|
||||
func unregisterExtDecoder(extID int8) {
|
||||
t, ok := typeDecMap.Load(extID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
typeDecMap.Delete(extID)
|
||||
delete(extTypes, extID)
|
||||
typ := t.(reflect.Type)
|
||||
typeDecMap.Delete(typ)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typeDecMap.Delete(typ.Elem())
|
||||
}
|
||||
}
|
||||
|
||||
func makeExtDecoder(
|
||||
wantedExtID int8,
|
||||
typ reflect.Type,
|
||||
decoder func(d *Decoder, v reflect.Value, extLen int) error,
|
||||
) decoderFunc {
|
||||
return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error {
|
||||
extID, extLen, err := d.DecodeExtHeader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if extID != wantedExtID {
|
||||
return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID)
|
||||
}
|
||||
return decoder(d, v, extLen)
|
||||
})
|
||||
}
|
||||
|
||||
func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc {
|
||||
return func(d *Decoder, v reflect.Value) error {
|
||||
if !v.CanAddr() {
|
||||
return fmt.Errorf("msgpack: DecodeExt(nonaddressable %T)", v.Interface())
|
||||
}
|
||||
return extDecoder(d, v.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Encoder) EncodeExtHeader(extID int8, extLen int) error {
|
||||
if err := e.encodeExtLen(extLen); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := e.w.WriteByte(byte(extID)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Encoder) encodeExtLen(l int) error {
|
||||
switch l {
|
||||
case 1:
|
||||
return e.writeCode(FixExt1)
|
||||
case 2:
|
||||
return e.writeCode(FixExt2)
|
||||
case 4:
|
||||
return e.writeCode(FixExt4)
|
||||
case 8:
|
||||
return e.writeCode(FixExt8)
|
||||
case 16:
|
||||
return e.writeCode(FixExt16)
|
||||
}
|
||||
if l <= math.MaxUint8 {
|
||||
return e.write1(Ext8, uint8(l))
|
||||
}
|
||||
if l <= math.MaxUint16 {
|
||||
return e.write2(Ext16, uint16(l))
|
||||
}
|
||||
return e.write4(Ext32, uint32(l))
|
||||
}
|
||||
|
||||
func (d *Decoder) DecodeExtHeader() (extID int8, extLen int, err error) {
|
||||
c, err := d.readCode()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return d.extHeader(c)
|
||||
}
|
||||
|
||||
func (d *Decoder) extHeader(c byte) (int8, int, error) {
|
||||
extLen, err := d.parseExtLen(c)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
extID, err := d.readCode()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return int8(extID), extLen, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) parseExtLen(c byte) (int, error) {
|
||||
switch c {
|
||||
case FixExt1:
|
||||
return 1, nil
|
||||
case FixExt2:
|
||||
return 2, nil
|
||||
case FixExt4:
|
||||
return 4, nil
|
||||
case FixExt8:
|
||||
return 8, nil
|
||||
case FixExt16:
|
||||
return 16, nil
|
||||
case Ext8:
|
||||
n, err := d.uint8()
|
||||
return int(n), err
|
||||
case Ext16:
|
||||
n, err := d.uint16()
|
||||
return int(n), err
|
||||
case Ext32:
|
||||
n, err := d.uint32()
|
||||
return int(n), err
|
||||
default:
|
||||
return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext len", c)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) {
|
||||
extID, extLen, err := d.extHeader(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info, ok := extTypes[extID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID)
|
||||
}
|
||||
|
||||
v := d.newValue(info.Type).Elem()
|
||||
if nilable(v.Kind()) && v.IsNil() {
|
||||
v.Set(d.newValue(info.Type.Elem()))
|
||||
}
|
||||
|
||||
if err := info.Decoder(d, v, extLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return v.Interface(), nil
|
||||
}
|
||||
|
||||
func (d *Decoder) skipExt(c byte) error {
|
||||
n, err := d.parseExtLen(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.skipN(n + 1)
|
||||
}
|
||||
|
||||
func (d *Decoder) skipExtHeader(c byte) error {
|
||||
// Read ext type.
|
||||
_, err := d.readCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Read ext body len.
|
||||
for i := 0; i < extHeaderLen(c); i++ {
|
||||
_, err := d.readCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extHeaderLen(c byte) int {
|
||||
switch c {
|
||||
case Ext8:
|
||||
return 1
|
||||
case Ext16:
|
||||
return 2
|
||||
case Ext32:
|
||||
return 4
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user