diff --git a/firmware/include/callback_list.h b/firmware/include/callback_list.h new file mode 100644 index 0000000..e6758fa --- /dev/null +++ b/firmware/include/callback_list.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include +#include + +template +struct callback_list { + struct node { + alignas(T) uint8_t storage[sizeof(T)]; + node* prev = nullptr; + node* next = nullptr; + bool active = false; + + T& value() { return *reinterpret_cast(storage); } + const T& value() const { return *reinterpret_cast(storage); } + }; + + node nodes[N]; + node* free_head = &nodes[0]; + node sentinel; + + callback_list() { + for (int i = 0; i < N - 1; i++) nodes[i].next = &nodes[i + 1]; + nodes[N - 1].next = nullptr; + sentinel.prev = &sentinel; + sentinel.next = &sentinel; + } + + bool empty() const { return sentinel.next == &sentinel; } + + node* insert(T value) { + if (!free_head) return nullptr; + node* n = alloc(std::move(value)); + link_before(&sentinel, n); + return n; + } + + template + node* insert_sorted(T value, Less&& less) { + if (!free_head) return nullptr; + node* n = alloc(std::move(value)); + node* pos = sentinel.next; + while (pos != &sentinel && !less(n->value(), pos->value())) + pos = pos->next; + link_before(pos, n); + return n; + } + + void remove(node* n) { + if (!n || !n->active) return; + n->prev->next = n->next; + n->next->prev = n->prev; + n->active = false; + n->value().~T(); + n->next = free_head; + n->prev = nullptr; + free_head = n; + } + + node* front() { return sentinel.next != &sentinel ? sentinel.next : nullptr; } + + template + void for_each(Fn&& fn) { + node* cur = sentinel.next; + while (cur != &sentinel) { + node* next = cur->next; + fn(cur); + cur = next; + } + } + +private: + node* alloc(T value) { + node* n = free_head; + free_head = n->next; + new (n->storage) T(std::move(value)); + n->active = true; + return n; + } + + void link_before(node* pos, node* n) { + n->prev = pos->prev; + n->next = pos; + pos->prev->next = n; + pos->prev = n; + } +}; diff --git a/firmware/include/net.h b/firmware/include/net.h index 0b856fc..ecc8caa 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -6,6 +6,7 @@ #include "ipv4.h" #include "span_writer.h" #include "msgpack.h" +#include "callback_list.h" struct net_state { eth::mac_addr mac; @@ -15,11 +16,19 @@ struct net_state { using net_handler = std::function payload, std::function)> send)>; -using net_frame_callback = std::function frame)>; +using net_frame_callback = std::function frame)>; + +struct frame_callback_entry { + net_frame_callback fn; +}; + +using frame_cb_list = callback_list; +using frame_cb_handle = frame_cb_list::node*; bool net_init(); const net_state& net_get_state(); void net_set_handler(net_handler handler); -void net_add_frame_callback(net_frame_callback cb); +frame_cb_handle net_add_frame_callback(net_frame_callback cb); +void net_remove_frame_callback(frame_cb_handle h); void net_poll(std::span tx); void net_send_raw(std::span data); diff --git a/firmware/include/timer_queue.h b/firmware/include/timer_queue.h index 2d07e13..4733622 100644 --- a/firmware/include/timer_queue.h +++ b/firmware/include/timer_queue.h @@ -1,26 +1,25 @@ #pragma once #include #include "pico/time.h" -#include "sorted_list.h" +#include "callback_list.h" struct timer_entry { absolute_time_t when; std::function fn; }; -inline bool operator<(const timer_entry& a, const timer_entry& b) { - return absolute_time_diff_us(b.when, a.when) < 0; -} - -using timer_handle = sorted_list::node*; +using timer_handle = callback_list::node*; struct timer_queue { - sorted_list queue; + callback_list list; alarm_id_t alarm = -1; volatile bool irq_pending = false; timer_handle schedule(absolute_time_t when, std::function fn) { - auto* n = queue.insert({when, std::move(fn)}); + auto* n = list.insert_sorted({when, std::move(fn)}, + [](const timer_entry& a, const timer_entry& b) { + return absolute_time_diff_us(b.when, a.when) < 0; + }); arm(); return n; } @@ -30,25 +29,25 @@ struct timer_queue { } bool cancel(timer_handle h) { - bool removed = queue.remove(h); - if (removed) arm(); - return removed; + if (!h || !h->active) return false; + list.remove(h); + arm(); + return true; } void run() { if (!irq_pending) return; irq_pending = false; - while (!queue.empty()) { - auto& front = queue.front(); - if (absolute_time_diff_us(get_absolute_time(), front.when) > 0) break; - auto fn = std::move(front.fn); - queue.pop_front(); + while (auto* n = list.front()) { + if (absolute_time_diff_us(get_absolute_time(), n->value().when) > 0) break; + auto fn = std::move(n->value().fn); + list.remove(n); fn(); } arm(); } - bool empty() const { return queue.empty(); } + bool empty() const { return list.empty(); } private: static int64_t alarm_cb(alarm_id_t, void* user_data) { @@ -59,7 +58,7 @@ private: void arm() { if (alarm >= 0) cancel_alarm(alarm); alarm = -1; - if (!queue.empty()) - alarm = add_alarm_at(queue.front().when, alarm_cb, this, false); + if (auto* n = list.front()) + alarm = add_alarm_at(n->value().when, alarm_cb, this, false); } }; diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 2583c0e..da6aedf 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -1,5 +1,4 @@ #include "net.h" -#include #include "pico/unique_id.h" #include "pico/time.h" #include "eth.h" @@ -18,7 +17,7 @@ static constexpr uint16_t PICOMAP_PORT = __builtin_bswap16(28781); static net_state state; static w6300::socket_id raw_socket{0}; static net_handler msg_handler; -static std::vector frame_callbacks; +static frame_cb_list frame_callbacks; void net_send_raw(std::span data) { dlog_if_slow("net_send_raw", 1000, [&]{ @@ -72,10 +71,10 @@ static void process_frame(std::span frame, span_writer& tx) { if (!mac_match(eth_hdr.dst)) return; - auto cbs = std::move(frame_callbacks); - frame_callbacks.clear(); - for (auto& cb : cbs) - cb(frame); + frame_callbacks.for_each([&](frame_cb_list::node* n) { + if (n->value().fn(frame)) + frame_callbacks.remove(n); + }); switch (eth_hdr.ethertype) { case eth::ETH_ARP: @@ -123,8 +122,12 @@ void net_set_handler(net_handler handler) { msg_handler = std::move(handler); } -void net_add_frame_callback(net_frame_callback cb) { - frame_callbacks.push_back(std::move(cb)); +frame_cb_handle net_add_frame_callback(net_frame_callback cb) { + return frame_callbacks.insert({std::move(cb)}); +} + +void net_remove_frame_callback(frame_cb_handle h) { + frame_callbacks.remove(h); } void net_poll(std::span tx) { diff --git a/firmware/lib/test_handlers.cpp b/firmware/lib/test_handlers.cpp index 625db72..a14654a 100644 --- a/firmware/lib/test_handlers.cpp +++ b/firmware/lib/test_handlers.cpp @@ -43,39 +43,25 @@ static void discover_peer(peer_callback on_found, fail_callback on_timeout) { ipv4::ip4_addr our_ip = ns.ip; auto timer = std::make_shared(nullptr); - auto cb = std::make_shared)>>(); - *cb = [our_ip, timer, cb, on_found = std::move(on_found)](std::span frame) { + auto cb_id = std::make_shared(nullptr); + *cb_id = net_add_frame_callback([our_ip, timer, cb_id, on_found = std::move(on_found)](std::span frame) -> bool { parse_buffer pb(frame); auto* eth_hdr = pb.consume(); - if (!eth_hdr || eth_hdr->ethertype != eth::ETH_IPV4) { - net_add_frame_callback(*cb); - return; - } + if (!eth_hdr || eth_hdr->ethertype != eth::ETH_IPV4) return false; auto* ip = pb.consume(); - if (!ip || ip->protocol != 17) { - net_add_frame_callback(*cb); - return; - } + if (!ip || ip->protocol != 17) return false; size_t options_len = ip->header_len() - sizeof(ipv4::header); - if (options_len > 0 && !pb.skip(options_len)) { - net_add_frame_callback(*cb); - return; - } + if (options_len > 0 && !pb.skip(options_len)) return false; auto* uhdr = pb.consume(); - if (!uhdr || uhdr->src_port != PICOMAP_PORT) { - net_add_frame_callback(*cb); - return; - } - if (ip->src == our_ip) { - net_add_frame_callback(*cb); - return; - } + if (!uhdr || uhdr->src_port != PICOMAP_PORT) return false; + if (ip->src == our_ip) return false; dispatch_cancel_timer(*timer); on_found({eth_hdr->src, ip->src}); - }; - net_add_frame_callback(*cb); + return true; + }); - *timer = dispatch_schedule_ms(5000, [on_timeout = std::move(on_timeout)]() { + *timer = dispatch_schedule_ms(5000, [cb_id, on_timeout = std::move(on_timeout)]() { + net_remove_frame_callback(*cb_id); on_timeout(); }); } @@ -88,23 +74,18 @@ static void test_discovery_igmp(const responder& resp) { net_send_raw(buf.span()); auto timer = std::make_shared(nullptr); - auto cb = std::make_shared)>>(); - *cb = [resp, timer, cb](std::span frame) { + auto cb_id = std::make_shared(nullptr); + *cb_id = net_add_frame_callback([resp, timer](std::span frame) -> bool { ipv4::ip4_addr group; - if (!igmp::parse_report(frame, group)) { - net_add_frame_callback(*cb); - return; - } - if (group != igmp::PICOMAP_DISCOVERY_GROUP) { - net_add_frame_callback(*cb); - return; - } + if (!igmp::parse_report(frame, group)) return false; + if (group != igmp::PICOMAP_DISCOVERY_GROUP) return false; dispatch_cancel_timer(*timer); resp.respond(ResponseTest{true, {"got IGMP report for " + ipv4::to_string(group)}}); - }; - net_add_frame_callback(*cb); + return true; + }); - *timer = dispatch_schedule_ms(5000, [resp]() { + *timer = dispatch_schedule_ms(5000, [cb_id, resp]() { + net_remove_frame_callback(*cb_id); resp.respond(ResponseTest{false, {"no IGMP report within 5s"}}); }); } @@ -131,23 +112,21 @@ static void test_ping(const responder& resp, ipv4::ip4_addr dst_ip) { ipv4::ip4_addr our_ip = ns.ip; auto timer = std::make_shared(nullptr); - auto cb = std::make_shared)>>(); - *cb = [resp, ping_id, our_ip, timer, cb](std::span frame) { + auto cb_id = std::make_shared(nullptr); + *cb_id = net_add_frame_callback([resp, ping_id, our_ip, timer](std::span frame) -> bool { ipv4::ip4_addr src_ip; - if (!icmp::parse_echo_reply(frame, src_ip, ping_id)) { - net_add_frame_callback(*cb); - return; - } + if (!icmp::parse_echo_reply(frame, src_ip, ping_id)) return false; dispatch_cancel_timer(*timer); if (src_ip == our_ip) { resp.respond(ResponseTest{false, {"got reply from self: " + ipv4::to_string(src_ip)}}); - return; + return true; } resp.respond(ResponseTest{true, {"reply from " + ipv4::to_string(src_ip)}}); - }; - net_add_frame_callback(*cb); + return true; + }); - *timer = dispatch_schedule_ms(5000, [resp]() { + *timer = dispatch_schedule_ms(5000, [cb_id, resp]() { + net_remove_frame_callback(*cb_id); resp.respond(ResponseTest{false, {"no reply from non-self host within 5s"}}); }); } @@ -163,6 +142,7 @@ static void test_ping_global(const responder& resp) { struct ping_rate_state { responder resp; timer_handle timer = nullptr; + frame_cb_handle cb_handle = nullptr; uint16_t ping_id; uint16_t sent = 0; uint16_t received = 0; @@ -190,44 +170,6 @@ static void ping_rate_send_one(std::shared_ptr st) { st->sent++; } -static void ping_rate_recv(std::shared_ptr st, - std::shared_ptr)>> cb, - std::span frame) { - ipv4::ip4_addr src_ip; - if (!icmp::parse_echo_reply(frame, src_ip, st->ping_id)) { - net_add_frame_callback(*cb); - return; - } - if (src_ip == st->our_ip) { - net_add_frame_callback(*cb); - return; - } - - st->received++; - if (st->received >= st->target) { - dispatch_cancel_timer(st->timer); - uint32_t elapsed_us = time_us_32() - st->start_us; - uint32_t elapsed_ms = elapsed_us / 1000; - uint32_t pps = static_cast( - static_cast(st->received) * 1000000 / elapsed_us); - uint64_t total_bytes = static_cast(st->received) * 2 * st->frame_size(); - uint32_t kbps = static_cast(total_bytes * 1000 / elapsed_us); - char msg[128]; - snprintf(msg, sizeof(msg), - "%u rt in %lu ms, %lu pps, %lu bytes, %lu KB/s", - st->received, static_cast(elapsed_ms), - static_cast(pps), - static_cast(total_bytes), - static_cast(kbps)); - st->resp.respond(ResponseTest{true, {msg}}); - return; - } - - if (st->sent < st->target) - ping_rate_send_one(st); - net_add_frame_callback(*cb); -} - static void start_ping_rate(const responder& resp, uint16_t target, uint16_t payload_len, uint16_t pipeline) { auto& ns = net_get_state(); @@ -245,16 +187,41 @@ static void start_ping_rate(const responder& resp, uint16_t target, st->peer = peer; st->start_us = time_us_32(); - auto cb = std::make_shared)>>(); - *cb = [st, cb](std::span frame) { - ping_rate_recv(st, cb, frame); - }; - for (uint16_t i = 0; i < st->pipeline && st->sent < st->target; i++) ping_rate_send_one(st); - net_add_frame_callback(*cb); + + st->cb_handle = net_add_frame_callback([st](std::span frame) -> bool { + ipv4::ip4_addr src_ip; + if (!icmp::parse_echo_reply(frame, src_ip, st->ping_id)) return false; + if (src_ip == st->our_ip) return false; + + st->received++; + if (st->received >= st->target) { + dispatch_cancel_timer(st->timer); + uint32_t elapsed_us = time_us_32() - st->start_us; + uint32_t elapsed_ms = elapsed_us / 1000; + uint32_t pps = static_cast( + static_cast(st->received) * 1000000 / elapsed_us); + uint64_t total_bytes = static_cast(st->received) * 2 * st->frame_size(); + uint32_t kbps = static_cast(total_bytes * 1000 / elapsed_us); + char msg[128]; + snprintf(msg, sizeof(msg), + "%u rt in %lu ms, %lu pps, %lu bytes, %lu KB/s", + st->received, static_cast(elapsed_ms), + static_cast(pps), + static_cast(total_bytes), + static_cast(kbps)); + st->resp.respond(ResponseTest{true, {msg}}); + return true; + } + + if (st->sent < st->target) + ping_rate_send_one(st); + return false; + }); st->timer = dispatch_schedule_ms(10000, [st]() { + net_remove_frame_callback(st->cb_handle); uint32_t elapsed_us = time_us_32() - st->start_us; char msg[64]; snprintf(msg, sizeof(msg), "timeout after %u/%u rt in %lu ms", @@ -269,7 +236,7 @@ static void start_ping_rate(const responder& resp, uint16_t target, } static void test_ping_rate(const responder& resp) { - start_ping_rate(resp, 8192, 0, 1); + start_ping_rate(resp, 8192, 0, 2); } static void test_ping_rate_1k(const responder& resp) {