diff --git a/lib/client/client.go b/lib/client/client.go index d31d7da..53ebc6a 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -42,15 +42,10 @@ func (c *Client) send(msg any) (uint32, error) { } func (c *Client) receive(expectedID uint32) (any, error) { - data, err := c.transport.Receive(c.timeout) - if err != nil { - return nil, fmt.Errorf("receive: %w", err) - } - if len(data) == 0 { - return nil, fmt.Errorf("no response") - } + c.transport.SetReadTimeout(c.timeout) + dec := msgpack.NewDecoder(c.transport.Reader()) var env Envelope - if err := msgpack.Unmarshal(data, &env); err != nil { + if err := dec.Decode(&env); err != nil { return nil, fmt.Errorf("decode envelope: %w", err) } if env.MessageID != expectedID { diff --git a/lib/picoserial/picoserial.go b/lib/picoserial/picoserial.go index e1b401e..f16f4c3 100644 --- a/lib/picoserial/picoserial.go +++ b/lib/picoserial/picoserial.go @@ -2,6 +2,7 @@ package picoserial import ( "fmt" + "io" "time" "go.bug.st/serial" @@ -38,20 +39,12 @@ func (t *SerialTransport) Send(data []byte) error { return err } -func (t *SerialTransport) Receive(timeout time.Duration) ([]byte, error) { +func (t *SerialTransport) SetReadTimeout(timeout time.Duration) { t.port.SetReadTimeout(timeout) - var resp []byte - buf := make([]byte, 256) - for { - n, err := t.port.Read(buf) - if n > 0 { - resp = append(resp, buf[:n]...) - } - if err != nil || n == 0 { - break - } - } - return resp, nil +} + +func (t *SerialTransport) Reader() io.Reader { + return t.port } func (t *SerialTransport) Close() error { diff --git a/lib/transport/transport.go b/lib/transport/transport.go index 664fc23..2dc6b35 100644 --- a/lib/transport/transport.go +++ b/lib/transport/transport.go @@ -1,9 +1,13 @@ package transport -import "time" +import ( + "io" + "time" +) type Transport interface { Send(data []byte) error - Receive(timeout time.Duration) ([]byte, error) + SetReadTimeout(timeout time.Duration) + Reader() io.Reader Close() error }