From 76c519c17af385e1929962546f372d92e8583bc6 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 10 Apr 2026 23:02:07 +0900 Subject: [PATCH] Overflow detection, span-based signatures, flatten control flow --- firmware/include/dispatch.h | 4 +-- firmware/include/net.h | 3 +- firmware/include/span_writer.h | 9 +++-- firmware/include/wire.h | 7 ++-- firmware/lib/dispatch.cpp | 27 +++++++------- firmware/lib/net.cpp | 64 ++++++++++++++++++---------------- firmware/test.cpp | 9 +++-- 7 files changed, 67 insertions(+), 56 deletions(-) diff --git a/firmware/include/dispatch.h b/firmware/include/dispatch.h index 34a8dc5..9a9c374 100644 --- a/firmware/include/dispatch.h +++ b/firmware/include/dispatch.h @@ -4,7 +4,7 @@ #include #include "wire.h" -using handler_fn = size_t (*)(uint32_t message_id, std::span payload, span_writer &out); +using handler_fn = msgpack::result (*)(uint32_t message_id, std::span payload, span_writer &out); struct handler_entry { int8_t type_id; @@ -12,7 +12,7 @@ struct handler_entry { }; template -size_t typed_handler(uint32_t message_id, std::span payload, span_writer &out) { +msgpack::result typed_handler(uint32_t message_id, std::span payload, span_writer &out) { msgpack::parser p(payload.data(), static_cast(payload.size())); Req req; auto tup = req.as_tuple(); diff --git a/firmware/include/net.h b/firmware/include/net.h index f05c327..b6356e4 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -4,13 +4,14 @@ #include #include #include "span_writer.h" +#include "msgpack.h" struct net_state { std::array mac; std::array ip; }; -using net_handler = std::function payload, span_writer &out)>; +using net_handler = std::function(std::span payload, span_writer &out)>; bool net_init(); const net_state& net_get_state(); diff --git a/firmware/include/span_writer.h b/firmware/include/span_writer.h index 513f384..47506d0 100644 --- a/firmware/include/span_writer.h +++ b/firmware/include/span_writer.h @@ -8,6 +8,7 @@ class span_writer { uint8_t *m_data; size_t m_capacity; size_t m_size = 0; + bool m_overflow = false; public: span_writer(uint8_t *data, size_t capacity) : m_data(data), m_capacity(capacity) {} @@ -15,17 +16,21 @@ public: void push_back(uint8_t v) { if (m_size < m_capacity) m_data[m_size++] = v; + else m_overflow = true; } template void insert(uint8_t *, It first, It last) { - while (first != last && m_size < m_capacity) - m_data[m_size++] = *first++; + while (first != last) { + if (m_size < m_capacity) m_data[m_size++] = *first++; + else { m_overflow = true; return; } + } } size_t size() const { return m_size; } size_t capacity() const { return m_capacity; } bool full() const { return m_size >= m_capacity; } + bool overflow() const { return m_overflow; } uint8_t *data() { return m_data; } const uint8_t *data() const { return m_data; } diff --git a/firmware/include/wire.h b/firmware/include/wire.h index 78313cc..d76f230 100644 --- a/firmware/include/wire.h +++ b/firmware/include/wire.h @@ -109,7 +109,7 @@ static constexpr size_t envelope_hdr_len = static constexpr size_t response_prefix_len = envelope_hdr_len + ext16_header_len; template -inline size_t encode_response_into(span_writer &out, uint32_t message_id, const T &msg) { +inline msgpack::result encode_response_into(span_writer &out, uint32_t message_id, const T &msg) { auto body = out.subspan(response_prefix_len); msgpack::packer body_p(body); body_p.pack(msg.as_tuple()); @@ -129,6 +129,8 @@ inline size_t encode_response_into(span_writer &out, uint32_t message_id, const hdr.pack_uint32_fixed(checksum); hdr.pack_bin16_header(static_cast(bin_len)); + if (body.overflow() || inner_ext.overflow() || env_hdr.overflow()) + return std::unexpected(msgpack::error_code::overflow); return response_prefix_len + body.size(); } @@ -176,6 +178,3 @@ inline msgpack::result decode_response(const uint8_t *data, size_t len) { return out; } -inline size_t encode_request_into(span_writer &out, uint32_t message_id, const auto &msg) { - return encode_response_into(out, message_id, msg); -} diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index c3d7a33..5b6ecef 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -31,7 +31,7 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { static static_vector usb_rx_buf; static std::array tx_buf; - net_set_handler([&](std::span payload, span_writer &out) -> size_t { + net_set_handler([&](std::span payload, span_writer &out) -> msgpack::result { auto msg = try_decode(payload.data(), payload.size()); if (!msg) return 0; auto it = handler_map.find(msg->type_id); @@ -62,20 +62,19 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { usb_rx_buf.clear(); auto it = handler_map.find(msg->type_id); - if (it != handler_map.end()) { - span_writer out(tx_buf); - size_t resp_len = it->second(msg->message_id, msg->payload, out); - if (resp_len > 0) { - if (resp_len > usb.tx.free()) { - span_writer err_out(tx_buf); - size_t err_len = encode_response_into(err_out, msg->message_id, - DeviceError{2, "response too large: " + std::to_string(resp_len)}); - usb.send(std::span{tx_buf.data(), err_len}); - } else { - usb.send(std::span{tx_buf.data(), resp_len}); - } - } + if (it == handler_map.end()) continue; + span_writer out(tx_buf); + auto resp = it->second(msg->message_id, msg->payload, out); + if (!resp || *resp == 0) continue; + size_t resp_len = *resp; + if (resp_len <= usb.tx.free()) { + usb.send(std::span{tx_buf.data(), resp_len}); + continue; } + span_writer err_out(tx_buf); + auto err = encode_response_into(err_out, msg->message_id, + DeviceError{2, "response too large: " + std::to_string(resp_len)}); + if (err) usb.send(std::span{tx_buf.data(), *err}); } __wfi(); diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index d34b0b6..4532876 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -106,23 +106,25 @@ static bool ip_match_or_broadcast(const ip4_addr& dst) { return ip_match(dst) || dst == IP_BROADCAST_ALL || dst == IP_BROADCAST_SUBNET; } -static void send_raw(const void* data, size_t len) { +static void send_raw(std::span data) { dlog_if_slow("send_raw", 1000, [&]{ - w6300::send(raw_socket, std::span{static_cast(data), len}); + w6300::send(raw_socket, data); }); } -static void handle_arp(const uint8_t* frame, size_t len) { - if (len < sizeof(arp_packet)) return; - auto& pkt = *reinterpret_cast(frame); +static void handle_arp(std::span frame, span_writer &tx) { + if (frame.size() < sizeof(arp_packet)) return; + auto& pkt = *reinterpret_cast(frame.data()); if (pkt.htype != ARP_HTYPE_ETH) return; if (pkt.ptype != ARP_PTYPE_IPV4) return; if (pkt.hlen != 6 || pkt.plen != 4) return; if (pkt.oper != ARP_OP_REQUEST) return; if (!ip_match(pkt.tpa)) return; + if (sizeof(arp_packet) > tx.capacity()) return; - arp_packet reply = {}; + auto& reply = *reinterpret_cast(tx.data()); + reply = {}; reply.eth.dst = pkt.eth.src; reply.eth.src = state.mac; reply.eth.ethertype = ETH_ARP; @@ -136,12 +138,12 @@ static void handle_arp(const uint8_t* frame, size_t len) { reply.tha = pkt.sha; reply.tpa = pkt.spa; - send_raw(&reply, sizeof(reply)); + send_raw({tx.data(), sizeof(arp_packet)}); } -static void handle_udp(const uint8_t* frame, size_t len, span_writer &tx) { - if (len < sizeof(udp_header)) return; - auto& pkt = *reinterpret_cast(frame); +static void handle_udp(std::span frame, span_writer &tx) { + if (frame.size() < sizeof(udp_header)) return; + auto& pkt = *reinterpret_cast(frame.data()); if (pkt.dst_port != PICOMAP_PORT) return; if (!ip_match_or_broadcast(pkt.ip.dst)) return; @@ -151,12 +153,12 @@ static void handle_udp(const uint8_t* frame, size_t len, span_writer &tx) { if (udp_len < 8) return; if (sizeof(eth_header) + pkt.ip.ip_total_len() < sizeof(udp_header) + udp_len - 8) return; - auto* payload = frame + sizeof(udp_header); size_t payload_len = udp_len - 8; auto resp = tx.subspan(sizeof(udp_header)); - size_t resp_len = msg_handler(std::span{payload, payload_len}, resp); - if (resp_len == 0) return; + auto result = msg_handler(frame.subspan(sizeof(udp_header), payload_len), resp); + if (!result || *result == 0) return; + size_t resp_len = *result; size_t ip_total = 20 + 8 + resp_len; size_t reply_len = sizeof(eth_header) + ip_total; @@ -183,19 +185,19 @@ static void handle_udp(const uint8_t* frame, size_t len, span_writer &tx) { rudp.length = __builtin_bswap16(8 + resp_len); rudp.checksum = 0; - send_raw(tx.data(), reply_len); + send_raw({tx.data(), reply_len}); } -static void handle_icmp(const uint8_t* frame, size_t len, span_writer &tx) { - auto& ip = *reinterpret_cast(frame); +static void handle_icmp(std::span frame, span_writer &tx) { + auto& ip = *reinterpret_cast(frame.data()); size_t ip_hdr_len = ip.ip_header_len(); size_t ip_total = ip.ip_total_len(); - if (sizeof(eth_header) + ip_total > len) return; + if (sizeof(eth_header) + ip_total > frame.size()) return; if (ip.protocol != 1) return; if (!ip_match_or_broadcast(ip.dst)) return; - auto& icmp = *reinterpret_cast(frame + sizeof(eth_header) + ip_hdr_len); + auto& icmp = *reinterpret_cast(frame.data() + sizeof(eth_header) + ip_hdr_len); size_t icmp_len = ip_total - ip_hdr_len; if (icmp_len < sizeof(icmp_echo)) return; if (icmp.type != 8) return; @@ -203,7 +205,7 @@ static void handle_icmp(const uint8_t* frame, size_t len, span_writer &tx) { size_t reply_len = sizeof(eth_header) + ip_total; if (reply_len > tx.capacity()) return; - memcpy(tx.data(), frame, reply_len); + memcpy(tx.data(), frame.data(), reply_len); auto& rip = *reinterpret_cast(tx.data()); rip.eth.dst = ip.eth.src; rip.eth.src = state.mac; @@ -218,36 +220,36 @@ static void handle_icmp(const uint8_t* frame, size_t len, span_writer &tx) { ricmp.checksum = 0; ricmp.checksum = ip_checksum(&ricmp, icmp_len); - send_raw(tx.data(), reply_len); + send_raw({tx.data(), reply_len}); } -static void handle_ipv4(const uint8_t* frame, size_t len, span_writer &tx) { - if (len < sizeof(ipv4_header)) return; - auto& ip = *reinterpret_cast(frame); +static void handle_ipv4(std::span frame, span_writer &tx) { + if (frame.size() < sizeof(ipv4_header)) return; + auto& ip = *reinterpret_cast(frame.data()); if ((ip.ver_ihl >> 4) != 4) return; switch (ip.protocol) { case 1: - handle_icmp(frame, len, tx); + handle_icmp(frame, tx); break; case 17: - handle_udp(frame, len, tx); + handle_udp(frame, tx); break; } } -static void process_frame(const uint8_t* frame, size_t len, span_writer &tx) { - if (len < sizeof(eth_header)) return; - auto& eth = *reinterpret_cast(frame); +static void process_frame(std::span frame, span_writer &tx) { + if (frame.size() < sizeof(eth_header)) return; + auto& eth = *reinterpret_cast(frame.data()); if (!mac_match(eth.dst)) return; switch (eth.ethertype) { case ETH_ARP: - handle_arp(frame, len); + handle_arp(frame, tx); break; case ETH_IPV4: - handle_ipv4(frame, len, tx); + handle_ipv4(frame, tx); break; } } @@ -295,6 +297,6 @@ void net_poll(std::span tx) { auto result = w6300::recv(raw_socket, std::span{rx_buf}); if (!result) break; span_writer tx_writer(tx); - process_frame(rx_buf, *result, tx_writer); + process_frame({rx_buf, *result}, tx_writer); } } diff --git a/firmware/test.cpp b/firmware/test.cpp index 9c200ac..a22d2fc 100644 --- a/firmware/test.cpp +++ b/firmware/test.cpp @@ -31,8 +31,13 @@ static ResponseTest test_discovery() { uint8_t req_buf[1514]; span_writer req_out(req_buf, sizeof(req_buf)); - size_t req_len = encode_request_into(req_out, 0, RequestInfo{}); - auto send_result = w6300::send(test_socket, std::span{req_buf, req_len}); + auto req_len = encode_response_into(req_out, 0, RequestInfo{}); + if (!req_len) { + resp.pass = false; + resp.messages.push_back("encode: overflow"); + return resp; + } + auto send_result = w6300::send(test_socket, std::span{req_buf, *req_len}); if (!send_result) { resp.pass = false; resp.messages.push_back("send: error " + std::to_string(static_cast(send_result.error())));