diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 3facccf..73a65b1 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -30,6 +30,7 @@ struct __attribute__((packed)) arp_packet { static_assert(sizeof(arp_packet) == 42); struct __attribute__((packed)) ipv4_header { + eth_header eth; uint8_t ver_ihl; uint8_t dscp_ecn; uint16_t total_len; @@ -41,19 +42,21 @@ struct __attribute__((packed)) ipv4_header { ip4_addr src; ip4_addr dst; - size_t header_len() const { return (ver_ihl & 0x0F) * 4; } - size_t payload_len() const { return __builtin_bswap16(total_len) - header_len(); } + size_t ip_header_len() const { return (ver_ihl & 0x0F) * 4; } + size_t ip_total_len() const { return __builtin_bswap16(total_len); } + const uint8_t* ip_start() const { return reinterpret_cast(&ver_ihl); } + uint8_t* ip_start() { return reinterpret_cast(&ver_ihl); } }; -static_assert(sizeof(ipv4_header) == 20); +static_assert(sizeof(ipv4_header) == 34); -struct __attribute__((packed)) icmp_header { +struct __attribute__((packed)) icmp_echo { uint8_t type; uint8_t code; uint16_t checksum; uint16_t id; uint16_t seq; }; -static_assert(sizeof(icmp_header) == 8); +static_assert(sizeof(icmp_echo) == 8); static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806); static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800); @@ -128,54 +131,51 @@ static void handle_arp(const uint8_t* frame, size_t len) { } static void handle_icmp(const uint8_t* frame, size_t len) { - auto& eth = *reinterpret_cast(frame); - auto& ip = *reinterpret_cast(frame + sizeof(eth_header)); - size_t ip_hdr_len = ip.header_len(); - size_t ip_total = __builtin_bswap16(ip.total_len); + auto& ip = *reinterpret_cast(frame); + size_t ip_hdr_len = ip.ip_header_len(); + size_t ip_total = ip.ip_total_len(); if (sizeof(eth_header) + ip_total > len) return; if (ip.protocol != 1) return; if (!ip_match_or_broadcast(ip.dst)) return; - auto* icmp = reinterpret_cast(frame + sizeof(eth_header) + ip_hdr_len); + auto& icmp = *reinterpret_cast(frame + sizeof(eth_header) + ip_hdr_len); size_t icmp_len = ip_total - ip_hdr_len; - if (icmp_len < sizeof(icmp_header)) return; - if (icmp->type != 8) return; + if (icmp_len < sizeof(icmp_echo)) return; + if (icmp.type != 8) return; uint8_t reply_buf[1514]; size_t reply_len = sizeof(eth_header) + ip_total; if (reply_len > sizeof(reply_buf)) return; - auto& reth = *reinterpret_cast(reply_buf); - reth.dst = eth.src; - reth.src = state.mac; - reth.ethertype = ETH_IPV4; + memcpy(reply_buf, frame, reply_len); + auto& rip = *reinterpret_cast(reply_buf); + rip.eth.dst = ip.eth.src; + rip.eth.src = state.mac; + rip.src = state.ip; + rip.dst = ip.src; + rip.ttl = 64; + rip.checksum = 0; + rip.checksum = ip_checksum(rip.ip_start(), ip_hdr_len); - auto* rip = reply_buf + sizeof(eth_header); - memcpy(rip, &ip, ip_hdr_len); - auto& rip_hdr = *reinterpret_cast(rip); - rip_hdr.src = state.ip; - rip_hdr.dst = ip.src; - rip_hdr.ttl = 64; - rip_hdr.checksum = 0; - rip_hdr.checksum = ip_checksum(rip, ip_hdr_len); - - auto* ricmp = rip + ip_hdr_len; - memcpy(ricmp, icmp, icmp_len); - auto& ricmp_hdr = *reinterpret_cast(ricmp); - ricmp_hdr.type = 0; - ricmp_hdr.checksum = 0; - ricmp_hdr.checksum = ip_checksum(ricmp, icmp_len); + auto& ricmp = *reinterpret_cast(reply_buf + sizeof(eth_header) + ip_hdr_len); + ricmp.type = 0; + ricmp.checksum = 0; + ricmp.checksum = ip_checksum(&ricmp, icmp_len); send_raw(reply_buf, reply_len); } static void handle_ipv4(const uint8_t* frame, size_t len) { - if (len < sizeof(eth_header) + sizeof(ipv4_header)) return; - auto& ip = *reinterpret_cast(frame + sizeof(eth_header)); + if (len < sizeof(ipv4_header)) return; + auto& ip = *reinterpret_cast(frame); if ((ip.ver_ihl >> 4) != 4) return; - handle_icmp(frame, len); + switch (ip.protocol) { + case 1: + handle_icmp(frame, len); + break; + } } static void process_frame(const uint8_t* frame, size_t len) {