From c35c1de76aa2c6b8bf32d77198ba3ac568093a05 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 11 Apr 2026 08:15:41 +0900 Subject: [PATCH] Split net stack into eth/arp/ipv4/icmp, deferred handler responses, ping tests --- firmware/CMakeLists.txt | 3 + firmware/include/arp.h | 28 +++ firmware/include/dispatch.h | 25 ++- firmware/include/eth.h | 20 +++ firmware/include/handlers.h | 8 +- firmware/include/icmp.h | 31 ++++ firmware/include/ipv4.h | 49 ++++++ firmware/include/net.h | 16 +- firmware/include/test_handlers.h | 6 +- firmware/lib/arp.cpp | 41 +++++ firmware/lib/dispatch.cpp | 41 ++--- firmware/lib/handlers.cpp | 8 +- firmware/lib/icmp.cpp | 97 +++++++++++ firmware/lib/ipv4.cpp | 46 +++++ firmware/lib/net.cpp | 282 +++++++------------------------ firmware/lib/test_handlers.cpp | 151 ++++++++--------- 16 files changed, 522 insertions(+), 330 deletions(-) create mode 100644 firmware/include/arp.h create mode 100644 firmware/include/eth.h create mode 100644 firmware/include/icmp.h create mode 100644 firmware/include/ipv4.h create mode 100644 firmware/lib/arp.cpp create mode 100644 firmware/lib/icmp.cpp create mode 100644 firmware/lib/ipv4.cpp diff --git a/firmware/CMakeLists.txt b/firmware/CMakeLists.txt index a5f708b..872e5cb 100644 --- a/firmware/CMakeLists.txt +++ b/firmware/CMakeLists.txt @@ -10,8 +10,11 @@ set(CMAKE_CXX_STANDARD 23) pico_sdk_init() set(LIB_SOURCES + lib/arp.cpp lib/dispatch.cpp lib/handlers.cpp + lib/icmp.cpp + lib/ipv4.cpp lib/net.cpp lib/tusb_config.cpp w6300/w6300.cpp diff --git a/firmware/include/arp.h b/firmware/include/arp.h new file mode 100644 index 0000000..6318c07 --- /dev/null +++ b/firmware/include/arp.h @@ -0,0 +1,28 @@ +#pragma once +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace arp { + +struct __attribute__((packed)) packet { + eth::header eth; + uint16_t htype; + uint16_t ptype; + uint8_t hlen; + uint8_t plen; + uint16_t oper; + eth::mac_addr sha; + ipv4::ip4_addr spa; + eth::mac_addr tha; + ipv4::ip4_addr tpa; +}; +static_assert(sizeof(packet) == 42); + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw); + +} // namespace arp diff --git a/firmware/include/dispatch.h b/firmware/include/dispatch.h index 9a9c374..db83a62 100644 --- a/firmware/include/dispatch.h +++ b/firmware/include/dispatch.h @@ -1,10 +1,24 @@ #pragma once #include #include +#include #include #include "wire.h" -using handler_fn = msgpack::result (*)(uint32_t message_id, std::span payload, span_writer &out); +struct responder { + uint32_t message_id; + std::function)> send; + + template + void respond(const T& msg) const { + uint8_t buf[1024]; + span_writer out(buf, sizeof(buf)); + auto r = encode_response_into(out, message_id, msg); + if (r) send({buf, *r}); + } +}; + +using handler_fn = void (*)(responder resp, std::span payload); struct handler_entry { int8_t type_id; @@ -12,16 +26,19 @@ struct handler_entry { }; template -msgpack::result typed_handler(uint32_t message_id, std::span payload, span_writer &out) { +void typed_handler(responder resp, std::span payload) { msgpack::parser p(payload.data(), static_cast(payload.size())); Req req; auto tup = req.as_tuple(); auto r = msgpack::unpack(p, tup); if (!r) { - return encode_response_into(out, message_id, DeviceError{1, "decode request ext_id=" + + resp.respond(DeviceError{1, "decode request ext_id=" + std::to_string(Req::ext_id) + ": msgpack error " + std::to_string(static_cast(r.error()))}); + return; } - return encode_response_into(out, message_id, Fn(req)); + auto result = Fn(resp, req); + if (result) + resp.respond(*result); } void dispatch_init(); diff --git a/firmware/include/eth.h b/firmware/include/eth.h new file mode 100644 index 0000000..75b56be --- /dev/null +++ b/firmware/include/eth.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace eth { + +using mac_addr = std::array; + +static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806); +static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800); + +struct __attribute__((packed)) header { + mac_addr dst; + mac_addr src; + uint16_t ethertype; +}; +static_assert(sizeof(header) == 14); + +} // namespace eth diff --git a/firmware/include/handlers.h b/firmware/include/handlers.h index 63d5b61..7fc69f5 100644 --- a/firmware/include/handlers.h +++ b/firmware/include/handlers.h @@ -1,9 +1,11 @@ #pragma once +#include #include +#include "dispatch.h" #include "wire.h" extern std::string_view firmware_name; -ResponsePICOBOOT handle_picoboot(const RequestPICOBOOT&); -ResponseInfo handle_info(const RequestInfo&); -ResponseLog handle_log(const RequestLog&); +std::optional handle_picoboot(const responder& resp, const RequestPICOBOOT&); +std::optional handle_info(const responder& resp, const RequestInfo&); +std::optional handle_log(const responder& resp, const RequestLog&); diff --git a/firmware/include/icmp.h b/firmware/include/icmp.h new file mode 100644 index 0000000..b93c879 --- /dev/null +++ b/firmware/include/icmp.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace icmp { + +struct __attribute__((packed)) echo { + uint8_t type; + uint8_t code; + uint16_t checksum; + uint16_t id; + uint16_t seq; +}; +static_assert(sizeof(echo) == 8); + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw); + +size_t build_echo_request(std::span buf, + eth::mac_addr src_mac, ipv4::ip4_addr src_ip, + eth::mac_addr dst_mac, ipv4::ip4_addr dst_ip, + uint16_t id, uint16_t seq); + +bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id); + +} // namespace icmp diff --git a/firmware/include/ipv4.h b/firmware/include/ipv4.h new file mode 100644 index 0000000..4d875a0 --- /dev/null +++ b/firmware/include/ipv4.h @@ -0,0 +1,49 @@ +#pragma once +#include +#include +#include +#include +#include "eth.h" +#include "span_writer.h" + +namespace ipv4 { + +using ip4_addr = std::array; + +struct __attribute__((packed)) header { + eth::header eth; + uint8_t ver_ihl; + uint8_t dscp_ecn; + uint16_t total_len; + uint16_t identification; + uint16_t flags_frag; + uint8_t ttl; + uint8_t protocol; + uint16_t checksum; + ip4_addr src; + ip4_addr dst; + + size_t ip_header_len() const { return (ver_ihl & 0x0F) * 4; } + size_t ip_total_len() const { return __builtin_bswap16(total_len); } + const uint8_t* ip_start() const { return reinterpret_cast(&ver_ihl); } + uint8_t* ip_start() { return reinterpret_cast(&ver_ihl); } +}; +static_assert(sizeof(header) == 34); + +struct __attribute__((packed)) udp_header { + header ip; + uint16_t src_port; + uint16_t dst_port; + uint16_t length; + uint16_t checksum; +}; +static_assert(sizeof(udp_header) == 42); + +uint16_t checksum(const void* data, size_t len); + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ip4_addr our_ip, ip4_addr subnet_broadcast, + std::function)> send_raw, + std::function, span_writer&)> handle_udp); + +} // namespace ipv4 diff --git a/firmware/include/net.h b/firmware/include/net.h index b6356e4..0b856fc 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -1,19 +1,25 @@ #pragma once -#include -#include +#include #include #include +#include "eth.h" +#include "ipv4.h" #include "span_writer.h" #include "msgpack.h" struct net_state { - std::array mac; - std::array ip; + eth::mac_addr mac; + ipv4::ip4_addr ip; }; -using net_handler = std::function(std::span payload, span_writer &out)>; +using net_handler = std::function payload, + std::function)> send)>; + +using net_frame_callback = std::function frame)>; bool net_init(); const net_state& net_get_state(); void net_set_handler(net_handler handler); +void net_add_frame_callback(net_frame_callback cb); void net_poll(std::span tx); +void net_send_raw(std::span data); diff --git a/firmware/include/test_handlers.h b/firmware/include/test_handlers.h index fdb0f79..87a98fa 100644 --- a/firmware/include/test_handlers.h +++ b/firmware/include/test_handlers.h @@ -1,5 +1,7 @@ #pragma once +#include +#include "dispatch.h" #include "wire.h" -ResponseListTests handle_list_tests(const RequestListTests&); -ResponseTest handle_test(const RequestTest&); +std::optional handle_list_tests(const responder& resp, const RequestListTests&); +std::optional handle_test(const responder& resp, const RequestTest&); diff --git a/firmware/lib/arp.cpp b/firmware/lib/arp.cpp new file mode 100644 index 0000000..953ab85 --- /dev/null +++ b/firmware/lib/arp.cpp @@ -0,0 +1,41 @@ +#include "arp.h" + +namespace arp { + +static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1); +static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800); +static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1); +static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + if (frame.size() < sizeof(packet)) return; + auto& pkt = *reinterpret_cast(frame.data()); + + if (pkt.htype != ARP_HTYPE_ETH) return; + if (pkt.ptype != ARP_PTYPE_IPV4) return; + if (pkt.hlen != 6 || pkt.plen != 4) return; + if (pkt.oper != ARP_OP_REQUEST) return; + if (pkt.tpa != our_ip) return; + if (sizeof(packet) > tx.capacity()) return; + + auto& reply = *reinterpret_cast(tx.data()); + reply = {}; + reply.eth.dst = pkt.eth.src; + reply.eth.src = our_mac; + reply.eth.ethertype = eth::ETH_ARP; + reply.htype = ARP_HTYPE_ETH; + reply.ptype = ARP_PTYPE_IPV4; + reply.hlen = 6; + reply.plen = 4; + reply.oper = ARP_OP_REPLY; + reply.sha = our_mac; + reply.spa = our_ip; + reply.tha = pkt.sha; + reply.tpa = pkt.spa; + + send_raw({tx.data(), sizeof(packet)}); +} + +} // namespace arp diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index f9f301a..2efb799 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -31,12 +31,18 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { static static_vector usb_rx_buf; static std::array tx_buf; - net_set_handler([&](std::span payload, span_writer &out) -> msgpack::result { + auto dispatch_msg = [&](const DecodedMessage& msg, std::function)> send) { + auto it = handler_map.find(msg.type_id); + if (it == handler_map.end()) return; + responder resp{msg.message_id, std::move(send)}; + it->second(resp, msg.payload); + }; + + net_set_handler([&](std::span payload, + std::function)> send) { auto msg = try_decode(payload.data(), payload.size()); - if (!msg) return 0; - auto it = handler_map.find(msg->type_id); - if (it == handler_map.end()) return 0; - return it->second(msg->message_id, msg->payload, out); + if (!msg) return; + dispatch_msg(*msg, std::move(send)); }); while (true) { @@ -59,21 +65,18 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { continue; } - auto it = handler_map.find(msg->type_id); - if (it == handler_map.end()) { usb_rx_buf.clear(); continue; } - span_writer out(tx_buf); - auto resp = it->second(msg->message_id, msg->payload, out); + dispatch_msg(*msg, [&](std::span data) { + if (data.size() <= usb.tx.free()) { + usb.send(data); + } else { + uint8_t err_buf[256]; + span_writer err_out(err_buf, sizeof(err_buf)); + auto err = encode_response_into(err_out, msg->message_id, + DeviceError{2, "response too large: " + std::to_string(data.size())}); + if (err) usb.send(std::span{err_buf, *err}); + } + }); usb_rx_buf.clear(); - if (!resp || *resp == 0) continue; - size_t resp_len = *resp; - if (resp_len <= usb.tx.free()) { - usb.send(std::span{tx_buf.data(), resp_len}); - continue; - } - span_writer err_out(tx_buf); - auto err = encode_response_into(err_out, msg->message_id, - DeviceError{2, "response too large: " + std::to_string(resp_len)}); - if (err) usb.send(std::span{tx_buf.data(), *err}); } __wfi(); diff --git a/firmware/lib/handlers.cpp b/firmware/lib/handlers.cpp index f916a8c..af432f8 100644 --- a/firmware/lib/handlers.cpp +++ b/firmware/lib/handlers.cpp @@ -5,12 +5,12 @@ #include "net.h" #include "debug_log.h" -ResponsePICOBOOT handle_picoboot(const RequestPICOBOOT&) { +std::optional handle_picoboot(const responder&, const RequestPICOBOOT&) { dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); }); - return {}; + return ResponsePICOBOOT{}; } -ResponseInfo handle_info(const RequestInfo&) { +std::optional handle_info(const responder&, const RequestInfo&) { ResponseInfo resp; pico_unique_board_id_t uid; pico_get_unique_board_id(&uid); @@ -22,7 +22,7 @@ ResponseInfo handle_info(const RequestInfo&) { return resp; } -ResponseLog handle_log(const RequestLog&) { +std::optional handle_log(const responder&, const RequestLog&) { ResponseLog resp; for (auto& e : dlog_drain()) resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)}); diff --git a/firmware/lib/icmp.cpp b/firmware/lib/icmp.cpp new file mode 100644 index 0000000..eb9acbf --- /dev/null +++ b/firmware/lib/icmp.cpp @@ -0,0 +1,97 @@ +#include "icmp.h" +#include +#include "ipv4.h" + +namespace icmp { + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + auto& ip = *reinterpret_cast(frame.data()); + size_t ip_hdr_len = ip.ip_header_len(); + size_t ip_total = ip.ip_total_len(); + + if (sizeof(eth::header) + ip_total > frame.size()) return; + if (ip.protocol != 1) return; + + auto& icmp_pkt = *reinterpret_cast(frame.data() + sizeof(eth::header) + ip_hdr_len); + size_t icmp_len = ip_total - ip_hdr_len; + if (icmp_len < sizeof(echo)) return; + if (icmp_pkt.type != 8) return; + + size_t reply_len = sizeof(eth::header) + ip_total; + if (reply_len > tx.capacity()) return; + + memcpy(tx.data(), frame.data(), reply_len); + auto& rip = *reinterpret_cast(tx.data()); + rip.eth.dst = ip.eth.src; + rip.eth.src = our_mac; + rip.src = our_ip; + rip.dst = ip.src; + rip.ttl = 64; + rip.checksum = 0; + rip.checksum = ipv4::checksum(rip.ip_start(), ip_hdr_len); + + auto& ricmp = *reinterpret_cast(tx.data() + sizeof(eth::header) + ip_hdr_len); + ricmp.type = 0; + ricmp.checksum = 0; + ricmp.checksum = ipv4::checksum(&ricmp, icmp_len); + + send_raw({tx.data(), reply_len}); +} + +size_t build_echo_request(std::span buf, + eth::mac_addr src_mac, ipv4::ip4_addr src_ip, + eth::mac_addr dst_mac, ipv4::ip4_addr dst_ip, + uint16_t id, uint16_t seq) { + size_t total = sizeof(ipv4::header) + sizeof(echo); + if (buf.size() < total) return 0; + + memset(buf.data(), 0, total); + + auto& ip = *reinterpret_cast(buf.data()); + ip.eth.dst = dst_mac; + ip.eth.src = src_mac; + ip.eth.ethertype = eth::ETH_IPV4; + ip.ver_ihl = 0x45; + ip.dscp_ecn = 0; + ip.total_len = __builtin_bswap16(20 + sizeof(echo)); + ip.identification = 0; + ip.flags_frag = 0; + ip.ttl = 64; + ip.protocol = 1; + ip.checksum = 0; + ip.src = src_ip; + ip.dst = dst_ip; + ip.checksum = ipv4::checksum(ip.ip_start(), 20); + + auto& icmp_pkt = *reinterpret_cast(buf.data() + sizeof(ipv4::header)); + icmp_pkt.type = 8; + icmp_pkt.code = 0; + icmp_pkt.checksum = 0; + icmp_pkt.id = id; + icmp_pkt.seq = seq; + icmp_pkt.checksum = ipv4::checksum(&icmp_pkt, sizeof(echo)); + + return total; +} + +bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id) { + if (frame.size() < sizeof(ipv4::header)) return false; + auto& ip = *reinterpret_cast(frame.data()); + if ((ip.ver_ihl >> 4) != 4) return false; + if (ip.eth.ethertype != eth::ETH_IPV4) return false; + if (ip.protocol != 1) return false; + + size_t ip_hdr_len = ip.ip_header_len(); + if (sizeof(eth::header) + ip_hdr_len + sizeof(echo) > frame.size()) return false; + + auto& icmp_pkt = *reinterpret_cast(frame.data() + sizeof(eth::header) + ip_hdr_len); + if (icmp_pkt.type != 0) return false; + if (icmp_pkt.id != expected_id) return false; + + src_ip = ip.src; + return true; +} + +} // namespace icmp diff --git a/firmware/lib/ipv4.cpp b/firmware/lib/ipv4.cpp new file mode 100644 index 0000000..3f0a247 --- /dev/null +++ b/firmware/lib/ipv4.cpp @@ -0,0 +1,46 @@ +#include "ipv4.h" +#include "icmp.h" + +namespace ipv4 { + +static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255}; + +uint16_t checksum(const void* data, size_t len) { + auto p = static_cast(data); + uint32_t sum = 0; + for (size_t i = 0; i < len - 1; i += 2) + sum += (p[i] << 8) | p[i + 1]; + if (len & 1) + sum += p[len - 1] << 8; + while (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16); + return __builtin_bswap16(~sum); +} + +static bool ip_match_or_broadcast(const ip4_addr& dst, const ip4_addr& our_ip, const ip4_addr& subnet_broadcast) { + return dst == our_ip || dst == IP_BROADCAST_ALL || dst == subnet_broadcast; +} + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ip4_addr our_ip, ip4_addr subnet_broadcast, + std::function)> send_raw, + std::function, span_writer&)> handle_udp) { + if (frame.size() < sizeof(header)) return; + auto& ip = *reinterpret_cast(frame.data()); + if ((ip.ver_ihl >> 4) != 4) return; + + switch (ip.protocol) { + case 1: + if (!ip_match_or_broadcast(ip.dst, our_ip, subnet_broadcast)) + return; + icmp::handle(frame, tx, our_mac, our_ip, send_raw); + break; + case 17: + if (!ip_match_or_broadcast(ip.dst, our_ip, subnet_broadcast)) + return; + handle_udp(frame, tx); + break; + } +} + +} // namespace ipv4 diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 4532876..f997f42 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -1,255 +1,99 @@ #include "net.h" -#include +#include #include "pico/unique_id.h" #include "pico/time.h" +#include "eth.h" +#include "arp.h" +#include "ipv4.h" #include "w6300.h" #include "debug_log.h" -using mac_addr = std::array; -using ip4_addr = std::array; - -struct __attribute__((packed)) eth_header { - mac_addr dst; - mac_addr src; - uint16_t ethertype; -}; -static_assert(sizeof(eth_header) == 14); - -struct __attribute__((packed)) arp_packet { - eth_header eth; - uint16_t htype; - uint16_t ptype; - uint8_t hlen; - uint8_t plen; - uint16_t oper; - mac_addr sha; - ip4_addr spa; - mac_addr tha; - ip4_addr tpa; -}; -static_assert(sizeof(arp_packet) == 42); - -struct __attribute__((packed)) ipv4_header { - eth_header eth; - uint8_t ver_ihl; - uint8_t dscp_ecn; - uint16_t total_len; - uint16_t identification; - uint16_t flags_frag; - uint8_t ttl; - uint8_t protocol; - uint16_t checksum; - ip4_addr src; - ip4_addr dst; - - size_t ip_header_len() const { return (ver_ihl & 0x0F) * 4; } - size_t ip_total_len() const { return __builtin_bswap16(total_len); } - const uint8_t* ip_start() const { return reinterpret_cast(&ver_ihl); } - uint8_t* ip_start() { return reinterpret_cast(&ver_ihl); } -}; -static_assert(sizeof(ipv4_header) == 34); - -struct __attribute__((packed)) udp_header { - ipv4_header ip; - uint16_t src_port; - uint16_t dst_port; - uint16_t length; - uint16_t checksum; -}; -static_assert(sizeof(udp_header) == 42); - -struct __attribute__((packed)) icmp_echo { - uint8_t type; - uint8_t code; - uint16_t checksum; - uint16_t id; - uint16_t seq; -}; -static_assert(sizeof(icmp_echo) == 8); - -static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806); -static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800); -static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1); -static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800); -static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1); -static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); +static constexpr ipv4::ip4_addr IP_BROADCAST_SUBNET = {169, 254, 255, 255}; static constexpr uint16_t PICOMAP_PORT = __builtin_bswap16(28781); -static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; -static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255}; -static constexpr ip4_addr IP_BROADCAST_SUBNET = {169, 254, 255, 255}; static net_state state; static w6300::socket_id raw_socket{0}; static net_handler msg_handler; +static std::vector frame_callbacks; -static uint16_t ip_checksum(const void* data, size_t len) { - auto p = static_cast(data); - uint32_t sum = 0; - for (size_t i = 0; i < len - 1; i += 2) - sum += (p[i] << 8) | p[i + 1]; - if (len & 1) - sum += p[len - 1] << 8; - while (sum >> 16) - sum = (sum & 0xFFFF) + (sum >> 16); - return __builtin_bswap16(~sum); -} - -static bool mac_match(const mac_addr& dst) { - return dst == state.mac || dst == MAC_BROADCAST; -} - -static bool ip_match(const ip4_addr& dst) { - return dst == state.ip; -} - -static bool ip_match_or_broadcast(const ip4_addr& dst) { - return ip_match(dst) || dst == IP_BROADCAST_ALL || dst == IP_BROADCAST_SUBNET; -} - -static void send_raw(std::span data) { - dlog_if_slow("send_raw", 1000, [&]{ +void net_send_raw(std::span data) { + dlog_if_slow("net_send_raw", 1000, [&]{ w6300::send(raw_socket, data); }); } -static void handle_arp(std::span frame, span_writer &tx) { - if (frame.size() < sizeof(arp_packet)) return; - auto& pkt = *reinterpret_cast(frame.data()); - - if (pkt.htype != ARP_HTYPE_ETH) return; - if (pkt.ptype != ARP_PTYPE_IPV4) return; - if (pkt.hlen != 6 || pkt.plen != 4) return; - if (pkt.oper != ARP_OP_REQUEST) return; - if (!ip_match(pkt.tpa)) return; - if (sizeof(arp_packet) > tx.capacity()) return; - - auto& reply = *reinterpret_cast(tx.data()); - reply = {}; - reply.eth.dst = pkt.eth.src; - reply.eth.src = state.mac; - reply.eth.ethertype = ETH_ARP; - reply.htype = ARP_HTYPE_ETH; - reply.ptype = ARP_PTYPE_IPV4; - reply.hlen = 6; - reply.plen = 4; - reply.oper = ARP_OP_REPLY; - reply.sha = state.mac; - reply.spa = state.ip; - reply.tha = pkt.sha; - reply.tpa = pkt.spa; - - send_raw({tx.data(), sizeof(arp_packet)}); -} - -static void handle_udp(std::span frame, span_writer &tx) { - if (frame.size() < sizeof(udp_header)) return; - auto& pkt = *reinterpret_cast(frame.data()); +static void handle_udp(std::span frame, span_writer& tx) { + if (frame.size() < sizeof(ipv4::udp_header)) return; + auto& pkt = *reinterpret_cast(frame.data()); if (pkt.dst_port != PICOMAP_PORT) return; - if (!ip_match_or_broadcast(pkt.ip.dst)) return; if (!msg_handler) return; size_t udp_len = __builtin_bswap16(pkt.length); if (udp_len < 8) return; - if (sizeof(eth_header) + pkt.ip.ip_total_len() < sizeof(udp_header) + udp_len - 8) return; + if (sizeof(eth::header) + pkt.ip.ip_total_len() < sizeof(ipv4::udp_header) + udp_len - 8) return; size_t payload_len = udp_len - 8; - auto resp = tx.subspan(sizeof(udp_header)); - auto result = msg_handler(frame.subspan(sizeof(udp_header), payload_len), resp); - if (!result || *result == 0) return; - size_t resp_len = *result; + eth::mac_addr dst_mac = pkt.ip.eth.src; + ipv4::ip4_addr dst_ip = pkt.ip.src; + uint16_t dst_port = pkt.src_port; - size_t ip_total = 20 + 8 + resp_len; - size_t reply_len = sizeof(eth_header) + ip_total; + msg_handler(frame.subspan(sizeof(ipv4::udp_header), payload_len), + [dst_mac, dst_ip, dst_port](std::span resp_data) { + size_t ip_total = 20 + 8 + resp_data.size(); + size_t reply_len = sizeof(eth::header) + ip_total; + uint8_t reply_buf[1514]; + if (reply_len > sizeof(reply_buf)) return; - auto& rip = *reinterpret_cast(tx.data()); - rip.eth.dst = pkt.ip.eth.src; - rip.eth.src = state.mac; - rip.eth.ethertype = ETH_IPV4; - rip.ver_ihl = 0x45; - rip.dscp_ecn = 0; - rip.total_len = __builtin_bswap16(ip_total); - rip.identification = 0; - rip.flags_frag = 0; - rip.ttl = 64; - rip.protocol = 17; - rip.checksum = 0; - rip.src = state.ip; - rip.dst = pkt.ip.src; - rip.checksum = ip_checksum(rip.ip_start(), 20); + auto& rip = *reinterpret_cast(reply_buf); + rip.eth.dst = dst_mac; + rip.eth.src = state.mac; + rip.eth.ethertype = eth::ETH_IPV4; + rip.ver_ihl = 0x45; + rip.dscp_ecn = 0; + rip.total_len = __builtin_bswap16(ip_total); + rip.identification = 0; + rip.flags_frag = 0; + rip.ttl = 64; + rip.protocol = 17; + rip.checksum = 0; + rip.src = state.ip; + rip.dst = dst_ip; + rip.checksum = ipv4::checksum(rip.ip_start(), 20); - auto& rudp = *reinterpret_cast(tx.data()); - rudp.src_port = PICOMAP_PORT; - rudp.dst_port = pkt.src_port; - rudp.length = __builtin_bswap16(8 + resp_len); - rudp.checksum = 0; + auto& rudp = *reinterpret_cast(reply_buf); + rudp.src_port = PICOMAP_PORT; + rudp.dst_port = dst_port; + rudp.length = __builtin_bswap16(8 + resp_data.size()); + rudp.checksum = 0; - send_raw({tx.data(), reply_len}); + memcpy(reply_buf + sizeof(ipv4::udp_header), resp_data.data(), resp_data.size()); + net_send_raw({reply_buf, reply_len}); + }); } -static void handle_icmp(std::span frame, span_writer &tx) { - auto& ip = *reinterpret_cast(frame.data()); - size_t ip_hdr_len = ip.ip_header_len(); - size_t ip_total = ip.ip_total_len(); - - if (sizeof(eth_header) + ip_total > frame.size()) return; - if (ip.protocol != 1) return; - if (!ip_match_or_broadcast(ip.dst)) return; - - auto& icmp = *reinterpret_cast(frame.data() + sizeof(eth_header) + ip_hdr_len); - size_t icmp_len = ip_total - ip_hdr_len; - if (icmp_len < sizeof(icmp_echo)) return; - if (icmp.type != 8) return; - - size_t reply_len = sizeof(eth_header) + ip_total; - if (reply_len > tx.capacity()) return; - - memcpy(tx.data(), frame.data(), reply_len); - auto& rip = *reinterpret_cast(tx.data()); - rip.eth.dst = ip.eth.src; - rip.eth.src = state.mac; - rip.src = state.ip; - rip.dst = ip.src; - rip.ttl = 64; - rip.checksum = 0; - rip.checksum = ip_checksum(rip.ip_start(), ip_hdr_len); - - auto& ricmp = *reinterpret_cast(tx.data() + sizeof(eth_header) + ip_hdr_len); - ricmp.type = 0; - ricmp.checksum = 0; - ricmp.checksum = ip_checksum(&ricmp, icmp_len); - - send_raw({tx.data(), reply_len}); +static bool mac_match(const eth::mac_addr& dst) { + return dst == state.mac || dst == eth::MAC_BROADCAST; } -static void handle_ipv4(std::span frame, span_writer &tx) { - if (frame.size() < sizeof(ipv4_header)) return; - auto& ip = *reinterpret_cast(frame.data()); - if ((ip.ver_ihl >> 4) != 4) return; +static void process_frame(std::span frame, span_writer& tx) { + if (frame.size() < sizeof(eth::header)) return; + auto& eth_hdr = *reinterpret_cast(frame.data()); - switch (ip.protocol) { - case 1: - handle_icmp(frame, tx); + if (!mac_match(eth_hdr.dst)) return; + + auto cbs = std::move(frame_callbacks); + frame_callbacks.clear(); + for (auto& cb : cbs) + cb(frame); + + switch (eth_hdr.ethertype) { + case eth::ETH_ARP: + arp::handle(frame, tx, state.mac, state.ip, net_send_raw); break; - case 17: - handle_udp(frame, tx); - break; - } -} - -static void process_frame(std::span frame, span_writer &tx) { - if (frame.size() < sizeof(eth_header)) return; - auto& eth = *reinterpret_cast(frame.data()); - - if (!mac_match(eth.dst)) return; - - switch (eth.ethertype) { - case ETH_ARP: - handle_arp(frame, tx); - break; - case ETH_IPV4: - handle_ipv4(frame, tx); + case eth::ETH_IPV4: + ipv4::handle(frame, tx, state.mac, state.ip, IP_BROADCAST_SUBNET, net_send_raw, handle_udp); break; } } @@ -288,6 +132,10 @@ void net_set_handler(net_handler handler) { msg_handler = std::move(handler); } +void net_add_frame_callback(net_frame_callback cb) { + frame_callbacks.push_back(std::move(cb)); +} + void net_poll(std::span tx) { if (!w6300::irq_pending) return; w6300::irq_pending = false; diff --git a/firmware/lib/test_handlers.cpp b/firmware/lib/test_handlers.cpp index 1113898..e10295b 100644 --- a/firmware/lib/test_handlers.cpp +++ b/firmware/lib/test_handlers.cpp @@ -1,98 +1,97 @@ #include "test_handlers.h" +#include #include #include "pico/stdlib.h" #include "pico/time.h" -#include "w6300.h" +#include "net.h" +#include "icmp.h" -static w6300::socket_id test_socket{1}; - -static ResponseTest test_discovery() { +static ResponseTest test_discovery(const responder&) { ResponseTest resp; resp.pass = true; - - uint8_t req_buf[1514]; - span_writer req_out(req_buf, sizeof(req_buf)); - auto req_len = encode_response_into(req_out, 0, RequestInfo{}); - if (!req_len) { - resp.pass = false; - resp.messages.push_back("encode: overflow"); - return resp; - } - auto send_result = w6300::send(test_socket, std::span{req_buf, *req_len}); - if (!send_result) { - resp.pass = false; - resp.messages.push_back("send: error " + std::to_string(static_cast(send_result.error()))); - return resp; - } - - uint8_t rx_buf[512]; - - auto deadline = make_timeout_time_ms(5000); - std::expected recv_result = std::unexpected(w6300::sock_error::busy); - while (get_absolute_time() < deadline) { - recv_result = w6300::recv(test_socket, std::span{rx_buf}); - if (recv_result || recv_result.error() != w6300::sock_error::busy) break; - } - - if (!recv_result) { - resp.pass = false; - if (recv_result.error() == w6300::sock_error::busy) { - resp.messages.push_back("recv: timed out after 5s"); - } else { - resp.messages.push_back("recv: error " + std::to_string(static_cast(recv_result.error()))); - } - return resp; - } - - resp.messages.push_back("received " + std::to_string(*recv_result) + " bytes"); - - auto info = decode_response(rx_buf, *recv_result); - if (!info) { - resp.pass = false; - resp.messages.push_back("decode: msgpack error " + std::to_string(static_cast(info.error()))); - return resp; - } - - if (info->firmware_name.empty()) { - resp.pass = false; - resp.messages.push_back("firmware_name is empty"); - } else { - resp.messages.push_back("firmware_name: " + info->firmware_name); - } - - bool mac_zero = true; - for (auto b : info->mac) { if (b != 0) { mac_zero = false; break; } } - if (mac_zero) { - resp.pass = false; - resp.messages.push_back("mac is all zeros"); - } - - bool ip_zero = true; - for (auto b : info->ip) { if (b != 0) { ip_zero = false; break; } } - if (ip_zero) { - resp.pass = false; - resp.messages.push_back("ip is all zeros"); - } - + resp.messages.push_back("TODO: rewrite as deferred test"); return resp; } -using test_fn = ResponseTest (*)(); +static void test_ping(const responder& resp, ipv4::ip4_addr dst_ip) { + auto& ns = net_get_state(); + uint16_t ping_id = 0x1234; -static const std::unordered_map tests = { - {"discovery", test_discovery}, + uint8_t tx_buf[128]; + size_t len = icmp::build_echo_request( + std::span{tx_buf}, ns.mac, ns.ip, + eth::MAC_BROADCAST, dst_ip, ping_id, 1); + if (len == 0) { + resp.respond(ResponseTest{false, {"build_echo_request failed"}}); + return; + } + + net_send_raw(std::span{tx_buf, len}); + + ipv4::ip4_addr our_ip = ns.ip; + + auto done = std::make_shared(false); + auto cb = std::make_shared)>>(); + *cb = [resp, ping_id, our_ip, done, cb](std::span frame) { + if (*done) return; + ipv4::ip4_addr src_ip; + if (!icmp::parse_echo_reply(frame, src_ip, ping_id)) { + net_add_frame_callback(*cb); + return; + } + if (src_ip == our_ip) { + net_add_frame_callback(*cb); + return; + } + *done = true; + std::string ip_str = std::to_string(src_ip[0]) + "." + std::to_string(src_ip[1]) + "." + + std::to_string(src_ip[2]) + "." + std::to_string(src_ip[3]); + resp.respond(ResponseTest{true, {"reply from " + ip_str}}); + }; + net_add_frame_callback(*cb); + + dispatch_schedule_ms(5000, [resp, done]() { + if (*done) return; + *done = true; + resp.respond(ResponseTest{false, {"no reply from non-self host within 5s"}}); + }); +} + +static void test_ping_subnet(const responder& resp) { + test_ping(resp, {169, 254, 255, 255}); +} + +static void test_ping_global(const responder& resp) { + test_ping(resp, {255, 255, 255, 255}); +} + +using sync_test_fn = ResponseTest (*)(const responder&); +using async_test_fn = void (*)(const responder&); + +struct test_entry { + sync_test_fn sync; + async_test_fn async; }; -ResponseListTests handle_list_tests(const RequestListTests&) { +static const std::unordered_map tests = { + {"discovery", {test_discovery, nullptr}}, + {"ping_subnet", {nullptr, test_ping_subnet}}, + {"ping_global", {nullptr, test_ping_global}}, +}; + +std::optional handle_list_tests(const responder&, const RequestListTests&) { ResponseListTests resp; for (const auto& [name, _] : tests) resp.names.emplace_back(name); return resp; } -ResponseTest handle_test(const RequestTest& req) { +std::optional handle_test(const responder& resp, const RequestTest& req) { auto it = tests.find(req.name); if (it == tests.end()) - return {false, {"unknown test: " + req.name}}; - return it->second(); + return ResponseTest{false, {"unknown test: " + req.name}}; + if (it->second.sync) + return it->second.sync(resp); + it->second.async(resp); + return std::nullopt; }