diff --git a/cmd/picomap/main.go b/cmd/picomap/main.go index 3c70e1b..66ae7f0 100644 --- a/cmd/picomap/main.go +++ b/cmd/picomap/main.go @@ -13,7 +13,7 @@ import ( "time" "github.com/theater/picomap/lib/client" - "github.com/theater/picomap/lib/picotool" + "github.com/theater/picomap/lib/uf2" ) func main() { @@ -197,9 +197,12 @@ func cmdLog(_ []string) error { } func cmdLoad(args []string) error { + fs := flag.NewFlagSet("load", flag.ExitOnError) + dryRun := fs.Bool("dry-run", false, "parse UF2 and log operations without flashing") + fs.Parse(args) target := "all" - if len(args) > 0 { - target = args[0] + if fs.NArg() > 0 { + target = fs.Arg(0) } wd, err := os.Getwd() @@ -212,11 +215,6 @@ func cmdLoad(args []string) error { return err } - devs, err := client.ListSerial() - if err != nil { - return err - } - allTargets := []struct { name string uf2 string @@ -240,6 +238,10 @@ func cmdLoad(args []string) error { return fmt.Errorf("unknown target %q", target) } + devs, err := client.ListSerial() + if err != nil { + return err + } if len(devs) < len(targets) { return fmt.Errorf("need %d device(s), found %d", len(targets), len(devs)) } @@ -281,47 +283,17 @@ func cmdLoad(args []string) error { for i := range devices { log := slog.With("serial", devices[i].serial) wg.Go(func() { - log.Info("sending PICOBOOT") - c, err := client.NewSerial(devices[i].dev, 500*time.Millisecond) - if err != nil { - errs[i] = err - return - } - err = c.PICOBOOT() - c.Close() - if err != nil { - errs[i] = fmt.Errorf("PICOBOOT %s: %w", devices[i].serial, err) - return - } - log.Info("PICOBOOT sent") - }) - } - wg.Wait() - for i, err := range errs { - if err != nil { - return fmt.Errorf("[%s] %w", devices[i].serial, err) - } - } - - uf2s := make([]string, len(targets)) - for i := range targets { - uf2s[i] = targets[i].uf2 - } - - for i := range devices { - log := slog.With("serial", devices[i].serial) - wg.Go(func() { - log.Info("loading", "uf2", devices[i].name) - errs[i] = picotool.Load(devices[i].uf2, devices[i].serial, 10*time.Second) + log.Info("flashing", "uf2", devices[i].name) + errs[i] = flashDevice(devices[i].dev, devices[i].uf2, *dryRun, log) if errs[i] == nil { - log.Info("loaded", "uf2", devices[i].name) + log.Info("flashed", "uf2", devices[i].name) } }) } wg.Wait() for i, err := range errs { if err != nil { - return fmt.Errorf("[%s] load: %w", devices[i].serial, err) + return fmt.Errorf("[%s] flash: %w", devices[i].serial, err) } } @@ -329,6 +301,54 @@ func cmdLoad(args []string) error { return nil } +func flashDevice(dev, uf2Path string, dryRun bool, log *slog.Logger) error { + blocks, err := uf2.Parse(uf2Path) + if err != nil { + return fmt.Errorf("parse uf2: %w", err) + } + + log.Info("parsed uf2", "blocks", len(blocks)) + + const sectorSize = 4096 + + if dryRun { + erased := make(map[uint32]bool) + for _, b := range blocks { + sector := b.Addr &^ (sectorSize - 1) + if !erased[sector] { + log.Info("erasing", "addr", fmt.Sprintf("%08x", sector)) + erased[sector] = true + } + log.Info("writing", "addr", fmt.Sprintf("%08x", b.Addr), "len", len(b.Data)) + } + return nil + } + + c, err := client.NewSerial(dev, 5*time.Second) + if err != nil { + return err + } + defer c.Close() + + erased := make(map[uint32]bool) + for _, b := range blocks { + sector := b.Addr &^ (sectorSize - 1) + if !erased[sector] { + if err := c.FlashErase(sector, sectorSize); err != nil { + return fmt.Errorf("erase %08x: %w", sector, err) + } + erased[sector] = true + } + if err := c.FlashWrite(b.Addr, b.Data); err != nil { + return fmt.Errorf("write %08x: %w", b.Addr, err) + } + } + + log.Info("rebooting") + c.Reboot() + return nil +} + func findTestDevice() (string, error) { devs, err := client.ListSerial() if err != nil { diff --git a/firmware/firmware.cpp b/firmware/firmware.cpp index 5f2675a..4d7f8e8 100644 --- a/firmware/firmware.cpp +++ b/firmware/firmware.cpp @@ -7,6 +7,9 @@ static constexpr handler_entry handlers[] = { {RequestPICOBOOT::ext_id, typed_handler}, {RequestInfo::ext_id, typed_handler}, {RequestLog::ext_id, typed_handler}, + {RequestFlashErase::ext_id, typed_handler}, + {RequestFlashWrite::ext_id, typed_handler}, + {RequestReboot::ext_id, typed_handler}, }; int main() { diff --git a/firmware/include/handlers.h b/firmware/include/handlers.h index 7fc69f5..1d19bc1 100644 --- a/firmware/include/handlers.h +++ b/firmware/include/handlers.h @@ -9,3 +9,6 @@ extern std::string_view firmware_name; std::optional handle_picoboot(const responder& resp, const RequestPICOBOOT&); std::optional handle_info(const responder& resp, const RequestInfo&); std::optional handle_log(const responder& resp, const RequestLog&); +std::optional handle_flash_erase(const responder& resp, const RequestFlashErase&); +std::optional handle_flash_write(const responder& resp, const RequestFlashWrite&); +std::optional handle_reboot(const responder& resp, const RequestReboot&); diff --git a/firmware/include/wire.h b/firmware/include/wire.h index a13e0ee..bb0c47d 100644 --- a/firmware/include/wire.h +++ b/firmware/include/wire.h @@ -74,6 +74,46 @@ struct ResponseLog { auto as_tuple() { return std::tie(entries); } }; +struct RequestFlashErase { + static constexpr int8_t ext_id = 8; + uint32_t addr; + uint32_t len; + auto as_tuple() const { return std::tie(addr, len); } + auto as_tuple() { return std::tie(addr, len); } +}; + +struct ResponseFlashErase { + static constexpr int8_t ext_id = 9; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestFlashWrite { + static constexpr int8_t ext_id = 10; + uint32_t addr; + std::span data; + auto as_tuple() const { return std::tie(addr, data); } + auto as_tuple() { return std::tie(addr, data); } +}; + +struct ResponseFlashWrite { + static constexpr int8_t ext_id = 11; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct RequestReboot { + static constexpr int8_t ext_id = 12; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + +struct ResponseReboot { + static constexpr int8_t ext_id = 13; + auto as_tuple() const { return std::tie(); } + auto as_tuple() { return std::tie(); } +}; + struct RequestListTests { static constexpr int8_t ext_id = 125; auto as_tuple() const { return std::tie(); } diff --git a/firmware/lib/dispatch.cpp b/firmware/lib/dispatch.cpp index e2a5206..92386d1 100644 --- a/firmware/lib/dispatch.cpp +++ b/firmware/lib/dispatch.cpp @@ -42,8 +42,8 @@ bool dispatch_cancel_timer(timer_handle h) { } static usb_cdc usb; - static static_vector usb_rx_buf; - static std::array tx_buf; + static static_vector usb_rx_buf; + static std::array tx_buf; auto dispatch_msg = [&](const DecodedMessage& msg, std::function)> send) { auto it = handler_map.find(msg.type_id); diff --git a/firmware/lib/handlers.cpp b/firmware/lib/handlers.cpp index af432f8..0e36019 100644 --- a/firmware/lib/handlers.cpp +++ b/firmware/lib/handlers.cpp @@ -1,10 +1,15 @@ #include "handlers.h" #include "pico/unique_id.h" #include "pico/bootrom.h" +#include "hardware/flash.h" +#include "hardware/watchdog.h" #include "dispatch.h" #include "net.h" #include "debug_log.h" +static constexpr uint32_t XIP_BASE_ADDR = 0x10000000; +static constexpr uint32_t FLASH_SIZE = 2 * 1024 * 1024; + std::optional handle_picoboot(const responder&, const RequestPICOBOOT&) { dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); }); return ResponsePICOBOOT{}; @@ -28,3 +33,41 @@ std::optional handle_log(const responder&, const RequestLog&) { resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)}); return resp; } + +std::optional handle_flash_erase(const responder&, const RequestFlashErase& req) { + if (req.addr < XIP_BASE_ADDR || req.addr + req.len > XIP_BASE_ADDR + FLASH_SIZE) { + dlogf("flash erase: out of range %08lx+%lu", + static_cast(req.addr), static_cast(req.len)); + return std::nullopt; + } + uint32_t offset = req.addr - XIP_BASE_ADDR; + if (offset % FLASH_SECTOR_SIZE != 0 || req.len % FLASH_SECTOR_SIZE != 0 || req.len == 0) { + dlogf("flash erase: bad alignment %08lx+%lu", + static_cast(req.addr), static_cast(req.len)); + return std::nullopt; + } + flash_range_erase(offset, req.len); + return ResponseFlashErase{}; +} + +std::optional handle_flash_write(const responder&, const RequestFlashWrite& req) { + if (req.addr < XIP_BASE_ADDR || req.addr + req.data.size() > XIP_BASE_ADDR + FLASH_SIZE) { + dlogf("flash write: out of range %08lx+%zu", + static_cast(req.addr), req.data.size()); + return std::nullopt; + } + uint32_t offset = req.addr - XIP_BASE_ADDR; + if (offset % FLASH_PAGE_SIZE != 0 || req.data.size() % FLASH_PAGE_SIZE != 0 || req.data.empty()) { + dlogf("flash write: bad alignment %08lx+%zu", + static_cast(req.addr), req.data.size()); + return std::nullopt; + } + flash_range_program(offset, req.data.data(), req.data.size()); + return ResponseFlashWrite{}; +} + +std::optional handle_reboot(const responder&, const RequestReboot&) { + dispatch_schedule_ms(100, []{ watchdog_reboot(0, 0, 0); }); + return ResponseReboot{}; +} + diff --git a/firmware/test.cpp b/firmware/test.cpp index b515371..24583e9 100644 --- a/firmware/test.cpp +++ b/firmware/test.cpp @@ -17,6 +17,9 @@ static constexpr handler_entry handlers[] = { {RequestPICOBOOT::ext_id, typed_handler}, {RequestInfo::ext_id, typed_handler}, {RequestLog::ext_id, typed_handler}, + {RequestFlashErase::ext_id, typed_handler}, + {RequestFlashWrite::ext_id, typed_handler}, + {RequestReboot::ext_id, typed_handler}, {RequestListTests::ext_id, typed_handler}, {RequestTest::ext_id, typed_handler}, }; diff --git a/lib/client/client.go b/lib/client/client.go index 75119ad..ebcb2a3 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -117,6 +117,21 @@ func (c *Client) Log() (*ResponseLog, error) { return first(roundTrip[ResponseLog](c, &RequestLog{})) } +func (c *Client) FlashErase(addr, length uint32) error { + _, err := first(roundTrip[ResponseFlashErase](c, &RequestFlashErase{Addr: addr, Len: length})) + return err +} + +func (c *Client) FlashWrite(addr uint32, data []byte) error { + _, err := first(roundTrip[ResponseFlashWrite](c, &RequestFlashWrite{Addr: addr, Data: data})) + return err +} + +func (c *Client) Reboot() error { + _, err := first(roundTrip[ResponseReboot](c, &RequestReboot{})) + return err +} + func (c *Client) ListTests() (*ResponseListTests, error) { return first(roundTrip[ResponseListTests](c, &RequestListTests{})) } diff --git a/lib/client/types.go b/lib/client/types.go index 14072dd..9eac788 100644 --- a/lib/client/types.go +++ b/lib/client/types.go @@ -24,6 +24,21 @@ type ResponseLog struct { Entries []LogEntry } +type RequestFlashErase struct { + Addr uint32 + Len uint32 +} +type ResponseFlashErase struct{} + +type RequestFlashWrite struct { + Addr uint32 + Data []byte +} +type ResponseFlashWrite struct{} + +type RequestReboot struct{} +type ResponseReboot struct{} + type RequestListTests struct{} type ResponseListTests struct { Names []string @@ -62,6 +77,12 @@ func init() { msgpack.RegisterExt(5, (*ResponseInfo)(nil)) msgpack.RegisterExt(6, (*RequestLog)(nil)) msgpack.RegisterExt(7, (*ResponseLog)(nil)) + msgpack.RegisterExt(8, (*RequestFlashErase)(nil)) + msgpack.RegisterExt(9, (*ResponseFlashErase)(nil)) + msgpack.RegisterExt(10, (*RequestFlashWrite)(nil)) + msgpack.RegisterExt(11, (*ResponseFlashWrite)(nil)) + msgpack.RegisterExt(12, (*RequestReboot)(nil)) + msgpack.RegisterExt(13, (*ResponseReboot)(nil)) msgpack.RegisterExt(125, (*RequestListTests)(nil)) msgpack.RegisterExt(124, (*ResponseListTests)(nil)) msgpack.RegisterExt(127, (*RequestTest)(nil)) diff --git a/lib/picotool/picotool.go b/lib/picotool/picotool.go deleted file mode 100644 index 68304f2..0000000 --- a/lib/picotool/picotool.go +++ /dev/null @@ -1,32 +0,0 @@ -package picotool - -import ( - "fmt" - "os/exec" - "time" -) - -func Load(uf2Path string, serial string, timeout time.Duration) error { - deadline := time.Now().Add(timeout) - var out []byte - var err error - for { - cmd := exec.Command("picotool", "load", uf2Path, "-x", "--ser", serial) - out, err = cmd.CombinedOutput() - if err == nil { - return nil - } - if time.Now().After(deadline) { - return fmt.Errorf("picotool load: %w\n%s", err, out) - } - } -} - -func Reboot(serial string) error { - cmd := exec.Command("picotool", "reboot", "--ser", serial) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("picotool reboot: %w\n%s", err, out) - } - return nil -} diff --git a/lib/uf2/uf2.go b/lib/uf2/uf2.go new file mode 100644 index 0000000..010c532 --- /dev/null +++ b/lib/uf2/uf2.go @@ -0,0 +1,71 @@ +package uf2 + +import ( + "encoding/binary" + "fmt" + "os" + "sort" +) + +const ( + blockSize = 512 + magic0 = 0x0A324655 + magic1 = 0x9E5D5157 + magicEnd = 0x0AB16F30 + + flagNotMainFlash = 0x00000001 + flagFamilyIDPresent = 0x00002000 + absoluteFamilyID = 0xe48bff57 +) + +type Block struct { + Addr uint32 + Data []byte +} + +func Parse(path string) ([]Block, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + if len(data)%blockSize != 0 { + return nil, fmt.Errorf("file size %d not multiple of %d", len(data), blockSize) + } + + var blocks []Block + for i := 0; i < len(data); i += blockSize { + b := data[i : i+blockSize] + m0 := binary.LittleEndian.Uint32(b[0:4]) + m1 := binary.LittleEndian.Uint32(b[4:8]) + me := binary.LittleEndian.Uint32(b[508:512]) + if m0 != magic0 || m1 != magic1 || me != magicEnd { + return nil, fmt.Errorf("block %d: bad magic", i/blockSize) + } + flags := binary.LittleEndian.Uint32(b[8:12]) + if flags&flagNotMainFlash != 0 { + continue + } + if flags&flagFamilyIDPresent != 0 { + familyID := binary.LittleEndian.Uint32(b[28:32]) + if familyID == absoluteFamilyID { + continue + } + } + addr := binary.LittleEndian.Uint32(b[12:16]) + size := binary.LittleEndian.Uint32(b[16:20]) + if size > 256 { + return nil, fmt.Errorf("block %d: data size %d > 256", i/blockSize, size) + } + blocks = append(blocks, Block{ + Addr: addr, + Data: make([]byte, size), + }) + copy(blocks[len(blocks)-1].Data, b[32:32+size]) + } + + sort.Slice(blocks, func(i, j int) bool { + return blocks[i].Addr < blocks[j].Addr + }) + + return blocks, nil +}