Rewrite net.cpp to use packed structs for frame encoding/decoding

This commit is contained in:
Ian Gulliver
2026-04-07 12:21:41 +09:00
parent b0294fada3
commit f161dda60a

View File

@@ -5,152 +5,190 @@
#include "w6300.h"
#include "debug_log.h"
using mac_addr = std::array<uint8_t, 6>;
using ip4_addr = std::array<uint8_t, 4>;
struct __attribute__((packed)) eth_header {
mac_addr dst;
mac_addr src;
uint16_t ethertype;
};
static_assert(sizeof(eth_header) == 14);
struct __attribute__((packed)) arp_packet {
eth_header eth;
uint16_t htype;
uint16_t ptype;
uint8_t hlen;
uint8_t plen;
uint16_t oper;
mac_addr sha;
ip4_addr spa;
mac_addr tha;
ip4_addr tpa;
};
static_assert(sizeof(arp_packet) == 42);
struct __attribute__((packed)) ipv4_header {
uint8_t ver_ihl;
uint8_t dscp_ecn;
uint16_t total_len;
uint16_t identification;
uint16_t flags_frag;
uint8_t ttl;
uint8_t protocol;
uint16_t checksum;
ip4_addr src;
ip4_addr dst;
size_t header_len() const { return (ver_ihl & 0x0F) * 4; }
size_t payload_len() const { return __builtin_bswap16(total_len) - header_len(); }
};
static_assert(sizeof(ipv4_header) == 20);
struct __attribute__((packed)) icmp_header {
uint8_t type;
uint8_t code;
uint16_t checksum;
uint16_t id;
uint16_t seq;
};
static_assert(sizeof(icmp_header) == 8);
static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806);
static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800);
static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1);
static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800);
static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1);
static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2);
static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255};
static constexpr ip4_addr IP_BROADCAST_SUBNET = {169, 254, 255, 255};
static net_state state;
static w6300::socket_id raw_socket{0};
static constexpr uint16_t ETHERTYPE_ARP = 0x0806;
static constexpr uint16_t ETHERTYPE_IPV4 = 0x0800;
static constexpr uint8_t IP_PROTO_ICMP = 1;
static constexpr uint8_t ICMP_ECHO_REQUEST = 8;
static constexpr uint8_t ICMP_ECHO_REPLY = 0;
static constexpr uint16_t ARP_OP_REQUEST = 1;
static constexpr uint16_t ARP_OP_REPLY = 2;
static uint16_t read_u16(const uint8_t* p) { return (p[0] << 8) | p[1]; }
static void write_u16(uint8_t* p, uint16_t v) {
p[0] = v >> 8;
p[1] = v & 0xFF;
}
static uint16_t ip_checksum(const uint8_t* data, size_t len) {
static uint16_t ip_checksum(const void* data, size_t len) {
auto p = static_cast<const uint8_t*>(data);
uint32_t sum = 0;
for (size_t i = 0; i < len - 1; i += 2)
sum += read_u16(data + i);
sum += (p[i] << 8) | p[i + 1];
if (len & 1)
sum += data[len - 1] << 8;
sum += p[len - 1] << 8;
while (sum >> 16)
sum = (sum & 0xFFFF) + (sum >> 16);
return ~sum;
return __builtin_bswap16(~sum);
}
static bool mac_match(const uint8_t* dst) {
static constexpr uint8_t broadcast[6] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
return memcmp(dst, state.mac.data(), 6) == 0 ||
memcmp(dst, broadcast, 6) == 0;
static bool mac_match(const mac_addr& dst) {
return dst == state.mac || dst == MAC_BROADCAST;
}
static bool ip_match(const uint8_t* dst) {
return memcmp(dst, state.ip.data(), 4) == 0;
static bool ip_match(const ip4_addr& dst) {
return dst == state.ip;
}
static bool ip_match_or_broadcast(const uint8_t* dst) {
static constexpr uint8_t bcast_all[4] = {255, 255, 255, 255};
static constexpr uint8_t bcast_subnet[4] = {169, 254, 255, 255};
return ip_match(dst) ||
memcmp(dst, bcast_all, 4) == 0 ||
memcmp(dst, bcast_subnet, 4) == 0;
static bool ip_match_or_broadcast(const ip4_addr& dst) {
return ip_match(dst) || dst == IP_BROADCAST_ALL || dst == IP_BROADCAST_SUBNET;
}
static void send_raw(const uint8_t* data, size_t len) {
static void send_raw(const void* data, size_t len) {
dlog_if_slow("send_raw", 1000, [&]{
w6300::ip_address dummy = {};
w6300::sendto(raw_socket, std::span<const uint8_t>{data, len}, dummy, w6300::port_num{0});
w6300::sendto(raw_socket, std::span<const uint8_t>{static_cast<const uint8_t*>(data), len},
dummy, w6300::port_num{0});
});
}
static void handle_arp(const uint8_t* frame, size_t len) {
if (len < 42) return;
const uint8_t* arp = frame + 14;
if (len < sizeof(arp_packet)) return;
auto& pkt = *reinterpret_cast<const arp_packet*>(frame);
if (read_u16(arp) != 1) return;
if (read_u16(arp + 2) != ETHERTYPE_IPV4) return;
if (arp[4] != 6 || arp[5] != 4) return;
if (pkt.htype != ARP_HTYPE_ETH) return;
if (pkt.ptype != ARP_PTYPE_IPV4) return;
if (pkt.hlen != 6 || pkt.plen != 4) return;
if (pkt.oper != ARP_OP_REQUEST) return;
if (!ip_match(pkt.tpa)) return;
if (read_u16(arp + 6) != ARP_OP_REQUEST) return;
if (!ip_match(arp + 24)) return;
arp_packet reply = {};
reply.eth.dst = pkt.eth.src;
reply.eth.src = state.mac;
reply.eth.ethertype = ETH_ARP;
reply.htype = ARP_HTYPE_ETH;
reply.ptype = ARP_PTYPE_IPV4;
reply.hlen = 6;
reply.plen = 4;
reply.oper = ARP_OP_REPLY;
reply.sha = state.mac;
reply.spa = state.ip;
reply.tha = pkt.sha;
reply.tpa = pkt.spa;
uint8_t reply[42];
memcpy(reply, frame + 6, 6);
memcpy(reply + 6, state.mac.data(), 6);
write_u16(reply + 12, ETHERTYPE_ARP);
uint8_t* rarp = reply + 14;
write_u16(rarp, 1);
write_u16(rarp + 2, ETHERTYPE_IPV4);
rarp[4] = 6;
rarp[5] = 4;
write_u16(rarp + 6, ARP_OP_REPLY);
memcpy(rarp + 8, state.mac.data(), 6);
memcpy(rarp + 14, state.ip.data(), 4);
memcpy(rarp + 18, arp + 8, 6);
memcpy(rarp + 24, arp + 14, 4);
send_raw(reply, 42);
send_raw(&reply, sizeof(reply));
}
static void handle_icmp(const uint8_t* frame, size_t len) {
const uint8_t* ip = frame + 14;
size_t ip_hdr_len = (ip[0] & 0x0F) * 4;
size_t ip_total_len = read_u16(ip + 2);
auto& eth = *reinterpret_cast<const eth_header*>(frame);
auto& ip = *reinterpret_cast<const ipv4_header*>(frame + sizeof(eth_header));
size_t ip_hdr_len = ip.header_len();
size_t ip_total = __builtin_bswap16(ip.total_len);
if (14 + ip_total_len > len) return;
if (ip[9] != IP_PROTO_ICMP) return;
if (!ip_match_or_broadcast(ip + 16)) return;
if (sizeof(eth_header) + ip_total > len) return;
if (ip.protocol != 1) return;
if (!ip_match_or_broadcast(ip.dst)) return;
const uint8_t* icmp = ip + ip_hdr_len;
size_t icmp_len = ip_total_len - ip_hdr_len;
if (icmp_len < 8) return;
auto* icmp = reinterpret_cast<const icmp_header*>(frame + sizeof(eth_header) + ip_hdr_len);
size_t icmp_len = ip_total - ip_hdr_len;
if (icmp_len < sizeof(icmp_header)) return;
if (icmp->type != 8) return;
if (icmp[0] != ICMP_ECHO_REQUEST) return;
uint8_t reply_buf[1514];
size_t reply_len = sizeof(eth_header) + ip_total;
if (reply_len > sizeof(reply_buf)) return;
uint8_t reply[1514];
size_t reply_len = 14 + ip_total_len;
if (reply_len > sizeof(reply)) return;
auto& reth = *reinterpret_cast<eth_header*>(reply_buf);
reth.dst = eth.src;
reth.src = state.mac;
reth.ethertype = ETH_IPV4;
memcpy(reply, frame + 6, 6);
memcpy(reply + 6, state.mac.data(), 6);
write_u16(reply + 12, ETHERTYPE_IPV4);
auto* rip = reply_buf + sizeof(eth_header);
memcpy(rip, &ip, ip_hdr_len);
auto& rip_hdr = *reinterpret_cast<ipv4_header*>(rip);
rip_hdr.src = state.ip;
rip_hdr.dst = ip.src;
rip_hdr.ttl = 64;
rip_hdr.checksum = 0;
rip_hdr.checksum = ip_checksum(rip, ip_hdr_len);
uint8_t* rip = reply + 14;
memcpy(rip, ip, ip_hdr_len);
memcpy(rip + 12, state.ip.data(), 4);
memcpy(rip + 16, ip + 12, 4);
rip[8] = 64;
memset(rip + 10, 0, 2);
uint16_t ip_cksum = ip_checksum(rip, ip_hdr_len);
write_u16(rip + 10, ip_cksum);
uint8_t* ricmp = rip + ip_hdr_len;
auto* ricmp = rip + ip_hdr_len;
memcpy(ricmp, icmp, icmp_len);
ricmp[0] = ICMP_ECHO_REPLY;
memset(ricmp + 2, 0, 2);
uint16_t icmp_cksum = ip_checksum(ricmp, icmp_len);
write_u16(ricmp + 2, icmp_cksum);
auto& ricmp_hdr = *reinterpret_cast<icmp_header*>(ricmp);
ricmp_hdr.type = 0;
ricmp_hdr.checksum = 0;
ricmp_hdr.checksum = ip_checksum(ricmp, icmp_len);
send_raw(reply, reply_len);
send_raw(reply_buf, reply_len);
}
static void handle_ipv4(const uint8_t* frame, size_t len) {
if (len < 34) return;
const uint8_t* ip = frame + 14;
if ((ip[0] >> 4) != 4) return;
if (len < sizeof(eth_header) + sizeof(ipv4_header)) return;
auto& ip = *reinterpret_cast<const ipv4_header*>(frame + sizeof(eth_header));
if ((ip.ver_ihl >> 4) != 4) return;
handle_icmp(frame, len);
}
static void process_frame(const uint8_t* frame, size_t len) {
if (len < 14) return;
if (len < sizeof(eth_header)) return;
auto& eth = *reinterpret_cast<const eth_header*>(frame);
if (!mac_match(frame)) return;
if (!mac_match(eth.dst)) return;
uint16_t ethertype = read_u16(frame + 12);
switch (ethertype) {
case ETHERTYPE_ARP:
switch (eth.ethertype) {
case ETH_ARP:
handle_arp(frame, len);
break;
case ETHERTYPE_IPV4:
case ETH_IPV4:
handle_ipv4(frame, len);
break;
}