diff --git a/cmd/picomap/main.go b/cmd/picomap/main.go index 3d0b8db..899acfc 100644 --- a/cmd/picomap/main.go +++ b/cmd/picomap/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/hex" + "flag" "fmt" "log/slog" "net" @@ -26,7 +27,7 @@ func main() { var err error switch cmd { case "info": - err = cmdInfo() + err = cmdInfo(args) case "load": err = cmdLoad(args) case "log": @@ -49,7 +50,47 @@ type deviceResult struct { err error } -func cmdInfo() error { +func infoFromClient(dev string, c *client.Client) (*client.ResponseInfo, error) { + log := slog.With("dev", dev) + log.Info("requesting info") + info, err := c.Info() + if err != nil { + return nil, err + } + log.Info("got info", "firmware", info.FirmwareName) + return info, nil +} + +func printInfo(dev string, info *client.ResponseInfo) { + slog.Info("device", + "dev", dev, + "board_id", hex.EncodeToString(info.BoardID[:]), + "mac", net.HardwareAddr(info.MAC[:]).String(), + "ip", net.IP(info.IP[:]).String(), + "firmware", info.FirmwareName) +} + +func cmdInfo(args []string) error { + fs := flag.NewFlagSet("info", flag.ExitOnError) + udpAddr := fs.String("udp", "", "connect via UDP to this IP address") + fs.Parse(args) + + if *udpAddr != "" { + log := slog.With("addr", *udpAddr) + log.Info("connecting via UDP") + c, err := client.NewUDP(*udpAddr, 2*time.Second) + if err != nil { + return err + } + defer c.Close() + info, err := infoFromClient(*udpAddr, c) + if err != nil { + return err + } + printInfo(*udpAddr, info) + return nil + } + devs, err := client.ListSerial() if err != nil { return err @@ -62,22 +103,19 @@ func cmdInfo() error { var wg sync.WaitGroup for i, dev := range devs { results[i].dev = dev - log := slog.With("dev", dev) wg.Go(func() { - log.Info("connecting") + slog.Info("connecting", "dev", dev) c, err := client.NewSerial(dev, 2*time.Second) if err != nil { results[i].err = err return } - log.Info("requesting info") - info, err := c.Info() + info, err := infoFromClient(dev, c) c.Close() if err != nil { results[i].err = err return } - log.Info("got info", "firmware", info.FirmwareName) results[i].info = info }) } @@ -88,12 +126,7 @@ func cmdInfo() error { slog.Error("device error", "dev", r.dev, "err", r.err) continue } - slog.Info("device", - "dev", r.dev, - "board_id", hex.EncodeToString(r.info.BoardID[:]), - "mac", net.HardwareAddr(r.info.MAC[:]).String(), - "ip", net.IP(r.info.IP[:]).String(), - "firmware", r.info.FirmwareName) + printInfo(r.dev, r.info) } return nil diff --git a/firmware/include/net.h b/firmware/include/net.h index e2a2327..15551f0 100644 --- a/firmware/include/net.h +++ b/firmware/include/net.h @@ -1,12 +1,18 @@ #pragma once #include #include +#include +#include +#include struct net_state { std::array mac; std::array ip; }; +using net_handler = std::function>(std::span payload)>; + bool net_init(); const net_state& net_get_state(); +void net_set_handler(net_handler handler); void net_poll(); diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index bb54f82..b303f88 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -30,6 +30,14 @@ void dispatch_schedule_ms(uint32_t ms, std::function fn) { static usb_cdc usb; static static_vector usb_rx_buf; + net_set_handler([&](std::span payload) -> std::vector> { + auto msg = try_decode(payload.data(), payload.size()); + if (!msg) return {}; + auto it = handler_map.find(msg->type_id); + if (it == handler_map.end()) return {}; + return it->second(msg->message_id, msg->payload); + }); + while (true) { dlog_if_slow("tud_task", 1000, [&]{ tud_task(); }); dlog_if_slow("drain", 1000, [&]{ usb.drain(); }); diff --git a/firmware/lib/net.cpp b/firmware/lib/net.cpp index 73a65b1..b5ab3fa 100644 --- a/firmware/lib/net.cpp +++ b/firmware/lib/net.cpp @@ -49,6 +49,15 @@ struct __attribute__((packed)) ipv4_header { }; static_assert(sizeof(ipv4_header) == 34); +struct __attribute__((packed)) udp_header { + ipv4_header ip; + uint16_t src_port; + uint16_t dst_port; + uint16_t length; + uint16_t checksum; +}; +static_assert(sizeof(udp_header) == 42); + struct __attribute__((packed)) icmp_echo { uint8_t type; uint8_t code; @@ -64,12 +73,14 @@ static constexpr uint16_t ARP_HTYPE_ETH = __builtin_bswap16(1); static constexpr uint16_t ARP_PTYPE_IPV4 = __builtin_bswap16(0x0800); static constexpr uint16_t ARP_OP_REQUEST = __builtin_bswap16(1); static constexpr uint16_t ARP_OP_REPLY = __builtin_bswap16(2); +static constexpr uint16_t PICOMAP_PORT = __builtin_bswap16(28781); static constexpr mac_addr MAC_BROADCAST = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; static constexpr ip4_addr IP_BROADCAST_ALL = {255, 255, 255, 255}; static constexpr ip4_addr IP_BROADCAST_SUBNET = {169, 254, 255, 255}; static net_state state; static w6300::socket_id raw_socket{0}; +static net_handler msg_handler; static uint16_t ip_checksum(const void* data, size_t len) { auto p = static_cast(data); @@ -130,6 +141,58 @@ static void handle_arp(const uint8_t* frame, size_t len) { send_raw(&reply, sizeof(reply)); } +static void handle_udp(const uint8_t* frame, size_t len) { + if (len < sizeof(udp_header)) return; + auto& pkt = *reinterpret_cast(frame); + + if (pkt.dst_port != PICOMAP_PORT) return; + if (!ip_match_or_broadcast(pkt.ip.dst)) return; + if (!msg_handler) return; + + size_t udp_len = __builtin_bswap16(pkt.length); + if (udp_len < 8) return; + if (sizeof(eth_header) + pkt.ip.ip_total_len() < sizeof(udp_header) + udp_len - 8) return; + + auto* payload = frame + sizeof(udp_header); + size_t payload_len = udp_len - 8; + + auto responses = msg_handler(std::span{payload, payload_len}); + + for (auto& resp : responses) { + uint8_t reply_buf[1514]; + size_t udp_data_len = resp.size(); + size_t ip_total = 20 + 8 + udp_data_len; + size_t reply_len = sizeof(eth_header) + ip_total; + if (reply_len > sizeof(reply_buf)) continue; + + auto& rip = *reinterpret_cast(reply_buf); + rip.eth.dst = pkt.ip.eth.src; + rip.eth.src = state.mac; + rip.eth.ethertype = ETH_IPV4; + rip.ver_ihl = 0x45; + rip.dscp_ecn = 0; + rip.total_len = __builtin_bswap16(ip_total); + rip.identification = 0; + rip.flags_frag = 0; + rip.ttl = 64; + rip.protocol = 17; + rip.checksum = 0; + rip.src = state.ip; + rip.dst = pkt.ip.src; + rip.checksum = ip_checksum(rip.ip_start(), 20); + + auto& rudp = *reinterpret_cast(reply_buf); + rudp.src_port = PICOMAP_PORT; + rudp.dst_port = pkt.src_port; + rudp.length = __builtin_bswap16(8 + udp_data_len); + rudp.checksum = 0; + + memcpy(reply_buf + sizeof(udp_header), resp.data(), udp_data_len); + + send_raw(reply_buf, reply_len); + } +} + static void handle_icmp(const uint8_t* frame, size_t len) { auto& ip = *reinterpret_cast(frame); size_t ip_hdr_len = ip.ip_header_len(); @@ -175,6 +238,9 @@ static void handle_ipv4(const uint8_t* frame, size_t len) { case 1: handle_icmp(frame, len); break; + case 17: + handle_udp(frame, len); + break; } } @@ -225,6 +291,10 @@ const net_state& net_get_state() { return state; } +void net_set_handler(net_handler handler) { + msg_handler = std::move(handler); +} + void net_poll() { if (w6300::get_socket_recv_buf(raw_socket) == 0) return; static uint8_t rx_buf[1518]; diff --git a/lib/client/udp.go b/lib/client/udp.go new file mode 100644 index 0000000..daf89d3 --- /dev/null +++ b/lib/client/udp.go @@ -0,0 +1,55 @@ +package client + +import ( + "bytes" + "fmt" + "io" + "net" + "time" +) + +const PicomapPort = 28781 + +type udpTransport struct { + conn *net.UDPConn + buf bytes.Buffer +} + +func NewUDP(addr string, timeout time.Duration) (*Client, error) { + raddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", addr, PicomapPort)) + if err != nil { + return nil, fmt.Errorf("resolve %s: %w", addr, err) + } + conn, err := net.DialUDP("udp4", nil, raddr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + return &Client{transport: &udpTransport{conn: conn}, timeout: timeout}, nil +} + +func (t *udpTransport) Send(data []byte) error { + _, err := t.conn.Write(data) + return err +} + +func (t *udpTransport) SetReadTimeout(timeout time.Duration) { + t.conn.SetReadDeadline(time.Now().Add(timeout)) +} + +func (t *udpTransport) Reader() io.Reader { + for { + if t.buf.Len() > 0 { + return &t.buf + } + pkt := make([]byte, 1500) + n, err := t.conn.Read(pkt) + if err != nil { + return &t.buf + } + t.buf.Write(pkt[:n]) + } +} + +func (t *udpTransport) Close() error { + return t.conn.Close() +}