Zero-copy TX: span_writer packer, static buffers, no vector returns

This commit is contained in:
Ian Gulliver
2026-04-10 22:18:44 +09:00
parent 94895fd2fe
commit e2a5d97dae
10 changed files with 173 additions and 133 deletions

View File

@@ -2,10 +2,9 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <span> #include <span>
#include <vector>
#include "wire.h" #include "wire.h"
using handler_fn = std::vector<std::vector<uint8_t>> (*)(uint32_t message_id, std::span<const uint8_t> payload); using handler_fn = size_t (*)(uint32_t message_id, std::span<const uint8_t> payload, span_writer &out);
struct handler_entry { struct handler_entry {
int8_t type_id; int8_t type_id;
@@ -13,16 +12,16 @@ struct handler_entry {
}; };
template <typename Req, auto Fn> template <typename Req, auto Fn>
std::vector<std::vector<uint8_t>> typed_handler(uint32_t message_id, std::span<const uint8_t> payload) { size_t typed_handler(uint32_t message_id, std::span<const uint8_t> payload, span_writer &out) {
msgpack::parser p(payload.data(), static_cast<int>(payload.size())); msgpack::parser p(payload.data(), static_cast<int>(payload.size()));
Req req; Req req;
auto tup = req.as_tuple(); auto tup = req.as_tuple();
auto r = msgpack::unpack(p, tup); auto r = msgpack::unpack(p, tup);
if (!r) { if (!r) {
return {encode_response(message_id, DeviceError{1, "decode request ext_id=" + return encode_response_into(out, message_id, DeviceError{1, "decode request ext_id=" +
std::to_string(Req::ext_id) + ": msgpack error " + std::to_string(static_cast<int>(r.error()))})}; std::to_string(Req::ext_id) + ": msgpack error " + std::to_string(static_cast<int>(r.error()))});
} }
return Fn(message_id, req); return Fn(message_id, req, out);
} }
void dispatch_init(); void dispatch_init();

View File

@@ -2,11 +2,10 @@
#include <cstdint> #include <cstdint>
#include <span> #include <span>
#include <string_view> #include <string_view>
#include <vector>
#include "wire.h" #include "wire.h"
extern std::string_view firmware_name; extern std::string_view firmware_name;
std::vector<std::vector<uint8_t>> handle_picoboot(uint32_t message_id, std::span<const uint8_t> payload); size_t handle_picoboot(uint32_t message_id, std::span<const uint8_t> payload, span_writer &out);
std::vector<std::vector<uint8_t>> handle_info(uint32_t message_id, std::span<const uint8_t> payload); size_t handle_info(uint32_t message_id, std::span<const uint8_t> payload, span_writer &out);
std::vector<std::vector<uint8_t>> handle_log(uint32_t message_id, std::span<const uint8_t> payload); size_t handle_log(uint32_t message_id, std::span<const uint8_t> payload, span_writer &out);

View File

@@ -6,11 +6,11 @@
#include <expected> #include <expected>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <memory>
#include <string_view> #include <string_view>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "span_writer.h"
namespace msgpack { namespace msgpack {
@@ -163,26 +163,23 @@ inline result<body_info> get_body_info(const uint8_t *p, int size) {
} }
class packer { class packer {
public:
using buffer = std::vector<std::uint8_t>;
private: private:
std::shared_ptr<buffer> m_buffer; span_writer m_buf;
template <typename T> void push_big_endian(T n) { template <typename T> void push_big_endian(T n) {
auto p = reinterpret_cast<std::uint8_t *>(&n) + (sizeof(T) - 1); auto p = reinterpret_cast<std::uint8_t *>(&n) + (sizeof(T) - 1);
for (size_t i = 0; i < sizeof(T); ++i, --p) { for (size_t i = 0; i < sizeof(T); ++i, --p) {
m_buffer->push_back(*p); m_buf.push_back(*p);
} }
} }
template <class Range> void push(const Range &r) { template <class Range> void push(const Range &r) {
m_buffer->insert(m_buffer->end(), std::begin(r), std::end(r)); m_buf.insert(m_buf.end(), std::begin(r), std::end(r));
} }
public: public:
packer() : m_buffer(std::make_shared<buffer>()) {} packer(uint8_t *data, size_t capacity) : m_buf(data, capacity) {}
packer(const std::shared_ptr<buffer> &buf) : m_buffer(buf) {} packer(span_writer buf) : m_buf(buf) {}
packer(const packer &) = delete; packer(const packer &) = delete;
packer &operator=(const packer &) = delete; packer &operator=(const packer &) = delete;
@@ -190,12 +187,12 @@ public:
using pack_result = result<std::reference_wrapper<packer>>; using pack_result = result<std::reference_wrapper<packer>>;
pack_result pack_nil() { pack_result pack_nil() {
m_buffer->push_back(format::NIL); m_buf.push_back(format::NIL);
return *this; return *this;
} }
pack_result pack_bool(bool v) { pack_result pack_bool(bool v) {
m_buffer->push_back(v ? format::TRUE : format::FALSE); m_buf.push_back(v ? format::TRUE : format::FALSE);
return *this; return *this;
} }
@@ -203,36 +200,36 @@ public:
pack_result pack_integer(T n) { pack_result pack_integer(T n) {
if constexpr (std::is_signed_v<T>) { if constexpr (std::is_signed_v<T>) {
if (n >= 0 && n <= 0x7F) { if (n >= 0 && n <= 0x7F) {
m_buffer->push_back(static_cast<uint8_t>(n)); m_buf.push_back(static_cast<uint8_t>(n));
} else if (n >= -32 && n < 0) { } else if (n >= -32 && n < 0) {
m_buffer->push_back(static_cast<uint8_t>(n)); // negative fixint m_buf.push_back(static_cast<uint8_t>(n)); // negative fixint
} else if (n >= std::numeric_limits<int8_t>::min() && n <= std::numeric_limits<int8_t>::max()) { } else if (n >= std::numeric_limits<int8_t>::min() && n <= std::numeric_limits<int8_t>::max()) {
m_buffer->push_back(format::INT8); m_buf.push_back(format::INT8);
m_buffer->push_back(static_cast<uint8_t>(n)); m_buf.push_back(static_cast<uint8_t>(n));
} else if (n >= std::numeric_limits<int16_t>::min() && n <= std::numeric_limits<int16_t>::max()) { } else if (n >= std::numeric_limits<int16_t>::min() && n <= std::numeric_limits<int16_t>::max()) {
m_buffer->push_back(format::INT16); m_buf.push_back(format::INT16);
push_big_endian(static_cast<int16_t>(n)); push_big_endian(static_cast<int16_t>(n));
} else if (n >= std::numeric_limits<int32_t>::min() && n <= std::numeric_limits<int32_t>::max()) { } else if (n >= std::numeric_limits<int32_t>::min() && n <= std::numeric_limits<int32_t>::max()) {
m_buffer->push_back(format::INT32); m_buf.push_back(format::INT32);
push_big_endian(static_cast<int32_t>(n)); push_big_endian(static_cast<int32_t>(n));
} else { } else {
m_buffer->push_back(format::INT64); m_buf.push_back(format::INT64);
push_big_endian(static_cast<int64_t>(n)); push_big_endian(static_cast<int64_t>(n));
} }
} else { } else {
if (n <= 0x7F) { if (n <= 0x7F) {
m_buffer->push_back(static_cast<uint8_t>(n)); m_buf.push_back(static_cast<uint8_t>(n));
} else if (n <= std::numeric_limits<uint8_t>::max()) { } else if (n <= std::numeric_limits<uint8_t>::max()) {
m_buffer->push_back(format::UINT8); m_buf.push_back(format::UINT8);
m_buffer->push_back(static_cast<uint8_t>(n)); m_buf.push_back(static_cast<uint8_t>(n));
} else if (n <= std::numeric_limits<uint16_t>::max()) { } else if (n <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::UINT16); m_buf.push_back(format::UINT16);
push_big_endian(static_cast<uint16_t>(n)); push_big_endian(static_cast<uint16_t>(n));
} else if (n <= std::numeric_limits<uint32_t>::max()) { } else if (n <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::UINT32); m_buf.push_back(format::UINT32);
push_big_endian(static_cast<uint32_t>(n)); push_big_endian(static_cast<uint32_t>(n));
} else { } else {
m_buffer->push_back(format::UINT64); m_buf.push_back(format::UINT64);
push_big_endian(static_cast<uint64_t>(n)); push_big_endian(static_cast<uint64_t>(n));
} }
} }
@@ -240,13 +237,13 @@ public:
} }
pack_result pack_float(float n) { pack_result pack_float(float n) {
m_buffer->push_back(format::FLOAT32); m_buf.push_back(format::FLOAT32);
push_big_endian(n); push_big_endian(n);
return *this; return *this;
} }
pack_result pack_double(double n) { pack_result pack_double(double n) {
m_buffer->push_back(format::FLOAT64); m_buf.push_back(format::FLOAT64);
push_big_endian(n); push_big_endian(n);
return *this; return *this;
} }
@@ -255,15 +252,15 @@ public:
pack_result pack_str(const Range &r) { pack_result pack_str(const Range &r) {
auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r))); auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r)));
if (sz < 32) { if (sz < 32) {
m_buffer->push_back(format::FIXSTR_MIN | static_cast<uint8_t>(sz)); m_buf.push_back(format::FIXSTR_MIN | static_cast<uint8_t>(sz));
} else if (sz <= std::numeric_limits<uint8_t>::max()) { } else if (sz <= std::numeric_limits<uint8_t>::max()) {
m_buffer->push_back(format::STR8); m_buf.push_back(format::STR8);
m_buffer->push_back(static_cast<uint8_t>(sz)); m_buf.push_back(static_cast<uint8_t>(sz));
} else if (sz <= std::numeric_limits<uint16_t>::max()) { } else if (sz <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::STR16); m_buf.push_back(format::STR16);
push_big_endian(static_cast<uint16_t>(sz)); push_big_endian(static_cast<uint16_t>(sz));
} else if (sz <= std::numeric_limits<uint32_t>::max()) { } else if (sz <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::STR32); m_buf.push_back(format::STR32);
push_big_endian(static_cast<uint32_t>(sz)); push_big_endian(static_cast<uint32_t>(sz));
} else { } else {
return std::unexpected(error_code::overflow); return std::unexpected(error_code::overflow);
@@ -280,13 +277,13 @@ public:
pack_result pack_bin(const Range &r) { pack_result pack_bin(const Range &r) {
auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r))); auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r)));
if (sz <= std::numeric_limits<uint8_t>::max()) { if (sz <= std::numeric_limits<uint8_t>::max()) {
m_buffer->push_back(format::BIN8); m_buf.push_back(format::BIN8);
m_buffer->push_back(static_cast<uint8_t>(sz)); m_buf.push_back(static_cast<uint8_t>(sz));
} else if (sz <= std::numeric_limits<uint16_t>::max()) { } else if (sz <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::BIN16); m_buf.push_back(format::BIN16);
push_big_endian(static_cast<uint16_t>(sz)); push_big_endian(static_cast<uint16_t>(sz));
} else if (sz <= std::numeric_limits<uint32_t>::max()) { } else if (sz <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::BIN32); m_buf.push_back(format::BIN32);
push_big_endian(static_cast<uint32_t>(sz)); push_big_endian(static_cast<uint32_t>(sz));
} else { } else {
return std::unexpected(error_code::overflow); return std::unexpected(error_code::overflow);
@@ -297,12 +294,12 @@ public:
pack_result pack_array(size_t n) { pack_result pack_array(size_t n) {
if (n <= 15) { if (n <= 15) {
m_buffer->push_back(format::FIXARRAY_MIN | static_cast<uint8_t>(n)); m_buf.push_back(format::FIXARRAY_MIN | static_cast<uint8_t>(n));
} else if (n <= std::numeric_limits<uint16_t>::max()) { } else if (n <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::ARRAY16); m_buf.push_back(format::ARRAY16);
push_big_endian(static_cast<uint16_t>(n)); push_big_endian(static_cast<uint16_t>(n));
} else if (n <= std::numeric_limits<uint32_t>::max()) { } else if (n <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::ARRAY32); m_buf.push_back(format::ARRAY32);
push_big_endian(static_cast<uint32_t>(n)); push_big_endian(static_cast<uint32_t>(n));
} else { } else {
return std::unexpected(error_code::overflow); return std::unexpected(error_code::overflow);
@@ -312,12 +309,12 @@ public:
pack_result pack_map(size_t n) { pack_result pack_map(size_t n) {
if (n <= 15) { if (n <= 15) {
m_buffer->push_back(format::FIXMAP_MIN | static_cast<uint8_t>(n)); m_buf.push_back(format::FIXMAP_MIN | static_cast<uint8_t>(n));
} else if (n <= std::numeric_limits<uint16_t>::max()) { } else if (n <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::MAP16); m_buf.push_back(format::MAP16);
push_big_endian(static_cast<uint16_t>(n)); push_big_endian(static_cast<uint16_t>(n));
} else if (n <= std::numeric_limits<uint32_t>::max()) { } else if (n <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::MAP32); m_buf.push_back(format::MAP32);
push_big_endian(static_cast<uint32_t>(n)); push_big_endian(static_cast<uint32_t>(n));
} else { } else {
return std::unexpected(error_code::overflow); return std::unexpected(error_code::overflow);
@@ -330,26 +327,26 @@ public:
auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r))); auto sz = static_cast<size_t>(std::distance(std::begin(r), std::end(r)));
switch (sz) { switch (sz) {
case 1: m_buffer->push_back(format::FIXEXT1); break; case 1: m_buf.push_back(format::FIXEXT1); break;
case 2: m_buffer->push_back(format::FIXEXT2); break; case 2: m_buf.push_back(format::FIXEXT2); break;
case 4: m_buffer->push_back(format::FIXEXT4); break; case 4: m_buf.push_back(format::FIXEXT4); break;
case 8: m_buffer->push_back(format::FIXEXT8); break; case 8: m_buf.push_back(format::FIXEXT8); break;
case 16: m_buffer->push_back(format::FIXEXT16); break; case 16: m_buf.push_back(format::FIXEXT16); break;
default: default:
if (sz <= std::numeric_limits<uint8_t>::max()) { if (sz <= std::numeric_limits<uint8_t>::max()) {
m_buffer->push_back(format::EXT8); m_buf.push_back(format::EXT8);
m_buffer->push_back(static_cast<uint8_t>(sz)); m_buf.push_back(static_cast<uint8_t>(sz));
} else if (sz <= std::numeric_limits<uint16_t>::max()) { } else if (sz <= std::numeric_limits<uint16_t>::max()) {
m_buffer->push_back(format::EXT16); m_buf.push_back(format::EXT16);
push_big_endian(static_cast<uint16_t>(sz)); push_big_endian(static_cast<uint16_t>(sz));
} else if (sz <= std::numeric_limits<uint32_t>::max()) { } else if (sz <= std::numeric_limits<uint32_t>::max()) {
m_buffer->push_back(format::EXT32); m_buf.push_back(format::EXT32);
push_big_endian(static_cast<uint32_t>(sz)); push_big_endian(static_cast<uint32_t>(sz));
} else { } else {
return std::unexpected(error_code::overflow); return std::unexpected(error_code::overflow);
} }
} }
m_buffer->push_back(static_cast<uint8_t>(type)); m_buf.push_back(static_cast<uint8_t>(type));
push(r); push(r);
return *this; return *this;
} }
@@ -392,7 +389,8 @@ public:
template <typename T> template <typename T>
requires requires(const T &v) { { T::ext_id } -> std::convertible_to<int8_t>; v.as_tuple(); } requires requires(const T &v) { { T::ext_id } -> std::convertible_to<int8_t>; v.as_tuple(); }
pack_result pack(const T &v) { pack_result pack(const T &v) {
packer inner; uint8_t ext_buf[256];
packer inner(ext_buf, sizeof(ext_buf));
auto r = inner.pack(v.as_tuple()); auto r = inner.pack(v.as_tuple());
if (!r) return r; if (!r) return r;
return pack_ext(T::ext_id, inner.get_payload()); return pack_ext(T::ext_id, inner.get_payload());
@@ -413,7 +411,7 @@ private:
} }
public: public:
const buffer &get_payload() const { return *m_buffer; } const span_writer &get_payload() const { return m_buf; }
}; };
class parser { class parser {

View File

@@ -3,14 +3,14 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <span> #include <span>
#include <vector> #include "span_writer.h"
struct net_state { struct net_state {
std::array<uint8_t, 6> mac; std::array<uint8_t, 6> mac;
std::array<uint8_t, 4> ip; std::array<uint8_t, 4> ip;
}; };
using net_handler = std::function<std::vector<std::vector<uint8_t>>(std::span<const uint8_t> payload)>; using net_handler = std::function<size_t(std::span<const uint8_t> payload, span_writer &out)>;
bool net_init(); bool net_init();
const net_state& net_get_state(); const net_state& net_get_state();

View File

@@ -0,0 +1,35 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
class span_writer {
uint8_t *m_data;
size_t m_capacity;
size_t m_size = 0;
public:
span_writer(uint8_t *data, size_t capacity) : m_data(data), m_capacity(capacity) {}
void push_back(uint8_t v) {
if (m_size < m_capacity) m_data[m_size++] = v;
}
template <class It>
void insert(uint8_t *, It first, It last) {
while (first != last && m_size < m_capacity)
m_data[m_size++] = *first++;
}
size_t size() const { return m_size; }
size_t capacity() const { return m_capacity; }
bool full() const { return m_size >= m_capacity; }
uint8_t *data() { return m_data; }
const uint8_t *data() const { return m_data; }
uint8_t *begin() { return m_data; }
uint8_t *end() { return m_data + m_size; }
const uint8_t *begin() const { return m_data; }
const uint8_t *end() const { return m_data + m_size; }
};

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include <array> #include <array>
#include <cstdint> #include <cstdint>
#include <span>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@@ -96,18 +97,27 @@ struct DecodedMessage {
std::vector<uint8_t> payload; std::vector<uint8_t> payload;
}; };
inline std::vector<uint8_t> pack_envelope(uint32_t message_id, const std::vector<uint8_t> &payload) { inline size_t pack_envelope_into(span_writer &out, uint32_t message_id, const uint8_t *payload, size_t payload_len) {
uint32_t checksum = halfsiphash::hash32(payload.data(), payload.size(), hash_key); uint32_t checksum = halfsiphash::hash32(payload, payload_len, hash_key);
msgpack::packer p; uint8_t env_buf[512];
p.pack(Envelope{message_id, checksum, payload}); span_writer env_body(env_buf, sizeof(env_buf));
return p.get_payload(); msgpack::packer env_p(env_body);
env_p.pack_array(3);
env_p.pack(message_id);
env_p.pack(checksum);
env_p.pack_bin(std::span<const uint8_t>{payload, payload_len});
msgpack::packer outer(out);
outer.pack_ext(Envelope::ext_id, env_body);
return out.size();
} }
template <typename T> template <typename T>
inline std::vector<uint8_t> encode_response(uint32_t message_id, const T &msg) { inline size_t encode_response_into(span_writer &out, uint32_t message_id, const T &msg) {
msgpack::packer inner; uint8_t inner_buf[256];
msgpack::packer inner(inner_buf, sizeof(inner_buf));
inner.pack(msg); inner.pack(msg);
return pack_envelope(message_id, inner.get_payload()); auto &pl = inner.get_payload();
return pack_envelope_into(out, message_id, pl.data(), pl.size());
} }
inline msgpack::result<DecodedMessage> try_decode(const uint8_t *data, size_t len) { inline msgpack::result<DecodedMessage> try_decode(const uint8_t *data, size_t len) {
@@ -154,8 +164,6 @@ inline msgpack::result<T> decode_response(const uint8_t *data, size_t len) {
return out; return out;
} }
inline std::vector<uint8_t> encode_request(uint32_t message_id, const auto &msg) { inline size_t encode_request_into(span_writer &out, uint32_t message_id, const auto &msg) {
msgpack::packer inner; return encode_response_into(out, message_id, msg);
inner.pack(msg);
return pack_envelope(message_id, inner.get_payload());
} }

View File

@@ -22,20 +22,21 @@ void dispatch_schedule_ms(uint32_t ms, std::function<void()> fn) {
} }
[[noreturn]] void dispatch_run(std::span<const handler_entry> handlers) { [[noreturn]] void dispatch_run(std::span<const handler_entry> handlers) {
std::unordered_map<int8_t, std::vector<std::vector<uint8_t>> (*)(uint32_t, std::span<const uint8_t>)> handler_map; std::unordered_map<int8_t, handler_fn> handler_map;
for (auto& entry : handlers) { for (auto& entry : handlers) {
handler_map[entry.type_id] = entry.handle; handler_map[entry.type_id] = entry.handle;
} }
static usb_cdc usb; static usb_cdc usb;
static static_vector<uint8_t, 256> usb_rx_buf; static static_vector<uint8_t, 256> usb_rx_buf;
static uint8_t tx_buf[1514];
net_set_handler([&](std::span<const uint8_t> payload) -> std::vector<std::vector<uint8_t>> { net_set_handler([&](std::span<const uint8_t> payload, span_writer &out) -> size_t {
auto msg = try_decode(payload.data(), payload.size()); auto msg = try_decode(payload.data(), payload.size());
if (!msg) return {}; if (!msg) return 0;
auto it = handler_map.find(msg->type_id); auto it = handler_map.find(msg->type_id);
if (it == handler_map.end()) return {}; if (it == handler_map.end()) return 0;
return it->second(msg->message_id, msg->payload); return it->second(msg->message_id, msg->payload, out);
}); });
while (true) { while (true) {
@@ -62,13 +63,16 @@ void dispatch_schedule_ms(uint32_t ms, std::function<void()> fn) {
auto it = handler_map.find(msg->type_id); auto it = handler_map.find(msg->type_id);
if (it != handler_map.end()) { if (it != handler_map.end()) {
for (auto& response : it->second(msg->message_id, msg->payload)) { span_writer out(tx_buf, sizeof(tx_buf));
if (response.size() > usb.tx.free()) { size_t resp_len = it->second(msg->message_id, msg->payload, out);
auto err = encode_response(msg->message_id, if (resp_len > 0) {
DeviceError{2, "response too large: " + std::to_string(response.size())}); if (resp_len > usb.tx.free()) {
usb.send(err); span_writer err_out(tx_buf, sizeof(tx_buf));
size_t err_len = encode_response_into(err_out, msg->message_id,
DeviceError{2, "response too large: " + std::to_string(resp_len)});
usb.send(std::span<const uint8_t>{tx_buf, err_len});
} else { } else {
usb.send(response); usb.send(std::span<const uint8_t>{tx_buf, resp_len});
} }
} }
} }

View File

@@ -5,12 +5,12 @@
#include "net.h" #include "net.h"
#include "debug_log.h" #include "debug_log.h"
std::vector<std::vector<uint8_t>> handle_picoboot(uint32_t message_id, std::span<const uint8_t>) { size_t handle_picoboot(uint32_t message_id, std::span<const uint8_t>, span_writer &out) {
dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); }); dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); });
return {encode_response(message_id, ResponsePICOBOOT{})}; return encode_response_into(out, message_id, ResponsePICOBOOT{});
} }
std::vector<std::vector<uint8_t>> handle_info(uint32_t message_id, std::span<const uint8_t>) { size_t handle_info(uint32_t message_id, std::span<const uint8_t>, span_writer &out) {
ResponseInfo resp; ResponseInfo resp;
pico_unique_board_id_t uid; pico_unique_board_id_t uid;
pico_get_unique_board_id(&uid); pico_get_unique_board_id(&uid);
@@ -19,12 +19,12 @@ std::vector<std::vector<uint8_t>> handle_info(uint32_t message_id, std::span<con
resp.mac = ns.mac; resp.mac = ns.mac;
resp.ip = ns.ip; resp.ip = ns.ip;
resp.firmware_name = firmware_name; resp.firmware_name = firmware_name;
return {encode_response(message_id, resp)}; return encode_response_into(out, message_id, resp);
} }
std::vector<std::vector<uint8_t>> handle_log(uint32_t message_id, std::span<const uint8_t>) { size_t handle_log(uint32_t message_id, std::span<const uint8_t>, span_writer &out) {
ResponseLog resp; ResponseLog resp;
for (auto& e : dlog_drain()) for (auto& e : dlog_drain())
resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)}); resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)});
return {encode_response(message_id, resp)}; return encode_response_into(out, message_id, resp);
} }

View File

@@ -139,6 +139,8 @@ static void handle_arp(const uint8_t* frame, size_t len) {
send_raw(&reply, sizeof(reply)); send_raw(&reply, sizeof(reply));
} }
static uint8_t tx_buf[1514];
static void handle_udp(const uint8_t* frame, size_t len) { static void handle_udp(const uint8_t* frame, size_t len) {
if (len < sizeof(udp_header)) return; if (len < sizeof(udp_header)) return;
auto& pkt = *reinterpret_cast<const udp_header*>(frame); auto& pkt = *reinterpret_cast<const udp_header*>(frame);
@@ -154,16 +156,14 @@ static void handle_udp(const uint8_t* frame, size_t len) {
auto* payload = frame + sizeof(udp_header); auto* payload = frame + sizeof(udp_header);
size_t payload_len = udp_len - 8; size_t payload_len = udp_len - 8;
auto responses = msg_handler(std::span<const uint8_t>{payload, payload_len}); span_writer resp(tx_buf + sizeof(udp_header), sizeof(tx_buf) - sizeof(udp_header));
size_t resp_len = msg_handler(std::span<const uint8_t>{payload, payload_len}, resp);
if (resp_len == 0) return;
for (auto& resp : responses) { size_t ip_total = 20 + 8 + resp_len;
uint8_t reply_buf[1514];
size_t udp_data_len = resp.size();
size_t ip_total = 20 + 8 + udp_data_len;
size_t reply_len = sizeof(eth_header) + ip_total; size_t reply_len = sizeof(eth_header) + ip_total;
if (reply_len > sizeof(reply_buf)) continue;
auto& rip = *reinterpret_cast<ipv4_header*>(reply_buf); auto& rip = *reinterpret_cast<ipv4_header*>(tx_buf);
rip.eth.dst = pkt.ip.eth.src; rip.eth.dst = pkt.ip.eth.src;
rip.eth.src = state.mac; rip.eth.src = state.mac;
rip.eth.ethertype = ETH_IPV4; rip.eth.ethertype = ETH_IPV4;
@@ -179,16 +179,13 @@ static void handle_udp(const uint8_t* frame, size_t len) {
rip.dst = pkt.ip.src; rip.dst = pkt.ip.src;
rip.checksum = ip_checksum(rip.ip_start(), 20); rip.checksum = ip_checksum(rip.ip_start(), 20);
auto& rudp = *reinterpret_cast<udp_header*>(reply_buf); auto& rudp = *reinterpret_cast<udp_header*>(tx_buf);
rudp.src_port = PICOMAP_PORT; rudp.src_port = PICOMAP_PORT;
rudp.dst_port = pkt.src_port; rudp.dst_port = pkt.src_port;
rudp.length = __builtin_bswap16(8 + udp_data_len); rudp.length = __builtin_bswap16(8 + resp_len);
rudp.checksum = 0; rudp.checksum = 0;
memcpy(reply_buf + sizeof(udp_header), resp.data(), udp_data_len); send_raw(tx_buf, reply_len);
send_raw(reply_buf, reply_len);
}
} }
static void handle_icmp(const uint8_t* frame, size_t len) { static void handle_icmp(const uint8_t* frame, size_t len) {

View File

@@ -29,8 +29,10 @@ static ResponseTest test_discovery() {
ResponseTest resp; ResponseTest resp;
resp.pass = true; resp.pass = true;
auto req = encode_request(0, RequestInfo{}); uint8_t req_buf[256];
auto send_result = w6300::send(test_socket, std::span<const uint8_t>{req}); span_writer req_out(req_buf, sizeof(req_buf));
size_t req_len = encode_request_into(req_out, 0, RequestInfo{});
auto send_result = w6300::send(test_socket, std::span<const uint8_t>{req_buf, req_len});
if (!send_result) { if (!send_result) {
resp.pass = false; resp.pass = false;
resp.messages.push_back("send: error " + std::to_string(static_cast<int>(send_result.error()))); resp.messages.push_back("send: error " + std::to_string(static_cast<int>(send_result.error())));
@@ -95,13 +97,11 @@ static const std::unordered_map<std::string_view, test_fn> tests = {
{"discovery", test_discovery}, {"discovery", test_discovery},
}; };
static std::vector<std::vector<uint8_t>> handle_test(uint32_t message_id, const RequestTest& req) { static size_t handle_test(uint32_t message_id, const RequestTest& req, span_writer &out) {
auto it = tests.find(req.name); auto it = tests.find(req.name);
if (it == tests.end()) { if (it == tests.end())
return {encode_response(message_id, ResponseTest{false, {"unknown test: " + req.name}})}; return encode_response_into(out, message_id, ResponseTest{false, {"unknown test: " + req.name}});
} return encode_response_into(out, message_id, it->second());
return {encode_response(message_id, it->second())};
} }
static constexpr handler_entry handlers[] = { static constexpr handler_entry handlers[] = {