diff --git a/cmd/info/main.go b/cmd/info/main.go index 97ff77c..3a627a5 100644 --- a/cmd/info/main.go +++ b/cmd/info/main.go @@ -6,7 +6,6 @@ import ( "time" "github.com/theater/picomap/lib/client" - "github.com/theater/picomap/lib/picoserial" ) func main() { @@ -17,20 +16,20 @@ func main() { } func run() error { - dev, err := picoserial.FindDevice() + devs, err := client.ListSerial() if err != nil { return err } - if dev == "" { + if len(devs) == 0 { return fmt.Errorf("no device found") } + dev := devs[0] fmt.Printf("Device: %s\n", dev) - t, err := picoserial.Open(dev) + c, err := client.NewSerial(dev, 2*time.Second) if err != nil { return err } - c := client.New(t, 2*time.Second) info, err := c.Info() c.Close() if err != nil { diff --git a/cmd/load/main.go b/cmd/load/main.go index f0b7274..541deaf 100644 --- a/cmd/load/main.go +++ b/cmd/load/main.go @@ -8,7 +8,6 @@ import ( "time" "github.com/theater/picomap/lib/client" - "github.com/theater/picomap/lib/picoserial" "github.com/theater/picomap/lib/picotool" ) @@ -35,17 +34,17 @@ func run(buildDir string) error { return fmt.Errorf("build failed: %w", err) } - dev, err := picoserial.FindDevice() + devs, err := client.ListSerial() if err != nil { return err } - if dev != "" { + if len(devs) > 0 { + dev := devs[0] fmt.Printf("Sending PICOBOOT request to %s...\n", dev) - t, err := picoserial.Open(dev) + c, err := client.NewSerial(dev, 2*time.Second) if err != nil { return err } - c := client.New(t, 2*time.Second) err = c.PICOBOOT() c.Close() if err != nil { diff --git a/include/device.h b/include/device.h deleted file mode 100644 index 5c31059..0000000 --- a/include/device.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once -#include -#include -#include "msgpackpp.h" -#include "halfsiphash.h" -#include "static_vector.h" -#include "protocol.h" - -static constexpr uint8_t hash_key[8] = {}; - -struct DecodedMessage { - uint32_t message_id; - int8_t type_id; -}; - -inline std::vector pack_envelope(uint32_t message_id, const std::vector &payload) { - uint32_t checksum = halfsiphash::hash32(payload.data(), payload.size(), hash_key); - msgpackpp::packer p; - p.pack(Envelope{message_id, checksum, payload}); - return p.get_payload(); -} - -template -inline std::vector encode_response(uint32_t message_id, const T &msg) { - msgpackpp::packer inner; - inner.pack(msg); - return pack_envelope(message_id, inner.get_payload()); -} - -inline msgpackpp::result try_decode(const uint8_t *data, size_t len) { - msgpackpp::parser p(data, static_cast(len)); - - Envelope env; - auto r = msgpackpp::unpack(p, env); - if (!r) return std::unexpected(r.error()); - - uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key); - if (env.checksum != expected) return std::unexpected(msgpackpp::error_code::invalid); - - msgpackpp::parser inner(env.payload.data(), static_cast(env.payload.size())); - if (!inner.is_ext()) return std::unexpected(msgpackpp::error_code::type_error); - auto ext = inner.get_ext(); - if (!ext) return std::unexpected(ext.error()); - - return DecodedMessage{env.message_id, std::get<0>(*ext)}; -} - -template -inline msgpackpp::result try_decode(const static_vector &buf) { - return try_decode(buf.data(), buf.size()); -} diff --git a/include/msgpackpp.h b/include/msgpack.h similarity index 99% rename from include/msgpackpp.h rename to include/msgpack.h index f94e254..5132f7a 100644 --- a/include/msgpackpp.h +++ b/include/msgpack.h @@ -12,7 +12,7 @@ #include #include -namespace msgpackpp { +namespace msgpack { enum class error_code { overflow, @@ -789,4 +789,4 @@ result unpack(const parser &p, T &out) { return p.next(); } -} // namespace msgpackpp +} // namespace msgpack diff --git a/include/protocol.h b/include/protocol.h deleted file mode 100644 index 3611a26..0000000 --- a/include/protocol.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include - -struct Envelope { - static constexpr int8_t ext_id = 0; - uint32_t message_id; - uint32_t checksum; - std::vector payload; - auto as_tuple() const { return std::tie(message_id, checksum, payload); } - auto as_tuple() { return std::tie(message_id, checksum, payload); } -}; - -struct DeviceError { - static constexpr int8_t ext_id = 1; - uint32_t code; - std::string message; - auto as_tuple() const { return std::tie(code, message); } - auto as_tuple() { return std::tie(code, message); } -}; - -struct RequestPICOBOOT { - static constexpr int8_t ext_id = 2; - auto as_tuple() const { return std::tie(); } - auto as_tuple() { return std::tie(); } -}; - -struct ResponsePICOBOOT { - static constexpr int8_t ext_id = 3; - auto as_tuple() const { return std::tie(); } - auto as_tuple() { return std::tie(); } -}; - -struct RequestInfo { - static constexpr int8_t ext_id = 4; - auto as_tuple() const { return std::tie(); } - auto as_tuple() { return std::tie(); } -}; - -struct ResponseInfo { - static constexpr int8_t ext_id = 5; - std::array board_id; - std::array mac; - auto as_tuple() const { return std::tie(board_id, mac); } - auto as_tuple() { return std::tie(board_id, mac); } -}; diff --git a/include/wire.h b/include/wire.h new file mode 100644 index 0000000..703b268 --- /dev/null +++ b/include/wire.h @@ -0,0 +1,96 @@ +#pragma once +#include +#include +#include +#include +#include +#include "msgpack.h" +#include "halfsiphash.h" +#include "static_vector.h" + +struct Envelope { + static constexpr int8_t ext_id = 0; + uint32_t message_id; + uint32_t checksum; + std::vector payload; + auto as_tuple() const { return std::tie(message_id, checksum, payload); } + auto as_tuple() { return std::tie(message_id, checksum, payload); } +}; + +struct DeviceError { + static constexpr int8_t ext_id = 1; + uint32_t code; + std::string message; + auto as_tuple() const { return std::tie(code, message); } + auto as_tuple() { return std::tie(code, message); } +}; + +struct RequestPICOBOOT { + static constexpr int8_t ext_id = 2; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponsePICOBOOT { + static constexpr int8_t ext_id = 3; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestInfo { + static constexpr int8_t ext_id = 4; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponseInfo { + static constexpr int8_t ext_id = 5; + std::array board_id; + std::array mac; + auto as_tuple() const { return std::tie(board_id, mac); } + auto as_tuple() { return std::tie(board_id, mac); } +}; + +static constexpr uint8_t hash_key[8] = {}; + +struct DecodedMessage { + uint32_t message_id; + int8_t type_id; +}; + +inline std::vector pack_envelope(uint32_t message_id, const std::vector &payload) { + uint32_t checksum = halfsiphash::hash32(payload.data(), payload.size(), hash_key); + msgpack::packer p; + p.pack(Envelope{message_id, checksum, payload}); + return p.get_payload(); +} + +template +inline std::vector encode_response(uint32_t message_id, const T &msg) { + msgpack::packer inner; + inner.pack(msg); + return pack_envelope(message_id, inner.get_payload()); +} + +inline msgpack::result try_decode(const uint8_t *data, size_t len) { + msgpack::parser p(data, static_cast(len)); + + Envelope env; + auto r = msgpack::unpack(p, env); + if (!r) return std::unexpected(r.error()); + + uint32_t expected = halfsiphash::hash32(env.payload.data(), env.payload.size(), hash_key); + if (env.checksum != expected) return std::unexpected(msgpack::error_code::invalid); + + msgpack::parser inner(env.payload.data(), static_cast(env.payload.size())); + if (!inner.is_ext()) return std::unexpected(msgpack::error_code::type_error); + auto ext = inner.get_ext(); + if (!ext) return std::unexpected(ext.error()); + + return DecodedMessage{env.message_id, std::get<0>(*ext)}; +} + +template +inline msgpack::result try_decode(const static_vector &buf) { + return try_decode(buf.data(), buf.size()); +} diff --git a/lib/client/client.go b/lib/client/client.go index 53ebc6a..e45f80f 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -7,19 +7,22 @@ import ( "github.com/theater/picomap/lib/halfsiphash" "github.com/theater/picomap/lib/msgpack" - "github.com/theater/picomap/lib/transport" + "io" ) var HashKey = [8]byte{} -type Client struct { - transport transport.Transport - timeout time.Duration - nextID atomic.Uint32 +type transport interface { + Send(data []byte) error + SetReadTimeout(timeout time.Duration) + Reader() io.Reader + Close() error } -func New(t transport.Transport, timeout time.Duration) *Client { - return &Client{transport: t, timeout: timeout} +type Client struct { + transport transport + timeout time.Duration + nextID atomic.Uint32 } func (c *Client) Close() error { diff --git a/lib/client/serial.go b/lib/client/serial.go new file mode 100644 index 0000000..6d586a4 --- /dev/null +++ b/lib/client/serial.go @@ -0,0 +1,53 @@ +package client + +import ( + "fmt" + "io" + "time" + + "go.bug.st/serial" + "go.bug.st/serial/enumerator" +) + +func ListSerial() ([]string, error) { + ports, err := enumerator.GetDetailedPortsList() + if err != nil { + return nil, fmt.Errorf("enumerating ports: %w", err) + } + var result []string + for _, p := range ports { + if p.IsUSB { + result = append(result, p.Name) + } + } + return result, nil +} + +type serialTransport struct { + port serial.Port +} + +func NewSerial(portName string, timeout time.Duration) (*Client, error) { + port, err := serial.Open(portName, &serial.Mode{BaudRate: 115200}) + if err != nil { + return nil, fmt.Errorf("opening %s: %w", portName, err) + } + return &Client{transport: &serialTransport{port: port}, timeout: timeout}, nil +} + +func (t *serialTransport) Send(data []byte) error { + _, err := t.port.Write(data) + return err +} + +func (t *serialTransport) SetReadTimeout(timeout time.Duration) { + t.port.SetReadTimeout(timeout) +} + +func (t *serialTransport) Reader() io.Reader { + return t.port +} + +func (t *serialTransport) Close() error { + return t.port.Close() +} diff --git a/lib/picoserial/picoserial.go b/lib/picoserial/picoserial.go deleted file mode 100644 index f16f4c3..0000000 --- a/lib/picoserial/picoserial.go +++ /dev/null @@ -1,52 +0,0 @@ -package picoserial - -import ( - "fmt" - "io" - "time" - - "go.bug.st/serial" - "go.bug.st/serial/enumerator" -) - -func FindDevice() (string, error) { - ports, err := enumerator.GetDetailedPortsList() - if err != nil { - return "", fmt.Errorf("enumerating ports: %w", err) - } - for _, p := range ports { - if p.IsUSB { - return p.Name, nil - } - } - return "", nil -} - -type SerialTransport struct { - port serial.Port -} - -func Open(portName string) (*SerialTransport, error) { - port, err := serial.Open(portName, &serial.Mode{BaudRate: 115200}) - if err != nil { - return nil, fmt.Errorf("opening %s: %w", portName, err) - } - return &SerialTransport{port: port}, nil -} - -func (t *SerialTransport) Send(data []byte) error { - _, err := t.port.Write(data) - return err -} - -func (t *SerialTransport) SetReadTimeout(timeout time.Duration) { - t.port.SetReadTimeout(timeout) -} - -func (t *SerialTransport) Reader() io.Reader { - return t.port -} - -func (t *SerialTransport) Close() error { - return t.port.Close() -} diff --git a/lib/transport/transport.go b/lib/transport/transport.go deleted file mode 100644 index 2dc6b35..0000000 --- a/lib/transport/transport.go +++ /dev/null @@ -1,13 +0,0 @@ -package transport - -import ( - "io" - "time" -) - -type Transport interface { - Send(data []byte) error - SetReadTimeout(timeout time.Duration) - Reader() io.Reader - Close() error -} diff --git a/picomap.cpp b/picomap.cpp index 97c7a2f..42be308 100644 --- a/picomap.cpp +++ b/picomap.cpp @@ -2,7 +2,7 @@ #include "pico/stdlib.h" #include "pico/bootrom.h" #include "pico/unique_id.h" -#include "device.h" +#include "wire.h" #include "w6300.h"