diff --git a/cmd/picomap/main.go b/cmd/picomap/main.go index ec21e0c..3c70e1b 100644 --- a/cmd/picomap/main.go +++ b/cmd/picomap/main.go @@ -356,15 +356,17 @@ func findTestDevice() (string, error) { func cmdTestGroup(args []string) error { if len(args) < 1 { - return fmt.Errorf("usage: picomap test [args...]") + return fmt.Errorf("usage: picomap test [args...]") } switch args[0] { case "list": return cmdTestList(args[1:]) case "run": return cmdTestRun(args[1:]) + case "all": + return cmdTestAll(args[1:]) default: - return fmt.Errorf("usage: picomap test [args...]") + return fmt.Errorf("usage: picomap test [args...]") } } @@ -389,6 +391,51 @@ func cmdTestList(_ []string) error { return nil } +func cmdTestAll(_ []string) error { + dev, err := findTestDevice() + if err != nil { + return err + } + + c, err := client.NewSerial(dev, 10*time.Second) + if err != nil { + return err + } + defer c.Close() + + list, err := c.ListTests() + if err != nil { + return fmt.Errorf("remote: %w", err) + } + + log := slog.With("dev", dev) + failed := 0 + for _, name := range list.Names { + log.Info("running test", "name", name) + result, err := c.Test(name) + if err != nil { + log.Error("error", "name", name, "err", err) + failed++ + continue + } + for _, msg := range result.Messages { + log.Info("remote", "name", name, "msg", msg) + } + if result.Pass { + log.Info("PASS", "name", name) + } else { + log.Error("FAIL", "name", name) + failed++ + } + } + if failed > 0 { + log.Error("tests failed", "count", failed) + os.Exit(1) + } + log.Info("all tests passed", "count", len(list.Names)) + return nil +} + func cmdTestRun(args []string) error { if len(args) < 1 { return fmt.Errorf("usage: picomap test run ") diff --git a/firmware/CMakeLists.txt b/firmware/CMakeLists.txt index 872e5cb..0e74a9a 100644 --- a/firmware/CMakeLists.txt +++ b/firmware/CMakeLists.txt @@ -14,6 +14,7 @@ set(LIB_SOURCES lib/dispatch.cpp lib/handlers.cpp lib/icmp.cpp + lib/igmp.cpp lib/ipv4.cpp lib/net.cpp lib/tusb_config.cpp diff --git a/firmware/include/arp.h b/firmware/include/arp.h index 6318c07..a27b221 100644 --- a/firmware/include/arp.h +++ b/firmware/include/arp.h @@ -7,8 +7,7 @@ namespace arp { -struct __attribute__((packed)) packet { - eth::header eth; +struct __attribute__((packed)) header { uint16_t htype; uint16_t ptype; uint8_t hlen; @@ -19,7 +18,7 @@ struct __attribute__((packed)) packet { eth::mac_addr tha; ipv4::ip4_addr tpa; }; -static_assert(sizeof(packet) == 42); +static_assert(sizeof(header) == 28); void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ipv4::ip4_addr our_ip, diff --git a/firmware/include/eth.h b/firmware/include/eth.h index 75b56be..df02e3f 100644 --- a/firmware/include/eth.h +++ b/firmware/include/eth.h @@ -17,4 +17,12 @@ struct __attribute__((packed)) header { }; static_assert(sizeof(header) == 14); +template +void prepend(Buf& buf, const mac_addr& dst, const mac_addr& src, uint16_t ethertype) { + auto* h = buf.template prepend
(); + h->dst = dst; + h->src = src; + h->ethertype = ethertype; +} + } // namespace eth diff --git a/firmware/include/icmp.h b/firmware/include/icmp.h index b93c879..18e1a01 100644 --- a/firmware/include/icmp.h +++ b/firmware/include/icmp.h @@ -21,10 +21,20 @@ void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ipv4::ip4_addr our_ip, std::function)> send_raw); -size_t build_echo_request(std::span buf, +template +void prepend_echo_request(Buf& buf, eth::mac_addr src_mac, ipv4::ip4_addr src_ip, eth::mac_addr dst_mac, ipv4::ip4_addr dst_ip, - uint16_t id, uint16_t seq); + uint16_t id, uint16_t seq) { + auto* e = buf.template prepend(); + e->type = 8; + e->code = 0; + e->checksum = 0; + e->id = id; + e->seq = seq; + e->checksum = ipv4::checksum(e, sizeof(echo)); + ipv4::prepend(buf, dst_mac, src_mac, src_ip, dst_ip, 1, sizeof(echo)); +} bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id); diff --git a/firmware/include/igmp.h b/firmware/include/igmp.h new file mode 100644 index 0000000..fa8529f --- /dev/null +++ b/firmware/include/igmp.h @@ -0,0 +1,64 @@ +#pragma once +#include +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace igmp { + +static constexpr ipv4::ip4_addr PICOMAP_DISCOVERY_GROUP = {239, 112, 77, 1}; +static constexpr ipv4::ip4_addr ALL_HOSTS = {224, 0, 0, 1}; + +struct __attribute__((packed)) message { + uint8_t type; + uint8_t max_resp_time; + uint16_t checksum; + ipv4::ip4_addr group; +}; +static_assert(sizeof(message) == 8); + +eth::mac_addr mac_for_ip(const ipv4::ip4_addr& group); +bool is_member(const ipv4::ip4_addr& ip); +bool is_member_mac(const eth::mac_addr& mac); + +void join(const ipv4::ip4_addr& group, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw); + +void send_all_reports(eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw); + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw); + +template +void prepend_report(Buf& buf, const eth::mac_addr& src_mac, ipv4::ip4_addr src_ip, + const ipv4::ip4_addr& group) { + auto* m = buf.template prepend(); + m->type = 0x16; + m->max_resp_time = 0; + m->checksum = 0; + m->group = group; + m->checksum = ipv4::checksum(m, sizeof(message)); + ipv4::prepend(buf, mac_for_ip(group), src_mac, src_ip, group, 2, sizeof(message), 1); +} + +template +void prepend_query(Buf& buf, const eth::mac_addr& src_mac, ipv4::ip4_addr src_ip, + const ipv4::ip4_addr& group) { + ipv4::ip4_addr dst_ip = (group == ipv4::ip4_addr{0, 0, 0, 0}) ? ALL_HOSTS : group; + auto* m = buf.template prepend(); + m->type = 0x11; + m->max_resp_time = 100; + m->checksum = 0; + m->group = group; + m->checksum = ipv4::checksum(m, sizeof(message)); + ipv4::prepend(buf, mac_for_ip(dst_ip), src_mac, src_ip, dst_ip, 2, sizeof(message), 1); +} + +bool parse_report(std::span frame, ipv4::ip4_addr& group); + +} // namespace igmp diff --git a/firmware/include/ipv4.h b/firmware/include/ipv4.h index c2a7b70..602b017 100644 --- a/firmware/include/ipv4.h +++ b/firmware/include/ipv4.h @@ -19,7 +19,6 @@ inline std::string to_string(const ip4_addr& ip) { } struct __attribute__((packed)) header { - eth::header eth; uint8_t ver_ihl; uint8_t dscp_ecn; uint16_t total_len; @@ -31,24 +30,32 @@ struct __attribute__((packed)) header { ip4_addr src; ip4_addr dst; - size_t ip_header_len() const { return (ver_ihl & 0x0F) * 4; } - size_t ip_total_len() const { return __builtin_bswap16(total_len); } - const uint8_t* ip_start() const { return reinterpret_cast(&ver_ihl); } - uint8_t* ip_start() { return reinterpret_cast(&ver_ihl); } + size_t header_len() const { return (ver_ihl & 0x0F) * 4; } + size_t total() const { return __builtin_bswap16(total_len); } }; -static_assert(sizeof(header) == 34); - -struct __attribute__((packed)) udp_header { - header ip; - uint16_t src_port; - uint16_t dst_port; - uint16_t length; - uint16_t checksum; -}; -static_assert(sizeof(udp_header) == 42); +static_assert(sizeof(header) == 20); uint16_t checksum(const void* data, size_t len); +template +void prepend(Buf& buf, const eth::mac_addr& dst_mac, const eth::mac_addr& src_mac, + ip4_addr src_ip, ip4_addr dst_ip, uint8_t protocol, + size_t payload_len, uint8_t ttl = 64) { + auto* h = buf.template prepend
(); + h->ver_ihl = 0x45; + h->dscp_ecn = 0; + h->total_len = __builtin_bswap16(sizeof(header) + payload_len); + h->identification = 0; + h->flags_frag = 0; + h->ttl = ttl; + h->protocol = protocol; + h->checksum = 0; + h->src = src_ip; + h->dst = dst_ip; + h->checksum = checksum(h, sizeof(header)); + eth::prepend(buf, dst_mac, src_mac, eth::ETH_IPV4); +} + void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ip4_addr our_ip, ip4_addr subnet_broadcast, std::function)> send_raw, diff --git a/firmware/include/parse_buffer.h b/firmware/include/parse_buffer.h new file mode 100644 index 0000000..9ee7014 --- /dev/null +++ b/firmware/include/parse_buffer.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include + +class parse_buffer { + const uint8_t* m_data; + size_t m_remaining; + +public: + parse_buffer(std::span data) + : m_data(data.data()), m_remaining(data.size()) {} + + template + const T* consume() { + if (m_remaining < sizeof(T)) return nullptr; + auto* p = reinterpret_cast(m_data); + m_data += sizeof(T); + m_remaining -= sizeof(T); + return p; + } + + bool skip(size_t len) { + if (m_remaining < len) return false; + m_data += len; + m_remaining -= len; + return true; + } + + std::span remaining() const { return {m_data, m_remaining}; } + size_t remaining_size() const { return m_remaining; } +}; diff --git a/firmware/include/prepend_buffer.h b/firmware/include/prepend_buffer.h new file mode 100644 index 0000000..9dc8155 --- /dev/null +++ b/firmware/include/prepend_buffer.h @@ -0,0 +1,37 @@ +#pragma once +#include +#include +#include +#include + +template +class prepend_buffer { + uint8_t m_buf[N]; + size_t m_start = N / 2; + size_t m_end = N / 2; + +public: + template + T* prepend() { + m_start -= sizeof(T); + return reinterpret_cast(m_buf + m_start); + } + + uint8_t* append(size_t len) { + uint8_t* p = m_buf + m_end; + m_end += len; + return p; + } + + void append_copy(std::span data) { + memcpy(append(data.size()), data.data(), data.size()); + } + + uint8_t* payload_ptr() { return m_buf + m_end; } + + uint8_t* data() { return m_buf + m_start; } + const uint8_t* data() const { return m_buf + m_start; } + size_t size() const { return m_end - m_start; } + + std::span span() const { return {data(), size()}; } +}; diff --git a/firmware/include/udp.h b/firmware/include/udp.h new file mode 100644 index 0000000..d5b87bd --- /dev/null +++ b/firmware/include/udp.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include "eth.h" +#include "ipv4.h" + +namespace udp { + +struct __attribute__((packed)) header { + uint16_t src_port; + uint16_t dst_port; + uint16_t length; + uint16_t checksum; +}; +static_assert(sizeof(header) == 8); + +template +void prepend(Buf& buf, const eth::mac_addr& dst_mac, const eth::mac_addr& src_mac, + ipv4::ip4_addr src_ip, ipv4::ip4_addr dst_ip, + uint16_t src_port, uint16_t dst_port, + size_t payload_len, uint8_t ttl = 64) { + auto* u = buf.template prepend
(); + u->src_port = src_port; + u->dst_port = dst_port; + u->length = __builtin_bswap16(sizeof(header) + payload_len); + u->checksum = 0; + ipv4::prepend(buf, dst_mac, src_mac, src_ip, dst_ip, 17, sizeof(header) + payload_len, ttl); +} + +} // namespace udp diff --git a/firmware/lib/arp.cpp b/firmware/lib/arp.cpp index 953ab85..d44e578 100644 --- a/firmware/lib/arp.cpp +++ b/firmware/lib/arp.cpp @@ -1,4 +1,6 @@ #include "arp.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" namespace arp { @@ -10,32 +12,31 @@ static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ipv4::ip4_addr our_ip, std::function)> send_raw) { - if (frame.size() < sizeof(packet)) return; - auto& pkt = *reinterpret_cast(frame.data()); + parse_buffer pb(frame); + pb.consume(); + auto* arp_hdr = pb.consume
(); + if (!arp_hdr) return; - if (pkt.htype != ARP_HTYPE_ETH) return; - if (pkt.ptype != ARP_PTYPE_IPV4) return; - if (pkt.hlen != 6 || pkt.plen != 4) return; - if (pkt.oper != ARP_OP_REQUEST) return; - if (pkt.tpa != our_ip) return; - if (sizeof(packet) > tx.capacity()) return; + if (arp_hdr->htype != ARP_HTYPE_ETH) return; + if (arp_hdr->ptype != ARP_PTYPE_IPV4) return; + if (arp_hdr->hlen != 6 || arp_hdr->plen != 4) return; + if (arp_hdr->oper != ARP_OP_REQUEST) return; + if (arp_hdr->tpa != our_ip) return; - auto& reply = *reinterpret_cast(tx.data()); - reply = {}; - reply.eth.dst = pkt.eth.src; - reply.eth.src = our_mac; - reply.eth.ethertype = eth::ETH_ARP; - reply.htype = ARP_HTYPE_ETH; - reply.ptype = ARP_PTYPE_IPV4; - reply.hlen = 6; - reply.plen = 4; - reply.oper = ARP_OP_REPLY; - reply.sha = our_mac; - reply.spa = our_ip; - reply.tha = pkt.sha; - reply.tpa = pkt.spa; + prepend_buffer<128> buf; + auto* reply = buf.template prepend
(); + reply->htype = ARP_HTYPE_ETH; + reply->ptype = ARP_PTYPE_IPV4; + reply->hlen = 6; + reply->plen = 4; + reply->oper = ARP_OP_REPLY; + reply->sha = our_mac; + reply->spa = our_ip; + reply->tha = arp_hdr->sha; + reply->tpa = arp_hdr->spa; + eth::prepend(buf, arp_hdr->sha, our_mac, eth::ETH_ARP); - send_raw({tx.data(), sizeof(packet)}); + send_raw(buf.span()); } } // namespace arp diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index 9198a97..3a11861 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -6,14 +6,22 @@ #include "usb_cdc.h" #include "timer_queue.h" #include "net.h" +#include "igmp.h" #include "debug_log.h" #include "hardware/sync.h" static timer_queue timers; +static void igmp_reannounce() { + auto& ns = net_get_state(); + igmp::send_all_reports(ns.mac, ns.ip, net_send_raw); + dispatch_schedule_ms(60000, igmp_reannounce); +} + void dispatch_init() { tusb_init(); net_init(); + dispatch_schedule_ms(60000, igmp_reannounce); dlog("dispatch_init complete"); } diff --git a/firmware/lib/icmp.cpp b/firmware/lib/icmp.cpp index eb9acbf..95f1765 100644 --- a/firmware/lib/icmp.cpp +++ b/firmware/lib/icmp.cpp @@ -1,96 +1,62 @@ #include "icmp.h" #include #include "ipv4.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" namespace icmp { void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ipv4::ip4_addr our_ip, std::function)> send_raw) { - auto& ip = *reinterpret_cast(frame.data()); - size_t ip_hdr_len = ip.ip_header_len(); - size_t ip_total = ip.ip_total_len(); + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; + if (ip->protocol != 1) return; - if (sizeof(eth::header) + ip_total > frame.size()) return; - if (ip.protocol != 1) return; + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; - auto& icmp_pkt = *reinterpret_cast(frame.data() + sizeof(eth::header) + ip_hdr_len); - size_t icmp_len = ip_total - ip_hdr_len; - if (icmp_len < sizeof(echo)) return; - if (icmp_pkt.type != 8) return; + size_t icmp_len = ip->total() - ip->header_len(); + if (pb.remaining_size() < icmp_len) return; - size_t reply_len = sizeof(eth::header) + ip_total; - if (reply_len > tx.capacity()) return; + auto* icmp_pkt = pb.consume(); + if (!icmp_pkt) return; + if (icmp_pkt->type != 8) return; - memcpy(tx.data(), frame.data(), reply_len); - auto& rip = *reinterpret_cast(tx.data()); - rip.eth.dst = ip.eth.src; - rip.eth.src = our_mac; - rip.src = our_ip; - rip.dst = ip.src; - rip.ttl = 64; - rip.checksum = 0; - rip.checksum = ipv4::checksum(rip.ip_start(), ip_hdr_len); + prepend_buffer<1514> buf; + memcpy(buf.append(icmp_len), pb.remaining().data() - sizeof(echo), icmp_len); - auto& ricmp = *reinterpret_cast(tx.data() + sizeof(eth::header) + ip_hdr_len); - ricmp.type = 0; - ricmp.checksum = 0; - ricmp.checksum = ipv4::checksum(&ricmp, icmp_len); + auto* reply = reinterpret_cast(buf.data()); + reply->type = 0; + reply->checksum = 0; + reply->checksum = ipv4::checksum(reply, icmp_len); - send_raw({tx.data(), reply_len}); -} - -size_t build_echo_request(std::span buf, - eth::mac_addr src_mac, ipv4::ip4_addr src_ip, - eth::mac_addr dst_mac, ipv4::ip4_addr dst_ip, - uint16_t id, uint16_t seq) { - size_t total = sizeof(ipv4::header) + sizeof(echo); - if (buf.size() < total) return 0; - - memset(buf.data(), 0, total); - - auto& ip = *reinterpret_cast(buf.data()); - ip.eth.dst = dst_mac; - ip.eth.src = src_mac; - ip.eth.ethertype = eth::ETH_IPV4; - ip.ver_ihl = 0x45; - ip.dscp_ecn = 0; - ip.total_len = __builtin_bswap16(20 + sizeof(echo)); - ip.identification = 0; - ip.flags_frag = 0; - ip.ttl = 64; - ip.protocol = 1; - ip.checksum = 0; - ip.src = src_ip; - ip.dst = dst_ip; - ip.checksum = ipv4::checksum(ip.ip_start(), 20); - - auto& icmp_pkt = *reinterpret_cast(buf.data() + sizeof(ipv4::header)); - icmp_pkt.type = 8; - icmp_pkt.code = 0; - icmp_pkt.checksum = 0; - icmp_pkt.id = id; - icmp_pkt.seq = seq; - icmp_pkt.checksum = ipv4::checksum(&icmp_pkt, sizeof(echo)); - - return total; + ipv4::prepend(buf, eth_hdr->src, our_mac, our_ip, ip->src, 1, icmp_len); + send_raw(buf.span()); } bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id) { - if (frame.size() < sizeof(ipv4::header)) return false; - auto& ip = *reinterpret_cast(frame.data()); - if ((ip.ver_ihl >> 4) != 4) return false; - if (ip.eth.ethertype != eth::ETH_IPV4) return false; - if (ip.protocol != 1) return false; + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr) return false; + if (eth_hdr->ethertype != eth::ETH_IPV4) return false; - size_t ip_hdr_len = ip.ip_header_len(); - if (sizeof(eth::header) + ip_hdr_len + sizeof(echo) > frame.size()) return false; + auto* ip = pb.consume(); + if (!ip) return false; + if ((ip->ver_ihl >> 4) != 4) return false; + if (ip->protocol != 1) return false; - auto& icmp_pkt = *reinterpret_cast(frame.data() + sizeof(eth::header) + ip_hdr_len); - if (icmp_pkt.type != 0) return false; - if (icmp_pkt.id != expected_id) return false; + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return false; - src_ip = ip.src; + auto* icmp_pkt = pb.consume(); + if (!icmp_pkt) return false; + if (icmp_pkt->type != 0) return false; + if (icmp_pkt->id != expected_id) return false; + + src_ip = ip->src; return true; } diff --git a/firmware/lib/igmp.cpp b/firmware/lib/igmp.cpp new file mode 100644 index 0000000..4514ca4 --- /dev/null +++ b/firmware/lib/igmp.cpp @@ -0,0 +1,109 @@ +#include "igmp.h" +#include +#include "ipv4.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" + +namespace igmp { + +struct group_entry { + ipv4::ip4_addr ip; + eth::mac_addr mac; +}; + +static std::vector groups; + +eth::mac_addr mac_for_ip(const ipv4::ip4_addr& group) { + return {0x01, 0x00, 0x5E, + static_cast(group[1] & 0x7F), group[2], group[3]}; +} + +bool is_member(const ipv4::ip4_addr& ip) { + if (ip == ALL_HOSTS) return true; + for (auto& g : groups) + if (g.ip == ip) return true; + return false; +} + +bool is_member_mac(const eth::mac_addr& mac) { + static constexpr eth::mac_addr ALL_HOSTS_MAC = {0x01, 0x00, 0x5E, 0x00, 0x00, 0x01}; + if (mac == ALL_HOSTS_MAC) return true; + for (auto& g : groups) + if (g.mac == mac) return true; + return false; +} + +static void send_report(const ipv4::ip4_addr& group, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + prepend_buffer<128> buf; + prepend_report(buf, our_mac, our_ip, group); + send_raw(buf.span()); +} + +void join(const ipv4::ip4_addr& group, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + for (auto& g : groups) + if (g.ip == group) return; + groups.push_back({group, mac_for_ip(group)}); + send_report(group, our_mac, our_ip, send_raw); +} + +void send_all_reports(eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + for (auto& g : groups) + send_report(g.ip, our_mac, our_ip, send_raw); +} + +void handle(std::span frame, span_writer& tx, + eth::mac_addr our_mac, ipv4::ip4_addr our_ip, + std::function)> send_raw) { + parse_buffer pb(frame); + pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; + + auto* msg = pb.consume(); + if (!msg) return; + if (msg->type != 0x11) return; + + if (msg->group == ipv4::ip4_addr{0, 0, 0, 0}) { + for (auto& g : groups) + send_report(g.ip, our_mac, our_ip, send_raw); + } else { + for (auto& g : groups) { + if (g.ip == msg->group) { + send_report(g.ip, our_mac, our_ip, send_raw); + break; + } + } + } +} + +bool parse_report(std::span frame, ipv4::ip4_addr& group) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr) return false; + if (eth_hdr->ethertype != eth::ETH_IPV4) return false; + + auto* ip = pb.consume(); + if (!ip) return false; + if ((ip->ver_ihl >> 4) != 4) return false; + if (ip->protocol != 2) return false; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return false; + + auto* msg = pb.consume(); + if (!msg) return false; + if (msg->type != 0x16) return false; + + group = msg->group; + return true; +} + +} // namespace igmp diff --git a/firmware/lib/ipv4.cpp b/firmware/lib/ipv4.cpp index 3f0a247..96ee2e0 100644 --- a/firmware/lib/ipv4.cpp +++ b/firmware/lib/ipv4.cpp @@ -1,5 +1,7 @@ #include "ipv4.h" #include "icmp.h" +#include "igmp.h" +#include "parse_buffer.h" namespace ipv4 { @@ -17,26 +19,34 @@ uint16_t checksum(const void* data, size_t len) { return __builtin_bswap16(~sum); } -static bool ip_match_or_broadcast(const ip4_addr& dst, const ip4_addr& our_ip, const ip4_addr& subnet_broadcast) { - return dst == our_ip || dst == IP_BROADCAST_ALL || dst == subnet_broadcast; +static bool ip_match(const ip4_addr& dst, const ip4_addr& our_ip, const ip4_addr& subnet_broadcast) { + return dst == our_ip || dst == IP_BROADCAST_ALL || dst == subnet_broadcast || igmp::is_member(dst); } void handle(std::span frame, span_writer& tx, eth::mac_addr our_mac, ip4_addr our_ip, ip4_addr subnet_broadcast, std::function)> send_raw, std::function, span_writer&)> handle_udp) { - if (frame.size() < sizeof(header)) return; - auto& ip = *reinterpret_cast(frame.data()); - if ((ip.ver_ihl >> 4) != 4) return; + parse_buffer pb(frame); + pb.consume(); + auto* ip = pb.consume
(); + if (!ip) return; + if ((ip->ver_ihl >> 4) != 4) return; - switch (ip.protocol) { + size_t options_len = ip->header_len() - sizeof(header); + if (options_len > 0 && !pb.skip(options_len)) return; + + switch (ip->protocol) { case 1: - if (!ip_match_or_broadcast(ip.dst, our_ip, subnet_broadcast)) + if (!ip_match(ip->dst, our_ip, subnet_broadcast)) return; icmp::handle(frame, tx, our_mac, our_ip, send_raw); break; + case 2: + igmp::handle(frame, tx, our_mac, our_ip, send_raw); + break; case 17: - if (!ip_match_or_broadcast(ip.dst, our_ip, subnet_broadcast)) + if (!ip_match(ip->dst, our_ip, subnet_broadcast)) return; handle_udp(frame, tx); break; diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index f997f42..29f06f2 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -5,6 +5,10 @@ #include "eth.h" #include "arp.h" #include "ipv4.h" +#include "udp.h" +#include "igmp.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" #include "w6300.h" #include "debug_log.h" @@ -23,58 +27,40 @@ void net_send_raw(std::span data) { } static void handle_udp(std::span frame, span_writer& tx) { - if (frame.size() < sizeof(ipv4::udp_header)) return; - auto& pkt = *reinterpret_cast(frame.data()); + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; - if (pkt.dst_port != PICOMAP_PORT) return; + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; + + auto* uhdr = pb.consume(); + if (!uhdr) return; + if (uhdr->dst_port != PICOMAP_PORT) return; if (!msg_handler) return; - size_t udp_len = __builtin_bswap16(pkt.length); - if (udp_len < 8) return; - if (sizeof(eth::header) + pkt.ip.ip_total_len() < sizeof(ipv4::udp_header) + udp_len - 8) return; + size_t udp_len = __builtin_bswap16(uhdr->length); + if (udp_len < sizeof(udp::header)) return; + size_t payload_len = udp_len - sizeof(udp::header); + if (pb.remaining_size() < payload_len) return; - size_t payload_len = udp_len - 8; + eth::mac_addr dst_mac = eth_hdr->src; + ipv4::ip4_addr dst_ip = ip->src; + uint16_t dst_port = uhdr->src_port; - eth::mac_addr dst_mac = pkt.ip.eth.src; - ipv4::ip4_addr dst_ip = pkt.ip.src; - uint16_t dst_port = pkt.src_port; - - msg_handler(frame.subspan(sizeof(ipv4::udp_header), payload_len), + msg_handler(pb.remaining().subspan(0, payload_len), [dst_mac, dst_ip, dst_port](std::span resp_data) { - size_t ip_total = 20 + 8 + resp_data.size(); - size_t reply_len = sizeof(eth::header) + ip_total; - uint8_t reply_buf[1514]; - if (reply_len > sizeof(reply_buf)) return; - - auto& rip = *reinterpret_cast(reply_buf); - rip.eth.dst = dst_mac; - rip.eth.src = state.mac; - rip.eth.ethertype = eth::ETH_IPV4; - rip.ver_ihl = 0x45; - rip.dscp_ecn = 0; - rip.total_len = __builtin_bswap16(ip_total); - rip.identification = 0; - rip.flags_frag = 0; - rip.ttl = 64; - rip.protocol = 17; - rip.checksum = 0; - rip.src = state.ip; - rip.dst = dst_ip; - rip.checksum = ipv4::checksum(rip.ip_start(), 20); - - auto& rudp = *reinterpret_cast(reply_buf); - rudp.src_port = PICOMAP_PORT; - rudp.dst_port = dst_port; - rudp.length = __builtin_bswap16(8 + resp_data.size()); - rudp.checksum = 0; - - memcpy(reply_buf + sizeof(ipv4::udp_header), resp_data.data(), resp_data.size()); - net_send_raw({reply_buf, reply_len}); + prepend_buffer<1514> buf; + buf.append_copy(resp_data); + udp::prepend(buf, dst_mac, state.mac, state.ip, dst_ip, + PICOMAP_PORT, dst_port, resp_data.size()); + net_send_raw(buf.span()); }); } static bool mac_match(const eth::mac_addr& dst) { - return dst == state.mac || dst == eth::MAC_BROADCAST; + return dst == state.mac || dst == eth::MAC_BROADCAST || igmp::is_member_mac(dst); } static void process_frame(std::span frame, span_writer& tx) { @@ -121,6 +107,8 @@ bool net_init() { w6300::open_socket(raw_socket, w6300::protocol::macraw, w6300::sock_flag::none); w6300::set_interrupt_mask(w6300::ik_sock_0); + igmp::join(igmp::PICOMAP_DISCOVERY_GROUP, state.mac, state.ip, net_send_raw); + return true; } diff --git a/firmware/lib/test_handlers.cpp b/firmware/lib/test_handlers.cpp index 4232013..ba041f5 100644 --- a/firmware/lib/test_handlers.cpp +++ b/firmware/lib/test_handlers.cpp @@ -5,28 +5,110 @@ #include "pico/time.h" #include "net.h" #include "icmp.h" +#include "igmp.h" +#include "udp.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" -static ResponseTest test_discovery(const responder&) { - ResponseTest resp; - resp.pass = true; - resp.messages.push_back("TODO: rewrite as deferred test"); - return resp; +static void test_discovery_igmp(const responder& resp) { + auto& ns = net_get_state(); + + prepend_buffer<128> buf; + igmp::prepend_query(buf, ns.mac, ns.ip, igmp::PICOMAP_DISCOVERY_GROUP); + net_send_raw(buf.span()); + + auto timer = std::make_shared(nullptr); + auto cb = std::make_shared)>>(); + *cb = [resp, timer, cb](std::span frame) { + ipv4::ip4_addr group; + if (!igmp::parse_report(frame, group)) { + net_add_frame_callback(*cb); + return; + } + if (group != igmp::PICOMAP_DISCOVERY_GROUP) { + net_add_frame_callback(*cb); + return; + } + dispatch_cancel_timer(*timer); + resp.respond(ResponseTest{true, {"got IGMP report for " + ipv4::to_string(group)}}); + }; + net_add_frame_callback(*cb); + + *timer = dispatch_schedule_ms(5000, [resp]() { + resp.respond(ResponseTest{false, {"no IGMP report within 5s"}}); + }); +} + +static void test_discovery_info(const responder& resp) { + auto& ns = net_get_state(); + + eth::mac_addr mcast_mac = igmp::mac_for_ip(igmp::PICOMAP_DISCOVERY_GROUP); + static constexpr uint16_t PICOMAP_PORT = __builtin_bswap16(28781); + + prepend_buffer<1514> buf; + uint8_t* payload = buf.payload_ptr(); + span_writer out(payload, 1024); + RequestInfo req_msg; + auto encoded = encode_response_into(out, 0xFFFF, req_msg); + if (!encoded) { + resp.respond(ResponseTest{false, {"encode RequestInfo failed"}}); + return; + } + buf.append(*encoded); + size_t payload_len = *encoded; + + udp::prepend(buf, mcast_mac, ns.mac, ns.ip, igmp::PICOMAP_DISCOVERY_GROUP, + PICOMAP_PORT, PICOMAP_PORT, payload_len, 1); + net_send_raw(buf.span()); + + ipv4::ip4_addr our_ip = ns.ip; + + auto timer = std::make_shared(nullptr); + auto cb = std::make_shared)>>(); + *cb = [resp, our_ip, timer, cb](std::span frame) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr || eth_hdr->ethertype != eth::ETH_IPV4) { + net_add_frame_callback(*cb); + return; + } + auto* ip = pb.consume(); + if (!ip || ip->protocol != 17) { + net_add_frame_callback(*cb); + return; + } + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) { + net_add_frame_callback(*cb); + return; + } + auto* uhdr = pb.consume(); + if (!uhdr || uhdr->src_port != __builtin_bswap16(28781)) { + net_add_frame_callback(*cb); + return; + } + if (ip->src == our_ip) { + net_add_frame_callback(*cb); + return; + } + dispatch_cancel_timer(*timer); + resp.respond(ResponseTest{true, {"got info response from " + ipv4::to_string(ip->src)}}); + }; + net_add_frame_callback(*cb); + + *timer = dispatch_schedule_ms(5000, [resp]() { + resp.respond(ResponseTest{false, {"no info response within 5s"}}); + }); } static void test_ping(const responder& resp, ipv4::ip4_addr dst_ip) { auto& ns = net_get_state(); uint16_t ping_id = 0x1234; - uint8_t tx_buf[128]; - size_t len = icmp::build_echo_request( - std::span{tx_buf}, ns.mac, ns.ip, - eth::MAC_BROADCAST, dst_ip, ping_id, 1); - if (len == 0) { - resp.respond(ResponseTest{false, {"build_echo_request failed"}}); - return; - } - - net_send_raw(std::span{tx_buf, len}); + prepend_buffer<128> buf; + icmp::prepend_echo_request(buf, ns.mac, ns.ip, + eth::MAC_BROADCAST, dst_ip, ping_id, 1); + net_send_raw(buf.span()); ipv4::ip4_addr our_ip = ns.ip; @@ -69,7 +151,8 @@ struct test_entry { }; static const std::unordered_map tests = { - {"discovery", {test_discovery, nullptr}}, + {"discovery_igmp", {nullptr, test_discovery_igmp}}, + {"discovery_info", {nullptr, test_discovery_info}}, {"ping_subnet", {nullptr, test_ping_subnet}}, {"ping_global", {nullptr, test_ping_global}}, };