commit 92d2ce81813b957b5e8c22a6e5fbf283f52e17ba Author: Ian Gulliver Date: Sun Apr 19 17:28:44 2026 -0700 initial import from picomap firmware diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..94f2187 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,48 @@ +cmake_minimum_required(VERSION 3.13) + +add_library(limen STATIC + src/arp.cpp + src/dispatch.cpp + src/flash.cpp + src/handlers.cpp + src/icmp.cpp + src/igmp.cpp + src/ipv4.cpp + src/net.cpp + src/test_handlers.cpp + src/udp.cpp + w6300/w6300.cpp +) + +target_include_directories(limen PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/w6300 +) + +target_compile_options(limen PRIVATE -Wall -Wextra -Wno-unused-parameter) + +target_link_libraries(limen PUBLIC + pico_stdlib + pico_sha256 + pico_unique_id + hardware_pio + hardware_spi + hardware_dma + hardware_clocks +) + +pico_generate_pio_header(limen ${CMAKE_CURRENT_SOURCE_DIR}/w6300/qspi.pio) + +set(LIMEN_PARTITION_TABLE ${CMAKE_CURRENT_SOURCE_DIR}/partition_table.json CACHE INTERNAL "") + +# Apply per-executable limen setup: pt embed, binary hash, extras, copy-to-ram, +# version, disable stdio. Callers still do pico_add_executable + link-libs. +function(limen_configure_executable target) + pico_enable_stdio_usb(${target} 0) + pico_enable_stdio_uart(${target} 0) + pico_set_binary_type(${target} copy_to_ram) + pico_hash_binary(${target}) + pico_embed_pt_in_binary(${target} ${LIMEN_PARTITION_TABLE}) + pico_add_extra_outputs(${target}) + target_compile_options(${target} PRIVATE -Wall -Wextra -Wno-unused-parameter) +endfunction() diff --git a/include/arp.h b/include/arp.h new file mode 100644 index 0000000..0fff18c --- /dev/null +++ b/include/arp.h @@ -0,0 +1,24 @@ +#pragma once +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace arp { + +struct __attribute__((packed)) header { + uint16_t htype; + uint16_t ptype; + uint8_t hlen; + uint8_t plen; + uint16_t oper; + eth::mac_addr sha; + ipv4::ip4_addr spa; + eth::mac_addr tha; + ipv4::ip4_addr tpa; +}; +static_assert(sizeof(header) == 28); + +void handle(std::span frame, span_writer& tx); + +} // namespace arp diff --git a/include/callback_list.h b/include/callback_list.h new file mode 100644 index 0000000..acd6af6 --- /dev/null +++ b/include/callback_list.h @@ -0,0 +1,80 @@ +#pragma once +#include + +template +struct callback_list { + struct node { + T value; + node* prev = nullptr; + node* next = nullptr; + }; + + node nodes[N]; + node* free_head = &nodes[0]; + node* head = nullptr; + + callback_list() { + for (int i = 0; i < N - 1; i++) nodes[i].next = &nodes[i + 1]; + nodes[N - 1].next = nullptr; + } + + bool empty() const { return head == nullptr; } + + node* insert(T value) { + if (!free_head) return nullptr; + node* n = free_head; + free_head = n->next; + n->value = std::move(value); + n->prev = nullptr; + n->next = head; + if (head) head->prev = n; + head = n; + return n; + } + + template + node* insert_sorted(T value, Less&& less) { + if (!free_head) return nullptr; + node* n = free_head; + free_head = n->next; + n->value = std::move(value); + if (!head || less(n->value, head->value)) { + n->prev = nullptr; + n->next = head; + if (head) head->prev = n; + head = n; + return n; + } + node* cur = head; + while (cur->next && !less(n->value, cur->next->value)) + cur = cur->next; + n->prev = cur; + n->next = cur->next; + if (cur->next) cur->next->prev = n; + cur->next = n; + return n; + } + + void remove(node* n) { + if (!n) return; + if (n->prev) n->prev->next = n->next; + else head = n->next; + if (n->next) n->next->prev = n->prev; + n->value = T{}; + n->next = free_head; + n->prev = nullptr; + free_head = n; + } + + node* front() { return head; } + + template + void for_each(Fn&& fn) { + node* cur = head; + while (cur) { + node* next = cur->next; + fn(cur); + cur = next; + } + } +}; diff --git a/include/debug_log.h b/include/debug_log.h new file mode 100644 index 0000000..ce241a9 --- /dev/null +++ b/include/debug_log.h @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include +#include +#include +#include "pico/time.h" +#include "ring_buffer.h" + +struct log_entry { + uint32_t timestamp_us; + std::string message; +}; + +inline ring_buffer g_debug_log; + +inline void dlog(std::string_view msg) { + g_debug_log.push_overwrite(log_entry{static_cast(time_us_32()), std::string(msg)}); +} + +__attribute__((format(printf, 1, 2))) +inline void dlogf(const char* fmt, ...) { + char buf[128]; + va_list args; + va_start(args, fmt); + vsnprintf(buf, sizeof(buf), fmt, args); + va_end(args); + dlog(buf); +} + +template +inline void dlog_if_slow(std::string_view label, uint32_t threshold_us, F&& fn) { + uint32_t t0 = time_us_32(); + fn(); + uint32_t elapsed = time_us_32() - t0; + if (elapsed > threshold_us) + dlogf("%.*s %luus", static_cast(label.size()), label.data(), static_cast(elapsed)); +} + diff --git a/include/dispatch.h b/include/dispatch.h new file mode 100644 index 0000000..1137c11 --- /dev/null +++ b/include/dispatch.h @@ -0,0 +1,60 @@ +#pragma once +#include +#include +#include +#include +#include "wire.h" +#include "timer_queue.h" +#include "net.h" +#include "prepend_buffer.h" +#include "udp.h" + +uint16_t dispatch_listen_port_be(); + +struct responder { + uint32_t message_id; + udp::address reply_to; + + template + void respond(const T& msg) const { + const auto& ns = net_get_state(); + prepend_buffer<4096> buf; + span_writer out(buf.payload_ptr(), 2048); + auto r = encode_response_into(out, message_id, msg); + if (!r) return; + buf.append(*r); + udp::prepend(buf, reply_to.mac, ns.mac, ns.ip, reply_to.ip, + dispatch_listen_port_be(), reply_to.port, *r); + net_send_raw(buf.span()); + } +}; + +using handler_fn = void (*)(const responder& resp, std::span payload); + +struct handler_entry { + int8_t type_id; + handler_fn handle; +}; + +template +void typed_handler(const responder& resp, std::span payload) { + msgpack::parser p(payload.data(), static_cast(payload.size())); + Req req; + auto tup = req.as_tuple(); + auto r = msgpack::unpack(p, tup); + if (!r) { + char err[64]; + snprintf(err, sizeof(err), "decode request ext_id=%d: msgpack error %d", + Req::ext_id, static_cast(r.error())); + resp.respond(DeviceError{1, err}); + return; + } + auto result = Fn(resp, req); + if (result) + resp.respond(*result); +} + +void dispatch_init(uint16_t listen_port_be); +timer_handle dispatch_schedule_ms(uint32_t ms, void (*fn)()); +bool dispatch_cancel_timer(timer_handle h); +[[noreturn]] void dispatch_run(std::span handlers); diff --git a/include/eth.h b/include/eth.h new file mode 100644 index 0000000..df02e3f --- /dev/null +++ b/include/eth.h @@ -0,0 +1,28 @@ +#pragma once +#include +#include + +namespace eth { + +using mac_addr = std::array; + +static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +static constexpr uint16_t ETH_ARP = __builtin_bswap16(0x0806); +static constexpr uint16_t ETH_IPV4 = __builtin_bswap16(0x0800); + +struct __attribute__((packed)) header { + mac_addr dst; + mac_addr src; + uint16_t ethertype; +}; +static_assert(sizeof(header) == 14); + +template +void prepend(Buf& buf, const mac_addr& dst, const mac_addr& src, uint16_t ethertype) { + auto* h = buf.template prepend
(); + h->dst = dst; + h->src = src; + h->ethertype = ethertype; +} + +} // namespace eth diff --git a/include/flash.h b/include/flash.h new file mode 100644 index 0000000..f9a166e --- /dev/null +++ b/include/flash.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace flash { + +constexpr uint32_t FLASH_BASE = 0x10000000; +constexpr uint32_t FLASH_SIZE = 2 * 1024 * 1024; + +struct slot { + bool valid; + uint32_t version; + bool hash_ok; + auto as_tuple() const { return std::tie(valid, version, hash_ok); } + auto as_tuple() { return std::tie(valid, version, hash_ok); } +}; + +slot scan(uint32_t flash_offset); + +} diff --git a/include/halfsiphash.h b/include/halfsiphash.h new file mode 100644 index 0000000..998a8a8 --- /dev/null +++ b/include/halfsiphash.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include + +namespace halfsiphash { + +namespace detail { + +constexpr uint32_t rotl(uint32_t x, int b) { + return (x << b) | (x >> (32 - b)); +} + +constexpr uint32_t load_le32(const uint8_t *p) { + return static_cast(p[0]) + | (static_cast(p[1]) << 8) + | (static_cast(p[2]) << 16) + | (static_cast(p[3]) << 24); +} + +inline void store_le32(uint8_t *p, uint32_t v) { + p[0] = static_cast(v); + p[1] = static_cast(v >> 8); + p[2] = static_cast(v >> 16); + p[3] = static_cast(v >> 24); +} + +inline void sipround(uint32_t &v0, uint32_t &v1, uint32_t &v2, uint32_t &v3) { + v0 += v1; v1 = rotl(v1, 5); v1 ^= v0; v0 = rotl(v0, 16); + v2 += v3; v3 = rotl(v3, 8); v3 ^= v2; + v0 += v3; v3 = rotl(v3, 7); v3 ^= v0; + v2 += v1; v1 = rotl(v1, 13); v1 ^= v2; v2 = rotl(v2, 16); +} + +} // namespace detail + +// Compute HalfSipHash-2-4 with an 8-byte key, returning a 32-bit hash. +inline uint32_t hash32(std::span data, const uint8_t key[8]) { + using namespace detail; + + uint32_t k0 = load_le32(key); + uint32_t k1 = load_le32(key + 4); + + uint32_t v0 = 0 ^ k0; + uint32_t v1 = 0 ^ k1; + uint32_t v2 = UINT32_C(0x6c796765) ^ k0; + uint32_t v3 = UINT32_C(0x74656462) ^ k1; + + const uint8_t *end = data.data() + data.size() - (data.size() % 4); + for (const uint8_t *p = data.data(); p != end; p += 4) { + uint32_t m = load_le32(p); + v3 ^= m; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + v0 ^= m; + } + + uint32_t b = static_cast(data.size()) << 24; + switch (data.size() & 3) { + case 3: b |= static_cast(end[2]) << 16; [[fallthrough]]; + case 2: b |= static_cast(end[1]) << 8; [[fallthrough]]; + case 1: b |= static_cast(end[0]); break; + case 0: break; + } + + v3 ^= b; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + v0 ^= b; + + v2 ^= 0xff; + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + sipround(v0, v1, v2, v3); + + return v1 ^ v3; +} + +} // namespace halfsiphash diff --git a/include/handlers.h b/include/handlers.h new file mode 100644 index 0000000..6c740f9 --- /dev/null +++ b/include/handlers.h @@ -0,0 +1,21 @@ +#pragma once +#include +#include +#include +#include "dispatch.h" +#include "ipv4.h" +#include "wire.h" + +inline constexpr uint16_t PICOMAP_PORT_BE = __builtin_bswap16(28781); +inline constexpr ipv4::ip4_addr PICOMAP_DISCOVERY_GROUP = {239, 112, 77, 1}; + +extern std::string_view firmware_name; + +void handlers_init(); +void handlers_start(); +std::optional handle_info(const responder& resp, const RequestInfo&); +std::optional handle_log(const responder& resp, const RequestLog&); +std::optional handle_flash_erase(const responder& resp, const RequestFlashErase&); +std::optional handle_flash_write(const responder& resp, const RequestFlashWrite&); +std::optional handle_reboot(const responder& resp, const RequestReboot&); +std::optional handle_flash_status(const responder& resp, const RequestFlashStatus&); diff --git a/include/icmp.h b/include/icmp.h new file mode 100644 index 0000000..b4f67b9 --- /dev/null +++ b/include/icmp.h @@ -0,0 +1,40 @@ +#pragma once +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace icmp { + +struct __attribute__((packed)) echo { + uint8_t type; + uint8_t code; + uint16_t checksum; + uint16_t id; + uint16_t seq; +}; +static_assert(sizeof(echo) == 8); + +void handle(std::span frame, span_writer& tx); + +template +void prepend_echo_request(Buf& buf, + eth::mac_addr src_mac, ipv4::ip4_addr src_ip, + eth::mac_addr dst_mac, ipv4::ip4_addr dst_ip, + uint16_t id, uint16_t seq, + size_t payload_len = 0) { + auto* e = buf.template prepend(); + e->type = 8; + e->code = 0; + e->checksum = 0; + e->id = id; + e->seq = seq; + size_t icmp_len = sizeof(echo) + payload_len; + e->checksum = ipv4::checksum(e, icmp_len); + ipv4::prepend(buf, dst_mac, src_mac, src_ip, dst_ip, 1, icmp_len); +} + +bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id); + +} // namespace icmp diff --git a/include/igmp.h b/include/igmp.h new file mode 100644 index 0000000..fc6bf10 --- /dev/null +++ b/include/igmp.h @@ -0,0 +1,57 @@ +#pragma once +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" + +namespace igmp { + +static constexpr ipv4::ip4_addr ALL_HOSTS = {224, 0, 0, 1}; + +struct __attribute__((packed)) message { + uint8_t type; + uint8_t max_resp_time; + uint16_t checksum; + ipv4::ip4_addr group; +}; +static_assert(sizeof(message) == 8); + +eth::mac_addr mac_for_ip(const ipv4::ip4_addr& group); +bool is_member(const ipv4::ip4_addr& ip); +bool is_member_mac(const eth::mac_addr& mac); + +void join(const ipv4::ip4_addr& group); + +void send_all_reports(); + +void handle(std::span frame, span_writer& tx); + +template +void prepend_report(Buf& buf, const eth::mac_addr& src_mac, ipv4::ip4_addr src_ip, + const ipv4::ip4_addr& group) { + auto* m = buf.template prepend(); + m->type = 0x16; + m->max_resp_time = 0; + m->checksum = 0; + m->group = group; + m->checksum = ipv4::checksum(m, sizeof(message)); + ipv4::prepend(buf, mac_for_ip(group), src_mac, src_ip, group, 2, sizeof(message), 1); +} + +template +void prepend_query(Buf& buf, const eth::mac_addr& src_mac, ipv4::ip4_addr src_ip, + const ipv4::ip4_addr& group) { + ipv4::ip4_addr dst_ip = (group == ipv4::ip4_addr{0, 0, 0, 0}) ? ALL_HOSTS : group; + auto* m = buf.template prepend(); + m->type = 0x11; + m->max_resp_time = 100; + m->checksum = 0; + m->group = group; + m->checksum = ipv4::checksum(m, sizeof(message)); + ipv4::prepend(buf, mac_for_ip(dst_ip), src_mac, src_ip, dst_ip, 2, sizeof(message), 1); +} + +bool parse_report(std::span frame, ipv4::ip4_addr& group); + +} // namespace igmp diff --git a/include/ipv4.h b/include/ipv4.h new file mode 100644 index 0000000..2f4ad92 --- /dev/null +++ b/include/ipv4.h @@ -0,0 +1,67 @@ +#pragma once +#include +#include +#include +#include +#include +#include "eth.h" +#include "span_writer.h" + +namespace ipv4 { + +using ip4_addr = std::array; + +inline std::string to_string(const ip4_addr& ip) { + char buf[16]; + snprintf(buf, sizeof(buf), "%u.%u.%u.%u", ip[0], ip[1], ip[2], ip[3]); + return buf; +} + +struct __attribute__((packed)) header { + uint8_t ver_ihl; + uint8_t dscp_ecn; + uint16_t total_len; + uint16_t identification; + uint16_t flags_frag; + uint8_t ttl; + uint8_t protocol; + uint16_t checksum; + ip4_addr src; + ip4_addr dst; + + size_t header_len() const { return (ver_ihl & 0x0F) * 4; } + size_t total() const { return __builtin_bswap16(total_len); } +}; +static_assert(sizeof(header) == 20); + +uint16_t checksum(const void* data, size_t len); + +static constexpr ip4_addr SUBNET_BROADCAST = {169, 254, 255, 255}; + +template +void prepend(Buf& buf, const eth::mac_addr& dst_mac, const eth::mac_addr& src_mac, + ip4_addr src_ip, ip4_addr dst_ip, uint8_t protocol, + size_t payload_len, uint8_t ttl = 64) { + auto* h = buf.template prepend
(); + h->ver_ihl = 0x45; + h->dscp_ecn = 0; + h->total_len = __builtin_bswap16(sizeof(header) + payload_len); + h->identification = 0; + h->flags_frag = 0; + h->ttl = ttl; + h->protocol = protocol; + h->checksum = 0; + h->src = src_ip; + h->dst = dst_ip; + h->checksum = checksum(h, sizeof(header)); + eth::prepend(buf, dst_mac, src_mac, eth::ETH_IPV4); +} + +void handle(std::span frame, span_writer& tx); + +bool addressed_to_us(ip4_addr dst); + +using protocol_handler = void (*)(std::span frame, span_writer& tx); +void register_protocol(uint8_t protocol, protocol_handler fn); + +} // namespace ipv4 diff --git a/include/msgpack.h b/include/msgpack.h new file mode 100644 index 0000000..ae3e331 --- /dev/null +++ b/include/msgpack.h @@ -0,0 +1,857 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "span_writer.h" + +namespace msgpack { + +enum class error_code { + overflow, + empty, + lack, + invalid, + type_error, +}; + +namespace format { + constexpr uint8_t POSITIVE_FIXINT_MIN = 0x00; + constexpr uint8_t POSITIVE_FIXINT_MAX = 0x7F; + constexpr uint8_t FIXMAP_MIN = 0x80; + constexpr uint8_t FIXMAP_MAX = 0x8F; + constexpr uint8_t FIXARRAY_MIN = 0x90; + constexpr uint8_t FIXARRAY_MAX = 0x9F; + constexpr uint8_t FIXSTR_MIN = 0xA0; + constexpr uint8_t FIXSTR_MAX = 0xBF; + constexpr uint8_t NEGATIVE_FIXINT_MIN = 0xE0; + constexpr uint8_t NEGATIVE_FIXINT_MAX = 0xFF; + + constexpr uint8_t NIL = 0xC0; + constexpr uint8_t NEVER_USED = 0xC1; + constexpr uint8_t FALSE = 0xC2; + constexpr uint8_t TRUE = 0xC3; + constexpr uint8_t BIN8 = 0xC4; + constexpr uint8_t BIN16 = 0xC5; + constexpr uint8_t BIN32 = 0xC6; + constexpr uint8_t EXT8 = 0xC7; + constexpr uint8_t EXT16 = 0xC8; + constexpr uint8_t EXT32 = 0xC9; + constexpr uint8_t FLOAT32 = 0xCA; + constexpr uint8_t FLOAT64 = 0xCB; + constexpr uint8_t UINT8 = 0xCC; + constexpr uint8_t UINT16 = 0xCD; + constexpr uint8_t UINT32 = 0xCE; + constexpr uint8_t UINT64 = 0xCF; + constexpr uint8_t INT8 = 0xD0; + constexpr uint8_t INT16 = 0xD1; + constexpr uint8_t INT32 = 0xD2; + constexpr uint8_t INT64 = 0xD3; + constexpr uint8_t FIXEXT1 = 0xD4; + constexpr uint8_t FIXEXT2 = 0xD5; + constexpr uint8_t FIXEXT4 = 0xD6; + constexpr uint8_t FIXEXT8 = 0xD7; + constexpr uint8_t FIXEXT16 = 0xD8; + constexpr uint8_t STR8 = 0xD9; + constexpr uint8_t STR16 = 0xDA; + constexpr uint8_t STR32 = 0xDB; + constexpr uint8_t ARRAY16 = 0xDC; + constexpr uint8_t ARRAY32 = 0xDD; + constexpr uint8_t MAP16 = 0xDE; + constexpr uint8_t MAP32 = 0xDF; + + constexpr bool is_positive_fixint(uint8_t b) { return b <= POSITIVE_FIXINT_MAX; } + constexpr bool is_fixmap(uint8_t b) { return b >= FIXMAP_MIN && b <= FIXMAP_MAX; } + constexpr bool is_fixarray(uint8_t b) { return b >= FIXARRAY_MIN && b <= FIXARRAY_MAX; } + constexpr bool is_fixstr(uint8_t b) { return b >= FIXSTR_MIN && b <= FIXSTR_MAX; } + constexpr bool is_negative_fixint(uint8_t b) { return b >= NEGATIVE_FIXINT_MIN; } +} // namespace format + +template +using result = std::expected; + +template +result body_number(const uint8_t *p, int size) { + if (size < 1 + static_cast(sizeof(T))) { + return std::unexpected(error_code::lack); + } + if constexpr (sizeof(T) == 1) { + return static_cast(p[1]); + } else if constexpr (sizeof(T) == 2) { + return static_cast((p[1] << 8) | p[2]); + } else if constexpr (sizeof(T) == 4) { + uint8_t buf[] = {p[4], p[3], p[2], p[1]}; + T val; + __builtin_memcpy(&val, buf, sizeof(T)); + return val; + } else if constexpr (sizeof(T) == 8) { + uint8_t buf[] = {p[8], p[7], p[6], p[5], p[4], p[3], p[2], p[1]}; + T val; + __builtin_memcpy(&val, buf, sizeof(T)); + return val; + } else { + return std::unexpected(error_code::invalid); + } +} + +struct body_info { + int header; // bytes before the body (includes format byte + length fields + ext type byte) + uint32_t body; // body size in bytes (0 for containers, computed for variable-length) +}; + +inline result get_body_info(const uint8_t *p, int size) { + if (size < 1) return std::unexpected(error_code::empty); + uint8_t b = p[0]; + + using namespace format; + + if (is_positive_fixint(b)) return body_info{1, 0}; + if (is_negative_fixint(b)) return body_info{1, 0}; + if (is_fixmap(b)) return body_info{1, 0}; // container + if (is_fixarray(b)) return body_info{1, 0}; // container + if (is_fixstr(b)) return body_info{1, static_cast(b & 0x1F)}; + + switch (b) { + case NIL: case FALSE: case TRUE: + return body_info{1, 0}; + case NEVER_USED: + return std::unexpected(error_code::invalid); + + case BIN8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1, *n}; } + case BIN16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2, *n}; } + case BIN32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4, *n}; } + + case EXT8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1+1, *n}; } + case EXT16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2+1, *n}; } + case EXT32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4+1, *n}; } + + case FLOAT32: return body_info{1, 4}; + case FLOAT64: return body_info{1, 8}; + case UINT8: return body_info{1, 1}; + case UINT16: return body_info{1, 2}; + case UINT32: return body_info{1, 4}; + case UINT64: return body_info{1, 8}; + case INT8: return body_info{1, 1}; + case INT16: return body_info{1, 2}; + case INT32: return body_info{1, 4}; + case INT64: return body_info{1, 8}; + + case FIXEXT1: return body_info{1+1, 1}; + case FIXEXT2: return body_info{1+1, 2}; + case FIXEXT4: return body_info{1+1, 4}; + case FIXEXT8: return body_info{1+1, 8}; + case FIXEXT16: return body_info{1+1, 16}; + + case STR8: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+1, *n}; } + case STR16: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+2, *n}; } + case STR32: { auto n = body_number(p, size); if (!n) return std::unexpected(n.error()); return body_info{1+4, *n}; } + + case ARRAY16: case ARRAY32: + case MAP16: case MAP32: + return body_info{1 + (b == ARRAY16 || b == MAP16 ? 2 : 4), 0}; // container + + default: + return std::unexpected(error_code::invalid); + } +} + +class packer { +private: + span_writer &m_buf; + + template void push_big_endian(T n) { + auto p = reinterpret_cast(&n) + (sizeof(T) - 1); + for (size_t i = 0; i < sizeof(T); ++i, --p) { + m_buf.push_back(*p); + } + } + + template void push(const Range &r) { + m_buf.insert(m_buf.end(), std::begin(r), std::end(r)); + } + +public: + packer(span_writer &buf) : m_buf(buf) {} + + packer(const packer &) = delete; + packer &operator=(const packer &) = delete; + + using pack_result = result>; + + pack_result pack_nil() { + m_buf.push_back(format::NIL); + return *this; + } + + pack_result pack_bool(bool v) { + m_buf.push_back(v ? format::TRUE : format::FALSE); + return *this; + } + + template + pack_result pack_integer(T n) { + if constexpr (std::is_signed_v) { + if (n >= 0 && n <= 0x7F) { + m_buf.push_back(static_cast(n)); + } else if (n >= -32 && n < 0) { + m_buf.push_back(static_cast(n)); // negative fixint + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buf.push_back(format::INT8); + m_buf.push_back(static_cast(n)); + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buf.push_back(format::INT16); + push_big_endian(static_cast(n)); + } else if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { + m_buf.push_back(format::INT32); + push_big_endian(static_cast(n)); + } else { + m_buf.push_back(format::INT64); + push_big_endian(static_cast(n)); + } + } else { + if (n <= 0x7F) { + m_buf.push_back(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::UINT8); + m_buf.push_back(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::UINT16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::UINT32); + push_big_endian(static_cast(n)); + } else { + m_buf.push_back(format::UINT64); + push_big_endian(static_cast(n)); + } + } + return *this; + } + + pack_result pack_uint32_fixed(uint32_t n) { + m_buf.push_back(format::UINT32); + push_big_endian(n); + return *this; + } + + pack_result pack_float(float n) { + m_buf.push_back(format::FLOAT32); + push_big_endian(n); + return *this; + } + + pack_result pack_double(double n) { + m_buf.push_back(format::FLOAT64); + push_big_endian(n); + return *this; + } + + template + pack_result pack_str(const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + if (sz < 32) { + m_buf.push_back(format::FIXSTR_MIN | static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::STR8); + m_buf.push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::STR16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::STR32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + push(r); + return *this; + } + + pack_result pack_str(const char *s) { + return pack_str(std::string_view(s)); + } + + template + pack_result pack_bin(const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::BIN8); + m_buf.push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::BIN16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::BIN32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + push(r); + return *this; + } + + pack_result pack_array(size_t n) { + if (n <= 15) { + m_buf.push_back(format::FIXARRAY_MIN | static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::ARRAY16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::ARRAY32); + push_big_endian(static_cast(n)); + } else { + return std::unexpected(error_code::overflow); + } + return *this; + } + + pack_result pack_map(size_t n) { + if (n <= 15) { + m_buf.push_back(format::FIXMAP_MIN | static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::MAP16); + push_big_endian(static_cast(n)); + } else if (n <= std::numeric_limits::max()) { + m_buf.push_back(format::MAP32); + push_big_endian(static_cast(n)); + } else { + return std::unexpected(error_code::overflow); + } + return *this; + } + + pack_result pack_ext16_header(char type, uint16_t len) { + m_buf.push_back(format::EXT16); + push_big_endian(len); + m_buf.push_back(static_cast(type)); + return *this; + } + + pack_result pack_bin16_header(uint16_t len) { + m_buf.push_back(format::BIN16); + push_big_endian(len); + return *this; + } + + template + pack_result pack_ext(char type, const Range &r) { + auto sz = static_cast(std::distance(std::begin(r), std::end(r))); + + switch (sz) { + case 1: m_buf.push_back(format::FIXEXT1); break; + case 2: m_buf.push_back(format::FIXEXT2); break; + case 4: m_buf.push_back(format::FIXEXT4); break; + case 8: m_buf.push_back(format::FIXEXT8); break; + case 16: m_buf.push_back(format::FIXEXT16); break; + default: + if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::EXT8); + m_buf.push_back(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::EXT16); + push_big_endian(static_cast(sz)); + } else if (sz <= std::numeric_limits::max()) { + m_buf.push_back(format::EXT32); + push_big_endian(static_cast(sz)); + } else { + return std::unexpected(error_code::overflow); + } + } + m_buf.push_back(static_cast(type)); + push(r); + return *this; + } + + template + requires std::is_integral_v && (!std::is_same_v) + pack_result pack(T n) { return pack_integer(n); } + + template + requires std::is_enum_v + pack_result pack(T v) { return pack_integer(static_cast>(v)); } + + pack_result pack(bool v) { return pack_bool(v); } + pack_result pack(float v) { return pack_float(v); } + pack_result pack(double v) { return pack_double(v); } + pack_result pack(const char *v) { return pack_str(v); } + pack_result pack(std::string_view v) { return pack_str(v); } + pack_result pack(const std::string &v) { return pack_str(v); } + + pack_result pack(const std::vector &v) { return pack_bin(v); } + + template + requires (!std::is_same_v) + pack_result pack(const std::vector &v) { + auto r = pack_array(v.size()); + if (!r) return r; + for (auto& elem : v) { + r = r->get().pack(elem); + if (!r) return r; + } + return r; + } + + template + pack_result pack(const std::array &v) { return pack_bin(v); } + + template + pack_result pack(const std::tuple &t) { + auto r = pack_array(sizeof...(Ts)); + if (!r) return r; + return pack_tuple_elements(t, std::index_sequence_for{}); + } + + template + requires requires(const T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } + pack_result pack(const T &v) { + uint8_t ext_buf[256]; + span_writer ext_writer(ext_buf, sizeof(ext_buf)); + packer inner(ext_writer); + auto r = inner.pack(v.as_tuple()); + if (!r) return r; + return pack_ext(T::ext_id, inner.get_payload()); + } + + template + requires (requires(const T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to; }) + pack_result pack(const T &v) { + return pack(v.as_tuple()); + } + +private: + template + pack_result pack_tuple_elements(const Tuple &t, std::index_sequence) { + pack_result r = *this; + ((r = r ? r->get().pack(std::get(t)) : r), ...); + return r; + } + +public: + const span_writer &get_payload() const { return m_buf; } +}; + +class parser { + const uint8_t *m_p = nullptr; + int m_size = 0; + + result header_byte() const { + if (m_size < 1) return std::unexpected(error_code::empty); + return m_p[0]; + } + +public: + parser() = default; + + parser(const std::vector &v) + : m_p(v.data()), m_size(static_cast(v.size())) {} + + parser(const uint8_t *p, int size) + : m_p(p), m_size(size < 0 ? 0 : size) {} + + bool is_empty() const { return m_size == 0; } + const uint8_t *data() const { return m_p; } + int size() const { return m_size; } + + result advance(int n) const { + if (n > m_size) return std::unexpected(error_code::lack); + return parser(m_p + n, m_size - n); + } + + result next() const { + auto hdr = header_byte(); + if (!hdr) return std::unexpected(hdr.error()); + + if (is_array()) { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto cnt = count(); + if (!cnt) return std::unexpected(cnt.error()); + auto cur = advance(info->header); + if (!cur) return std::unexpected(cur.error()); + for (uint32_t i = 0; i < *cnt; ++i) { + auto n = cur->next(); + if (!n) return std::unexpected(n.error()); + cur = *n; + } + return *cur; + } else if (is_map()) { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto cnt = count(); + if (!cnt) return std::unexpected(cnt.error()); + auto cur = advance(info->header); + if (!cur) return std::unexpected(cur.error()); + for (uint32_t i = 0; i < *cnt; ++i) { + auto k = cur->next(); + if (!k) return std::unexpected(k.error()); + cur = *k; + auto v = cur->next(); + if (!v) return std::unexpected(v.error()); + cur = *v; + } + return *cur; + } else { + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + auto total = info->header + static_cast(info->body); + return advance(total); + } + } + + bool is_nil() const { + auto h = header_byte(); + return h && *h == format::NIL; + } + + bool is_bool() const { + auto h = header_byte(); + return h && (*h == format::TRUE || *h == format::FALSE); + } + + bool is_number() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_positive_fixint(b)) return true; + if (format::is_negative_fixint(b)) return true; + return b >= format::FLOAT32 && b <= format::INT64; + } + + bool is_string() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixstr(b)) return true; + return b == format::STR8 || b == format::STR16 || b == format::STR32; + } + + bool is_binary() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + return b == format::BIN8 || b == format::BIN16 || b == format::BIN32; + } + + bool is_ext() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + return (b >= format::FIXEXT1 && b <= format::FIXEXT16) || + b == format::EXT8 || b == format::EXT16 || b == format::EXT32; + } + + bool is_array() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixarray(b)) return true; + return b == format::ARRAY16 || b == format::ARRAY32; + } + + bool is_map() const { + auto h = header_byte(); + if (!h) return false; + uint8_t b = *h; + if (format::is_fixmap(b)) return true; + return b == format::MAP16 || b == format::MAP32; + } + + + result get_bool() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + if (*h == format::TRUE) return true; + if (*h == format::FALSE) return false; + return std::unexpected(error_code::type_error); + } + + result get_string() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + size_t offset, len; + if (format::is_fixstr(b)) { + len = b & 0x1F; + offset = 1; + } else if (b == format::STR8) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 1; + } else if (b == format::STR16) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 2; + } else if (b == format::STR32) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 4; + } else { + return std::unexpected(error_code::type_error); + } + if (static_cast(offset + len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::string_view(reinterpret_cast(m_p + offset), len); + } + + result get_binary_view() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + size_t offset, len; + + if (b == format::BIN8) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 1; + } else if (b == format::BIN16) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 2; + } else if (b == format::BIN32) { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + len = *n; offset = 1 + 4; + } else { + return std::unexpected(error_code::type_error); + } + if (static_cast(offset + len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::string_view(reinterpret_cast(m_p + offset), len); + } + + result> get_ext() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + int8_t ext_type; + size_t data_offset, data_len; + + switch (b) { + case format::FIXEXT1: ext_type = m_p[1]; data_offset = 2; data_len = 1; break; + case format::FIXEXT2: ext_type = m_p[1]; data_offset = 2; data_len = 2; break; + case format::FIXEXT4: ext_type = m_p[1]; data_offset = 2; data_len = 4; break; + case format::FIXEXT8: ext_type = m_p[1]; data_offset = 2; data_len = 8; break; + case format::FIXEXT16: ext_type = m_p[1]; data_offset = 2; data_len = 16; break; + case format::EXT8: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[2]; data_offset = 3; data_len = *n; + break; + } + case format::EXT16: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[3]; data_offset = 4; data_len = *n; + break; + } + case format::EXT32: { + auto n = body_number(m_p, m_size); + if (!n) return std::unexpected(n.error()); + ext_type = m_p[5]; data_offset = 6; data_len = *n; + break; + } + default: + return std::unexpected(error_code::type_error); + } + if (static_cast(data_offset + data_len) > m_size) { + return std::unexpected(error_code::lack); + } + return std::tuple{ext_type, + std::string_view(reinterpret_cast(m_p + data_offset), data_len)}; + } + + template + result get_number() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + + if (format::is_positive_fixint(b)) return static_cast(b); + if (format::is_negative_fixint(b)) return static_cast(static_cast(b)); + + switch (b) { + case format::UINT8: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::UINT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT8: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::INT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::FLOAT32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::FLOAT64: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + default: + return std::unexpected(error_code::type_error); + } + } + + result count() const { + auto h = header_byte(); + if (!h) return std::unexpected(h.error()); + uint8_t b = *h; + + if (format::is_fixarray(b)) return static_cast(b & 0x0F); + if (format::is_fixmap(b)) return static_cast(b & 0x0F); + + switch (b) { + case format::ARRAY16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::ARRAY32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return *n; } + case format::MAP16: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return static_cast(*n); } + case format::MAP32: { auto n = body_number(m_p, m_size); if (!n) return std::unexpected(n.error()); return *n; } + default: + return std::unexpected(error_code::type_error); + } + } + + result first_item() const { + if (!is_array() && !is_map()) return std::unexpected(error_code::type_error); + auto info = get_body_info(m_p, m_size); + if (!info) return std::unexpected(info.error()); + return advance(info->header); + } + + parser operator[](int index) const { + auto cur = first_item(); + if (!cur) return {}; + for (int i = 0; i < index; ++i) { + auto n = cur->next(); + if (!n) return {}; + cur = *n; + } + return *cur; + } +}; + +template + requires std::is_enum_v +result unpack(const parser &p, T &out) { + std::underlying_type_t v; + auto r = unpack(p, v); + if (!r) return r; + out = static_cast(v); + return r; +} + +template + requires std::is_integral_v && (!std::is_same_v) +result unpack(const parser &p, T &out) { + auto v = p.get_number(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, bool &out) { + auto v = p.get_bool(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, std::string_view &out) { + auto v = p.get_string(); + if (!v) return std::unexpected(v.error()); + out = *v; + return p.next(); +} + +inline result unpack(const parser &p, std::string &out) { + auto v = p.get_string(); + if (!v) return std::unexpected(v.error()); + out = std::string(v->data(), v->size()); + return p.next(); +} + +template +result unpack(const parser &p, std::array &out) { + auto v = p.get_binary_view(); + if (!v) return std::unexpected(v.error()); + if (v->size() != N) return std::unexpected(error_code::type_error); + std::copy(v->begin(), v->end(), out.begin()); + return p.next(); +} + +inline result unpack(const parser &p, std::vector &out) { + auto v = p.get_binary_view(); + if (!v) return std::unexpected(v.error()); + out.assign(v->begin(), v->end()); + return p.next(); +} + +inline result unpack(const parser &p, std::span &out) { + auto v = p.get_binary_view(); + if (!v) return std::unexpected(v.error()); + out = std::span(reinterpret_cast(v->data()), v->size()); + return p.next(); +} + +template + requires (!std::is_same_v) +result unpack(const parser &p, std::vector &out) { + auto cnt = p.count(); + if (!cnt) return std::unexpected(cnt.error()); + out.resize(*cnt); + result cur = p.first_item(); + for (size_t i = 0; i < *cnt; i++) { + if (!cur) return cur; + cur = unpack(*cur, out[i]); + } + if (!cur) return cur; + return p.next(); +} + +template +result unpack_tuple_elements(const parser &p, std::tuple &t, std::index_sequence) { + result cur = p.first_item(); + if (!cur) return cur; + ((cur = cur ? unpack(*cur, std::get(t)) : cur), ...); + return cur; +} + +template +result unpack(const parser &p, std::tuple &t) { + auto cnt = p.count(); + if (!cnt) return std::unexpected(cnt.error()); + if (*cnt != sizeof...(Ts)) return std::unexpected(error_code::type_error); + auto r = unpack_tuple_elements(p, t, std::index_sequence_for{}); + if (!r) return r; + return p.next(); +} + +template + requires (requires(T &v) { v.as_tuple(); } && !requires { { T::ext_id } -> std::convertible_to; }) +result unpack(const parser &p, T &out) { + auto tup = out.as_tuple(); + auto cnt = p.count(); + if (!cnt) return std::unexpected(cnt.error()); + if (*cnt != std::tuple_size_v) return std::unexpected(error_code::type_error); + auto r = unpack_tuple_elements(p, tup, std::make_index_sequence>{}); + if (!r) return r; + return p.next(); +} + +template + requires requires(T &v) { { T::ext_id } -> std::convertible_to; v.as_tuple(); } +result unpack(const parser &p, T &out) { + auto ext = p.get_ext(); + if (!ext) return std::unexpected(ext.error()); + auto [ext_type, ext_data] = *ext; + if (ext_type != T::ext_id) return std::unexpected(error_code::type_error); + parser inner(reinterpret_cast(ext_data.data()), + static_cast(ext_data.size())); + auto tup = out.as_tuple(); + auto r = unpack_tuple_elements(inner, tup, std::make_index_sequence>{}); + if (!r) return r; + return p.next(); +} + +} // namespace msgpack diff --git a/include/net.h b/include/net.h new file mode 100644 index 0000000..813204f --- /dev/null +++ b/include/net.h @@ -0,0 +1,27 @@ +#pragma once +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "span_writer.h" +#include "callback_list.h" + +struct net_state { + eth::mac_addr mac; + ipv4::ip4_addr ip; +}; + +using net_frame_callback = bool (*)(std::span frame); + +using frame_cb_list = callback_list; +using frame_cb_handle = frame_cb_list::node*; + +using ethertype_handler = void (*)(std::span frame, span_writer& tx); + +bool net_init(); +const net_state& net_get_state(); +frame_cb_handle net_add_frame_callback(net_frame_callback cb); +void net_remove_frame_callback(frame_cb_handle h); +void net_poll(std::span tx); +void net_send_raw(std::span data); +void net_register_ethertype(uint16_t ethertype_be, ethertype_handler fn); diff --git a/include/parse_buffer.h b/include/parse_buffer.h new file mode 100644 index 0000000..9ee7014 --- /dev/null +++ b/include/parse_buffer.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include + +class parse_buffer { + const uint8_t* m_data; + size_t m_remaining; + +public: + parse_buffer(std::span data) + : m_data(data.data()), m_remaining(data.size()) {} + + template + const T* consume() { + if (m_remaining < sizeof(T)) return nullptr; + auto* p = reinterpret_cast(m_data); + m_data += sizeof(T); + m_remaining -= sizeof(T); + return p; + } + + bool skip(size_t len) { + if (m_remaining < len) return false; + m_data += len; + m_remaining -= len; + return true; + } + + std::span remaining() const { return {m_data, m_remaining}; } + size_t remaining_size() const { return m_remaining; } +}; diff --git a/include/prepend_buffer.h b/include/prepend_buffer.h new file mode 100644 index 0000000..9dc8155 --- /dev/null +++ b/include/prepend_buffer.h @@ -0,0 +1,37 @@ +#pragma once +#include +#include +#include +#include + +template +class prepend_buffer { + uint8_t m_buf[N]; + size_t m_start = N / 2; + size_t m_end = N / 2; + +public: + template + T* prepend() { + m_start -= sizeof(T); + return reinterpret_cast(m_buf + m_start); + } + + uint8_t* append(size_t len) { + uint8_t* p = m_buf + m_end; + m_end += len; + return p; + } + + void append_copy(std::span data) { + memcpy(append(data.size()), data.data(), data.size()); + } + + uint8_t* payload_ptr() { return m_buf + m_end; } + + uint8_t* data() { return m_buf + m_start; } + const uint8_t* data() const { return m_buf + m_start; } + size_t size() const { return m_end - m_start; } + + std::span span() const { return {data(), size()}; } +}; diff --git a/include/ring_buffer.h b/include/ring_buffer.h new file mode 100644 index 0000000..5220335 --- /dev/null +++ b/include/ring_buffer.h @@ -0,0 +1,67 @@ +#pragma once +#include +#include +#include + +template +struct ring_buffer { + std::array data = {}; + uint16_t head = 0; + uint16_t tail = 0; + + uint16_t used() const { return tail - head; } + uint16_t free() const { return N - used(); } + bool empty() const { return head == tail; } + + bool push(std::span src) { + if (src.size() > free()) return false; + for (auto& v : src) + data[(tail++) % N] = v; + return true; + } + + bool push(const T& v) { + if (free() == 0) return false; + data[(tail++) % N] = v; + return true; + } + + void push_overwrite(const T& v) { + if (free() == 0) head++; + data[(tail++) % N] = v; + } + + uint16_t peek(std::span dst) const { + uint16_t len = dst.size() < used() ? dst.size() : used(); + for (uint16_t i = 0; i < len; i++) + dst[i] = data[(head + i) % N]; + return len; + } + + void consume(uint16_t len) { + head += len; + if (head >= N) { + head -= N; + tail -= N; + } + } + + std::span read_contiguous() const { + uint16_t offset = head % N; + uint16_t contig = N - offset; + uint16_t pending = used(); + uint16_t len = pending < contig ? pending : contig; + return {data.data() + offset, len}; + } + + struct iterator { + const ring_buffer* rb; + uint16_t index; + const T& operator*() const { return rb->data[(rb->head + index) % N]; } + iterator& operator++() { index++; return *this; } + bool operator!=(const iterator& o) const { return index != o.index; } + }; + + iterator begin() const { return {this, 0}; } + iterator end() const { return {this, used()}; } +}; diff --git a/include/span_writer.h b/include/span_writer.h new file mode 100644 index 0000000..47506d0 --- /dev/null +++ b/include/span_writer.h @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include +#include + +class span_writer { + uint8_t *m_data; + size_t m_capacity; + size_t m_size = 0; + bool m_overflow = false; + +public: + span_writer(uint8_t *data, size_t capacity) : m_data(data), m_capacity(capacity) {} + span_writer(std::span buf) : m_data(buf.data()), m_capacity(buf.size()) {} + + void push_back(uint8_t v) { + if (m_size < m_capacity) m_data[m_size++] = v; + else m_overflow = true; + } + + template + void insert(uint8_t *, It first, It last) { + while (first != last) { + if (m_size < m_capacity) m_data[m_size++] = *first++; + else { m_overflow = true; return; } + } + } + + size_t size() const { return m_size; } + size_t capacity() const { return m_capacity; } + bool full() const { return m_size >= m_capacity; } + bool overflow() const { return m_overflow; } + + uint8_t *data() { return m_data; } + const uint8_t *data() const { return m_data; } + + uint8_t *begin() { return m_data; } + uint8_t *end() { return m_data + m_size; } + const uint8_t *begin() const { return m_data; } + const uint8_t *end() const { return m_data + m_size; } + + span_writer subspan(size_t offset) { + return span_writer(m_data + offset, m_capacity - offset); + } + + span_writer subspan(size_t offset, size_t len) { + return span_writer(m_data + offset, len); + } +}; diff --git a/include/static_vector.h b/include/static_vector.h new file mode 100644 index 0000000..df3934c --- /dev/null +++ b/include/static_vector.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#include + +template +class static_vector { + T m_data[Capacity]; + size_t m_size = 0; + +public: + void push_back(const T &v) { + if (m_size < Capacity) m_data[m_size++] = v; + } + + void clear() { m_size = 0; } + + size_t size() const { return m_size; } + size_t capacity() const { return Capacity; } + bool full() const { return m_size >= Capacity; } + bool empty() const { return m_size == 0; } + + T *data() { return m_data; } + const T *data() const { return m_data; } + + T &operator[](size_t i) { return m_data[i]; } + const T &operator[](size_t i) const { return m_data[i]; } + + T *begin() { return m_data; } + T *end() { return m_data + m_size; } + const T *begin() const { return m_data; } + const T *end() const { return m_data + m_size; } +}; diff --git a/include/test_handlers.h b/include/test_handlers.h new file mode 100644 index 0000000..87a98fa --- /dev/null +++ b/include/test_handlers.h @@ -0,0 +1,7 @@ +#pragma once +#include +#include "dispatch.h" +#include "wire.h" + +std::optional handle_list_tests(const responder& resp, const RequestListTests&); +std::optional handle_test(const responder& resp, const RequestTest&); diff --git a/include/timer_queue.h b/include/timer_queue.h new file mode 100644 index 0000000..c8c304d --- /dev/null +++ b/include/timer_queue.h @@ -0,0 +1,63 @@ +#pragma once +#include "pico/time.h" +#include "callback_list.h" + +struct timer_entry { + absolute_time_t when; + void (*fn)() = nullptr; +}; + +using timer_handle = callback_list::node*; + +struct timer_queue { + callback_list list; + alarm_id_t alarm = -1; + volatile bool irq_pending = false; + + timer_handle schedule(absolute_time_t when, void (*fn)()) { + auto* n = list.insert_sorted({when, fn}, + [](const timer_entry& a, const timer_entry& b) { + return absolute_time_diff_us(b.when, a.when) < 0; + }); + arm(); + return n; + } + + timer_handle schedule_ms(uint32_t ms, void (*fn)()) { + return schedule(make_timeout_time_ms(ms), fn); + } + + bool cancel(timer_handle h) { + if (!h) return false; + list.remove(h); + arm(); + return true; + } + + void run() { + if (!irq_pending) return; + irq_pending = false; + while (auto* n = list.front()) { + if (absolute_time_diff_us(get_absolute_time(), n->value.when) > 0) break; + auto fn = n->value.fn; + list.remove(n); + fn(); + } + arm(); + } + + bool empty() const { return list.empty(); } + +private: + static int64_t alarm_cb(alarm_id_t, void* user_data) { + static_cast(user_data)->irq_pending = true; + return 0; + } + + void arm() { + if (alarm >= 0) cancel_alarm(alarm); + alarm = -1; + if (auto* n = list.front()) + alarm = add_alarm_at(n->value.when, alarm_cb, this, false); + } +}; diff --git a/include/udp.h b/include/udp.h new file mode 100644 index 0000000..71c6e9e --- /dev/null +++ b/include/udp.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include "eth.h" +#include "ipv4.h" +#include "net.h" +#include "span_writer.h" + +namespace udp { + +struct __attribute__((packed)) header { + uint16_t src_port; + uint16_t dst_port; + uint16_t length; + uint16_t checksum; +}; +static_assert(sizeof(header) == 8); + +struct address { + // mac is carried here until we grow an ARP cache; once we do, a + // destination mac comes from resolving ip and this field goes away. + eth::mac_addr mac; + ipv4::ip4_addr ip; + uint16_t port; +}; + +template +void prepend(Buf& buf, const eth::mac_addr& dst_mac, const eth::mac_addr& src_mac, + ipv4::ip4_addr src_ip, ipv4::ip4_addr dst_ip, + uint16_t src_port, uint16_t dst_port, + size_t payload_len, uint8_t ttl = 64) { + auto* u = buf.template prepend
(); + u->src_port = src_port; + u->dst_port = dst_port; + u->length = __builtin_bswap16(sizeof(header) + payload_len); + u->checksum = 0; + ipv4::prepend(buf, dst_mac, src_mac, src_ip, dst_ip, 17, sizeof(header) + payload_len, ttl); +} + +void handle(std::span frame, span_writer& tx); + +using port_handler = void (*)(std::span payload, const address& from); +void register_port(uint16_t port_be, port_handler fn); + +} // namespace udp diff --git a/include/wire.h b/include/wire.h new file mode 100644 index 0000000..b0199b9 --- /dev/null +++ b/include/wire.h @@ -0,0 +1,245 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "msgpack.h" +#include "halfsiphash.h" +#include "static_vector.h" +#include "flash.h" + +struct Envelope { + static constexpr int8_t ext_id = 0; + uint32_t message_id; + uint32_t checksum; + std::span payload; + auto as_tuple() const { return std::tie(message_id, checksum, payload); } + auto as_tuple() { return std::tie(message_id, checksum, payload); } +}; + +struct DeviceError { + static constexpr int8_t ext_id = 1; + uint32_t code; + std::string message; + auto as_tuple() const { return std::tie(code, message); } + auto as_tuple() { return std::tie(code, message); } +}; + +struct RequestInfo { + static constexpr int8_t ext_id = 4; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +enum class boot_reason : uint8_t { + cold_boot = 0, + request_reboot = 1, + watchdog = 2, +}; + +struct ResponseInfo { + static constexpr int8_t ext_id = 5; + std::array board_id; + std::array mac; + std::array ip; + std::string firmware_name; + boot_reason boot; + uint32_t build_epoch; + auto as_tuple() const { return std::tie(board_id, mac, ip, firmware_name, boot, build_epoch); } + auto as_tuple() { return std::tie(board_id, mac, ip, firmware_name, boot, build_epoch); } +}; + +struct RequestLog { + static constexpr int8_t ext_id = 6; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct LogEntry { + uint32_t timestamp_us; + std::string message; + auto as_tuple() const { return std::tie(timestamp_us, message); } + auto as_tuple() { return std::tie(timestamp_us, message); } +}; + +struct ResponseLog { + static constexpr int8_t ext_id = 7; + std::vector entries; + auto as_tuple() const { return std::tie(entries); } + auto as_tuple() { return std::tie(entries); } +}; + +struct RequestFlashErase { + static constexpr int8_t ext_id = 8; + uint32_t addr; + uint32_t len; + auto as_tuple() const { return std::tie(addr, len); } + auto as_tuple() { return std::tie(addr, len); } +}; + +struct ResponseFlashErase { + static constexpr int8_t ext_id = 9; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestFlashWrite { + static constexpr int8_t ext_id = 10; + uint32_t addr; + std::span data; + auto as_tuple() const { return std::tie(addr, data); } + auto as_tuple() { return std::tie(addr, data); } +}; + +struct ResponseFlashWrite { + static constexpr int8_t ext_id = 11; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestReboot { + static constexpr int8_t ext_id = 12; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponseReboot { + static constexpr int8_t ext_id = 13; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestFlashStatus { + static constexpr int8_t ext_id = 14; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponseFlashStatus { + static constexpr int8_t ext_id = 15; + int8_t boot_partition; + flash::slot slot_a; + flash::slot slot_b; + auto as_tuple() const { return std::tie(boot_partition, slot_a, slot_b); } + auto as_tuple() { return std::tie(boot_partition, slot_a, slot_b); } +}; + +struct RequestListTests { + static constexpr int8_t ext_id = 125; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponseListTests { + static constexpr int8_t ext_id = 124; + std::vector names; + auto as_tuple() const { return std::tie(names); } + auto as_tuple() { return std::tie(names); } +}; + +struct RequestTest { + static constexpr int8_t ext_id = 127; + std::string name; + auto as_tuple() const { return std::tie(name); } + auto as_tuple() { return std::tie(name); } +}; + +struct ResponseTest { + static constexpr int8_t ext_id = 126; + bool pass; + std::vector messages; + auto as_tuple() const { return std::tie(pass, messages); } + auto as_tuple() { return std::tie(pass, messages); } +}; + +static constexpr uint8_t hash_key[8] = {}; + +struct DecodedMessage { + uint32_t message_id; + int8_t type_id; + std::span payload; +}; + +static constexpr size_t ext16_header_len = 4; +static constexpr size_t array3_header_len = 1; +static constexpr size_t uint32_fixed_len = 5; +static constexpr size_t bin16_header_len = 3; + +static constexpr size_t envelope_hdr_len = + ext16_header_len + array3_header_len + + uint32_fixed_len + uint32_fixed_len + bin16_header_len; + +static constexpr size_t response_prefix_len = envelope_hdr_len + ext16_header_len; + +template +inline msgpack::result encode_response_into(span_writer &out, uint32_t message_id, const T &msg) { + auto body = out.subspan(response_prefix_len); + msgpack::packer body_p(body); + body_p.pack(msg.as_tuple()); + + auto inner_ext = out.subspan(envelope_hdr_len, ext16_header_len); + msgpack::packer inner_ext_p(inner_ext); + inner_ext_p.pack_ext16_header(T::ext_id, static_cast(body.size())); + + size_t bin_len = inner_ext.size() + body.size(); + uint32_t checksum = halfsiphash::hash32({inner_ext.data(), bin_len}, hash_key); + + auto env_hdr = out.subspan(0, envelope_hdr_len); + size_t env_body_len = array3_header_len + uint32_fixed_len + uint32_fixed_len + bin16_header_len + bin_len; + msgpack::packer hdr(env_hdr); + hdr.pack_ext16_header(Envelope::ext_id, static_cast(env_body_len)); + hdr.pack_array(3); + hdr.pack_uint32_fixed(message_id); + hdr.pack_uint32_fixed(checksum); + hdr.pack_bin16_header(static_cast(bin_len)); + + if (body.overflow() || inner_ext.overflow() || env_hdr.overflow()) + return std::unexpected(msgpack::error_code::overflow); + return response_prefix_len + body.size(); +} + +inline msgpack::result try_decode(const uint8_t *data, size_t len) { + msgpack::parser p(data, static_cast(len)); + + Envelope env; + auto r = msgpack::unpack(p, env); + if (!r) return std::unexpected(r.error()); + + uint32_t expected = halfsiphash::hash32(env.payload, hash_key); + if (env.checksum != expected) return std::unexpected(msgpack::error_code::invalid); + + msgpack::parser inner(env.payload.data(), static_cast(env.payload.size())); + if (!inner.is_ext()) return std::unexpected(msgpack::error_code::type_error); + auto ext = inner.get_ext(); + if (!ext) return std::unexpected(ext.error()); + + auto& [type_id, ext_data] = *ext; + return DecodedMessage{env.message_id, type_id, + std::span(reinterpret_cast(ext_data.data()), ext_data.size())}; +} + +template +inline msgpack::result try_decode(const static_vector &buf) { + return try_decode(buf.data(), buf.size()); +} + +template +inline msgpack::result decode_response(const uint8_t *data, size_t len) { + msgpack::parser p(data, static_cast(len)); + + Envelope env; + auto r = msgpack::unpack(p, env); + if (!r) return std::unexpected(r.error()); + + uint32_t expected = halfsiphash::hash32(env.payload, hash_key); + if (env.checksum != expected) return std::unexpected(msgpack::error_code::invalid); + + msgpack::parser inner(env.payload.data(), static_cast(env.payload.size())); + T out; + auto r2 = msgpack::unpack(inner, out); + if (!r2) return std::unexpected(r2.error()); + return out; +} + diff --git a/partition_table.json b/partition_table.json new file mode 100644 index 0000000..4d19598 --- /dev/null +++ b/partition_table.json @@ -0,0 +1,37 @@ +{ + "version": [1, 0], + "unpartitioned": { + "families": ["absolute"], + "permissions": { + "secure": "rw", + "nonsecure": "rw", + "bootloader": "rw" + } + }, + "partitions": [ + { + "name": "A", + "id": 0, + "start": "0K", + "size": "512K", + "families": ["rp2350-arm-s"], + "permissions": { + "secure": "rw", + "nonsecure": "rw", + "bootloader": "rw" + } + }, + { + "name": "B", + "start": "512K", + "size": "512K", + "families": ["rp2350-arm-s"], + "permissions": { + "secure": "rw", + "nonsecure": "rw", + "bootloader": "rw" + }, + "link": ["a", 0] + } + ] +} diff --git a/src/arp.cpp b/src/arp.cpp new file mode 100644 index 0000000..0fce56b --- /dev/null +++ b/src/arp.cpp @@ -0,0 +1,47 @@ +#include "arp.h" +#include "net.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" + +namespace arp { + +static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1); +static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800); +static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1); +static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); + +void handle(std::span frame, span_writer& tx) { + const auto& ns = net_get_state(); + parse_buffer pb(frame); + pb.consume(); + auto* arp_hdr = pb.consume
(); + if (!arp_hdr) return; + + if (arp_hdr->htype != ARP_HTYPE_ETH) return; + if (arp_hdr->ptype != ARP_PTYPE_IPV4) return; + if (arp_hdr->hlen != 6 || arp_hdr->plen != 4) return; + if (arp_hdr->oper != ARP_OP_REQUEST) return; + if (arp_hdr->tpa != ns.ip) return; + + prepend_buffer<4096> buf; + auto* reply = buf.template prepend
(); + reply->htype = ARP_HTYPE_ETH; + reply->ptype = ARP_PTYPE_IPV4; + reply->hlen = 6; + reply->plen = 4; + reply->oper = ARP_OP_REPLY; + reply->sha = ns.mac; + reply->spa = ns.ip; + reply->tha = arp_hdr->sha; + reply->tpa = arp_hdr->spa; + eth::prepend(buf, arp_hdr->sha, ns.mac, eth::ETH_ARP); + + net_send_raw(buf.span()); +} + +__attribute__((constructor)) +static void register_ethertype() { + net_register_ethertype(eth::ETH_ARP, handle); +} + +} // namespace arp diff --git a/src/dispatch.cpp b/src/dispatch.cpp new file mode 100644 index 0000000..a4b9cc0 --- /dev/null +++ b/src/dispatch.cpp @@ -0,0 +1,67 @@ +#include "dispatch.h" +#include +#include "pico/stdlib.h" +#include "wire.h" +#include "timer_queue.h" +#include "net.h" +#include "igmp.h" +#include "udp.h" +#include "debug_log.h" +#include "hardware/sync.h" + +static timer_queue timers; +static std::array handler_map{}; +static uint16_t listen_port_be = 0; + +uint16_t dispatch_listen_port_be() { return listen_port_be; } + +static void igmp_reannounce() { + igmp::send_all_reports(); + dispatch_schedule_ms(60000, igmp_reannounce); +} + +static void on_udp_message(std::span payload, const udp::address& from) { + auto msg = try_decode(payload.data(), payload.size()); + if (!msg) return; + if (msg->type_id < 0 || !handler_map[msg->type_id]) { + dlogf("dispatch: unknown type_id %d", msg->type_id); + return; + } + responder resp{msg->message_id, from}; + handler_map[msg->type_id](resp, msg->payload); +} + +void dispatch_init(uint16_t port_be) { + listen_port_be = port_be; + udp::register_port(port_be, on_udp_message); + net_init(); + dispatch_schedule_ms(60000, igmp_reannounce); + dlog("dispatch_init complete"); +} + +timer_handle dispatch_schedule_ms(uint32_t ms, void (*fn)()) { + auto h = timers.schedule_ms(ms, fn); + if (!h) dlogf("timer alloc failed: %lu ms", static_cast(ms)); + return h; +} + +bool dispatch_cancel_timer(timer_handle h) { + return timers.cancel(h); +} + +[[noreturn]] void dispatch_run(std::span handlers) { + for (auto& entry : handlers) + handler_map[entry.type_id] = entry.handle; + + static std::array tx_buf; + + while (true) { + uint32_t save = save_and_disable_interrupts(); + + dlog_if_slow("timers", 1000, [&]{ timers.run(); }); + dlog_if_slow("net_poll", 1000, [&]{ net_poll(std::span{tx_buf}); }); + + __wfi(); + restore_interrupts(save); + } +} diff --git a/src/flash.cpp b/src/flash.cpp new file mode 100644 index 0000000..fb90b01 --- /dev/null +++ b/src/flash.cpp @@ -0,0 +1,163 @@ +#include "flash.h" +#include +#include "pico/sha256.h" +#include "boot/picobin.h" + +namespace flash { +namespace { + +constexpr uint32_t PICOBIN_MARKER_END = 0xab123579; + +struct __attribute__((packed)) last_item { + uint8_t type; + uint16_t block_item_words; + uint8_t pad; + int32_t next_block_offset; + uint32_t marker_end; +}; + +struct __attribute__((packed)) hash_def_header { + uint8_t type; + uint8_t size_words; + uint8_t reserved; + uint8_t hash_type; + uint16_t block_words_to_hash; + uint16_t pad; +}; + +struct __attribute__((packed)) load_map_header { + uint8_t type; + uint8_t size_words; + uint8_t reserved; + uint8_t flags_and_count; + + uint8_t count() const { return flags_and_count & 0x7f; } + bool absolute() const { return flags_and_count & 0x80; } +}; + +struct load_map_entry { + uint32_t storage_addr; + uint32_t runtime_addr; + uint32_t size; +}; + +struct parsed_item { + uint8_t type; + uint8_t size_words; + const uint32_t* words; +}; + +struct parsed_block { + const uint32_t* base; + parsed_item items[16]; + uint8_t item_count; + int32_t next_block_offset; +}; + +bool parse_block(const uint32_t* start, const uint32_t* limit, parsed_block& out) { + if (start >= limit || *start != PICOBIN_BLOCK_MARKER_START) return false; + out.base = start; + out.item_count = 0; + out.next_block_offset = 0; + + auto* w = start + 1; + while (w + 3 < limit && out.item_count < 16) { + uint8_t type = *w & 0xff; + if (type == PICOBIN_BLOCK_ITEM_2BS_LAST) { + auto* last = reinterpret_cast(w); + out.next_block_offset = last->next_block_offset; + return last->marker_end == PICOBIN_MARKER_END; + } + uint8_t size_words = (*w >> 8) & 0xff; + if (w + size_words > limit) return false; + out.items[out.item_count++] = {type, size_words, w}; + w += size_words; + } + return false; +} + +const parsed_item* find_item(const parsed_block& blk, uint8_t type) { + for (uint8_t i = 0; i < blk.item_count; i++) + if (blk.items[i].type == type) return &blk.items[i]; + return nullptr; +} + +bool verify_hash(const parsed_block& last) { + auto* lm_item = find_item(last, PICOBIN_BLOCK_ITEM_LOAD_MAP); + auto* hd_item = find_item(last, PICOBIN_BLOCK_ITEM_1BS_HASH_DEF); + auto* hv_item = find_item(last, PICOBIN_BLOCK_ITEM_HASH_VALUE); + if (!lm_item || !hd_item || !hv_item) return false; + if (hd_item->size_words < 2 || hv_item->size_words < 9) return false; + + auto* hd = reinterpret_cast(hd_item->words); + if (hd->hash_type != PICOBIN_HASH_SHA256) return false; + + auto* lm = reinterpret_cast(lm_item->words); + auto* entries = reinterpret_cast(lm_item->words + 1); + uint32_t lm_xip_addr = reinterpret_cast(lm_item->words); + + pico_sha256_state_t sha; + if (pico_sha256_try_start(&sha, SHA256_BIG_ENDIAN, false) != PICO_OK) return false; + + for (uint8_t i = 0; i < lm->count(); i++) { + uint32_t storage_addr = entries[i].storage_addr; + uint32_t size = entries[i].size; + if (lm->absolute()) size -= entries[i].runtime_addr; + if (storage_addr == 0) { + pico_sha256_update_blocking(&sha, reinterpret_cast(&size), 4); + } else { + if (!lm->absolute()) storage_addr += lm_xip_addr; + pico_sha256_update_blocking(&sha, reinterpret_cast(storage_addr), size); + } + } + + pico_sha256_update_blocking(&sha, + reinterpret_cast(last.base), + static_cast(hd->block_words_to_hash) * 4); + + sha256_result_t result; + pico_sha256_finish(&sha, &result); + + auto* expected = reinterpret_cast(hv_item->words + 1); + return memcmp(result.bytes, expected, 32) == 0; +} + +} + +slot scan(uint32_t flash_offset) { + slot info{}; + constexpr uint32_t scan_limit = 4096; + constexpr uint32_t slot_size = 512 * 1024; + auto* slot_base = reinterpret_cast(FLASH_BASE + flash_offset); + auto* slot_end = reinterpret_cast(FLASH_BASE + flash_offset + slot_size); + + auto* s = slot_base; + auto* s_end = slot_base + scan_limit / 4; + while (s < s_end && *s != PICOBIN_BLOCK_MARKER_START) s++; + if (s >= s_end) return info; + + parsed_block start; + if (!parse_block(s, slot_end, start)) return info; + + for (uint8_t i = 0; i < start.item_count; i++) { + if (start.items[i].type == PICOBIN_BLOCK_ITEM_1BS_VERSION && start.items[i].size_words >= 2) + info.version = start.items[i].words[1]; + } + + auto* cur = &start; + parsed_block next; + while (cur->next_block_offset != 0) { + auto* np = reinterpret_cast( + reinterpret_cast(cur->base) + cur->next_block_offset); + if (np <= slot_base || np >= slot_end) return info; + if (np == start.base) break; + if (!parse_block(np, slot_end, next)) return info; + cur = &next; + } + + info.valid = true; + info.hash_ok = verify_hash(*cur); + return info; +} + +} diff --git a/src/handlers.cpp b/src/handlers.cpp new file mode 100644 index 0000000..c1b9b84 --- /dev/null +++ b/src/handlers.cpp @@ -0,0 +1,104 @@ +#include "handlers.h" +#include "pico/unique_id.h" +#include "pico/bootrom.h" +#include "hardware/flash.h" +#include "hardware/watchdog.h" +#include "flash.h" +#include "dispatch.h" +#include "net.h" +#include "debug_log.h" + +static boot_reason detected_boot_reason; + +static void poke_watchdog() { + watchdog_update(); + dispatch_schedule_ms(500, poke_watchdog); +} + +void handlers_init() { + auto val = static_cast(watchdog_hw->scratch[0]); + if (val == boot_reason::request_reboot || val == boot_reason::watchdog) + detected_boot_reason = val; + else + detected_boot_reason = boot_reason::cold_boot; + watchdog_hw->scratch[0] = static_cast(boot_reason::watchdog); + watchdog_enable(1000, true); +} + +void handlers_start() { + poke_watchdog(); +} + +std::optional handle_info(const responder&, const RequestInfo&) { + ResponseInfo resp; + pico_unique_board_id_t uid; + pico_get_unique_board_id(&uid); + std::copy(uid.id, uid.id + 8, resp.board_id.begin()); + auto& ns = net_get_state(); + resp.mac = ns.mac; + resp.ip = ns.ip; + resp.firmware_name = firmware_name; + resp.boot = detected_boot_reason; + resp.build_epoch = BUILD_EPOCH; + return resp; +} + +std::optional handle_log(const responder&, const RequestLog&) { + ResponseLog resp; + for (auto& e : g_debug_log) + resp.entries.push_back(LogEntry{e.timestamp_us, e.message}); + return resp; +} + +std::optional handle_flash_erase(const responder&, const RequestFlashErase& req) { + if (req.addr < flash::FLASH_BASE || req.addr + req.len > flash::FLASH_BASE + flash::FLASH_SIZE) { + dlogf("flash erase: out of range %08lx+%lu", + static_cast(req.addr), static_cast(req.len)); + return std::nullopt; + } + uint32_t offset = req.addr - flash::FLASH_BASE; + if (offset % FLASH_SECTOR_SIZE != 0 || req.len % FLASH_SECTOR_SIZE != 0 || req.len == 0) { + dlogf("flash erase: bad alignment %08lx+%lu", + static_cast(req.addr), static_cast(req.len)); + return std::nullopt; + } + flash_range_erase(offset, req.len); + return ResponseFlashErase{}; +} + +std::optional handle_flash_write(const responder&, const RequestFlashWrite& req) { + if (req.addr < flash::FLASH_BASE || req.addr + req.data.size() > flash::FLASH_BASE + flash::FLASH_SIZE) { + dlogf("flash write: out of range %08lx+%zu", + static_cast(req.addr), req.data.size()); + return std::nullopt; + } + uint32_t offset = req.addr - flash::FLASH_BASE; + if (offset % FLASH_PAGE_SIZE != 0 || req.data.size() % FLASH_PAGE_SIZE != 0 || req.data.empty()) { + dlogf("flash write: bad alignment %08lx+%zu", + static_cast(req.addr), req.data.size()); + return std::nullopt; + } + flash_range_program(offset, req.data.data(), req.data.size()); + return ResponseFlashWrite{}; +} + +std::optional handle_flash_status(const responder&, const RequestFlashStatus&) { + ResponseFlashStatus resp; + boot_info_t bi; + if (rom_get_boot_info(&bi)) + resp.boot_partition = bi.partition; + else + resp.boot_partition = -1; + resp.slot_a = flash::scan(0x00000); + resp.slot_b = flash::scan(0x80000); + return resp; +} + +std::optional handle_reboot(const responder&, const RequestReboot&) { + dispatch_schedule_ms(100, []{ + watchdog_hw->scratch[0] = static_cast(boot_reason::request_reboot); + watchdog_reboot(0, 0, 0); + }); + return ResponseReboot{}; +} + diff --git a/src/icmp.cpp b/src/icmp.cpp new file mode 100644 index 0000000..9b717f5 --- /dev/null +++ b/src/icmp.cpp @@ -0,0 +1,68 @@ +#include "icmp.h" +#include +#include "ipv4.h" +#include "net.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" + +namespace icmp { + +void handle(std::span frame, span_writer& tx) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; + if (!ipv4::addressed_to_us(ip->dst)) return; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; + + size_t icmp_len = ip->total() - ip->header_len(); + if (pb.remaining_size() < icmp_len) return; + + auto* icmp_pkt = pb.consume(); + if (!icmp_pkt) return; + if (icmp_pkt->type != 8) return; + + const auto& ns = net_get_state(); + prepend_buffer<4096> buf; + memcpy(buf.append(icmp_len), pb.remaining().data() - sizeof(echo), icmp_len); + + auto* reply = reinterpret_cast(buf.data()); + reply->type = 0; + reply->checksum = 0; + reply->checksum = ipv4::checksum(reply, icmp_len); + + ipv4::prepend(buf, eth_hdr->src, ns.mac, ns.ip, ip->src, 1, icmp_len); + net_send_raw(buf.span()); +} + +__attribute__((constructor)) +static void register_protocol() { + ipv4::register_protocol(1, handle); +} + +bool parse_echo_reply(std::span frame, ipv4::ip4_addr& src_ip, uint16_t expected_id) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr) return false; + if (eth_hdr->ethertype != eth::ETH_IPV4) return false; + + auto* ip = pb.consume(); + if (!ip) return false; + if ((ip->ver_ihl >> 4) != 4) return false; + if (ip->protocol != 1) return false; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return false; + + auto* icmp_pkt = pb.consume(); + if (!icmp_pkt) return false; + if (icmp_pkt->type != 0) return false; + if (icmp_pkt->id != expected_id) return false; + + src_ip = ip->src; + return true; +} + +} // namespace icmp diff --git a/src/igmp.cpp b/src/igmp.cpp new file mode 100644 index 0000000..43aaa2e --- /dev/null +++ b/src/igmp.cpp @@ -0,0 +1,109 @@ +#include "igmp.h" +#include +#include "ipv4.h" +#include "net.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" + +namespace igmp { + +struct group_entry { + ipv4::ip4_addr ip; + eth::mac_addr mac; +}; + +static std::vector groups; + +eth::mac_addr mac_for_ip(const ipv4::ip4_addr& group) { + return {0x01, 0x00, 0x5E, + static_cast(group[1] & 0x7F), group[2], group[3]}; +} + +bool is_member(const ipv4::ip4_addr& ip) { + if (ip == ALL_HOSTS) return true; + for (auto& g : groups) + if (g.ip == ip) return true; + return false; +} + +bool is_member_mac(const eth::mac_addr& mac) { + static constexpr eth::mac_addr ALL_HOSTS_MAC = {0x01, 0x00, 0x5E, 0x00, 0x00, 0x01}; + if (mac == ALL_HOSTS_MAC) return true; + for (auto& g : groups) + if (g.mac == mac) return true; + return false; +} + +static void send_report(const ipv4::ip4_addr& group) { + const auto& ns = net_get_state(); + prepend_buffer<4096> buf; + prepend_report(buf, ns.mac, ns.ip, group); + net_send_raw(buf.span()); +} + +void join(const ipv4::ip4_addr& group) { + for (auto& g : groups) + if (g.ip == group) return; + groups.push_back({group, mac_for_ip(group)}); + send_report(group); +} + +void send_all_reports() { + for (auto& g : groups) + send_report(g.ip); +} + +void handle(std::span frame, span_writer& tx) { + parse_buffer pb(frame); + pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; + + auto* msg = pb.consume(); + if (!msg) return; + if (msg->type != 0x11) return; + + if (msg->group == ipv4::ip4_addr{0, 0, 0, 0}) { + for (auto& g : groups) + send_report(g.ip); + } else { + for (auto& g : groups) { + if (g.ip == msg->group) { + send_report(g.ip); + break; + } + } + } +} + +__attribute__((constructor)) +static void register_protocol() { + ipv4::register_protocol(2, handle); +} + +bool parse_report(std::span frame, ipv4::ip4_addr& group) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr) return false; + if (eth_hdr->ethertype != eth::ETH_IPV4) return false; + + auto* ip = pb.consume(); + if (!ip) return false; + if ((ip->ver_ihl >> 4) != 4) return false; + if (ip->protocol != 2) return false; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return false; + + auto* msg = pb.consume(); + if (!msg) return false; + if (msg->type != 0x16) return false; + + group = msg->group; + return true; +} + +} // namespace igmp diff --git a/src/ipv4.cpp b/src/ipv4.cpp new file mode 100644 index 0000000..99fd7b4 --- /dev/null +++ b/src/ipv4.cpp @@ -0,0 +1,63 @@ +#include "ipv4.h" +#include +#include "igmp.h" +#include "net.h" +#include "parse_buffer.h" + +namespace ipv4 { + +static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255}; + +uint16_t checksum(const void* data, size_t len) { + auto p = static_cast(data); + uint32_t sum = 0; + for (size_t i = 0; i < len - 1; i += 2) + sum += (p[i] << 8) | p[i + 1]; + if (len & 1) + sum += p[len - 1] << 8; + while (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16); + return __builtin_bswap16(~sum); +} + +bool addressed_to_us(ip4_addr dst) { + const auto& ns = net_get_state(); + return dst == ns.ip || dst == IP_BROADCAST_ALL || dst == SUBNET_BROADCAST || igmp::is_member(dst); +} + +struct protocol_entry { + uint8_t protocol; + protocol_handler fn; +}; +static std::array protocol_handlers; +static size_t protocol_handler_count = 0; + +void register_protocol(uint8_t protocol, protocol_handler fn) { + if (protocol_handler_count < protocol_handlers.size()) + protocol_handlers[protocol_handler_count++] = {protocol, fn}; +} + +void handle(std::span frame, span_writer& tx) { + parse_buffer pb(frame); + pb.consume(); + auto* ip = pb.consume
(); + if (!ip) return; + if ((ip->ver_ihl >> 4) != 4) return; + + size_t options_len = ip->header_len() - sizeof(header); + if (options_len > 0 && !pb.skip(options_len)) return; + + for (size_t i = 0; i < protocol_handler_count; i++) { + if (protocol_handlers[i].protocol == ip->protocol) { + protocol_handlers[i].fn(frame, tx); + break; + } + } +} + +__attribute__((constructor)) +static void register_ethertype() { + net_register_ethertype(eth::ETH_IPV4, handle); +} + +} // namespace ipv4 diff --git a/src/net.cpp b/src/net.cpp new file mode 100644 index 0000000..1a4378c --- /dev/null +++ b/src/net.cpp @@ -0,0 +1,112 @@ +#include "net.h" +#include +#include "pico/unique_id.h" +#include "pico/time.h" +#include "eth.h" +#include "igmp.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" +#include "w6300.h" +#include "debug_log.h" + +static net_state state; +static w6300::socket_id raw_socket{0}; +static frame_cb_list frame_callbacks; + +struct ethertype_entry { + uint16_t ethertype_be; + ethertype_handler fn; +}; +static std::array eth_handlers; +static size_t eth_handler_count = 0; + +void net_register_ethertype(uint16_t ethertype_be, ethertype_handler fn) { + if (eth_handler_count < eth_handlers.size()) + eth_handlers[eth_handler_count++] = {ethertype_be, fn}; +} + +void net_send_raw(std::span data) { + dlog_if_slow("net_send_raw", 1000, [&]{ + auto result = w6300::send(raw_socket, data); + if (!result) + dlogf("w6300 send failed: %zu bytes, err %d", + data.size(), static_cast(result.error())); + }); +} + +static bool mac_match(const eth::mac_addr& dst) { + return dst == state.mac || dst == eth::MAC_BROADCAST || igmp::is_member_mac(dst); +} + +static void process_frame(std::span frame, span_writer& tx) { + if (frame.size() < sizeof(eth::header)) return; + auto& eth_hdr = *reinterpret_cast(frame.data()); + + if (!mac_match(eth_hdr.dst)) return; + + frame_callbacks.for_each([&](frame_cb_list::node* n) { + if (n->value(frame)) + frame_callbacks.remove(n); + }); + + for (size_t i = 0; i < eth_handler_count; i++) { + if (eth_handlers[i].ethertype_be == eth_hdr.ethertype) { + eth_handlers[i].fn(frame, tx); + break; + } + } +} + +bool net_init() { + w6300::init_spi(); + w6300::reset(); + w6300::init(); + if (!w6300::check()) return false; + + pico_unique_board_id_t uid; + pico_get_unique_board_id(&uid); + state.mac[0] = (uid.id[0] & 0xFC) | 0x02; + state.mac[1] = uid.id[1]; + state.mac[2] = uid.id[2]; + state.mac[3] = uid.id[3]; + state.mac[4] = uid.id[4]; + state.mac[5] = uid.id[5]; + + state.ip[0] = 169; + state.ip[1] = 254; + state.ip[2] = state.mac[4]; + state.ip[3] = state.mac[5]; + + w6300::open_socket(raw_socket, w6300::protocol::macraw, w6300::sock_flag::none); + w6300::set_interrupt_mask(w6300::ik_sock_0); + + return true; +} + +const net_state& net_get_state() { + return state; +} + +frame_cb_handle net_add_frame_callback(net_frame_callback cb) { + auto h = frame_callbacks.insert(cb); + if (!h) dlog("frame callback alloc failed"); + return h; +} + +void net_remove_frame_callback(frame_cb_handle h) { + frame_callbacks.remove(h); +} + +void net_poll(std::span tx) { + if (!w6300::irq_pending) return; + w6300::irq_pending = false; + w6300::clear_interrupt(w6300::ik_int_all); + static uint8_t rx_buf[1518]; + for (int i = 0; i < 16 && w6300::get_socket_recv_buf(raw_socket) > 0; i++) { + auto result = w6300::recv(raw_socket, std::span{rx_buf}); + if (!result) break; + span_writer tx_writer(tx); + process_frame({rx_buf, *result}, tx_writer); + } + w6300::rearm_gpio_irq(); +} diff --git a/src/test_handlers.cpp b/src/test_handlers.cpp new file mode 100644 index 0000000..7fdadee --- /dev/null +++ b/src/test_handlers.cpp @@ -0,0 +1,343 @@ +#include "test_handlers.h" +#include +#include +#include "pico/stdlib.h" +#include "pico/time.h" +#include "handlers.h" +#include "net.h" +#include "icmp.h" +#include "igmp.h" +#include "udp.h" +#include "parse_buffer.h" +#include "prepend_buffer.h" + +static constexpr uint16_t PING_ECHO_ID = 0x1234; +static constexpr uint16_t PING_RATE_ECHO_ID = 0x5678; + +struct peer_info { + eth::mac_addr mac; + ipv4::ip4_addr ip; +}; + +struct discovery_data { + void (*on_found)(const peer_info&) = nullptr; + void (*on_timeout)() = nullptr; +}; + +struct ping_rate_data { + peer_info peer; + uint16_t target; + uint16_t pipeline; + uint16_t payload_len; + uint16_t sent; + uint16_t received; + uint32_t start_us; +}; + +struct discovery_igmp_test {}; +struct discovery_info_test { + discovery_data discovery; +}; +struct ping_subnet_test {}; +struct ping_global_test {}; +struct packet_rate_test { + discovery_data discovery; + ping_rate_data rate; +}; +struct byte_rate_test { + discovery_data discovery; + ping_rate_data rate; +}; + +// One test runs at a time; in_flight gates that. active_* let shared +// primitive callbacks find the running test's sub-state. +struct test_state { + bool in_flight = false; + responder resp; + timer_handle timer = nullptr; + frame_cb_handle frame_cb = nullptr; + + discovery_data* active_discovery = nullptr; + ping_rate_data* active_rate = nullptr; + + discovery_igmp_test discovery_igmp; + discovery_info_test discovery_info; + ping_subnet_test ping_subnet; + ping_global_test ping_global; + packet_rate_test packet_rate; + byte_rate_test byte_rate; +}; + +static test_state ts; + +static void test_end(const ResponseTest& result) { + if (ts.timer) { dispatch_cancel_timer(ts.timer); ts.timer = nullptr; } + if (ts.frame_cb) { net_remove_frame_callback(ts.frame_cb); ts.frame_cb = nullptr; } + ts.active_discovery = nullptr; + ts.active_rate = nullptr; + ts.resp.respond(result); + ts.in_flight = false; +} + +// When a callback fires, its dispatcher (net or timer_queue) has already +// removed the node; the matching ts.timer / ts.frame_cb is stale. Callbacks +// that self-consume must null that handle before calling test_end. + +static bool discover_reply_cb(std::span frame) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + if (!eth_hdr || eth_hdr->ethertype != eth::ETH_IPV4) return false; + auto* ip = pb.consume(); + if (!ip || ip->protocol != 17) return false; + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return false; + auto* uhdr = pb.consume(); + if (!uhdr || uhdr->src_port != PICOMAP_PORT_BE) return false; + if (ip->src == net_get_state().ip) return false; + dispatch_cancel_timer(ts.timer); + ts.timer = nullptr; + ts.frame_cb = nullptr; + auto cont = ts.active_discovery ? ts.active_discovery->on_found : nullptr; + ts.active_discovery = nullptr; + peer_info peer{eth_hdr->src, ip->src}; + if (cont) cont(peer); + return true; +} + +static void discover_timeout_cb() { + net_remove_frame_callback(ts.frame_cb); + ts.frame_cb = nullptr; + ts.timer = nullptr; + auto cont = ts.active_discovery ? ts.active_discovery->on_timeout : nullptr; + ts.active_discovery = nullptr; + if (cont) cont(); +} + +static void discover_peer(discovery_data& d, + void (*found)(const peer_info&), void (*timeout)()) { + d.on_found = found; + d.on_timeout = timeout; + ts.active_discovery = &d; + + const auto& ns = net_get_state(); + eth::mac_addr mcast_mac = igmp::mac_for_ip(PICOMAP_DISCOVERY_GROUP); + + prepend_buffer<4096> buf; + uint8_t* payload = buf.payload_ptr(); + span_writer out(payload, 1024); + RequestInfo req_msg; + auto encoded = encode_response_into(out, 0xFFFF, req_msg); + if (!encoded) { + ts.active_discovery = nullptr; + timeout(); + return; + } + buf.append(*encoded); + + udp::prepend(buf, mcast_mac, ns.mac, ns.ip, PICOMAP_DISCOVERY_GROUP, + PICOMAP_PORT_BE, PICOMAP_PORT_BE, *encoded, 1); + + ts.frame_cb = net_add_frame_callback(discover_reply_cb); + ts.timer = dispatch_schedule_ms(5000, discover_timeout_cb); + + net_send_raw(buf.span()); +} + +static bool igmp_report_cb(std::span frame) { + ipv4::ip4_addr group; + if (!igmp::parse_report(frame, group)) return false; + if (group != PICOMAP_DISCOVERY_GROUP) return false; + ts.frame_cb = nullptr; + test_end({true, {"got IGMP report for " + ipv4::to_string(group)}}); + return true; +} + +static void igmp_timeout_cb() { + ts.timer = nullptr; + test_end({false, {"no IGMP report within 5s"}}); +} + +static void test_discovery_igmp() { + const auto& ns = net_get_state(); + prepend_buffer<4096> buf; + igmp::prepend_query(buf, ns.mac, ns.ip, PICOMAP_DISCOVERY_GROUP); + + ts.frame_cb = net_add_frame_callback(igmp_report_cb); + ts.timer = dispatch_schedule_ms(5000, igmp_timeout_cb); + + net_send_raw(buf.span()); +} + +static void info_found(const peer_info& peer) { + test_end({true, {"got info response from " + ipv4::to_string(peer.ip)}}); +} + +static void info_timeout() { + test_end({false, {"no info response within 5s"}}); +} + +static void test_discovery_info() { + discover_peer(ts.discovery_info.discovery, info_found, info_timeout); +} + +static bool ping_reply_cb(std::span frame) { + ipv4::ip4_addr src_ip; + if (!icmp::parse_echo_reply(frame, src_ip, PING_ECHO_ID)) return false; + ts.frame_cb = nullptr; + if (src_ip == net_get_state().ip) + test_end({false, {"got reply from self: " + ipv4::to_string(src_ip)}}); + else + test_end({true, {"reply from " + ipv4::to_string(src_ip)}}); + return true; +} + +static void ping_timeout_cb() { + ts.timer = nullptr; + test_end({false, {"no reply from non-self host within 5s"}}); +} + +static void start_ping(ipv4::ip4_addr dst_ip) { + const auto& ns = net_get_state(); + prepend_buffer<4096> buf; + icmp::prepend_echo_request(buf, ns.mac, ns.ip, + eth::MAC_BROADCAST, dst_ip, PING_ECHO_ID, 1); + ts.frame_cb = net_add_frame_callback(ping_reply_cb); + ts.timer = dispatch_schedule_ms(5000, ping_timeout_cb); + net_send_raw(buf.span()); +} + +static void test_ping_subnet() { start_ping({169, 254, 255, 255}); } +static void test_ping_global() { start_ping({255, 255, 255, 255}); } + +static size_t ping_rate_frame_size() { + return sizeof(eth::header) + sizeof(ipv4::header) + sizeof(icmp::echo) + + ts.active_rate->payload_len; +} + +static void ping_rate_send_one() { + const auto& ns = net_get_state(); + auto& r = *ts.active_rate; + prepend_buffer<4096> buf; + if (r.payload_len > 0) + memset(buf.append(r.payload_len), 0xAA, r.payload_len); + icmp::prepend_echo_request(buf, ns.mac, ns.ip, + r.peer.mac, r.peer.ip, PING_RATE_ECHO_ID, + r.sent + 1, r.payload_len); + net_send_raw(buf.span()); + r.sent++; +} + +static bool ping_rate_reply_cb(std::span frame) { + ipv4::ip4_addr src_ip; + if (!icmp::parse_echo_reply(frame, src_ip, PING_RATE_ECHO_ID)) return false; + if (src_ip == net_get_state().ip) return false; + + auto& r = *ts.active_rate; + r.received++; + if (r.received >= r.target) { + uint32_t elapsed_us = time_us_32() - r.start_us; + uint32_t elapsed_ms = elapsed_us / 1000; + uint32_t pps = static_cast( + static_cast(r.received) * 1000000 / elapsed_us); + uint64_t total_bytes = static_cast(r.received) * 2 * ping_rate_frame_size(); + uint32_t kbps = static_cast(total_bytes * 1000 / elapsed_us); + char msg[128]; + snprintf(msg, sizeof(msg), + "%u rt in %lu ms, %lu pps, %lu bytes, %lu KB/s", + r.received, static_cast(elapsed_ms), + static_cast(pps), + static_cast(total_bytes), + static_cast(kbps)); + ts.frame_cb = nullptr; + test_end({true, {msg}}); + return true; + } + + if (r.sent < r.target) + ping_rate_send_one(); + return false; +} + +static void ping_rate_timeout_cb() { + ts.timer = nullptr; + auto& r = *ts.active_rate; + uint32_t elapsed_us = time_us_32() - r.start_us; + char msg[64]; + snprintf(msg, sizeof(msg), "timeout after %u/%u rt in %lu ms", + r.received, r.sent, + static_cast(elapsed_us / 1000)); + test_end({false, {msg}}); +} + +static void ping_rate_found(const peer_info& peer) { + auto& r = *ts.active_rate; + r.peer = peer; + r.sent = 0; + r.received = 0; + r.start_us = time_us_32(); + + ts.frame_cb = net_add_frame_callback(ping_rate_reply_cb); + ts.timer = dispatch_schedule_ms(10000, ping_rate_timeout_cb); + + for (uint16_t i = 0; i < r.pipeline && r.sent < r.target; i++) + ping_rate_send_one(); +} + +static void ping_rate_no_peer() { + test_end({false, {"no peer found"}}); +} + +static void start_ping_rate(discovery_data& d, ping_rate_data& r, + uint16_t target, uint16_t payload_len, uint16_t pipeline) { + r.target = target; + r.payload_len = payload_len; + r.pipeline = pipeline; + ts.active_rate = &r; + discover_peer(d, ping_rate_found, ping_rate_no_peer); +} + +static void test_packet_rate() { + start_ping_rate(ts.packet_rate.discovery, ts.packet_rate.rate, 8192, 0, 8); +} +static void test_byte_rate() { + start_ping_rate(ts.byte_rate.discovery, ts.byte_rate.rate, 2048, 1400, 8); +} + +using sync_test_fn = ResponseTest (*)(); +using async_test_fn = void (*)(); + +struct test_entry { + sync_test_fn sync; + async_test_fn async; +}; + +static const std::unordered_map tests = { + {"discovery_igmp", {nullptr, test_discovery_igmp}}, + {"discovery_info", {nullptr, test_discovery_info}}, + {"ping_subnet", {nullptr, test_ping_subnet}}, + {"ping_global", {nullptr, test_ping_global}}, + {"packet_rate", {nullptr, test_packet_rate}}, + {"byte_rate", {nullptr, test_byte_rate}}, +}; + +std::optional handle_list_tests(const responder&, const RequestListTests&) { + ResponseListTests resp; + for (const auto& [name, _] : tests) + resp.names.emplace_back(name); + return resp; +} + +std::optional handle_test(const responder& resp, const RequestTest& req) { + if (ts.in_flight) + return ResponseTest{false, {"test already running"}}; + auto it = tests.find(req.name); + if (it == tests.end()) + return ResponseTest{false, {"unknown test: " + req.name}}; + if (it->second.sync) + return it->second.sync(); + + ts.in_flight = true; + ts.resp = resp; + it->second.async(); + return std::nullopt; +} diff --git a/src/udp.cpp b/src/udp.cpp new file mode 100644 index 0000000..52af99d --- /dev/null +++ b/src/udp.cpp @@ -0,0 +1,54 @@ +#include "udp.h" +#include +#include "eth.h" +#include "ipv4.h" +#include "net.h" +#include "parse_buffer.h" + +namespace udp { + +struct port_entry { + uint16_t port_be; + port_handler fn; +}; +static std::array port_handlers; +static size_t port_handler_count = 0; + +void register_port(uint16_t port_be, port_handler fn) { + if (port_handler_count < port_handlers.size()) + port_handlers[port_handler_count++] = {port_be, fn}; +} + +void handle(std::span frame, span_writer& tx) { + parse_buffer pb(frame); + auto* eth_hdr = pb.consume(); + auto* ip = pb.consume(); + if (!ip) return; + if (!ipv4::addressed_to_us(ip->dst)) return; + + size_t options_len = ip->header_len() - sizeof(ipv4::header); + if (options_len > 0 && !pb.skip(options_len)) return; + + auto* uhdr = pb.consume
(); + if (!uhdr) return; + + size_t udp_len = __builtin_bswap16(uhdr->length); + if (udp_len < sizeof(header)) return; + size_t payload_len = udp_len - sizeof(header); + if (pb.remaining_size() < payload_len) return; + + for (size_t i = 0; i < port_handler_count; i++) { + if (port_handlers[i].port_be == uhdr->dst_port) { + address from{eth_hdr->src, ip->src, uhdr->src_port}; + port_handlers[i].fn(pb.remaining().subspan(0, payload_len), from); + break; + } + } +} + +__attribute__((constructor)) +static void register_protocol() { + ipv4::register_protocol(17, handle); +} + +} // namespace udp diff --git a/w6300/qspi.pio b/w6300/qspi.pio new file mode 100644 index 0000000..fac05ee --- /dev/null +++ b/w6300/qspi.pio @@ -0,0 +1,18 @@ +.program qspi +.side_set 1 + +write_bits: + out pins, 4 side 0 + jmp x-- write_bits side 1 + set pins 0 side 0 +public write_bits_end: +read_byte_delay: + set pindirs 0 side 0 +read_byte: + set x 0 side 1 +read_bits: + in pins, 4 side 0 + jmp x-- read_bits side 1 + in pins, 4 side 0 + jmp y-- read_byte side 0 +public read_bits_end: diff --git a/w6300/w6300.cpp b/w6300/w6300.cpp new file mode 100644 index 0000000..e11c9c4 --- /dev/null +++ b/w6300/w6300.cpp @@ -0,0 +1,650 @@ +#include +#include +#include "pico/stdlib.h" +#include "pico/error.h" +#include "hardware/dma.h" +#include "hardware/clocks.h" +#include "w6300.h" +#include "qspi.pio.h" + +namespace w6300 { + +constexpr int sock_count = 8; + +namespace { + +#define PIO_PROGRAM_NAME qspi +#define PIO_PROGRAM_FUNC __CONCAT(PIO_PROGRAM_NAME, _program) +#define PIO_PROGRAM_GET_DEFAULT_CONFIG_FUNC __CONCAT(PIO_PROGRAM_NAME, _program_get_default_config) +#define PIO_OFFSET_WRITE_BITS_END __CONCAT(PIO_PROGRAM_NAME, _offset_write_bits_end) +#define PIO_OFFSET_READ_BITS_END __CONCAT(PIO_PROGRAM_NAME, _offset_read_bits_end) + +constexpr uint8_t PIN_INT = 15; +constexpr uint8_t PIN_CS = 16; +constexpr uint8_t PIO_SPI_SCK_PIN = 17; +constexpr uint8_t PIO_SPI_DATA_IO0_PIN = 18; +constexpr uint8_t PIO_SPI_DATA_IO1_PIN = 19; +constexpr uint8_t PIO_SPI_DATA_IO2_PIN = 20; +constexpr uint8_t PIO_SPI_DATA_IO3_PIN = 21; +constexpr uint8_t PIN_RST = 22; + +constexpr uint16_t SPI_CLKDIV_MAJOR = 2; +constexpr uint8_t SPI_CLKDIV_MINOR = 0; + +constexpr uint32_t PADS_DRIVE = PADS_BANK0_GPIO0_DRIVE_VALUE_12MA; +constexpr uint32_t IRQ_DELAY_NS = 100; +constexpr uint32_t QSPI_LOOP_CNT = 2; + +struct { + pio_hw_t *pio; + uint8_t pio_func_sel; + int8_t pio_offset; + int8_t pio_sm; + int8_t dma_out; + int8_t dma_in; +} state; + +uint16_t mk_cmd_buf(uint8_t *pdst, uint8_t opcode, uint16_t addr) { + pdst[0] = ((opcode >> 7 & 0x01) << 4) | ((opcode >> 6 & 0x01) << 0); + pdst[1] = ((opcode >> 5 & 0x01) << 4) | ((opcode >> 4 & 0x01) << 0); + pdst[2] = ((opcode >> 3 & 0x01) << 4) | ((opcode >> 2 & 0x01) << 0); + pdst[3] = ((opcode >> 1 & 0x01) << 4) | ((opcode >> 0 & 0x01) << 0); + pdst[4] = (uint8_t)(addr >> 8); + pdst[5] = (uint8_t)(addr); + pdst[6] = 0; + return 7; +} + +uint32_t data_pin_mask() { + return (1u << PIO_SPI_DATA_IO0_PIN) | (1u << PIO_SPI_DATA_IO1_PIN) | + (1u << PIO_SPI_DATA_IO2_PIN) | (1u << PIO_SPI_DATA_IO3_PIN); +} + +__noinline void ns_delay(uint32_t ns) { + uint32_t cycles = ns * (clock_get_hz(clk_sys) >> 16u) / (1000000000u >> 16u); + busy_wait_at_least_cycles(cycles); +} + +void pio_init() { + for (auto pin : {PIO_SPI_DATA_IO0_PIN, PIO_SPI_DATA_IO1_PIN, PIO_SPI_DATA_IO2_PIN, PIO_SPI_DATA_IO3_PIN}) { + gpio_init(pin); + gpio_set_dir(pin, GPIO_OUT); + gpio_put(pin, false); + } + gpio_init(PIN_CS); + gpio_set_dir(PIN_CS, GPIO_OUT); + gpio_put(PIN_CS, true); + gpio_init(PIN_INT); + gpio_set_dir(PIN_INT, GPIO_IN); + gpio_pull_up(PIN_INT); + gpio_set_irq_enabled_with_callback(PIN_INT, GPIO_IRQ_LEVEL_LOW, true, + [](uint gpio, uint32_t){ + irq_pending = true; + gpio_set_irq_enabled(gpio, GPIO_IRQ_LEVEL_LOW, false); + }); + + pio_hw_t *pios[2] = {pio0, pio1}; + uint pio_index = 1; + if (!pio_can_add_program(pios[pio_index], &PIO_PROGRAM_FUNC)) { + pio_index ^= 1; + assert(pio_can_add_program(pios[pio_index], &PIO_PROGRAM_FUNC)); + } + + state.pio = pios[pio_index]; + state.dma_in = -1; + state.dma_out = -1; + + static_assert(GPIO_FUNC_PIO1 == GPIO_FUNC_PIO0 + 1); + state.pio_func_sel = GPIO_FUNC_PIO0 + pio_index; + state.pio_sm = (int8_t)pio_claim_unused_sm(state.pio, true); + state.pio_offset = pio_add_program(state.pio, &PIO_PROGRAM_FUNC); + + pio_sm_config sm_config = PIO_PROGRAM_GET_DEFAULT_CONFIG_FUNC(state.pio_offset); + sm_config_set_clkdiv_int_frac(&sm_config, SPI_CLKDIV_MAJOR, SPI_CLKDIV_MINOR); + + hw_write_masked(&pads_bank0_hw->io[PIO_SPI_SCK_PIN], + (uint)PADS_DRIVE << PADS_BANK0_GPIO0_DRIVE_LSB, + PADS_BANK0_GPIO0_DRIVE_BITS); + hw_write_masked(&pads_bank0_hw->io[PIO_SPI_SCK_PIN], + 1u << PADS_BANK0_GPIO0_SLEWFAST_LSB, + PADS_BANK0_GPIO0_SLEWFAST_BITS); + + sm_config_set_out_pins(&sm_config, PIO_SPI_DATA_IO0_PIN, 4); + sm_config_set_in_pins(&sm_config, PIO_SPI_DATA_IO0_PIN); + sm_config_set_set_pins(&sm_config, PIO_SPI_DATA_IO0_PIN, 4); + sm_config_set_sideset(&sm_config, 1, false, false); + sm_config_set_sideset_pins(&sm_config, PIO_SPI_SCK_PIN); + sm_config_set_in_shift(&sm_config, false, true, 8); + sm_config_set_out_shift(&sm_config, false, true, 8); + hw_set_bits(&state.pio->input_sync_bypass, data_pin_mask()); + pio_sm_set_config(state.pio, state.pio_sm, &sm_config); + pio_sm_set_consecutive_pindirs(state.pio, state.pio_sm, PIO_SPI_SCK_PIN, 1, true); + + for (auto pin : {PIO_SPI_DATA_IO0_PIN, PIO_SPI_DATA_IO1_PIN, PIO_SPI_DATA_IO2_PIN, PIO_SPI_DATA_IO3_PIN}) { + gpio_set_function(pin, (gpio_function_t)state.pio_func_sel); + gpio_set_pulls(pin, false, true); + gpio_set_input_hysteresis_enabled(pin, true); + } + + pio_sm_exec(state.pio, state.pio_sm, pio_encode_set(pio_pins, 1)); + + state.dma_out = (int8_t)dma_claim_unused_channel(true); + state.dma_in = (int8_t)dma_claim_unused_channel(true); +} + + +void pio_frame_start() { + for (auto pin : {PIO_SPI_DATA_IO0_PIN, PIO_SPI_DATA_IO1_PIN, PIO_SPI_DATA_IO2_PIN, PIO_SPI_DATA_IO3_PIN}) + gpio_set_function(pin, (gpio_function_t)state.pio_func_sel); + gpio_set_function(PIO_SPI_SCK_PIN, (gpio_function_t)state.pio_func_sel); + gpio_pull_down(PIO_SPI_SCK_PIN); + gpio_put(PIN_CS, false); +} + +void pio_frame_end() { + gpio_put(PIN_CS, true); + ns_delay(IRQ_DELAY_NS); +} + +void pio_read(uint8_t opcode, uint16_t addr, uint8_t* buf, uint16_t len) { + uint8_t cmd[8] = {}; + uint16_t cmd_len = mk_cmd_buf(cmd, opcode, addr); + + pio_sm_set_enabled(state.pio, state.pio_sm, false); + pio_sm_set_wrap(state.pio, state.pio_sm, state.pio_offset, state.pio_offset + PIO_OFFSET_READ_BITS_END - 1); + pio_sm_clear_fifos(state.pio, state.pio_sm); + pio_sm_set_pindirs_with_mask(state.pio, state.pio_sm, data_pin_mask(), data_pin_mask()); + pio_sm_restart(state.pio, state.pio_sm); + pio_sm_clkdiv_restart(state.pio, state.pio_sm); + + pio_sm_put(state.pio, state.pio_sm, cmd_len * QSPI_LOOP_CNT - 1); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_out(pio_x, 32)); + pio_sm_put(state.pio, state.pio_sm, len - 1); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_out(pio_y, 32)); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_jmp(state.pio_offset)); + + dma_channel_abort(state.dma_out); + dma_channel_abort(state.dma_in); + + dma_channel_config out_cfg = dma_channel_get_default_config(state.dma_out); + channel_config_set_transfer_data_size(&out_cfg, DMA_SIZE_8); + channel_config_set_bswap(&out_cfg, true); + channel_config_set_dreq(&out_cfg, pio_get_dreq(state.pio, state.pio_sm, true)); + dma_channel_configure(state.dma_out, &out_cfg, &state.pio->txf[state.pio_sm], cmd, cmd_len, true); + + dma_channel_config in_cfg = dma_channel_get_default_config(state.dma_in); + channel_config_set_transfer_data_size(&in_cfg, DMA_SIZE_8); + channel_config_set_bswap(&in_cfg, true); + channel_config_set_dreq(&in_cfg, pio_get_dreq(state.pio, state.pio_sm, false)); + channel_config_set_write_increment(&in_cfg, true); + channel_config_set_read_increment(&in_cfg, false); + dma_channel_configure(state.dma_in, &in_cfg, buf, &state.pio->rxf[state.pio_sm], len, true); + + pio_sm_set_enabled(state.pio, state.pio_sm, true); + __compiler_memory_barrier(); + dma_channel_wait_for_finish_blocking(state.dma_out); + dma_channel_wait_for_finish_blocking(state.dma_in); + __compiler_memory_barrier(); + pio_sm_set_enabled(state.pio, state.pio_sm, false); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_mov(pio_pins, pio_null)); +} + +void pio_write(uint8_t opcode, uint16_t addr, uint8_t* buf, uint16_t len) { + uint8_t cmd[8] = {}; + uint16_t cmd_len = mk_cmd_buf(cmd, opcode, addr); + uint16_t total = len + cmd_len; + + pio_sm_set_enabled(state.pio, state.pio_sm, false); + pio_sm_set_wrap(state.pio, state.pio_sm, state.pio_offset, state.pio_offset + PIO_OFFSET_WRITE_BITS_END - 1); + pio_sm_clear_fifos(state.pio, state.pio_sm); + pio_sm_set_pindirs_with_mask(state.pio, state.pio_sm, data_pin_mask(), data_pin_mask()); + pio_sm_restart(state.pio, state.pio_sm); + pio_sm_clkdiv_restart(state.pio, state.pio_sm); + + pio_sm_put(state.pio, state.pio_sm, total * QSPI_LOOP_CNT - 1); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_out(pio_x, 32)); + pio_sm_put(state.pio, state.pio_sm, 0); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_out(pio_y, 32)); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_jmp(state.pio_offset)); + + dma_channel_abort(state.dma_out); + + dma_channel_config out_cfg = dma_channel_get_default_config(state.dma_out); + channel_config_set_transfer_data_size(&out_cfg, DMA_SIZE_8); + channel_config_set_bswap(&out_cfg, true); + channel_config_set_dreq(&out_cfg, pio_get_dreq(state.pio, state.pio_sm, true)); + + pio_sm_set_enabled(state.pio, state.pio_sm, true); + dma_channel_configure(state.dma_out, &out_cfg, &state.pio->txf[state.pio_sm], cmd, cmd_len, true); + dma_channel_wait_for_finish_blocking(state.dma_out); + dma_channel_configure(state.dma_out, &out_cfg, &state.pio->txf[state.pio_sm], buf, len, true); + dma_channel_wait_for_finish_blocking(state.dma_out); + + const uint32_t stall = 1u << (PIO_FDEBUG_TXSTALL_LSB + state.pio_sm); + state.pio->fdebug = stall; + while (!(state.pio->fdebug & stall)) tight_loop_contents(); + + __compiler_memory_barrier(); + pio_sm_set_consecutive_pindirs(state.pio, state.pio_sm, PIO_SPI_DATA_IO0_PIN, 4, false); + pio_sm_exec(state.pio, state.pio_sm, pio_encode_mov(pio_pins, pio_null)); + pio_sm_set_enabled(state.pio, state.pio_sm, false); +} + +using datasize_t = int16_t; + + +constexpr uint8_t QSPI_MODE = 0x02 << 6; + + +constexpr uint8_t PACK_NONE = 0x00; +constexpr uint8_t PACK_FIRST = 1 << 1; +constexpr uint8_t PACK_REMAINED = 1 << 2; +constexpr uint8_t PACK_COMPLETED = 1 << 3; + +constexpr uint8_t SPI_READ = (0x00 << 5); +constexpr uint8_t SPI_WRITE = (0x01 << 5); + +constexpr uint32_t CREG_BLOCK = 0x00; +constexpr uint32_t SREG_BLOCK(uint8_t n) { return 1 + 4 * n; } +constexpr uint32_t TXBUF_BLOCK(uint8_t n) { return 2 + 4 * n; } +constexpr uint32_t RXBUF_BLOCK(uint8_t n) { return 3 + 4 * n; } + +constexpr uint32_t offset_inc(uint32_t addr, uint32_t n) { return addr + (n << 8); } + +constexpr uint32_t REG_CIDR = (0x0000 << 8) + CREG_BLOCK; +constexpr uint32_t REG_RTL = (0x0004 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SYSR = (0x2000 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SYCR0 = (0x2004 << 8) + CREG_BLOCK; +constexpr uint32_t REG_IMR = (0x2104 << 8) + CREG_BLOCK; +constexpr uint32_t REG_IRCLR = (0x2108 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SIMR = (0x2114 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SLIMR = (0x2124 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SLIRCLR = (0x2128 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SHAR = (0x4120 << 8) + CREG_BLOCK; +constexpr uint32_t REG_GAR = (0x4130 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SUBR = (0x4134 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SIPR = (0x4138 << 8) + CREG_BLOCK; +constexpr uint32_t REG_LLAR = (0x4140 << 8) + CREG_BLOCK; +constexpr uint32_t REG_GUAR = (0x4150 << 8) + CREG_BLOCK; +constexpr uint32_t REG_SUB6R = (0x4160 << 8) + CREG_BLOCK; +constexpr uint32_t REG_GA6R = (0x4170 << 8) + CREG_BLOCK; +constexpr uint32_t REG_CHPLCKR = (0x41F4 << 8) + CREG_BLOCK; +constexpr uint32_t REG_NETLCKR = (0x41F5 << 8) + CREG_BLOCK; + +constexpr uint32_t REG_SN_MR(uint8_t n) { return (0x0000 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_CR(uint8_t n) { return (0x0010 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_IR(uint8_t n) { return (0x0020 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_IRCLR(uint8_t n) { return (0x0028 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_SR(uint8_t n) { return (0x0030 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_MR2(uint8_t n) { return (0x0144 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_TX_BSR(uint8_t n) { return (0x0200 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_TX_FSR(uint8_t n) { return (0x0204 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_TX_WR(uint8_t n) { return (0x020C << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_RX_BSR(uint8_t n) { return (0x0220 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_RX_RSR(uint8_t n) { return (0x0224 << 8) + SREG_BLOCK(n); } +constexpr uint32_t REG_SN_RX_RD(uint8_t n) { return (0x0228 << 8) + SREG_BLOCK(n); } + +constexpr uint8_t SYSR_CHPL = 1 << 7; +constexpr uint8_t SYSR_NETL = 1 << 6; +constexpr uint8_t SYCR0_RST = 0x00; +constexpr uint8_t SN_MR_MACRAW = 0x07; +constexpr uint8_t SN_CR_OPEN = 0x01; +constexpr uint8_t SN_CR_CLOSE = 0x10; +constexpr uint8_t SN_CR_SEND = 0x20; +constexpr uint8_t SN_CR_RECV = 0x40; +constexpr uint8_t SN_IR_SENDOK = 0x10; +constexpr uint8_t SN_IR_TIMEOUT = 0x08; +constexpr uint8_t SOCK_CLOSED = 0x00; + +uint8_t reg_read(uint32_t addr_sel); +void reg_write(uint32_t addr_sel, uint8_t wb); +void reg_read_buf(uint32_t addr_sel, uint8_t* buf, datasize_t len); +void reg_write_buf(uint32_t addr_sel, uint8_t* buf, datasize_t len); +uint16_t get_sn_tx_fsr(uint8_t sn); +uint16_t get_sn_rx_rsr(uint8_t sn); + +uint16_t get_cidr() { return (((uint16_t)reg_read(REG_CIDR) | (((reg_read(REG_RTL)) & 0x0F) << 1)) << 8) + reg_read(offset_inc(REG_CIDR, 1)); } +uint8_t get_sysr() { return reg_read(REG_SYSR); } +uint8_t get_sycr0() { return reg_read(REG_SYCR0); } +void set_sycr0(uint8_t v) { reg_write(REG_SYCR0, v); } +void set_imr(uint8_t v) { reg_write(REG_IMR, v); } +void set_irclr(uint8_t v) { reg_write(REG_IRCLR, v); } +void set_simr(uint8_t v) { reg_write(REG_SIMR, v); } +void set_slimr(uint8_t v) { reg_write(REG_SLIMR, v); } +void set_slirclr(uint8_t v) { reg_write(REG_SLIRCLR, v); } +void set_shar(uint8_t* v) { reg_write_buf(REG_SHAR, v, 6); } +void get_shar(uint8_t* v) { reg_read_buf(REG_SHAR, v, 6); } +void set_gar(uint8_t* v) { reg_write_buf(REG_GAR, v, 4); } +void get_gar(uint8_t* v) { reg_read_buf(REG_GAR, v, 4); } +void set_subr(uint8_t* v) { reg_write_buf(REG_SUBR, v, 4); } +void get_subr(uint8_t* v) { reg_read_buf(REG_SUBR, v, 4); } +void set_sipr(uint8_t* v) { reg_write_buf(REG_SIPR, v, 4); } +void get_sipr(uint8_t* v) { reg_read_buf(REG_SIPR, v, 4); } +void set_llar(uint8_t* v) { reg_write_buf(REG_LLAR, v, 16); } +void get_llar(uint8_t* v) { reg_read_buf(REG_LLAR, v, 16); } +void set_guar(uint8_t* v) { reg_write_buf(REG_GUAR, v, 16); } +void get_guar(uint8_t* v) { reg_read_buf(REG_GUAR, v, 16); } +void set_sub6r(uint8_t* v) { reg_write_buf(REG_SUB6R, v, 16); } +void get_sub6r(uint8_t* v) { reg_read_buf(REG_SUB6R, v, 16); } +void set_ga6r(uint8_t* v) { reg_write_buf(REG_GA6R, v, 16); } +void get_ga6r(uint8_t* v) { reg_read_buf(REG_GA6R, v, 16); } +void set_chplckr(uint8_t v) { reg_write(REG_CHPLCKR, v); } +void chip_lock() { set_chplckr(0xFF); } +void chip_unlock() { set_chplckr(0xCE); } +void set_netlckr(uint8_t v) { reg_write(REG_NETLCKR, v); } +void net_lock() { set_netlckr(0xC5); } +void net_unlock() { set_netlckr(0x3A); } + +void set_sn_mr(uint8_t sn, uint8_t v) { reg_write(REG_SN_MR(sn), v); } +void set_sn_cr(uint8_t sn, uint8_t v) { reg_write(REG_SN_CR(sn), v); } +uint8_t get_sn_cr(uint8_t sn) { return reg_read(REG_SN_CR(sn)); } +uint8_t get_sn_ir(uint8_t sn) { return reg_read(REG_SN_IR(sn)); } +void set_sn_irclr(uint8_t sn, uint8_t v) { reg_write(REG_SN_IRCLR(sn), v); } +void set_sn_ir(uint8_t sn, uint8_t v) { set_sn_irclr(sn, v); } +uint8_t get_sn_sr(uint8_t sn) { return reg_read(REG_SN_SR(sn)); } +void set_sn_mr2(uint8_t sn, uint8_t v) { reg_write(REG_SN_MR2(sn), v); } +void set_sn_tx_bsr(uint8_t sn, uint8_t v) { reg_write(REG_SN_TX_BSR(sn), v); } +void set_sn_txbuf_size(uint8_t sn, uint8_t v) { set_sn_tx_bsr(sn, v); } +uint8_t get_sn_tx_bsr(uint8_t sn) { return reg_read(REG_SN_TX_BSR(sn)); } +uint16_t get_sn_tx_max(uint8_t sn) { return get_sn_tx_bsr(sn) << 10; } +uint16_t get_sn_tx_wr(uint8_t sn) { return ((uint16_t)reg_read(REG_SN_TX_WR(sn)) << 8) + reg_read(offset_inc(REG_SN_TX_WR(sn), 1)); } +void set_sn_tx_wr(uint8_t sn, uint16_t v) { + reg_write(REG_SN_TX_WR(sn), (uint8_t)(v >> 8)); + reg_write(offset_inc(REG_SN_TX_WR(sn), 1), (uint8_t)v); +} +void set_sn_rx_bsr(uint8_t sn, uint8_t v) { reg_write(REG_SN_RX_BSR(sn), v); } +void set_sn_rxbuf_size(uint8_t sn, uint8_t v) { set_sn_rx_bsr(sn, v); } +void set_sn_rx_rd(uint8_t sn, uint16_t v) { + reg_write(REG_SN_RX_RD(sn), (uint8_t)(v >> 8)); + reg_write(offset_inc(REG_SN_RX_RD(sn), 1), (uint8_t)v); +} +uint16_t get_sn_rx_rd(uint8_t sn) { return ((uint16_t)reg_read(REG_SN_RX_RD(sn)) << 8) + reg_read(offset_inc(REG_SN_RX_RD(sn), 1)); } +static uint8_t make_opcode(uint32_t addr, uint8_t rw) { + return static_cast((addr & 0xFF) | rw | QSPI_MODE); +} + +static uint16_t make_addr(uint32_t addr) { + return static_cast((addr & 0x00FFFF00) >> 8); +} + +void reg_write(uint32_t addr_sel, uint8_t wb) { + pio_frame_start(); + pio_write(make_opcode(addr_sel, SPI_WRITE), make_addr(addr_sel), &wb, 1); + pio_frame_end(); +} + +uint8_t reg_read(uint32_t addr_sel) { + uint8_t ret[2] = {0}; + pio_frame_start(); + pio_read(make_opcode(addr_sel, SPI_READ), make_addr(addr_sel), ret, 1); + pio_frame_end(); + return ret[0]; +} + +void reg_write_buf(uint32_t addr_sel, uint8_t* buf, datasize_t len) { + pio_frame_start(); + pio_write(make_opcode(addr_sel, SPI_WRITE), make_addr(addr_sel), buf, len); + pio_frame_end(); +} + +void reg_read_buf(uint32_t addr_sel, uint8_t* buf, datasize_t len) { + pio_frame_start(); + pio_read(make_opcode(addr_sel, SPI_READ), make_addr(addr_sel), buf, len); + pio_frame_end(); +} + +uint16_t get_sn_tx_fsr(uint8_t sn) { + uint16_t prev_val = -1, val = 0; + do { + prev_val = val; + val = reg_read(REG_SN_TX_FSR(sn)); + val = (val << 8) + reg_read(offset_inc(REG_SN_TX_FSR(sn), 1)); + } while (val != prev_val); + return val; +} + +uint16_t get_sn_rx_rsr(uint8_t sn) { + uint16_t prev_val = -1, val = 0; + do { + prev_val = val; + val = reg_read(REG_SN_RX_RSR(sn)); + val = (val << 8) + reg_read(offset_inc(REG_SN_RX_RSR(sn), 1)); + } while (val != prev_val); + return val; +} + +void send_data(uint8_t sn, uint8_t *data, uint16_t len) { + uint16_t ptr = get_sn_tx_wr(sn); + uint32_t addrsel = ((uint32_t)ptr << 8) + TXBUF_BLOCK(sn); + reg_write_buf(addrsel, data, len); + ptr += len; + set_sn_tx_wr(sn, ptr); +} + +void recv_data(uint8_t sn, uint8_t *data, uint16_t len) { + if (len == 0) return; + uint16_t ptr = get_sn_rx_rd(sn); + uint32_t addrsel = ((uint32_t)ptr << 8) + RXBUF_BLOCK(sn); + reg_read_buf(addrsel, data, len); + ptr += len; + set_sn_rx_rd(sn, ptr); +} + +void soft_reset() { + uint8_t gw[4], sn[4], sip[4], mac[6]; + uint8_t gw6[16], sn6[16], lla[16], gua[16]; + uint8_t islock = get_sysr(); + + chip_unlock(); + get_shar(mac); get_gar(gw); get_subr(sn); get_sipr(sip); + get_ga6r(gw6); get_sub6r(sn6); get_llar(lla); get_guar(gua); + set_sycr0(SYCR0_RST); + get_sycr0(); + + net_unlock(); + set_shar(mac); set_gar(gw); set_subr(sn); set_sipr(sip); + set_ga6r(gw6); set_sub6r(sn6); set_llar(lla); set_guar(gua); + + if (islock & SYSR_CHPL) chip_lock(); + if (islock & SYSR_NETL) net_lock(); +} + +int8_t init_buffers(std::span txsize, std::span rxsize) { + soft_reset(); + if (!txsize.empty()) { + int8_t tmp = 0; + for (int i = 0; i < sock_count; i++) { + tmp += txsize[i]; + if (tmp > 32) return -1; + } + for (int i = 0; i < sock_count; i++) set_sn_txbuf_size(i, txsize[i]); + } + if (!rxsize.empty()) { + int8_t tmp = 0; + for (int i = 0; i < sock_count; i++) { + tmp += rxsize[i]; + if (tmp > 32) return -1; + } + for (int i = 0; i < sock_count; i++) set_sn_rxbuf_size(i, rxsize[i]); + } + return 0; +} + +uint16_t sock_is_sending = 0; +uint16_t sock_remained_size[sock_count] = {0,}; +uint8_t sock_pack_info[sock_count] = {0,}; + +#define FAIL(e) return std::unexpected(sock_error::e) +#define CHECK_SOCKNUM() do { if(sn >= sock_count) FAIL(sock_num); } while(0) +#define CHECK_SOCKDATA() do { if(len == 0) FAIL(data_len); } while(0) + +std::expected close(socket_id sid) { + uint8_t sn = static_cast(sid); + CHECK_SOCKNUM(); + set_sn_cr(sn, SN_CR_CLOSE); + while (get_sn_cr(sn)); + set_sn_ir(sn, 0xFF); + sock_is_sending &= ~(1 << sn); + sock_remained_size[sn] = 0; + sock_pack_info[sn] = PACK_NONE; + while (get_sn_sr(sn) != SOCK_CLOSED); + return {}; +} + +} // namespace + +volatile bool irq_pending = false; + +void clear_interrupt(intr_kind intr) { + set_irclr((uint8_t)intr); + uint8_t sir = (uint8_t)((uint16_t)intr >> 8); + for (int i = 0; i < sock_count; i++) + if (sir & (1 << i)) set_sn_irclr(i, 0xFF); + set_slirclr((uint8_t)((uint32_t)intr >> 16)); +} + +void rearm_gpio_irq() { + gpio_set_irq_enabled(PIN_INT, GPIO_IRQ_LEVEL_LOW, true); +} + +void set_interrupt_mask(intr_kind intr) { + set_imr((uint8_t)intr); + set_simr((uint8_t)((uint16_t)intr >> 8)); + set_slimr((uint8_t)((uint32_t)intr >> 16)); +} + + +std::expected open_socket(socket_id sid, protocol proto, sock_flag flag) { + uint8_t sn = static_cast(sid); + uint8_t pr = static_cast(proto); + uint8_t fl = static_cast(flag); + CHECK_SOCKNUM(); + if ((pr & 0x0F) != SN_MR_MACRAW) FAIL(sock_mode); + close(sid); + set_sn_mr(sn, (pr | (fl & 0xF0))); + set_sn_mr2(sn, fl & 0x03); + set_sn_cr(sn, SN_CR_OPEN); + while (get_sn_cr(sn)); + sock_is_sending &= ~(1 << sn); + sock_remained_size[sn] = 0; + sock_pack_info[sn] = PACK_COMPLETED; + while (get_sn_sr(sn) == SOCK_CLOSED); + return sid; +} + +std::expected send(socket_id sid, std::span buf) { + uint8_t sn = static_cast(sid); + uint16_t len = buf.size(); + uint8_t tmp = 0; + uint16_t freesize = 0; + + CHECK_SOCKNUM(); + + freesize = get_sn_tx_max(sn); + if (len > freesize) len = freesize; + while (1) { + freesize = get_sn_tx_fsr(sn); + if (get_sn_sr(sn) == SOCK_CLOSED) FAIL(sock_closed); + if (len <= freesize) break; + }; + send_data(sn, const_cast(buf.data()), len); + set_sn_cr(sn, SN_CR_SEND); + while (get_sn_cr(sn)); + while (1) { + tmp = get_sn_ir(sn); + if (tmp & SN_IR_SENDOK) { + set_sn_ir(sn, SN_IR_SENDOK); + break; + } else if (tmp & SN_IR_TIMEOUT) { + set_sn_ir(sn, SN_IR_TIMEOUT); + FAIL(timeout); + } + } + return len; +} + +std::expected recv(socket_id sid, std::span buf) { + uint8_t sn = static_cast(sid); + uint16_t len = buf.size(); + uint8_t head[2]; + uint16_t pack_len = 0; + + CHECK_SOCKNUM(); + CHECK_SOCKDATA(); + + if (sock_remained_size[sn] == 0) { + while (1) { + pack_len = get_sn_rx_rsr(sn); + if (get_sn_sr(sn) == SOCK_CLOSED) FAIL(sock_closed); + if (pack_len != 0) { + sock_pack_info[sn] = PACK_NONE; + break; + } + }; + } + + recv_data(sn, head, 2); + set_sn_cr(sn, SN_CR_RECV); + while (get_sn_cr(sn)); + + if (sock_remained_size[sn] == 0) { + sock_remained_size[sn] = head[0]; + sock_remained_size[sn] = (sock_remained_size[sn] << 8) + head[1] - 2; + if (sock_remained_size[sn] > 1514) { + close(sid); + FAIL(fatal_packlen); + } + sock_pack_info[sn] = PACK_FIRST; + } + if (len < sock_remained_size[sn]) pack_len = len; + else pack_len = sock_remained_size[sn]; + recv_data(sn, buf.data(), pack_len); + + sock_remained_size[sn] = pack_len; + sock_pack_info[sn] |= PACK_FIRST; + + if (len < sock_remained_size[sn]) pack_len = len; + else pack_len = sock_remained_size[sn]; + recv_data(sn, buf.data(), pack_len); + set_sn_cr(sn, SN_CR_RECV); + while (get_sn_cr(sn)); + + sock_remained_size[sn] -= pack_len; + if (sock_remained_size[sn] != 0) sock_pack_info[sn] |= PACK_REMAINED; + else sock_pack_info[sn] |= PACK_COMPLETED; + + return pack_len; +} + + +void reset() { + gpio_init(PIN_RST); + gpio_set_dir(PIN_RST, GPIO_OUT); + gpio_put(PIN_RST, 0); + sleep_ms(100); + gpio_put(PIN_RST, 1); + sleep_ms(100); +} + +void init_spi() { + pio_init(); +} + +void init() { + pio_frame_end(); + std::array txsize = {32, 0, 0, 0, 0, 0, 0, 0}; + std::array rxsize = {32, 0, 0, 0, 0, 0, 0, 0}; + init_buffers(txsize, rxsize); +} + +bool check() { + return get_cidr() == 0x6300; +} + + + +uint16_t get_socket_recv_buf(socket_id sid) { + return get_sn_rx_rsr(static_cast(sid)); +} + +} // namespace w6300 diff --git a/w6300/w6300.h b/w6300/w6300.h new file mode 100644 index 0000000..ccd70e2 --- /dev/null +++ b/w6300/w6300.h @@ -0,0 +1,52 @@ +#pragma once +#include +#include +#include +#include + +namespace w6300 { + +enum class socket_id : uint8_t {}; + +enum class sock_error : int16_t { + busy = 0, + sock_num = -1, + sock_closed = -4, + sock_mode = -5, + arg = -10, + timeout = -13, + data_len = -14, + fatal_packlen = -1001, +}; + +enum class protocol : uint8_t { + macraw = 0x07, +}; + +enum class sock_flag : uint8_t { + none = 0, +}; + +enum intr_kind : uint32_t { + ik_sock_0 = (1 << 8), + ik_int_all = 0x00FFFF97 +}; + +void init_spi(); +void reset(); +void init(); +bool check(); + +extern volatile bool irq_pending; + +void clear_interrupt(intr_kind intr); +void set_interrupt_mask(intr_kind intr); +void rearm_gpio_irq(); + +std::expected open_socket(socket_id sn, protocol proto, sock_flag flag); +std::expected send(socket_id sn, std::span buf); +std::expected recv(socket_id sn, std::span buf); + +uint16_t get_socket_recv_buf(socket_id sn); + +} // namespace w6300