From e2a5d97daef3c599391204a6d92ea5a8939c4c9c Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 10 Apr 2026 22:18:44 +0900 Subject: [PATCH] Zero-copy TX: span_writer packer, static buffers, no vector returns --- firmware/include/dispatch.h | 11 ++-- firmware/include/handlers.h | 7 +-- firmware/include/msgpack.h | 104 ++++++++++++++++----------------- firmware/include/net.h | 4 +- firmware/include/span_writer.h | 35 +++++++++++ firmware/include/wire.h | 32 ++++++---- firmware/lib/dispatch.cpp | 26 +++++---- firmware/lib/handlers.cpp | 12 ++-- firmware/lib/net.cpp | 59 +++++++++---------- firmware/test.cpp | 16 ++--- 10 files changed, 173 insertions(+), 133 deletions(-) create mode 100644 firmware/include/span_writer.h diff --git a/firmware/include/dispatch.h b/firmware/include/dispatch.h index 099f93a..e185ca9 100644 --- a/firmware/include/dispatch.h +++ b/firmware/include/dispatch.h @@ -2,10 +2,9 @@ #include #include #include -#include #include "wire.h" -using handler_fn = std::vector> (*)(uint32_t message_id, std::span payload); +using handler_fn = size_t (*)(uint32_t message_id, std::span payload, span_writer &out); struct handler_entry { int8_t type_id; @@ -13,16 +12,16 @@ struct handler_entry { }; template -std::vector> typed_handler(uint32_t message_id, std::span payload) { +size_t typed_handler(uint32_t message_id, std::span payload, span_writer &out) { msgpack::parser p(payload.data(), static_cast(payload.size())); Req req; auto tup = req.as_tuple(); auto r = msgpack::unpack(p, tup); if (!r) { - return {encode_response(message_id, DeviceError{1, "decode request ext_id=" + - std::to_string(Req::ext_id) + ": msgpack error " + std::to_string(static_cast(r.error()))})}; + return encode_response_into(out, message_id, DeviceError{1, "decode request ext_id=" + + std::to_string(Req::ext_id) + ": msgpack error " + std::to_string(static_cast(r.error()))}); } - return Fn(message_id, req); + return Fn(message_id, req, out); } void dispatch_init(); diff --git a/firmware/include/handlers.h b/firmware/include/handlers.h index 328a03f..f18368f 100644 --- a/firmware/include/handlers.h +++ b/firmware/include/handlers.h @@ -2,11 +2,10 @@ #include #include #include -#include #include "wire.h" extern std::string_view firmware_name; -std::vector> handle_picoboot(uint32_t message_id, std::span payload); -std::vector> handle_info(uint32_t message_id, std::span payload); -std::vector> handle_log(uint32_t message_id, std::span payload); +size_t handle_picoboot(uint32_t message_id, std::span payload, span_writer &out); +size_t handle_info(uint32_t message_id, std::span payload, span_writer &out); +size_t handle_log(uint32_t message_id, std::span payload, span_writer &out); diff --git a/firmware/include/msgpack.h b/firmware/include/msgpack.h index c217b9b..85c0eec 100644 --- a/firmware/include/msgpack.h +++ b/firmware/include/msgpack.h @@ -6,11 +6,11 @@ #include #include #include -#include #include #include #include #include +#include "span_writer.h" namespace msgpack { @@ -163,26 +163,23 @@ inline result get_body_info(const uint8_t *p, int size) { } class packer { -public: - using buffer = std::vector; - private: - std::shared_ptr m_buffer; + span_writer m_buf; template void push_big_endian(T n) { auto p = reinterpret_cast(&n) + (sizeof(T) - 1); for (size_t i = 0; i < sizeof(T); ++i, --p) { - m_buffer->push_back(*p); + m_buf.push_back(*p); } } template void push(const Range &r) { - m_buffer->insert(m_buffer->end(), std::begin(r), std::end(r)); + m_buf.insert(m_buf.end(), std::begin(r), std::end(r)); } public: - packer() : m_buffer(std::make_shared()) {} - packer(const std::shared_ptr &buf) : m_buffer(buf) {} + packer(uint8_t *data, size_t capacity) : m_buf(data, capacity) {} + packer(span_writer buf) : m_buf(buf) {} packer(const packer &) = delete; packer &operator=(const packer &) = delete; @@ -190,12 +187,12 @@ public: using pack_result = result>; pack_result pack_nil() { - m_buffer->push_back(format::NIL); + m_buf.push_back(format::NIL); return *this; } pack_result pack_bool(bool v) { - m_buffer->push_back(v ? format::TRUE : format::FALSE); + m_buf.push_back(v ? format::TRUE : format::FALSE); return *this; } @@ -203,36 +200,36 @@ public: pack_result pack_integer(T n) { if constexpr (std::is_signed_v) { if (n >= 0 && n <= 0x7F) { - m_buffer->push_back(static_cast(n)); + m_buf.push_back(static_cast(n)); } else if (n >= -32 && n < 0) { - m_buffer->push_back(static_cast(n)); // negative fixint + m_buf.push_back(static_cast(n)); // negative fixint } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { - m_buffer->push_back(format::INT8); - m_buffer->push_back(static_cast(n)); + m_buf.push_back(format::INT8); + m_buf.push_back(static_cast(n)); } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { - m_buffer->push_back(format::INT16); + m_buf.push_back(format::INT16); push_big_endian(static_cast(n)); } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { - m_buffer->push_back(format::INT32); + m_buf.push_back(format::INT32); push_big_endian(static_cast(n)); } else { - m_buffer->push_back(format::INT64); + m_buf.push_back(format::INT64); push_big_endian(static_cast(n)); } } else { if (n <= 0x7F) { - m_buffer->push_back(static_cast(n)); + m_buf.push_back(static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::UINT8); - m_buffer->push_back(static_cast(n)); + m_buf.push_back(format::UINT8); + m_buf.push_back(static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::UINT16); + m_buf.push_back(format::UINT16); push_big_endian(static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::UINT32); + m_buf.push_back(format::UINT32); push_big_endian(static_cast(n)); } else { - m_buffer->push_back(format::UINT64); + m_buf.push_back(format::UINT64); push_big_endian(static_cast(n)); } } @@ -240,13 +237,13 @@ public: } pack_result pack_float(float n) { - m_buffer->push_back(format::FLOAT32); + m_buf.push_back(format::FLOAT32); push_big_endian(n); return *this; } pack_result pack_double(double n) { - m_buffer->push_back(format::FLOAT64); + m_buf.push_back(format::FLOAT64); push_big_endian(n); return *this; } @@ -255,15 +252,15 @@ public: pack_result pack_str(const Range &r) { auto sz = static_cast(std::distance(std::begin(r), std::end(r))); if (sz < 32) { - m_buffer->push_back(format::FIXSTR_MIN | static_cast(sz)); + m_buf.push_back(format::FIXSTR_MIN | static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::STR8); - m_buffer->push_back(static_cast(sz)); + m_buf.push_back(format::STR8); + m_buf.push_back(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::STR16); + m_buf.push_back(format::STR16); push_big_endian(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::STR32); + m_buf.push_back(format::STR32); push_big_endian(static_cast(sz)); } else { return std::unexpected(error_code::overflow); @@ -280,13 +277,13 @@ public: pack_result pack_bin(const Range &r) { auto sz = static_cast(std::distance(std::begin(r), std::end(r))); if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::BIN8); - m_buffer->push_back(static_cast(sz)); + m_buf.push_back(format::BIN8); + m_buf.push_back(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::BIN16); + m_buf.push_back(format::BIN16); push_big_endian(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::BIN32); + m_buf.push_back(format::BIN32); push_big_endian(static_cast(sz)); } else { return std::unexpected(error_code::overflow); @@ -297,12 +294,12 @@ public: pack_result pack_array(size_t n) { if (n <= 15) { - m_buffer->push_back(format::FIXARRAY_MIN | static_cast(n)); + m_buf.push_back(format::FIXARRAY_MIN | static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::ARRAY16); + m_buf.push_back(format::ARRAY16); push_big_endian(static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::ARRAY32); + m_buf.push_back(format::ARRAY32); push_big_endian(static_cast(n)); } else { return std::unexpected(error_code::overflow); @@ -312,12 +309,12 @@ public: pack_result pack_map(size_t n) { if (n <= 15) { - m_buffer->push_back(format::FIXMAP_MIN | static_cast(n)); + m_buf.push_back(format::FIXMAP_MIN | static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::MAP16); + m_buf.push_back(format::MAP16); push_big_endian(static_cast(n)); } else if (n <= std::numeric_limits::max()) { - m_buffer->push_back(format::MAP32); + m_buf.push_back(format::MAP32); push_big_endian(static_cast(n)); } else { return std::unexpected(error_code::overflow); @@ -330,26 +327,26 @@ public: auto sz = static_cast(std::distance(std::begin(r), std::end(r))); switch (sz) { - case 1: m_buffer->push_back(format::FIXEXT1); break; - case 2: m_buffer->push_back(format::FIXEXT2); break; - case 4: m_buffer->push_back(format::FIXEXT4); break; - case 8: m_buffer->push_back(format::FIXEXT8); break; - case 16: m_buffer->push_back(format::FIXEXT16); break; + case 1: m_buf.push_back(format::FIXEXT1); break; + case 2: m_buf.push_back(format::FIXEXT2); break; + case 4: m_buf.push_back(format::FIXEXT4); break; + case 8: m_buf.push_back(format::FIXEXT8); break; + case 16: m_buf.push_back(format::FIXEXT16); break; default: if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::EXT8); - m_buffer->push_back(static_cast(sz)); + m_buf.push_back(format::EXT8); + m_buf.push_back(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::EXT16); + m_buf.push_back(format::EXT16); push_big_endian(static_cast(sz)); } else if (sz <= std::numeric_limits::max()) { - m_buffer->push_back(format::EXT32); + m_buf.push_back(format::EXT32); push_big_endian(static_cast(sz)); } else { return std::unexpected(error_code::overflow); } } - m_buffer->push_back(static_cast(type)); + m_buf.push_back(static_cast(type)); push(r); return *this; } @@ -392,7 +389,8 @@ public: template requires requires(const T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } pack_result pack(const T &v) { - packer inner; + uint8_t ext_buf[256]; + packer inner(ext_buf, sizeof(ext_buf)); auto r = inner.pack(v.as_tuple()); if (!r) return r; return pack_ext(T::ext_id, inner.get_payload()); @@ -413,7 +411,7 @@ private: } public: - const buffer &get_payload() const { return *m_buffer; } + const span_writer &get_payload() const { return m_buf; } }; class parser { diff --git a/firmware/include/net.h b/firmware/include/net.h index 15551f0..3e02741 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -3,14 +3,14 @@ #include #include #include -#include +#include "span_writer.h" struct net_state { std::array mac; std::array ip; }; -using net_handler = std::function>(std::span payload)>; +using net_handler = std::function payload, span_writer &out)>; bool net_init(); const net_state& net_get_state(); diff --git a/firmware/include/span_writer.h b/firmware/include/span_writer.h new file mode 100644 index 0000000..c3aec3f --- /dev/null +++ b/firmware/include/span_writer.h @@ -0,0 +1,35 @@ +#pragma once +#include +#include +#include + +class span_writer { + uint8_t *m_data; + size_t m_capacity; + size_t m_size = 0; + +public: + span_writer(uint8_t *data, size_t capacity) : m_data(data), m_capacity(capacity) {} + + void push_back(uint8_t v) { + if (m_size < m_capacity) m_data[m_size++] = v; + } + + template + void insert(uint8_t *, It first, It last) { + while (first != last && m_size < m_capacity) + m_data[m_size++] = *first++; + } + + size_t size() const { return m_size; } + size_t capacity() const { return m_capacity; } + bool full() const { return m_size >= m_capacity; } + + uint8_t *data() { return m_data; } + const uint8_t *data() const { return m_data; } + + uint8_t *begin() { return m_data; } + uint8_t *end() { return m_data + m_size; } + const uint8_t *begin() const { return m_data; } + const uint8_t *end() const { return m_data + m_size; } +}; diff --git a/firmware/include/wire.h b/firmware/include/wire.h index 1b1b85a..eadc300 100644 --- a/firmware/include/wire.h +++ b/firmware/include/wire.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include #include #include @@ -96,18 +97,27 @@ struct DecodedMessage { std::vector payload; }; -inline std::vector pack_envelope(uint32_t message_id, const std::vector &payload) { - uint32_t checksum = halfsiphash::hash32(payload.data(), payload.size(), hash_key); - msgpack::packer p; - p.pack(Envelope{message_id, checksum, payload}); - return p.get_payload(); +inline size_t pack_envelope_into(span_writer &out, uint32_t message_id, const uint8_t *payload, size_t payload_len) { + uint32_t checksum = halfsiphash::hash32(payload, payload_len, hash_key); + uint8_t env_buf[512]; + span_writer env_body(env_buf, sizeof(env_buf)); + msgpack::packer env_p(env_body); + env_p.pack_array(3); + env_p.pack(message_id); + env_p.pack(checksum); + env_p.pack_bin(std::span{payload, payload_len}); + msgpack::packer outer(out); + outer.pack_ext(Envelope::ext_id, env_body); + return out.size(); } template -inline std::vector encode_response(uint32_t message_id, const T &msg) { - msgpack::packer inner; +inline size_t encode_response_into(span_writer &out, uint32_t message_id, const T &msg) { + uint8_t inner_buf[256]; + msgpack::packer inner(inner_buf, sizeof(inner_buf)); inner.pack(msg); - return pack_envelope(message_id, inner.get_payload()); + auto &pl = inner.get_payload(); + return pack_envelope_into(out, message_id, pl.data(), pl.size()); } inline msgpack::result try_decode(const uint8_t *data, size_t len) { @@ -154,8 +164,6 @@ inline msgpack::result decode_response(const uint8_t *data, size_t len) { return out; } -inline std::vector encode_request(uint32_t message_id, const auto &msg) { - msgpack::packer inner; - inner.pack(msg); - return pack_envelope(message_id, inner.get_payload()); +inline size_t encode_request_into(span_writer &out, uint32_t message_id, const auto &msg) { + return encode_response_into(out, message_id, msg); } diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index b9fbc51..4988aee 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -22,20 +22,21 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { } [[noreturn]] void dispatch_run(std::span handlers) { - std::unordered_map> (*)(uint32_t, std::span)> handler_map; + std::unordered_map handler_map; for (auto& entry : handlers) { handler_map[entry.type_id] = entry.handle; } static usb_cdc usb; static static_vector usb_rx_buf; + static uint8_t tx_buf[1514]; - net_set_handler([&](std::span payload) -> std::vector> { + net_set_handler([&](std::span payload, span_writer &out) -> size_t { auto msg = try_decode(payload.data(), payload.size()); - if (!msg) return {}; + if (!msg) return 0; auto it = handler_map.find(msg->type_id); - if (it == handler_map.end()) return {}; - return it->second(msg->message_id, msg->payload); + if (it == handler_map.end()) return 0; + return it->second(msg->message_id, msg->payload, out); }); while (true) { @@ -62,13 +63,16 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { auto it = handler_map.find(msg->type_id); if (it != handler_map.end()) { - for (auto& response : it->second(msg->message_id, msg->payload)) { - if (response.size() > usb.tx.free()) { - auto err = encode_response(msg->message_id, - DeviceError{2, "response too large: " + std::to_string(response.size())}); - usb.send(err); + span_writer out(tx_buf, sizeof(tx_buf)); + size_t resp_len = it->second(msg->message_id, msg->payload, out); + if (resp_len > 0) { + if (resp_len > usb.tx.free()) { + span_writer err_out(tx_buf, sizeof(tx_buf)); + size_t err_len = encode_response_into(err_out, msg->message_id, + DeviceError{2, "response too large: " + std::to_string(resp_len)}); + usb.send(std::span{tx_buf, err_len}); } else { - usb.send(response); + usb.send(std::span{tx_buf, resp_len}); } } } diff --git a/firmware/lib/handlers.cpp b/firmware/lib/handlers.cpp index 2b792f4..707239a 100644 --- a/firmware/lib/handlers.cpp +++ b/firmware/lib/handlers.cpp @@ -5,12 +5,12 @@ #include "net.h" #include "debug_log.h" -std::vector> handle_picoboot(uint32_t message_id, std::span) { +size_t handle_picoboot(uint32_t message_id, std::span, span_writer &out) { dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); }); - return {encode_response(message_id, ResponsePICOBOOT{})}; + return encode_response_into(out, message_id, ResponsePICOBOOT{}); } -std::vector> handle_info(uint32_t message_id, std::span) { +size_t handle_info(uint32_t message_id, std::span, span_writer &out) { ResponseInfo resp; pico_unique_board_id_t uid; pico_get_unique_board_id(&uid); @@ -19,12 +19,12 @@ std::vector> handle_info(uint32_t message_id, std::span> handle_log(uint32_t message_id, std::span) { +size_t handle_log(uint32_t message_id, std::span, span_writer &out) { ResponseLog resp; for (auto& e : dlog_drain()) resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)}); - return {encode_response(message_id, resp)}; + return encode_response_into(out, message_id, resp); } diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 65fca8a..66d9952 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -139,6 +139,8 @@ static void handle_arp(const uint8_t* frame, size_t len) { send_raw(&reply, sizeof(reply)); } +static uint8_t tx_buf[1514]; + static void handle_udp(const uint8_t* frame, size_t len) { if (len < sizeof(udp_header)) return; auto& pkt = *reinterpret_cast(frame); @@ -154,41 +156,36 @@ static void handle_udp(const uint8_t* frame, size_t len) { auto* payload = frame + sizeof(udp_header); size_t payload_len = udp_len - 8; - auto responses = msg_handler(std::span{payload, payload_len}); + span_writer resp(tx_buf + sizeof(udp_header), sizeof(tx_buf) - sizeof(udp_header)); + size_t resp_len = msg_handler(std::span{payload, payload_len}, resp); + if (resp_len == 0) return; - for (auto& resp : responses) { - uint8_t reply_buf[1514]; - size_t udp_data_len = resp.size(); - size_t ip_total = 20 + 8 + udp_data_len; - size_t reply_len = sizeof(eth_header) + ip_total; - if (reply_len > sizeof(reply_buf)) continue; + size_t ip_total = 20 + 8 + resp_len; + size_t reply_len = sizeof(eth_header) + ip_total; - auto& rip = *reinterpret_cast(reply_buf); - rip.eth.dst = pkt.ip.eth.src; - rip.eth.src = state.mac; - rip.eth.ethertype = ETH_IPV4; - rip.ver_ihl = 0x45; - rip.dscp_ecn = 0; - rip.total_len = __builtin_bswap16(ip_total); - rip.identification = 0; - rip.flags_frag = 0; - rip.ttl = 64; - rip.protocol = 17; - rip.checksum = 0; - rip.src = state.ip; - rip.dst = pkt.ip.src; - rip.checksum = ip_checksum(rip.ip_start(), 20); + auto& rip = *reinterpret_cast(tx_buf); + rip.eth.dst = pkt.ip.eth.src; + rip.eth.src = state.mac; + rip.eth.ethertype = ETH_IPV4; + rip.ver_ihl = 0x45; + rip.dscp_ecn = 0; + rip.total_len = __builtin_bswap16(ip_total); + rip.identification = 0; + rip.flags_frag = 0; + rip.ttl = 64; + rip.protocol = 17; + rip.checksum = 0; + rip.src = state.ip; + rip.dst = pkt.ip.src; + rip.checksum = ip_checksum(rip.ip_start(), 20); - auto& rudp = *reinterpret_cast(reply_buf); - rudp.src_port = PICOMAP_PORT; - rudp.dst_port = pkt.src_port; - rudp.length = __builtin_bswap16(8 + udp_data_len); - rudp.checksum = 0; + auto& rudp = *reinterpret_cast(tx_buf); + rudp.src_port = PICOMAP_PORT; + rudp.dst_port = pkt.src_port; + rudp.length = __builtin_bswap16(8 + resp_len); + rudp.checksum = 0; - memcpy(reply_buf + sizeof(udp_header), resp.data(), udp_data_len); - - send_raw(reply_buf, reply_len); - } + send_raw(tx_buf, reply_len); } static void handle_icmp(const uint8_t* frame, size_t len) { diff --git a/firmware/test.cpp b/firmware/test.cpp index b95fee2..f698eab 100644 --- a/firmware/test.cpp +++ b/firmware/test.cpp @@ -29,8 +29,10 @@ static ResponseTest test_discovery() { ResponseTest resp; resp.pass = true; - auto req = encode_request(0, RequestInfo{}); - auto send_result = w6300::send(test_socket, std::span{req}); + uint8_t req_buf[256]; + span_writer req_out(req_buf, sizeof(req_buf)); + size_t req_len = encode_request_into(req_out, 0, RequestInfo{}); + auto send_result = w6300::send(test_socket, std::span{req_buf, req_len}); if (!send_result) { resp.pass = false; resp.messages.push_back("send: error " + std::to_string(static_cast(send_result.error()))); @@ -95,13 +97,11 @@ static const std::unordered_map tests = { {"discovery", test_discovery}, }; -static std::vector> handle_test(uint32_t message_id, const RequestTest& req) { +static size_t handle_test(uint32_t message_id, const RequestTest& req, span_writer &out) { auto it = tests.find(req.name); - if (it == tests.end()) { - return {encode_response(message_id, ResponseTest{false, {"unknown test: " + req.name}})}; - } - - return {encode_response(message_id, it->second())}; + if (it == tests.end()) + return encode_response_into(out, message_id, ResponseTest{false, {"unknown test: " + req.name}}); + return encode_response_into(out, message_id, it->second()); } static constexpr handler_entry handlers[] = {