From e301c672a9dd20e292643177c061bf856fa7f6be Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Tue, 7 Apr 2026 22:12:20 +0900 Subject: [PATCH] Broadcast discovery with InfoAll, interface broadcast detection, clean output --- cmd/picomap/main.go | 50 ++++++++++++---------------- lib/client/client.go | 79 ++++++++++++++++++++++++++------------------ lib/client/serial.go | 2 ++ lib/client/udp.go | 68 +++++++++++++++++++++++++------------- 4 files changed, 115 insertions(+), 84 deletions(-) diff --git a/cmd/picomap/main.go b/cmd/picomap/main.go index 90c9164..6d595eb 100644 --- a/cmd/picomap/main.go +++ b/cmd/picomap/main.go @@ -50,17 +50,6 @@ type deviceResult struct { err error } -func infoFromClient(dev string, c *client.Client) (*client.ResponseInfo, error) { - log := slog.With("dev", dev) - log.Info("requesting info") - info, err := c.Info() - if err != nil { - return nil, err - } - log.Info("got info", "firmware", info.FirmwareName) - return info, nil -} - func printInfo(dev string, info *client.ResponseInfo) { slog.Info("device", "dev", dev, @@ -76,22 +65,30 @@ func cmdInfo(args []string) error { iface := fs.String("iface", "", "bind to this network interface (for broadcast)") fs.Parse(args) - if *udpAddr != "" { - log := slog.With("addr", *udpAddr) - if *iface != "" { - log = log.With("iface", *iface) + if *udpAddr == "" && *iface != "" { + bcast, err := client.InterfaceBroadcast(*iface) + if err != nil { + return err } - log.Info("connecting via UDP") - c, err := client.NewUDP(*udpAddr, *iface, 2*time.Second) + *udpAddr = bcast + } + + if *udpAddr != "" { + c, err := client.NewUDP(*udpAddr, *iface, 500*time.Millisecond) if err != nil { return err } defer c.Close() - info, err := infoFromClient(*udpAddr, c) + infos, err := c.InfoAll() if err != nil { return err } - printInfo(*udpAddr, info) + if len(infos) == 0 { + return fmt.Errorf("no devices responded") + } + for _, info := range infos { + printInfo(net.IP(info.IP[:]).String(), info) + } return nil } @@ -108,13 +105,12 @@ func cmdInfo(args []string) error { for i, dev := range devs { results[i].dev = dev wg.Go(func() { - slog.Info("connecting", "dev", dev) - c, err := client.NewSerial(dev, 2*time.Second) + c, err := client.NewSerial(dev, 500*time.Millisecond) if err != nil { results[i].err = err return } - info, err := infoFromClient(dev, c) + info, err := c.Info() c.Close() if err != nil { results[i].err = err @@ -177,7 +173,7 @@ func cmdLog(_ []string) error { } for _, dev := range devs { log := slog.With("dev", dev) - c, err := client.NewSerial(dev, 2*time.Second) + c, err := client.NewSerial(dev, 500*time.Millisecond) if err != nil { log.Error("connect error", "err", err) continue @@ -254,8 +250,7 @@ func cmdLoad(args []string) error { for i := range targets { log := slog.With("dev", devs[i]) wg.Go(func() { - log.Info("connecting for info") - c, err := client.NewSerial(devs[i], 2*time.Second) + c, err := client.NewSerial(devs[i], 500*time.Millisecond) if err != nil { errs[i] = err return @@ -286,7 +281,7 @@ func cmdLoad(args []string) error { log := slog.With("serial", devices[i].serial) wg.Go(func() { log.Info("sending PICOBOOT") - c, err := client.NewSerial(devices[i].dev, 2*time.Second) + c, err := client.NewSerial(devices[i].dev, 500*time.Millisecond) if err != nil { errs[i] = err return @@ -347,8 +342,7 @@ func cmdTest(args []string) error { var testDev string for _, dev := range devs { log := slog.With("dev", dev) - log.Info("connecting for info") - c, err := client.NewSerial(dev, 2*time.Second) + c, err := client.NewSerial(dev, 500*time.Millisecond) if err != nil { log.Warn("connect error", "err", err) continue diff --git a/lib/client/client.go b/lib/client/client.go index 24e5476..18bc917 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -16,6 +16,7 @@ type transport interface { Send(data []byte) error SetReadTimeout(timeout time.Duration) Reader() io.Reader + Broadcast() bool Close() error } @@ -44,59 +45,71 @@ func (c *Client) send(msg any) (uint32, error) { return id, c.transport.Send(data) } -func (c *Client) receive(expectedID uint32) (any, error) { - c.transport.SetReadTimeout(c.timeout) - dec := msgpack.NewDecoder(c.transport.Reader()) - var env Envelope - if err := dec.Decode(&env); err != nil { - return nil, fmt.Errorf("decode envelope: %w", err) - } - if env.MessageID != expectedID { - return nil, fmt.Errorf("message id mismatch: got %d, want %d", env.MessageID, expectedID) - } - expected := halfsiphash.Sum32(env.Payload, HashKey) - if env.Checksum != expected { - return nil, fmt.Errorf("checksum mismatch: got %08x, want %08x", env.Checksum, expected) - } - var inner any - if err := msgpack.Unmarshal(env.Payload, &inner); err != nil { - return nil, fmt.Errorf("decode inner: %w", err) - } - if devErr, ok := inner.(*DeviceError); ok { - return nil, devErr - } - return inner, nil -} - -func roundTrip[T any](c *Client, req any) (*T, error) { +func roundTrip[T any](c *Client, req any) ([]*T, error) { id, err := c.send(req) if err != nil { return nil, err } - resp, err := c.receive(id) + c.transport.SetReadTimeout(c.timeout) + dec := msgpack.NewDecoder(c.transport.Reader()) + broadcast := c.transport.Broadcast() + var results []*T + for { + var env Envelope + if err := dec.Decode(&env); err != nil { + break + } + if env.MessageID != id { + continue + } + expected := halfsiphash.Sum32(env.Payload, HashKey) + if env.Checksum != expected { + continue + } + var inner any + if err := msgpack.Unmarshal(env.Payload, &inner); err != nil { + continue + } + if devErr, ok := inner.(*DeviceError); ok { + return nil, devErr + } + if typed, ok := inner.(*T); ok { + results = append(results, typed) + if !broadcast { + break + } + } + } + return results, nil +} + +func first[T any](results []*T, err error) (*T, error) { if err != nil { return nil, err } - typed, ok := resp.(*T) - if !ok { - return nil, fmt.Errorf("unexpected response: %T", resp) + if len(results) == 0 { + return nil, fmt.Errorf("no response") } - return typed, nil + return results[0], nil } func (c *Client) PICOBOOT() error { - _, err := roundTrip[ResponsePICOBOOT](c, &RequestPICOBOOT{}) + _, err := first(roundTrip[ResponsePICOBOOT](c, &RequestPICOBOOT{})) return err } func (c *Client) Info() (*ResponseInfo, error) { + return first(roundTrip[ResponseInfo](c, &RequestInfo{})) +} + +func (c *Client) InfoAll() ([]*ResponseInfo, error) { return roundTrip[ResponseInfo](c, &RequestInfo{}) } func (c *Client) Log() (*ResponseLog, error) { - return roundTrip[ResponseLog](c, &RequestLog{}) + return first(roundTrip[ResponseLog](c, &RequestLog{})) } func (c *Client) Test(name string) (*ResponseTest, error) { - return roundTrip[ResponseTest](c, &RequestTest{Name: name}) + return first(roundTrip[ResponseTest](c, &RequestTest{Name: name})) } diff --git a/lib/client/serial.go b/lib/client/serial.go index 570aa1c..9ac0503 100644 --- a/lib/client/serial.go +++ b/lib/client/serial.go @@ -61,6 +61,8 @@ func (t *serialTransport) Reader() io.Reader { return t.port } +func (t *serialTransport) Broadcast() bool { return false } + func (t *serialTransport) Close() error { return t.port.Close() } diff --git a/lib/client/udp.go b/lib/client/udp.go index f09d18b..348a2ee 100644 --- a/lib/client/udp.go +++ b/lib/client/udp.go @@ -14,36 +14,46 @@ import ( const PicomapPort = 28781 type udpTransport struct { - conn *net.UDPConn - addr *net.UDPAddr - buf bytes.Buffer + conn *net.UDPConn + addr *net.UDPAddr + broadcast bool + buf bytes.Buffer } -func isBroadcast(ip net.IP) bool { - if ip.Equal(net.IPv4bcast) { - return true - } - ip4 := ip.To4() - return ip4 != nil && ip4[3] == 255 -} - -func interfaceIPv4(name string) (net.IP, error) { +func interfaceIPv4Net(name string) (net.IP, *net.IPNet, error) { ifi, err := net.InterfaceByName(name) if err != nil { - return nil, fmt.Errorf("interface %s: %w", name, err) + return nil, nil, fmt.Errorf("interface %s: %w", name, err) } addrs, err := ifi.Addrs() if err != nil { - return nil, fmt.Errorf("interface %s addrs: %w", name, err) + return nil, nil, fmt.Errorf("interface %s addrs: %w", name, err) } for _, a := range addrs { if ipnet, ok := a.(*net.IPNet); ok { if ip4 := ipnet.IP.To4(); ip4 != nil { - return ip4, nil + return ip4, ipnet, nil } } } - return nil, fmt.Errorf("interface %s has no IPv4 address", name) + return nil, nil, fmt.Errorf("interface %s has no IPv4 address", name) +} + +func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { + bcast := make(net.IP, 4) + ip4 := ip.To4() + for i := range 4 { + bcast[i] = ip4[i] | ^mask[i] + } + return bcast +} + +func InterfaceBroadcast(name string) (string, error) { + ip, ipnet, err := interfaceIPv4Net(name) + if err != nil { + return "", err + } + return broadcastAddr(ip, ipnet.Mask).String(), nil } func enableBroadcast(conn *net.UDPConn) error { @@ -58,6 +68,17 @@ func enableBroadcast(conn *net.UDPConn) error { return serr } +func isInterfaceBroadcast(iface string, ip net.IP) bool { + if iface == "" { + return false + } + localIP, ipnet, err := interfaceIPv4Net(iface) + if err != nil { + return false + } + return ip.Equal(broadcastAddr(localIP, ipnet.Mask)) || ip.Equal(net.IPv4bcast) +} + func NewUDP(addr string, iface string, timeout time.Duration) (*Client, error) { raddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", addr, PicomapPort)) if err != nil { @@ -66,7 +87,7 @@ func NewUDP(addr string, iface string, timeout time.Duration) (*Client, error) { var laddr *net.UDPAddr if iface != "" { - ip, err := interfaceIPv4(iface) + ip, _, err := interfaceIPv4Net(iface) if err != nil { return nil, err } @@ -78,14 +99,13 @@ func NewUDP(addr string, iface string, timeout time.Duration) (*Client, error) { return nil, fmt.Errorf("listen: %w", err) } - if isBroadcast(raddr.IP) { - if err := enableBroadcast(conn); err != nil { - conn.Close() - return nil, fmt.Errorf("SO_BROADCAST: %w", err) - } + if err := enableBroadcast(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("SO_BROADCAST: %w", err) } - return &Client{transport: &udpTransport{conn: conn, addr: raddr}, timeout: timeout}, nil + bcast := isInterfaceBroadcast(iface, raddr.IP) + return &Client{transport: &udpTransport{conn: conn, addr: raddr, broadcast: bcast}, timeout: timeout}, nil } func (t *udpTransport) Send(data []byte) error { @@ -111,6 +131,8 @@ func (t *udpTransport) Reader() io.Reader { } } +func (t *udpTransport) Broadcast() bool { return t.broadcast } + func (t *udpTransport) Close() error { return t.conn.Close() }