diff --git a/.claude/skills/load/SKILL.md b/.claude/skills/load/SKILL.md index 9033c8f..05843cd 100644 --- a/.claude/skills/load/SKILL.md +++ b/.claude/skills/load/SKILL.md @@ -7,3 +7,5 @@ user-invocable: true Run `go run ./cmd/load/` from the project root. This builds the firmware, loads it onto the Pico, and reboots. If cmake needs reconfiguring (e.g. after CMakeLists.txt changes), run `cmake -B build` first. + +After modifying the load command itself (cmd/load/, lib/wire/, lib/picoserial/), run it twice: once to load the firmware, once to verify the load process still works end-to-end. diff --git a/cmd/load/main.go b/cmd/load/main.go index 5db160d..17a589d 100644 --- a/cmd/load/main.go +++ b/cmd/load/main.go @@ -40,8 +40,12 @@ func run(buildDir string) error { return err } if dev != "" { - fmt.Printf("Sending 'b' to %s to enter BOOTSEL mode...\n", dev) - resp, err := picoserial.SendByteAndRead(dev, 'b', 2*time.Second) + fmt.Printf("Sending bootsel request to %s...\n", dev) + req, err := wire.EncodeMessage(&wire.RequestBOOTSEL{}) + if err != nil { + return fmt.Errorf("encoding request: %w", err) + } + resp, err := picoserial.SendAndReceive(dev, req, 2*time.Second) if err != nil { return err } @@ -51,7 +55,7 @@ func run(buildDir string) error { fmt.Fprintf(os.Stderr, "warning: failed to decode response: %v\n", err) } else { switch msg.(type) { - case *wire.RebootingBootsel: + case *wire.ResponseBOOTSEL: fmt.Println("Device confirmed reboot into BOOTSEL mode.") default: fmt.Printf("Unexpected response type: %T\n", msg) diff --git a/include/msgpackpp.h b/include/msgpackpp.h index c56109b..6c960f0 100644 --- a/include/msgpackpp.h +++ b/include/msgpackpp.h @@ -21,10 +21,7 @@ enum class error_code { type_error, }; -// MessagePack format byte constants. Ranges are handled via helper functions -// rather than enumerating every value. namespace format { - // Fixed ranges (use is_* helpers below) constexpr uint8_t POSITIVE_FIXINT_MIN = 0x00; constexpr uint8_t POSITIVE_FIXINT_MAX = 0x7F; constexpr uint8_t FIXMAP_MIN = 0x80; @@ -36,7 +33,6 @@ namespace format { constexpr uint8_t NEGATIVE_FIXINT_MIN = 0xE0; constexpr uint8_t NEGATIVE_FIXINT_MAX = 0xFF; - // Specific type bytes constexpr uint8_t NIL = 0xC0; constexpr uint8_t NEVER_USED = 0xC1; constexpr uint8_t FALSE = 0xC2; @@ -80,7 +76,6 @@ namespace format { template using result = std::expected; -// Read a big-endian number from position m_p+1. template result body_number(const uint8_t *p, int size) { if (size < 1 + static_cast(sizeof(T))) { @@ -105,8 +100,6 @@ result body_number(const uint8_t *p, int size) { } } -// Returns {header_bytes, body_bytes} for a given format byte. -// For container types (array/map), body_bytes is not meaningful (returns 0). struct body_info { int header; // bytes before the body (includes format byte + length fields + ext type byte) uint32_t body; // body size in bytes (0 for containers, computed for variable-length) @@ -360,8 +353,6 @@ public: return *this; } - // Generic pack: dispatches based on type. - // Integers template requires std::is_integral_v && (!std::is_same_v) pack_result pack(T n) { return pack_integer(n); } @@ -373,10 +364,8 @@ public: pack_result pack(std::string_view v) { return pack_str(v); } pack_result pack(const std::string &v) { return pack_str(v); } - // Binary (vector) pack_result pack(const std::vector &v) { return pack_bin(v); } - // Tuples → msgpack array template pack_result pack(const std::tuple &t) { auto r = pack_array(sizeof...(Ts)); @@ -384,7 +373,6 @@ public: return pack_tuple_elements(t, std::index_sequence_for{}); } - // Structs with ext_id and as_tuple() → ext wrapping msgpack array template requires requires(const T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } pack_result pack(const T &v) { @@ -394,7 +382,6 @@ public: return pack_ext(T::ext_id, inner.get_payload()); } - // Structs with as_tuple() but no ext_id → plain msgpack array template requires (requires(const T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to; }) pack_result pack(const T &v) { @@ -440,7 +427,6 @@ public: return parser(m_p + n, m_size - n); } - // Navigate to the next value in a sequence. result next() const { auto hdr = header_byte(); if (!hdr) return std::unexpected(hdr.error()); @@ -467,11 +453,9 @@ public: auto cur = advance(info->header); if (!cur) return std::unexpected(cur.error()); for (uint32_t i = 0; i < *cnt; ++i) { - // key auto k = cur->next(); if (!k) return std::unexpected(k.error()); cur = *k; - // value auto v = cur->next(); if (!v) return std::unexpected(v.error()); cur = *v; @@ -485,7 +469,6 @@ public: } } - // Type checks bool is_nil() const { auto h = header_byte(); return h && *h == format::NIL; @@ -544,7 +527,6 @@ public: return b == format::MAP16 || b == format::MAP32; } - // Value accessors result get_bool() const { auto h = header_byte(); @@ -610,7 +592,6 @@ public: return std::string_view(reinterpret_cast(m_p + offset), len); } - // Returns {ext_type, data_view}. result> get_ext() const { auto h = header_byte(); if (!h) return std::unexpected(h.error()); @@ -714,4 +695,79 @@ public: } }; +template + requires std::is_integral_v && (!std::is_same_v) +result unpack(const parser &p, T &out) { + auto v = p.get_number(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, bool &out) { + auto v = p.get_bool(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, std::string_view &out) { + auto v = p.get_string(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, std::vector &out) { + auto v = p.get_binary_view(); + if (!v) return std::unexpected(v.error()); + out.assign(v->begin(), v->end()); + return p.next(); +} + +template +result unpack_tuple_elements(const parser &p, std::tuple &t, std::index_sequence) { + result cur = p.first_item(); + if (!cur) return cur; + ((cur = cur ? unpack(*cur, std::get(t)) : cur), ...); + return cur; +} + +template +result unpack(const parser &p, std::tuple &t) { + auto cnt = p.count(); + if (!cnt) return std::unexpected(cnt.error()); + if (*cnt != sizeof...(Ts)) return std::unexpected(error_code::type_error); + auto r = unpack_tuple_elements(p, t, std::index_sequence_for{}); + if (!r) return r; + return p.next(); +} + +template + requires (requires(T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to; }) +result unpack(const parser &p, T &out) { + auto tup = out.as_tuple(); + auto cnt = p.count(); + if (!cnt) return std::unexpected(cnt.error()); + if (*cnt != std::tuple_size_v) return std::unexpected(error_code::type_error); + auto r = unpack_tuple_elements(p, tup, std::make_index_sequence>{}); + if (!r) return r; + return p.next(); +} + +template + requires requires(T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } +result unpack(const parser &p, T &out) { + auto ext = p.get_ext(); + if (!ext) return std::unexpected(ext.error()); + auto [ext_type, ext_data] = *ext; + if (ext_type != T::ext_id) return std::unexpected(error_code::type_error); + parser inner(reinterpret_cast(ext_data.data()), + static_cast(ext_data.size())); + auto tup = out.as_tuple(); + auto r = unpack_tuple_elements(inner, tup, std::make_index_sequence>{}); + if (!r) return r; + return p.next(); +} + } // namespace msgpackpp diff --git a/include/static_vector.h b/include/static_vector.h new file mode 100644 index 0000000..df3934c --- /dev/null +++ b/include/static_vector.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#include + +template +class static_vector { + T m_data[Capacity]; + size_t m_size = 0; + +public: + void push_back(const T &v) { + if (m_size < Capacity) m_data[m_size++] = v; + } + + void clear() { m_size = 0; } + + size_t size() const { return m_size; } + size_t capacity() const { return Capacity; } + bool full() const { return m_size >= Capacity; } + bool empty() const { return m_size == 0; } + + T *data() { return m_data; } + const T *data() const { return m_data; } + + T &operator[](size_t i) { return m_data[i]; } + const T &operator[](size_t i) const { return m_data[i]; } + + T *begin() { return m_data; } + T *end() { return m_data + m_size; } + const T *begin() const { return m_data; } + const T *end() const { return m_data + m_size; } +}; diff --git a/lib/msgpack/ext.go b/lib/msgpack/ext.go index 42e0fb5..c5d1824 100644 --- a/lib/msgpack/ext.go +++ b/lib/msgpack/ext.go @@ -33,11 +33,20 @@ func RegisterExt(extID int8, value interface{}) { return v.Interface().(Marshaler).MarshalMsgpack() }) } else { + encTyp := typ + if encTyp.Kind() == reflect.Ptr { + encTyp = encTyp.Elem() + } + structEncoder := _getEncoder(encTyp) RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) { var buf bytes.Buffer enc := NewEncoder(&buf) enc.UseArrayEncodedStructs(true) - if err := enc.Encode(v.Interface()); err != nil { + val := v + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if err := structEncoder(enc, val); err != nil { return nil, err } return buf.Bytes(), nil diff --git a/lib/picoserial/picoserial.go b/lib/picoserial/picoserial.go index 633a7d5..de8ec27 100644 --- a/lib/picoserial/picoserial.go +++ b/lib/picoserial/picoserial.go @@ -25,22 +25,8 @@ func Open(portName string) (serial.Port, error) { return serial.Open(portName, &serial.Mode{BaudRate: 115200}) } -func SendByte(portName string, b byte) error { - port, err := Open(portName) - if err != nil { - return fmt.Errorf("opening %s: %w", portName, err) - } - defer port.Close() - - _, err = port.Write([]byte{b}) - if err != nil { - return fmt.Errorf("writing to %s: %w", portName, err) - } - return nil -} - -// SendByteAndRead sends a byte and reads the response with a timeout. -func SendByteAndRead(portName string, b byte, timeout time.Duration) ([]byte, error) { +// SendAndReceive sends data and reads the response with a timeout. +func SendAndReceive(portName string, data []byte, timeout time.Duration) ([]byte, error) { port, err := Open(portName) if err != nil { return nil, fmt.Errorf("opening %s: %w", portName, err) @@ -49,7 +35,7 @@ func SendByteAndRead(portName string, b byte, timeout time.Duration) ([]byte, er port.SetReadTimeout(timeout) - _, err = port.Write([]byte{b}) + _, err = port.Write(data) if err != nil { return nil, fmt.Errorf("writing to %s: %w", portName, err) } diff --git a/lib/wire/wire.go b/lib/wire/wire.go index fead255..6c36f0f 100644 --- a/lib/wire/wire.go +++ b/lib/wire/wire.go @@ -9,7 +9,8 @@ import ( var HashKey = [8]byte{} -type RebootingBootsel struct{} +type ResponseBOOTSEL struct{} +type RequestBOOTSEL struct{} type Envelope struct { Checksum uint32 @@ -18,7 +19,23 @@ type Envelope struct { func init() { msgpack.RegisterExt(0, (*Envelope)(nil)) - msgpack.RegisterExt(1, (*RebootingBootsel)(nil)) + msgpack.RegisterExt(1, (*ResponseBOOTSEL)(nil)) + msgpack.RegisterExt(2, (*RequestBOOTSEL)(nil)) +} + +func EncodeMessage(msg any) ([]byte, error) { + payload, err := msgpack.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("encode inner: %w", err) + } + + checksum := halfsiphash.Sum32(payload, HashKey) + env := Envelope{Checksum: checksum, Payload: payload} + data, err := msgpack.Marshal(&env) + if err != nil { + return nil, fmt.Errorf("encode envelope: %w", err) + } + return data, nil } func DecodeMessage(data []byte) (any, error) { diff --git a/picomap.cpp b/picomap.cpp index d58dbbf..fe99ee8 100644 --- a/picomap.cpp +++ b/picomap.cpp @@ -5,12 +5,20 @@ #include "pico/bootrom.h" #include "msgpackpp.h" #include "halfsiphash.h" +#include "static_vector.h" static constexpr uint8_t hash_key[8] = {}; -struct RebootingBootsel { +struct ResponseBOOTSEL { static constexpr int8_t ext_id = 1; auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestBOOTSEL { + static constexpr int8_t ext_id = 2; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } }; struct Envelope { @@ -18,6 +26,7 @@ struct Envelope { uint32_t checksum; std::vector payload; auto as_tuple() const { return std::tie(checksum, payload); } + auto as_tuple() { return std::tie(checksum, payload); } }; static std::vector pack_envelope(const std::vector &payload) { @@ -34,20 +43,56 @@ static void send_bytes(const std::vector &data) { stdio_flush(); } +template +static void send_message(const T &msg) { + msgpackpp::packer inner; + inner.pack(msg); + auto envelope = pack_envelope(inner.get_payload()); + send_bytes(envelope); +} + +static int8_t try_decode(const static_vector &buf) { + msgpackpp::parser p(buf.data(), static_cast(buf.size())); + + Envelope env; + if (!msgpackpp::unpack(p, env)) return -1; + + uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key); + if (env.checksum != expected) return -1; + + msgpackpp::parser inner(env.payload.data(), static_cast(env.payload.size())); + if (!inner.is_ext()) return -1; + auto ext = inner.get_ext(); + if (!ext) return -1; + + return std::get<0>(*ext); +} + int main() { stdio_init_all(); + static static_vector rx_buf; + while (true) { int c = getchar_timeout_us(100000); - if (c == 'p') { - printf("p"); - } else if (c == 'b') { - msgpackpp::packer inner; - inner.pack(RebootingBootsel{}); - auto msg = pack_envelope(inner.get_payload()); - send_bytes(msg); + if (c == PICO_ERROR_TIMEOUT) continue; + + rx_buf.push_back(static_cast(c)); + + int8_t msg_type = try_decode(rx_buf); + if (msg_type < 0) { + if (rx_buf.full()) rx_buf.clear(); + continue; + } + + rx_buf.clear(); + + switch (msg_type) { + case RequestBOOTSEL::ext_id: + send_message(ResponseBOOTSEL{}); sleep_ms(100); reset_usb_boot(0, 0); + break; } } }