From f161dda60a69eceb3c412973543005fd51c176b3 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Tue, 7 Apr 2026 12:21:41 +0900 Subject: [PATCH] Rewrite net.cpp to use packed structs for frame encoding/decoding --- firmware/lib/net.cpp | 234 +++++++++++++++++++++++++------------------ 1 file changed, 136 insertions(+), 98 deletions(-) diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 7ae6412..3facccf 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -5,152 +5,190 @@ #include "w6300.h" #include "debug_log.h" +using mac_addr = std::array; +using ip4_addr = std::array; + +struct __attribute__((packed)) eth_header { + mac_addr dst; + mac_addr src; + uint16_t ethertype; +}; +static_assert(sizeof(eth_header) == 14); + +struct __attribute__((packed)) arp_packet { + eth_header eth; + uint16_t htype; + uint16_t ptype; + uint8_t hlen; + uint8_t plen; + uint16_t oper; + mac_addr sha; + ip4_addr spa; + mac_addr tha; + ip4_addr tpa; +}; +static_assert(sizeof(arp_packet) == 42); + +struct __attribute__((packed)) ipv4_header { + uint8_t ver_ihl; + uint8_t dscp_ecn; + uint16_t total_len; + uint16_t identification; + uint16_t flags_frag; + uint8_t ttl; + uint8_t protocol; + uint16_t checksum; + ip4_addr src; + ip4_addr dst; + + size_t header_len() const { return (ver_ihl & 0x0F) * 4; } + size_t payload_len() const { return __builtin_bswap16(total_len) - header_len(); } +}; +static_assert(sizeof(ipv4_header) == 20); + +struct __attribute__((packed)) icmp_header { + uint8_t type; + uint8_t code; + uint16_t checksum; + uint16_t id; + uint16_t seq; +}; +static_assert(sizeof(icmp_header) == 8); + +static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806); +static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800); +static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1); +static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800); +static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1); +static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); +static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255}; +static constexpr ip4_addr IP_BROADCAST_SUBNET = {169, 254, 255, 255}; + static net_state state; static w6300::socket_id raw_socket{0}; -static constexpr uint16_t ETHERTYPE_ARP = 0x0806; -static constexpr uint16_t ETHERTYPE_IPV4 = 0x0800; -static constexpr uint8_t IP_PROTO_ICMP = 1; -static constexpr uint8_t ICMP_ECHO_REQUEST = 8; -static constexpr uint8_t ICMP_ECHO_REPLY = 0; -static constexpr uint16_t ARP_OP_REQUEST = 1; -static constexpr uint16_t ARP_OP_REPLY = 2; - -static uint16_t read_u16(const uint8_t* p) { return (p[0] << 8) | p[1]; } - -static void write_u16(uint8_t* p, uint16_t v) { - p[0] = v >> 8; - p[1] = v & 0xFF; -} - -static uint16_t ip_checksum(const uint8_t* data, size_t len) { +static uint16_t ip_checksum(const void* data, size_t len) { + auto p = static_cast(data); uint32_t sum = 0; for (size_t i = 0; i < len - 1; i += 2) - sum += read_u16(data + i); + sum += (p[i] << 8) | p[i + 1]; if (len & 1) - sum += data[len - 1] << 8; + sum += p[len - 1] << 8; while (sum >> 16) sum = (sum & 0xFFFF) + (sum >> 16); - return ~sum; + return __builtin_bswap16(~sum); } -static bool mac_match(const uint8_t* dst) { - static constexpr uint8_t broadcast[6] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; - return memcmp(dst, state.mac.data(), 6) == 0 || - memcmp(dst, broadcast, 6) == 0; +static bool mac_match(const mac_addr& dst) { + return dst == state.mac || dst == MAC_BROADCAST; } -static bool ip_match(const uint8_t* dst) { - return memcmp(dst, state.ip.data(), 4) == 0; +static bool ip_match(const ip4_addr& dst) { + return dst == state.ip; } -static bool ip_match_or_broadcast(const uint8_t* dst) { - static constexpr uint8_t bcast_all[4] = {255, 255, 255, 255}; - static constexpr uint8_t bcast_subnet[4] = {169, 254, 255, 255}; - return ip_match(dst) || - memcmp(dst, bcast_all, 4) == 0 || - memcmp(dst, bcast_subnet, 4) == 0; +static bool ip_match_or_broadcast(const ip4_addr& dst) { + return ip_match(dst) || dst == IP_BROADCAST_ALL || dst == IP_BROADCAST_SUBNET; } -static void send_raw(const uint8_t* data, size_t len) { +static void send_raw(const void* data, size_t len) { dlog_if_slow("send_raw", 1000, [&]{ w6300::ip_address dummy = {}; - w6300::sendto(raw_socket, std::span{data, len}, dummy, w6300::port_num{0}); + w6300::sendto(raw_socket, std::span{static_cast(data), len}, + dummy, w6300::port_num{0}); }); } static void handle_arp(const uint8_t* frame, size_t len) { - if (len < 42) return; - const uint8_t* arp = frame + 14; + if (len < sizeof(arp_packet)) return; + auto& pkt = *reinterpret_cast(frame); - if (read_u16(arp) != 1) return; - if (read_u16(arp + 2) != ETHERTYPE_IPV4) return; - if (arp[4] != 6 || arp[5] != 4) return; + if (pkt.htype != ARP_HTYPE_ETH) return; + if (pkt.ptype != ARP_PTYPE_IPV4) return; + if (pkt.hlen != 6 || pkt.plen != 4) return; + if (pkt.oper != ARP_OP_REQUEST) return; + if (!ip_match(pkt.tpa)) return; - if (read_u16(arp + 6) != ARP_OP_REQUEST) return; - if (!ip_match(arp + 24)) return; + arp_packet reply = {}; + reply.eth.dst = pkt.eth.src; + reply.eth.src = state.mac; + reply.eth.ethertype = ETH_ARP; + reply.htype = ARP_HTYPE_ETH; + reply.ptype = ARP_PTYPE_IPV4; + reply.hlen = 6; + reply.plen = 4; + reply.oper = ARP_OP_REPLY; + reply.sha = state.mac; + reply.spa = state.ip; + reply.tha = pkt.sha; + reply.tpa = pkt.spa; - uint8_t reply[42]; - memcpy(reply, frame + 6, 6); - memcpy(reply + 6, state.mac.data(), 6); - write_u16(reply + 12, ETHERTYPE_ARP); - - uint8_t* rarp = reply + 14; - write_u16(rarp, 1); - write_u16(rarp + 2, ETHERTYPE_IPV4); - rarp[4] = 6; - rarp[5] = 4; - write_u16(rarp + 6, ARP_OP_REPLY); - memcpy(rarp + 8, state.mac.data(), 6); - memcpy(rarp + 14, state.ip.data(), 4); - memcpy(rarp + 18, arp + 8, 6); - memcpy(rarp + 24, arp + 14, 4); - - send_raw(reply, 42); + send_raw(&reply, sizeof(reply)); } static void handle_icmp(const uint8_t* frame, size_t len) { - const uint8_t* ip = frame + 14; - size_t ip_hdr_len = (ip[0] & 0x0F) * 4; - size_t ip_total_len = read_u16(ip + 2); + auto& eth = *reinterpret_cast(frame); + auto& ip = *reinterpret_cast(frame + sizeof(eth_header)); + size_t ip_hdr_len = ip.header_len(); + size_t ip_total = __builtin_bswap16(ip.total_len); - if (14 + ip_total_len > len) return; - if (ip[9] != IP_PROTO_ICMP) return; - if (!ip_match_or_broadcast(ip + 16)) return; + if (sizeof(eth_header) + ip_total > len) return; + if (ip.protocol != 1) return; + if (!ip_match_or_broadcast(ip.dst)) return; - const uint8_t* icmp = ip + ip_hdr_len; - size_t icmp_len = ip_total_len - ip_hdr_len; - if (icmp_len < 8) return; + auto* icmp = reinterpret_cast(frame + sizeof(eth_header) + ip_hdr_len); + size_t icmp_len = ip_total - ip_hdr_len; + if (icmp_len < sizeof(icmp_header)) return; + if (icmp->type != 8) return; - if (icmp[0] != ICMP_ECHO_REQUEST) return; + uint8_t reply_buf[1514]; + size_t reply_len = sizeof(eth_header) + ip_total; + if (reply_len > sizeof(reply_buf)) return; - uint8_t reply[1514]; - size_t reply_len = 14 + ip_total_len; - if (reply_len > sizeof(reply)) return; + auto& reth = *reinterpret_cast(reply_buf); + reth.dst = eth.src; + reth.src = state.mac; + reth.ethertype = ETH_IPV4; - memcpy(reply, frame + 6, 6); - memcpy(reply + 6, state.mac.data(), 6); - write_u16(reply + 12, ETHERTYPE_IPV4); + auto* rip = reply_buf + sizeof(eth_header); + memcpy(rip, &ip, ip_hdr_len); + auto& rip_hdr = *reinterpret_cast(rip); + rip_hdr.src = state.ip; + rip_hdr.dst = ip.src; + rip_hdr.ttl = 64; + rip_hdr.checksum = 0; + rip_hdr.checksum = ip_checksum(rip, ip_hdr_len); - uint8_t* rip = reply + 14; - memcpy(rip, ip, ip_hdr_len); - memcpy(rip + 12, state.ip.data(), 4); - memcpy(rip + 16, ip + 12, 4); - rip[8] = 64; - memset(rip + 10, 0, 2); - uint16_t ip_cksum = ip_checksum(rip, ip_hdr_len); - write_u16(rip + 10, ip_cksum); - - uint8_t* ricmp = rip + ip_hdr_len; + auto* ricmp = rip + ip_hdr_len; memcpy(ricmp, icmp, icmp_len); - ricmp[0] = ICMP_ECHO_REPLY; - memset(ricmp + 2, 0, 2); - uint16_t icmp_cksum = ip_checksum(ricmp, icmp_len); - write_u16(ricmp + 2, icmp_cksum); + auto& ricmp_hdr = *reinterpret_cast(ricmp); + ricmp_hdr.type = 0; + ricmp_hdr.checksum = 0; + ricmp_hdr.checksum = ip_checksum(ricmp, icmp_len); - send_raw(reply, reply_len); + send_raw(reply_buf, reply_len); } static void handle_ipv4(const uint8_t* frame, size_t len) { - if (len < 34) return; - const uint8_t* ip = frame + 14; - if ((ip[0] >> 4) != 4) return; + if (len < sizeof(eth_header) + sizeof(ipv4_header)) return; + auto& ip = *reinterpret_cast(frame + sizeof(eth_header)); + if ((ip.ver_ihl >> 4) != 4) return; handle_icmp(frame, len); } static void process_frame(const uint8_t* frame, size_t len) { - if (len < 14) return; + if (len < sizeof(eth_header)) return; + auto& eth = *reinterpret_cast(frame); - if (!mac_match(frame)) return; + if (!mac_match(eth.dst)) return; - uint16_t ethertype = read_u16(frame + 12); - - switch (ethertype) { - case ETHERTYPE_ARP: + switch (eth.ethertype) { + case ETH_ARP: handle_arp(frame, len); break; - case ETHERTYPE_IPV4: + case ETH_IPV4: handle_ipv4(frame, len); break; }