diff --git a/firmware/include/halfsiphash.h b/firmware/include/halfsiphash.h index 5a79b03..998a8a8 100644 --- a/firmware/include/halfsiphash.h +++ b/firmware/include/halfsiphash.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include namespace halfsiphash { @@ -34,7 +35,7 @@ inline void sipround(uint32_t &v0, uint32_t &v1, uint32_t &v2, uint32_t &v3) { } // namespace detail // Compute HalfSipHash-2-4 with an 8-byte key, returning a 32-bit hash. -inline uint32_t hash32(const uint8_t *data, size_t len, const uint8_t key[8]) { +inline uint32_t hash32(std::span data, const uint8_t key[8]) { using namespace detail; uint32_t k0 = load_le32(key); @@ -45,8 +46,8 @@ inline uint32_t hash32(const uint8_t *data, size_t len, const uint8_t key[8]) { uint32_t v2 = UINT32_C(0x6c796765) ^ k0; uint32_t v3 = UINT32_C(0x74656462) ^ k1; - const uint8_t *end = data + len - (len % 4); - for (const uint8_t *p = data; p != end; p += 4) { + const uint8_t *end = data.data() + data.size() - (data.size() % 4); + for (const uint8_t *p = data.data(); p != end; p += 4) { uint32_t m = load_le32(p); v3 ^= m; sipround(v0, v1, v2, v3); @@ -54,8 +55,8 @@ inline uint32_t hash32(const uint8_t *data, size_t len, const uint8_t key[8]) { v0 ^= m; } - uint32_t b = static_cast(len) << 24; - switch (len & 3) { + uint32_t b = static_cast(data.size()) << 24; + switch (data.size() & 3) { case 3: b |= static_cast(end[2]) << 16; [[fallthrough]]; case 2: b |= static_cast(end[1]) << 8; [[fallthrough]]; case 1: b |= static_cast(end[0]); break; diff --git a/firmware/include/msgpack.h b/firmware/include/msgpack.h index 85c0eec..424d360 100644 --- a/firmware/include/msgpack.h +++ b/firmware/include/msgpack.h @@ -236,6 +236,12 @@ public: return *this; } + pack_result pack_uint32_fixed(uint32_t n) { + m_buf.push_back(format::UINT32); + push_big_endian(n); + return *this; + } + pack_result pack_float(float n) { m_buf.push_back(format::FLOAT32); push_big_endian(n); @@ -322,6 +328,19 @@ public: return *this; } + pack_result pack_ext16_header(char type, uint16_t len) { + m_buf.push_back(format::EXT16); + push_big_endian(len); + m_buf.push_back(static_cast(type)); + return *this; + } + + pack_result pack_bin16_header(uint16_t len) { + m_buf.push_back(format::BIN16); + push_big_endian(len); + return *this; + } + template pack_result pack_ext(char type, const Range &r) { auto sz = static_cast(std::distance(std::begin(r), std::end(r))); diff --git a/firmware/include/net.h b/firmware/include/net.h index 3e02741..f05c327 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -15,4 +15,4 @@ using net_handler = std::function payload, span_ bool net_init(); const net_state& net_get_state(); void net_set_handler(net_handler handler); -void net_poll(); +void net_poll(std::span tx); diff --git a/firmware/include/span_writer.h b/firmware/include/span_writer.h index c3aec3f..513f384 100644 --- a/firmware/include/span_writer.h +++ b/firmware/include/span_writer.h @@ -2,6 +2,7 @@ #include #include #include +#include class span_writer { uint8_t *m_data; @@ -10,6 +11,7 @@ class span_writer { public: span_writer(uint8_t *data, size_t capacity) : m_data(data), m_capacity(capacity) {} + span_writer(std::span buf) : m_data(buf.data()), m_capacity(buf.size()) {} void push_back(uint8_t v) { if (m_size < m_capacity) m_data[m_size++] = v; @@ -32,4 +34,12 @@ public: uint8_t *end() { return m_data + m_size; } const uint8_t *begin() const { return m_data; } const uint8_t *end() const { return m_data + m_size; } + + span_writer subspan(size_t offset) { + return span_writer(m_data + offset, m_capacity - offset); + } + + span_writer subspan(size_t offset, size_t len) { + return span_writer(m_data + offset, len); + } }; diff --git a/firmware/include/wire.h b/firmware/include/wire.h index eadc300..78313cc 100644 --- a/firmware/include/wire.h +++ b/firmware/include/wire.h @@ -97,27 +97,39 @@ struct DecodedMessage { std::vector payload; }; -inline size_t pack_envelope_into(span_writer &out, uint32_t message_id, const uint8_t *payload, size_t payload_len) { - uint32_t checksum = halfsiphash::hash32(payload, payload_len, hash_key); - uint8_t env_buf[512]; - span_writer env_body(env_buf, sizeof(env_buf)); - msgpack::packer env_p(env_body); - env_p.pack_array(3); - env_p.pack(message_id); - env_p.pack(checksum); - env_p.pack_bin(std::span{payload, payload_len}); - msgpack::packer outer(out); - outer.pack_ext(Envelope::ext_id, env_body); - return out.size(); -} +static constexpr size_t ext16_header_len = 4; +static constexpr size_t array3_header_len = 1; +static constexpr size_t uint32_fixed_len = 5; +static constexpr size_t bin16_header_len = 3; + +static constexpr size_t envelope_hdr_len = + ext16_header_len + array3_header_len + + uint32_fixed_len + uint32_fixed_len + bin16_header_len; + +static constexpr size_t response_prefix_len = envelope_hdr_len + ext16_header_len; template inline size_t encode_response_into(span_writer &out, uint32_t message_id, const T &msg) { - uint8_t inner_buf[256]; - msgpack::packer inner(inner_buf, sizeof(inner_buf)); - inner.pack(msg); - auto &pl = inner.get_payload(); - return pack_envelope_into(out, message_id, pl.data(), pl.size()); + auto body = out.subspan(response_prefix_len); + msgpack::packer body_p(body); + body_p.pack(msg.as_tuple()); + + auto inner_ext = out.subspan(envelope_hdr_len, ext16_header_len); + msgpack::packer(inner_ext).pack_ext16_header(T::ext_id, static_cast(body.size())); + + size_t bin_len = inner_ext.size() + body.size(); + uint32_t checksum = halfsiphash::hash32({inner_ext.data(), bin_len}, hash_key); + + auto env_hdr = out.subspan(0, envelope_hdr_len); + size_t env_body_len = array3_header_len + uint32_fixed_len + uint32_fixed_len + bin16_header_len + bin_len; + msgpack::packer hdr(env_hdr); + hdr.pack_ext16_header(Envelope::ext_id, static_cast(env_body_len)); + hdr.pack_array(3); + hdr.pack_uint32_fixed(message_id); + hdr.pack_uint32_fixed(checksum); + hdr.pack_bin16_header(static_cast(bin_len)); + + return response_prefix_len + body.size(); } inline msgpack::result try_decode(const uint8_t *data, size_t len) { @@ -127,7 +139,7 @@ inline msgpack::result try_decode(const uint8_t *data, size_t le auto r = msgpack::unpack(p, env); if (!r) return std::unexpected(r.error()); - uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key); + uint32_t expected = halfsiphash::hash32(env.payload, hash_key); if (env.checksum != expected) return std::unexpected(msgpack::error_code::invalid); msgpack::parser inner(env.payload.data(), static_cast(env.payload.size())); @@ -154,7 +166,7 @@ inline msgpack::result decode_response(const uint8_t *data, size_t len) { auto r = msgpack::unpack(p, env); if (!r) return std::unexpected(r.error()); - uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key); + uint32_t expected = halfsiphash::hash32(env.payload, hash_key); if (env.checksum != expected) return std::unexpected(msgpack::error_code::invalid); msgpack::parser inner(env.payload.data(), static_cast(env.payload.size())); diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index 4988aee..c3d7a33 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -29,7 +29,7 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { static usb_cdc usb; static static_vector usb_rx_buf; - static uint8_t tx_buf[1514]; + static std::array tx_buf; net_set_handler([&](std::span payload, span_writer &out) -> size_t { auto msg = try_decode(payload.data(), payload.size()); @@ -45,7 +45,7 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { dlog_if_slow("tud_task", 1000, [&]{ tud_task(); }); dlog_if_slow("drain", 1000, [&]{ usb.drain(); }); dlog_if_slow("timers", 1000, [&]{ timers.run(); }); - dlog_if_slow("net_poll", 1000, [&]{ net_poll(); }); + dlog_if_slow("net_poll", 1000, [&]{ net_poll(std::span{tx_buf}); }); while (tud_cdc_available()) { uint8_t byte; @@ -63,16 +63,16 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { auto it = handler_map.find(msg->type_id); if (it != handler_map.end()) { - span_writer out(tx_buf, sizeof(tx_buf)); + span_writer out(tx_buf); size_t resp_len = it->second(msg->message_id, msg->payload, out); if (resp_len > 0) { if (resp_len > usb.tx.free()) { - span_writer err_out(tx_buf, sizeof(tx_buf)); + span_writer err_out(tx_buf); size_t err_len = encode_response_into(err_out, msg->message_id, DeviceError{2, "response too large: " + std::to_string(resp_len)}); - usb.send(std::span{tx_buf, err_len}); + usb.send(std::span{tx_buf.data(), err_len}); } else { - usb.send(std::span{tx_buf, resp_len}); + usb.send(std::span{tx_buf.data(), resp_len}); } } } diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 66d9952..d34b0b6 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -139,9 +139,7 @@ static void handle_arp(const uint8_t* frame, size_t len) { send_raw(&reply, sizeof(reply)); } -static uint8_t tx_buf[1514]; - -static void handle_udp(const uint8_t* frame, size_t len) { +static void handle_udp(const uint8_t* frame, size_t len, span_writer &tx) { if (len < sizeof(udp_header)) return; auto& pkt = *reinterpret_cast(frame); @@ -156,14 +154,14 @@ static void handle_udp(const uint8_t* frame, size_t len) { auto* payload = frame + sizeof(udp_header); size_t payload_len = udp_len - 8; - span_writer resp(tx_buf + sizeof(udp_header), sizeof(tx_buf) - sizeof(udp_header)); + auto resp = tx.subspan(sizeof(udp_header)); size_t resp_len = msg_handler(std::span{payload, payload_len}, resp); if (resp_len == 0) return; size_t ip_total = 20 + 8 + resp_len; size_t reply_len = sizeof(eth_header) + ip_total; - auto& rip = *reinterpret_cast(tx_buf); + auto& rip = *reinterpret_cast(tx.data()); rip.eth.dst = pkt.ip.eth.src; rip.eth.src = state.mac; rip.eth.ethertype = ETH_IPV4; @@ -179,16 +177,16 @@ static void handle_udp(const uint8_t* frame, size_t len) { rip.dst = pkt.ip.src; rip.checksum = ip_checksum(rip.ip_start(), 20); - auto& rudp = *reinterpret_cast(tx_buf); + auto& rudp = *reinterpret_cast(tx.data()); rudp.src_port = PICOMAP_PORT; rudp.dst_port = pkt.src_port; rudp.length = __builtin_bswap16(8 + resp_len); rudp.checksum = 0; - send_raw(tx_buf, reply_len); + send_raw(tx.data(), reply_len); } -static void handle_icmp(const uint8_t* frame, size_t len) { +static void handle_icmp(const uint8_t* frame, size_t len, span_writer &tx) { auto& ip = *reinterpret_cast(frame); size_t ip_hdr_len = ip.ip_header_len(); size_t ip_total = ip.ip_total_len(); @@ -202,12 +200,11 @@ static void handle_icmp(const uint8_t* frame, size_t len) { if (icmp_len < sizeof(icmp_echo)) return; if (icmp.type != 8) return; - uint8_t reply_buf[1514]; size_t reply_len = sizeof(eth_header) + ip_total; - if (reply_len > sizeof(reply_buf)) return; + if (reply_len > tx.capacity()) return; - memcpy(reply_buf, frame, reply_len); - auto& rip = *reinterpret_cast(reply_buf); + memcpy(tx.data(), frame, reply_len); + auto& rip = *reinterpret_cast(tx.data()); rip.eth.dst = ip.eth.src; rip.eth.src = state.mac; rip.src = state.ip; @@ -216,30 +213,30 @@ static void handle_icmp(const uint8_t* frame, size_t len) { rip.checksum = 0; rip.checksum = ip_checksum(rip.ip_start(), ip_hdr_len); - auto& ricmp = *reinterpret_cast(reply_buf + sizeof(eth_header) + ip_hdr_len); + auto& ricmp = *reinterpret_cast(tx.data() + sizeof(eth_header) + ip_hdr_len); ricmp.type = 0; ricmp.checksum = 0; ricmp.checksum = ip_checksum(&ricmp, icmp_len); - send_raw(reply_buf, reply_len); + send_raw(tx.data(), reply_len); } -static void handle_ipv4(const uint8_t* frame, size_t len) { +static void handle_ipv4(const uint8_t* frame, size_t len, span_writer &tx) { if (len < sizeof(ipv4_header)) return; auto& ip = *reinterpret_cast(frame); if ((ip.ver_ihl >> 4) != 4) return; switch (ip.protocol) { case 1: - handle_icmp(frame, len); + handle_icmp(frame, len, tx); break; case 17: - handle_udp(frame, len); + handle_udp(frame, len, tx); break; } } -static void process_frame(const uint8_t* frame, size_t len) { +static void process_frame(const uint8_t* frame, size_t len, span_writer &tx) { if (len < sizeof(eth_header)) return; auto& eth = *reinterpret_cast(frame); @@ -250,7 +247,7 @@ static void process_frame(const uint8_t* frame, size_t len) { handle_arp(frame, len); break; case ETH_IPV4: - handle_ipv4(frame, len); + handle_ipv4(frame, len, tx); break; } } @@ -289,13 +286,15 @@ void net_set_handler(net_handler handler) { msg_handler = std::move(handler); } -void net_poll() { +void net_poll(std::span tx) { if (!w6300::irq_pending) return; w6300::irq_pending = false; w6300::clear_interrupt(w6300::ik_int_all); - if (w6300::get_socket_recv_buf(raw_socket) == 0) return; static uint8_t rx_buf[1518]; - auto result = w6300::recv(raw_socket, std::span{rx_buf}); - if (!result) return; - process_frame(rx_buf, *result); + while (w6300::get_socket_recv_buf(raw_socket) > 0) { + auto result = w6300::recv(raw_socket, std::span{rx_buf}); + if (!result) break; + span_writer tx_writer(tx); + process_frame(rx_buf, *result, tx_writer); + } } diff --git a/firmware/test.cpp b/firmware/test.cpp index 4e6005e..9c200ac 100644 --- a/firmware/test.cpp +++ b/firmware/test.cpp @@ -29,7 +29,7 @@ static ResponseTest test_discovery() { ResponseTest resp; resp.pass = true; - uint8_t req_buf[256]; + uint8_t req_buf[1514]; span_writer req_out(req_buf, sizeof(req_buf)); size_t req_len = encode_request_into(req_out, 0, RequestInfo{}); auto send_result = w6300::send(test_socket, std::span{req_buf, req_len});