Bidirectional msgpack wire protocol with unpack support

This commit is contained in:
Ian Gulliver
2026-04-03 17:32:14 +09:00
parent d06d8b595e
commit 302f7fdb6a
8 changed files with 202 additions and 50 deletions

View File

@@ -7,3 +7,5 @@ user-invocable: true
Run `go run ./cmd/load/` from the project root. This builds the firmware, loads it onto the Pico, and reboots. Run `go run ./cmd/load/` from the project root. This builds the firmware, loads it onto the Pico, and reboots.
If cmake needs reconfiguring (e.g. after CMakeLists.txt changes), run `cmake -B build` first. If cmake needs reconfiguring (e.g. after CMakeLists.txt changes), run `cmake -B build` first.
After modifying the load command itself (cmd/load/, lib/wire/, lib/picoserial/), run it twice: once to load the firmware, once to verify the load process still works end-to-end.

View File

@@ -40,8 +40,12 @@ func run(buildDir string) error {
return err return err
} }
if dev != "" { if dev != "" {
fmt.Printf("Sending 'b' to %s to enter BOOTSEL mode...\n", dev) fmt.Printf("Sending bootsel request to %s...\n", dev)
resp, err := picoserial.SendByteAndRead(dev, 'b', 2*time.Second) req, err := wire.EncodeMessage(&wire.RequestBOOTSEL{})
if err != nil {
return fmt.Errorf("encoding request: %w", err)
}
resp, err := picoserial.SendAndReceive(dev, req, 2*time.Second)
if err != nil { if err != nil {
return err return err
} }
@@ -51,7 +55,7 @@ func run(buildDir string) error {
fmt.Fprintf(os.Stderr, "warning: failed to decode response: %v\n", err) fmt.Fprintf(os.Stderr, "warning: failed to decode response: %v\n", err)
} else { } else {
switch msg.(type) { switch msg.(type) {
case *wire.RebootingBootsel: case *wire.ResponseBOOTSEL:
fmt.Println("Device confirmed reboot into BOOTSEL mode.") fmt.Println("Device confirmed reboot into BOOTSEL mode.")
default: default:
fmt.Printf("Unexpected response type: %T\n", msg) fmt.Printf("Unexpected response type: %T\n", msg)

View File

@@ -21,10 +21,7 @@ enum class error_code {
type_error, type_error,
}; };
// MessagePack format byte constants. Ranges are handled via helper functions
// rather than enumerating every value.
namespace format { namespace format {
// Fixed ranges (use is_* helpers below)
constexpr uint8_t POSITIVE_FIXINT_MIN = 0x00; constexpr uint8_t POSITIVE_FIXINT_MIN = 0x00;
constexpr uint8_t POSITIVE_FIXINT_MAX = 0x7F; constexpr uint8_t POSITIVE_FIXINT_MAX = 0x7F;
constexpr uint8_t FIXMAP_MIN = 0x80; constexpr uint8_t FIXMAP_MIN = 0x80;
@@ -36,7 +33,6 @@ namespace format {
constexpr uint8_t NEGATIVE_FIXINT_MIN = 0xE0; constexpr uint8_t NEGATIVE_FIXINT_MIN = 0xE0;
constexpr uint8_t NEGATIVE_FIXINT_MAX = 0xFF; constexpr uint8_t NEGATIVE_FIXINT_MAX = 0xFF;
// Specific type bytes
constexpr uint8_t NIL = 0xC0; constexpr uint8_t NIL = 0xC0;
constexpr uint8_t NEVER_USED = 0xC1; constexpr uint8_t NEVER_USED = 0xC1;
constexpr uint8_t FALSE = 0xC2; constexpr uint8_t FALSE = 0xC2;
@@ -80,7 +76,6 @@ namespace format {
template <typename T> template <typename T>
using result = std::expected<T, error_code>; using result = std::expected<T, error_code>;
// Read a big-endian number from position m_p+1.
template <typename T> template <typename T>
result<T> body_number(const uint8_t *p, int size) { result<T> body_number(const uint8_t *p, int size) {
if (size < 1 + static_cast<int>(sizeof(T))) { if (size < 1 + static_cast<int>(sizeof(T))) {
@@ -105,8 +100,6 @@ result<T> body_number(const uint8_t *p, int size) {
} }
} }
// Returns {header_bytes, body_bytes} for a given format byte.
// For container types (array/map), body_bytes is not meaningful (returns 0).
struct body_info { struct body_info {
int header; // bytes before the body (includes format byte + length fields + ext type byte) int header; // bytes before the body (includes format byte + length fields + ext type byte)
uint32_t body; // body size in bytes (0 for containers, computed for variable-length) uint32_t body; // body size in bytes (0 for containers, computed for variable-length)
@@ -360,8 +353,6 @@ public:
return *this; return *this;
} }
// Generic pack: dispatches based on type.
// Integers
template <typename T> template <typename T>
requires std::is_integral_v<T> && (!std::is_same_v<T, bool>) requires std::is_integral_v<T> && (!std::is_same_v<T, bool>)
pack_result pack(T n) { return pack_integer(n); } pack_result pack(T n) { return pack_integer(n); }
@@ -373,10 +364,8 @@ public:
pack_result pack(std::string_view v) { return pack_str(v); } pack_result pack(std::string_view v) { return pack_str(v); }
pack_result pack(const std::string &v) { return pack_str(v); } pack_result pack(const std::string &v) { return pack_str(v); }
// Binary (vector<uint8_t>)
pack_result pack(const std::vector<uint8_t> &v) { return pack_bin(v); } pack_result pack(const std::vector<uint8_t> &v) { return pack_bin(v); }
// Tuples → msgpack array
template <typename... Ts> template <typename... Ts>
pack_result pack(const std::tuple<Ts...> &t) { pack_result pack(const std::tuple<Ts...> &t) {
auto r = pack_array(sizeof...(Ts)); auto r = pack_array(sizeof...(Ts));
@@ -384,7 +373,6 @@ public:
return pack_tuple_elements(t, std::index_sequence_for<Ts...>{}); return pack_tuple_elements(t, std::index_sequence_for<Ts...>{});
} }
// Structs with ext_id and as_tuple() → ext wrapping msgpack array
template <typename T> template <typename T>
requires requires(const T &v) { { T::ext_id } -> std::convertible_to<int8_t>; v.as_tuple(); } requires requires(const T &v) { { T::ext_id } -> std::convertible_to<int8_t>; v.as_tuple(); }
pack_result pack(const T &v) { pack_result pack(const T &v) {
@@ -394,7 +382,6 @@ public:
return pack_ext(T::ext_id, inner.get_payload()); return pack_ext(T::ext_id, inner.get_payload());
} }
// Structs with as_tuple() but no ext_id → plain msgpack array
template <typename T> template <typename T>
requires (requires(const T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to<int8_t>; }) requires (requires(const T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to<int8_t>; })
pack_result pack(const T &v) { pack_result pack(const T &v) {
@@ -440,7 +427,6 @@ public:
return parser(m_p + n, m_size - n); return parser(m_p + n, m_size - n);
} }
// Navigate to the next value in a sequence.
result<parser> next() const { result<parser> next() const {
auto hdr = header_byte(); auto hdr = header_byte();
if (!hdr) return std::unexpected(hdr.error()); if (!hdr) return std::unexpected(hdr.error());
@@ -467,11 +453,9 @@ public:
auto cur = advance(info->header); auto cur = advance(info->header);
if (!cur) return std::unexpected(cur.error()); if (!cur) return std::unexpected(cur.error());
for (uint32_t i = 0; i < *cnt; ++i) { for (uint32_t i = 0; i < *cnt; ++i) {
// key
auto k = cur->next(); auto k = cur->next();
if (!k) return std::unexpected(k.error()); if (!k) return std::unexpected(k.error());
cur = *k; cur = *k;
// value
auto v = cur->next(); auto v = cur->next();
if (!v) return std::unexpected(v.error()); if (!v) return std::unexpected(v.error());
cur = *v; cur = *v;
@@ -485,7 +469,6 @@ public:
} }
} }
// Type checks
bool is_nil() const { bool is_nil() const {
auto h = header_byte(); auto h = header_byte();
return h && *h == format::NIL; return h && *h == format::NIL;
@@ -544,7 +527,6 @@ public:
return b == format::MAP16 || b == format::MAP32; return b == format::MAP16 || b == format::MAP32;
} }
// Value accessors
result<bool> get_bool() const { result<bool> get_bool() const {
auto h = header_byte(); auto h = header_byte();
@@ -610,7 +592,6 @@ public:
return std::string_view(reinterpret_cast<const char *>(m_p + offset), len); return std::string_view(reinterpret_cast<const char *>(m_p + offset), len);
} }
// Returns {ext_type, data_view}.
result<std::tuple<int8_t, std::string_view>> get_ext() const { result<std::tuple<int8_t, std::string_view>> get_ext() const {
auto h = header_byte(); auto h = header_byte();
if (!h) return std::unexpected(h.error()); if (!h) return std::unexpected(h.error());
@@ -714,4 +695,79 @@ public:
} }
}; };
template <typename T>
requires std::is_integral_v<T> && (!std::is_same_v<T, bool>)
result<parser> unpack(const parser &p, T &out) {
auto v = p.get_number<T>();
if (!v) return std::unexpected(v.error());
out = *v;
return p.next();
}
inline result<parser> unpack(const parser &p, bool &out) {
auto v = p.get_bool();
if (!v) return std::unexpected(v.error());
out = *v;
return p.next();
}
inline result<parser> unpack(const parser &p, std::string_view &out) {
auto v = p.get_string();
if (!v) return std::unexpected(v.error());
out = *v;
return p.next();
}
inline result<parser> unpack(const parser &p, std::vector<uint8_t> &out) {
auto v = p.get_binary_view();
if (!v) return std::unexpected(v.error());
out.assign(v->begin(), v->end());
return p.next();
}
template <typename... Ts, size_t... Is>
result<parser> unpack_tuple_elements(const parser &p, std::tuple<Ts...> &t, std::index_sequence<Is...>) {
result<parser> cur = p.first_item();
if (!cur) return cur;
((cur = cur ? unpack(*cur, std::get<Is>(t)) : cur), ...);
return cur;
}
template <typename... Ts>
result<parser> unpack(const parser &p, std::tuple<Ts...> &t) {
auto cnt = p.count();
if (!cnt) return std::unexpected(cnt.error());
if (*cnt != sizeof...(Ts)) return std::unexpected(error_code::type_error);
auto r = unpack_tuple_elements(p, t, std::index_sequence_for<Ts...>{});
if (!r) return r;
return p.next();
}
template <typename T>
requires (requires(T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to<int8_t>; })
result<parser> unpack(const parser &p, T &out) {
auto tup = out.as_tuple();
auto cnt = p.count();
if (!cnt) return std::unexpected(cnt.error());
if (*cnt != std::tuple_size_v<decltype(tup)>) return std::unexpected(error_code::type_error);
auto r = unpack_tuple_elements(p, tup, std::make_index_sequence<std::tuple_size_v<decltype(tup)>>{});
if (!r) return r;
return p.next();
}
template <typename T>
requires requires(T &v) { { T::ext_id } -> std::convertible_to<int8_t>; v.as_tuple(); }
result<parser> unpack(const parser &p, T &out) {
auto ext = p.get_ext();
if (!ext) return std::unexpected(ext.error());
auto [ext_type, ext_data] = *ext;
if (ext_type != T::ext_id) return std::unexpected(error_code::type_error);
parser inner(reinterpret_cast<const uint8_t *>(ext_data.data()),
static_cast<int>(ext_data.size()));
auto tup = out.as_tuple();
auto r = unpack_tuple_elements(inner, tup, std::make_index_sequence<std::tuple_size_v<decltype(tup)>>{});
if (!r) return r;
return p.next();
}
} // namespace msgpackpp } // namespace msgpackpp

33
include/static_vector.h Normal file
View File

@@ -0,0 +1,33 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
template <typename T, size_t Capacity>
class static_vector {
T m_data[Capacity];
size_t m_size = 0;
public:
void push_back(const T &v) {
if (m_size < Capacity) m_data[m_size++] = v;
}
void clear() { m_size = 0; }
size_t size() const { return m_size; }
size_t capacity() const { return Capacity; }
bool full() const { return m_size >= Capacity; }
bool empty() const { return m_size == 0; }
T *data() { return m_data; }
const T *data() const { return m_data; }
T &operator[](size_t i) { return m_data[i]; }
const T &operator[](size_t i) const { return m_data[i]; }
T *begin() { return m_data; }
T *end() { return m_data + m_size; }
const T *begin() const { return m_data; }
const T *end() const { return m_data + m_size; }
};

View File

@@ -33,11 +33,20 @@ func RegisterExt(extID int8, value interface{}) {
return v.Interface().(Marshaler).MarshalMsgpack() return v.Interface().(Marshaler).MarshalMsgpack()
}) })
} else { } else {
encTyp := typ
if encTyp.Kind() == reflect.Ptr {
encTyp = encTyp.Elem()
}
structEncoder := _getEncoder(encTyp)
RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) { RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
enc := NewEncoder(&buf) enc := NewEncoder(&buf)
enc.UseArrayEncodedStructs(true) enc.UseArrayEncodedStructs(true)
if err := enc.Encode(v.Interface()); err != nil { val := v
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if err := structEncoder(enc, val); err != nil {
return nil, err return nil, err
} }
return buf.Bytes(), nil return buf.Bytes(), nil

View File

@@ -25,22 +25,8 @@ func Open(portName string) (serial.Port, error) {
return serial.Open(portName, &serial.Mode{BaudRate: 115200}) return serial.Open(portName, &serial.Mode{BaudRate: 115200})
} }
func SendByte(portName string, b byte) error { // SendAndReceive sends data and reads the response with a timeout.
port, err := Open(portName) func SendAndReceive(portName string, data []byte, timeout time.Duration) ([]byte, error) {
if err != nil {
return fmt.Errorf("opening %s: %w", portName, err)
}
defer port.Close()
_, err = port.Write([]byte{b})
if err != nil {
return fmt.Errorf("writing to %s: %w", portName, err)
}
return nil
}
// SendByteAndRead sends a byte and reads the response with a timeout.
func SendByteAndRead(portName string, b byte, timeout time.Duration) ([]byte, error) {
port, err := Open(portName) port, err := Open(portName)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening %s: %w", portName, err) return nil, fmt.Errorf("opening %s: %w", portName, err)
@@ -49,7 +35,7 @@ func SendByteAndRead(portName string, b byte, timeout time.Duration) ([]byte, er
port.SetReadTimeout(timeout) port.SetReadTimeout(timeout)
_, err = port.Write([]byte{b}) _, err = port.Write(data)
if err != nil { if err != nil {
return nil, fmt.Errorf("writing to %s: %w", portName, err) return nil, fmt.Errorf("writing to %s: %w", portName, err)
} }

View File

@@ -9,7 +9,8 @@ import (
var HashKey = [8]byte{} var HashKey = [8]byte{}
type RebootingBootsel struct{} type ResponseBOOTSEL struct{}
type RequestBOOTSEL struct{}
type Envelope struct { type Envelope struct {
Checksum uint32 Checksum uint32
@@ -18,7 +19,23 @@ type Envelope struct {
func init() { func init() {
msgpack.RegisterExt(0, (*Envelope)(nil)) msgpack.RegisterExt(0, (*Envelope)(nil))
msgpack.RegisterExt(1, (*RebootingBootsel)(nil)) msgpack.RegisterExt(1, (*ResponseBOOTSEL)(nil))
msgpack.RegisterExt(2, (*RequestBOOTSEL)(nil))
}
func EncodeMessage(msg any) ([]byte, error) {
payload, err := msgpack.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("encode inner: %w", err)
}
checksum := halfsiphash.Sum32(payload, HashKey)
env := Envelope{Checksum: checksum, Payload: payload}
data, err := msgpack.Marshal(&env)
if err != nil {
return nil, fmt.Errorf("encode envelope: %w", err)
}
return data, nil
} }
func DecodeMessage(data []byte) (any, error) { func DecodeMessage(data []byte) (any, error) {

View File

@@ -5,12 +5,20 @@
#include "pico/bootrom.h" #include "pico/bootrom.h"
#include "msgpackpp.h" #include "msgpackpp.h"
#include "halfsiphash.h" #include "halfsiphash.h"
#include "static_vector.h"
static constexpr uint8_t hash_key[8] = {}; static constexpr uint8_t hash_key[8] = {};
struct RebootingBootsel { struct ResponseBOOTSEL {
static constexpr int8_t ext_id = 1; static constexpr int8_t ext_id = 1;
auto as_tuple() const { return std::tie(); } auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
};
struct RequestBOOTSEL {
static constexpr int8_t ext_id = 2;
auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
}; };
struct Envelope { struct Envelope {
@@ -18,6 +26,7 @@ struct Envelope {
uint32_t checksum; uint32_t checksum;
std::vector<uint8_t> payload; std::vector<uint8_t> payload;
auto as_tuple() const { return std::tie(checksum, payload); } auto as_tuple() const { return std::tie(checksum, payload); }
auto as_tuple() { return std::tie(checksum, payload); }
}; };
static std::vector<uint8_t> pack_envelope(const std::vector<uint8_t> &payload) { static std::vector<uint8_t> pack_envelope(const std::vector<uint8_t> &payload) {
@@ -34,20 +43,56 @@ static void send_bytes(const std::vector<uint8_t> &data) {
stdio_flush(); stdio_flush();
} }
template <typename T>
static void send_message(const T &msg) {
msgpackpp::packer inner;
inner.pack(msg);
auto envelope = pack_envelope(inner.get_payload());
send_bytes(envelope);
}
static int8_t try_decode(const static_vector<uint8_t, 256> &buf) {
msgpackpp::parser p(buf.data(), static_cast<int>(buf.size()));
Envelope env;
if (!msgpackpp::unpack(p, env)) return -1;
uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key);
if (env.checksum != expected) return -1;
msgpackpp::parser inner(env.payload.data(), static_cast<int>(env.payload.size()));
if (!inner.is_ext()) return -1;
auto ext = inner.get_ext();
if (!ext) return -1;
return std::get<0>(*ext);
}
int main() { int main() {
stdio_init_all(); stdio_init_all();
static static_vector<uint8_t, 256> rx_buf;
while (true) { while (true) {
int c = getchar_timeout_us(100000); int c = getchar_timeout_us(100000);
if (c == 'p') { if (c == PICO_ERROR_TIMEOUT) continue;
printf("p");
} else if (c == 'b') { rx_buf.push_back(static_cast<uint8_t>(c));
msgpackpp::packer inner;
inner.pack(RebootingBootsel{}); int8_t msg_type = try_decode(rx_buf);
auto msg = pack_envelope(inner.get_payload()); if (msg_type < 0) {
send_bytes(msg); if (rx_buf.full()) rx_buf.clear();
continue;
}
rx_buf.clear();
switch (msg_type) {
case RequestBOOTSEL::ext_id:
send_message(ResponseBOOTSEL{});
sleep_ms(100); sleep_ms(100);
reset_usb_boot(0, 0); reset_usb_boot(0, 0);
break;
} }
} }
} }