In-app flash write, UF2 parser, remove picotool dependency, reboot command

This commit is contained in:
Ian Gulliver
2026-04-11 22:26:54 +09:00
parent a635aa04e0
commit e3d97f4946
11 changed files with 263 additions and 76 deletions

View File

@@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/theater/picomap/lib/client" "github.com/theater/picomap/lib/client"
"github.com/theater/picomap/lib/picotool" "github.com/theater/picomap/lib/uf2"
) )
func main() { func main() {
@@ -197,9 +197,12 @@ func cmdLog(_ []string) error {
} }
func cmdLoad(args []string) error { func cmdLoad(args []string) error {
fs := flag.NewFlagSet("load", flag.ExitOnError)
dryRun := fs.Bool("dry-run", false, "parse UF2 and log operations without flashing")
fs.Parse(args)
target := "all" target := "all"
if len(args) > 0 { if fs.NArg() > 0 {
target = args[0] target = fs.Arg(0)
} }
wd, err := os.Getwd() wd, err := os.Getwd()
@@ -212,11 +215,6 @@ func cmdLoad(args []string) error {
return err return err
} }
devs, err := client.ListSerial()
if err != nil {
return err
}
allTargets := []struct { allTargets := []struct {
name string name string
uf2 string uf2 string
@@ -240,6 +238,10 @@ func cmdLoad(args []string) error {
return fmt.Errorf("unknown target %q", target) return fmt.Errorf("unknown target %q", target)
} }
devs, err := client.ListSerial()
if err != nil {
return err
}
if len(devs) < len(targets) { if len(devs) < len(targets) {
return fmt.Errorf("need %d device(s), found %d", len(targets), len(devs)) return fmt.Errorf("need %d device(s), found %d", len(targets), len(devs))
} }
@@ -281,47 +283,17 @@ func cmdLoad(args []string) error {
for i := range devices { for i := range devices {
log := slog.With("serial", devices[i].serial) log := slog.With("serial", devices[i].serial)
wg.Go(func() { wg.Go(func() {
log.Info("sending PICOBOOT") log.Info("flashing", "uf2", devices[i].name)
c, err := client.NewSerial(devices[i].dev, 500*time.Millisecond) errs[i] = flashDevice(devices[i].dev, devices[i].uf2, *dryRun, log)
if err != nil {
errs[i] = err
return
}
err = c.PICOBOOT()
c.Close()
if err != nil {
errs[i] = fmt.Errorf("PICOBOOT %s: %w", devices[i].serial, err)
return
}
log.Info("PICOBOOT sent")
})
}
wg.Wait()
for i, err := range errs {
if err != nil {
return fmt.Errorf("[%s] %w", devices[i].serial, err)
}
}
uf2s := make([]string, len(targets))
for i := range targets {
uf2s[i] = targets[i].uf2
}
for i := range devices {
log := slog.With("serial", devices[i].serial)
wg.Go(func() {
log.Info("loading", "uf2", devices[i].name)
errs[i] = picotool.Load(devices[i].uf2, devices[i].serial, 10*time.Second)
if errs[i] == nil { if errs[i] == nil {
log.Info("loaded", "uf2", devices[i].name) log.Info("flashed", "uf2", devices[i].name)
} }
}) })
} }
wg.Wait() wg.Wait()
for i, err := range errs { for i, err := range errs {
if err != nil { if err != nil {
return fmt.Errorf("[%s] load: %w", devices[i].serial, err) return fmt.Errorf("[%s] flash: %w", devices[i].serial, err)
} }
} }
@@ -329,6 +301,54 @@ func cmdLoad(args []string) error {
return nil return nil
} }
func flashDevice(dev, uf2Path string, dryRun bool, log *slog.Logger) error {
blocks, err := uf2.Parse(uf2Path)
if err != nil {
return fmt.Errorf("parse uf2: %w", err)
}
log.Info("parsed uf2", "blocks", len(blocks))
const sectorSize = 4096
if dryRun {
erased := make(map[uint32]bool)
for _, b := range blocks {
sector := b.Addr &^ (sectorSize - 1)
if !erased[sector] {
log.Info("erasing", "addr", fmt.Sprintf("%08x", sector))
erased[sector] = true
}
log.Info("writing", "addr", fmt.Sprintf("%08x", b.Addr), "len", len(b.Data))
}
return nil
}
c, err := client.NewSerial(dev, 5*time.Second)
if err != nil {
return err
}
defer c.Close()
erased := make(map[uint32]bool)
for _, b := range blocks {
sector := b.Addr &^ (sectorSize - 1)
if !erased[sector] {
if err := c.FlashErase(sector, sectorSize); err != nil {
return fmt.Errorf("erase %08x: %w", sector, err)
}
erased[sector] = true
}
if err := c.FlashWrite(b.Addr, b.Data); err != nil {
return fmt.Errorf("write %08x: %w", b.Addr, err)
}
}
log.Info("rebooting")
c.Reboot()
return nil
}
func findTestDevice() (string, error) { func findTestDevice() (string, error) {
devs, err := client.ListSerial() devs, err := client.ListSerial()
if err != nil { if err != nil {

View File

@@ -7,6 +7,9 @@ static constexpr handler_entry handlers[] = {
{RequestPICOBOOT::ext_id, typed_handler<RequestPICOBOOT, handle_picoboot>}, {RequestPICOBOOT::ext_id, typed_handler<RequestPICOBOOT, handle_picoboot>},
{RequestInfo::ext_id, typed_handler<RequestInfo, handle_info>}, {RequestInfo::ext_id, typed_handler<RequestInfo, handle_info>},
{RequestLog::ext_id, typed_handler<RequestLog, handle_log>}, {RequestLog::ext_id, typed_handler<RequestLog, handle_log>},
{RequestFlashErase::ext_id, typed_handler<RequestFlashErase, handle_flash_erase>},
{RequestFlashWrite::ext_id, typed_handler<RequestFlashWrite, handle_flash_write>},
{RequestReboot::ext_id, typed_handler<RequestReboot, handle_reboot>},
}; };
int main() { int main() {

View File

@@ -9,3 +9,6 @@ extern std::string_view firmware_name;
std::optional<ResponsePICOBOOT> handle_picoboot(const responder& resp, const RequestPICOBOOT&); std::optional<ResponsePICOBOOT> handle_picoboot(const responder& resp, const RequestPICOBOOT&);
std::optional<ResponseInfo> handle_info(const responder& resp, const RequestInfo&); std::optional<ResponseInfo> handle_info(const responder& resp, const RequestInfo&);
std::optional<ResponseLog> handle_log(const responder& resp, const RequestLog&); std::optional<ResponseLog> handle_log(const responder& resp, const RequestLog&);
std::optional<ResponseFlashErase> handle_flash_erase(const responder& resp, const RequestFlashErase&);
std::optional<ResponseFlashWrite> handle_flash_write(const responder& resp, const RequestFlashWrite&);
std::optional<ResponseReboot> handle_reboot(const responder& resp, const RequestReboot&);

View File

@@ -74,6 +74,46 @@ struct ResponseLog {
auto as_tuple() { return std::tie(entries); } auto as_tuple() { return std::tie(entries); }
}; };
struct RequestFlashErase {
static constexpr int8_t ext_id = 8;
uint32_t addr;
uint32_t len;
auto as_tuple() const { return std::tie(addr, len); }
auto as_tuple() { return std::tie(addr, len); }
};
struct ResponseFlashErase {
static constexpr int8_t ext_id = 9;
auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
};
struct RequestFlashWrite {
static constexpr int8_t ext_id = 10;
uint32_t addr;
std::span<const uint8_t> data;
auto as_tuple() const { return std::tie(addr, data); }
auto as_tuple() { return std::tie(addr, data); }
};
struct ResponseFlashWrite {
static constexpr int8_t ext_id = 11;
auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
};
struct RequestReboot {
static constexpr int8_t ext_id = 12;
auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
};
struct ResponseReboot {
static constexpr int8_t ext_id = 13;
auto as_tuple() const { return std::tie(); }
auto as_tuple() { return std::tie(); }
};
struct RequestListTests { struct RequestListTests {
static constexpr int8_t ext_id = 125; static constexpr int8_t ext_id = 125;
auto as_tuple() const { return std::tie(); } auto as_tuple() const { return std::tie(); }

View File

@@ -42,8 +42,8 @@ bool dispatch_cancel_timer(timer_handle h) {
} }
static usb_cdc usb; static usb_cdc usb;
static static_vector<uint8_t, 256> usb_rx_buf; static static_vector<uint8_t, 4096> usb_rx_buf;
static std::array<uint8_t, 1514> tx_buf; static std::array<uint8_t, 4096> tx_buf;
auto dispatch_msg = [&](const DecodedMessage& msg, std::function<void(std::span<const uint8_t>)> send) { auto dispatch_msg = [&](const DecodedMessage& msg, std::function<void(std::span<const uint8_t>)> send) {
auto it = handler_map.find(msg.type_id); auto it = handler_map.find(msg.type_id);

View File

@@ -1,10 +1,15 @@
#include "handlers.h" #include "handlers.h"
#include "pico/unique_id.h" #include "pico/unique_id.h"
#include "pico/bootrom.h" #include "pico/bootrom.h"
#include "hardware/flash.h"
#include "hardware/watchdog.h"
#include "dispatch.h" #include "dispatch.h"
#include "net.h" #include "net.h"
#include "debug_log.h" #include "debug_log.h"
static constexpr uint32_t XIP_BASE_ADDR = 0x10000000;
static constexpr uint32_t FLASH_SIZE = 2 * 1024 * 1024;
std::optional<ResponsePICOBOOT> handle_picoboot(const responder&, const RequestPICOBOOT&) { std::optional<ResponsePICOBOOT> handle_picoboot(const responder&, const RequestPICOBOOT&) {
dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); }); dispatch_schedule_ms(100, []{ reset_usb_boot(0, 1); });
return ResponsePICOBOOT{}; return ResponsePICOBOOT{};
@@ -28,3 +33,41 @@ std::optional<ResponseLog> handle_log(const responder&, const RequestLog&) {
resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)}); resp.entries.push_back(LogEntry{e.timestamp_us, std::move(e.message)});
return resp; return resp;
} }
std::optional<ResponseFlashErase> handle_flash_erase(const responder&, const RequestFlashErase& req) {
if (req.addr < XIP_BASE_ADDR || req.addr + req.len > XIP_BASE_ADDR + FLASH_SIZE) {
dlogf("flash erase: out of range %08lx+%lu",
static_cast<unsigned long>(req.addr), static_cast<unsigned long>(req.len));
return std::nullopt;
}
uint32_t offset = req.addr - XIP_BASE_ADDR;
if (offset % FLASH_SECTOR_SIZE != 0 || req.len % FLASH_SECTOR_SIZE != 0 || req.len == 0) {
dlogf("flash erase: bad alignment %08lx+%lu",
static_cast<unsigned long>(req.addr), static_cast<unsigned long>(req.len));
return std::nullopt;
}
flash_range_erase(offset, req.len);
return ResponseFlashErase{};
}
std::optional<ResponseFlashWrite> handle_flash_write(const responder&, const RequestFlashWrite& req) {
if (req.addr < XIP_BASE_ADDR || req.addr + req.data.size() > XIP_BASE_ADDR + FLASH_SIZE) {
dlogf("flash write: out of range %08lx+%zu",
static_cast<unsigned long>(req.addr), req.data.size());
return std::nullopt;
}
uint32_t offset = req.addr - XIP_BASE_ADDR;
if (offset % FLASH_PAGE_SIZE != 0 || req.data.size() % FLASH_PAGE_SIZE != 0 || req.data.empty()) {
dlogf("flash write: bad alignment %08lx+%zu",
static_cast<unsigned long>(req.addr), req.data.size());
return std::nullopt;
}
flash_range_program(offset, req.data.data(), req.data.size());
return ResponseFlashWrite{};
}
std::optional<ResponseReboot> handle_reboot(const responder&, const RequestReboot&) {
dispatch_schedule_ms(100, []{ watchdog_reboot(0, 0, 0); });
return ResponseReboot{};
}

View File

@@ -17,6 +17,9 @@ static constexpr handler_entry handlers[] = {
{RequestPICOBOOT::ext_id, typed_handler<RequestPICOBOOT, handle_picoboot>}, {RequestPICOBOOT::ext_id, typed_handler<RequestPICOBOOT, handle_picoboot>},
{RequestInfo::ext_id, typed_handler<RequestInfo, handle_info>}, {RequestInfo::ext_id, typed_handler<RequestInfo, handle_info>},
{RequestLog::ext_id, typed_handler<RequestLog, handle_log>}, {RequestLog::ext_id, typed_handler<RequestLog, handle_log>},
{RequestFlashErase::ext_id, typed_handler<RequestFlashErase, handle_flash_erase>},
{RequestFlashWrite::ext_id, typed_handler<RequestFlashWrite, handle_flash_write>},
{RequestReboot::ext_id, typed_handler<RequestReboot, handle_reboot>},
{RequestListTests::ext_id, typed_handler<RequestListTests, handle_list_tests>}, {RequestListTests::ext_id, typed_handler<RequestListTests, handle_list_tests>},
{RequestTest::ext_id, typed_handler<RequestTest, handle_test>}, {RequestTest::ext_id, typed_handler<RequestTest, handle_test>},
}; };

View File

@@ -117,6 +117,21 @@ func (c *Client) Log() (*ResponseLog, error) {
return first(roundTrip[ResponseLog](c, &RequestLog{})) return first(roundTrip[ResponseLog](c, &RequestLog{}))
} }
func (c *Client) FlashErase(addr, length uint32) error {
_, err := first(roundTrip[ResponseFlashErase](c, &RequestFlashErase{Addr: addr, Len: length}))
return err
}
func (c *Client) FlashWrite(addr uint32, data []byte) error {
_, err := first(roundTrip[ResponseFlashWrite](c, &RequestFlashWrite{Addr: addr, Data: data}))
return err
}
func (c *Client) Reboot() error {
_, err := first(roundTrip[ResponseReboot](c, &RequestReboot{}))
return err
}
func (c *Client) ListTests() (*ResponseListTests, error) { func (c *Client) ListTests() (*ResponseListTests, error) {
return first(roundTrip[ResponseListTests](c, &RequestListTests{})) return first(roundTrip[ResponseListTests](c, &RequestListTests{}))
} }

View File

@@ -24,6 +24,21 @@ type ResponseLog struct {
Entries []LogEntry Entries []LogEntry
} }
type RequestFlashErase struct {
Addr uint32
Len uint32
}
type ResponseFlashErase struct{}
type RequestFlashWrite struct {
Addr uint32
Data []byte
}
type ResponseFlashWrite struct{}
type RequestReboot struct{}
type ResponseReboot struct{}
type RequestListTests struct{} type RequestListTests struct{}
type ResponseListTests struct { type ResponseListTests struct {
Names []string Names []string
@@ -62,6 +77,12 @@ func init() {
msgpack.RegisterExt(5, (*ResponseInfo)(nil)) msgpack.RegisterExt(5, (*ResponseInfo)(nil))
msgpack.RegisterExt(6, (*RequestLog)(nil)) msgpack.RegisterExt(6, (*RequestLog)(nil))
msgpack.RegisterExt(7, (*ResponseLog)(nil)) msgpack.RegisterExt(7, (*ResponseLog)(nil))
msgpack.RegisterExt(8, (*RequestFlashErase)(nil))
msgpack.RegisterExt(9, (*ResponseFlashErase)(nil))
msgpack.RegisterExt(10, (*RequestFlashWrite)(nil))
msgpack.RegisterExt(11, (*ResponseFlashWrite)(nil))
msgpack.RegisterExt(12, (*RequestReboot)(nil))
msgpack.RegisterExt(13, (*ResponseReboot)(nil))
msgpack.RegisterExt(125, (*RequestListTests)(nil)) msgpack.RegisterExt(125, (*RequestListTests)(nil))
msgpack.RegisterExt(124, (*ResponseListTests)(nil)) msgpack.RegisterExt(124, (*ResponseListTests)(nil))
msgpack.RegisterExt(127, (*RequestTest)(nil)) msgpack.RegisterExt(127, (*RequestTest)(nil))

View File

@@ -1,32 +0,0 @@
package picotool
import (
"fmt"
"os/exec"
"time"
)
func Load(uf2Path string, serial string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
var out []byte
var err error
for {
cmd := exec.Command("picotool", "load", uf2Path, "-x", "--ser", serial)
out, err = cmd.CombinedOutput()
if err == nil {
return nil
}
if time.Now().After(deadline) {
return fmt.Errorf("picotool load: %w\n%s", err, out)
}
}
}
func Reboot(serial string) error {
cmd := exec.Command("picotool", "reboot", "--ser", serial)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("picotool reboot: %w\n%s", err, out)
}
return nil
}

71
lib/uf2/uf2.go Normal file
View File

@@ -0,0 +1,71 @@
package uf2
import (
"encoding/binary"
"fmt"
"os"
"sort"
)
const (
blockSize = 512
magic0 = 0x0A324655
magic1 = 0x9E5D5157
magicEnd = 0x0AB16F30
flagNotMainFlash = 0x00000001
flagFamilyIDPresent = 0x00002000
absoluteFamilyID = 0xe48bff57
)
type Block struct {
Addr uint32
Data []byte
}
func Parse(path string) ([]Block, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if len(data)%blockSize != 0 {
return nil, fmt.Errorf("file size %d not multiple of %d", len(data), blockSize)
}
var blocks []Block
for i := 0; i < len(data); i += blockSize {
b := data[i : i+blockSize]
m0 := binary.LittleEndian.Uint32(b[0:4])
m1 := binary.LittleEndian.Uint32(b[4:8])
me := binary.LittleEndian.Uint32(b[508:512])
if m0 != magic0 || m1 != magic1 || me != magicEnd {
return nil, fmt.Errorf("block %d: bad magic", i/blockSize)
}
flags := binary.LittleEndian.Uint32(b[8:12])
if flags&flagNotMainFlash != 0 {
continue
}
if flags&flagFamilyIDPresent != 0 {
familyID := binary.LittleEndian.Uint32(b[28:32])
if familyID == absoluteFamilyID {
continue
}
}
addr := binary.LittleEndian.Uint32(b[12:16])
size := binary.LittleEndian.Uint32(b[16:20])
if size > 256 {
return nil, fmt.Errorf("block %d: data size %d > 256", i/blockSize, size)
}
blocks = append(blocks, Block{
Addr: addr,
Data: make([]byte, size),
})
copy(blocks[len(blocks)-1].Data, b[32:32+size])
}
sort.Slice(blocks, func(i, j int) bool {
return blocks[i].Addr < blocks[j].Addr
})
return blocks, nil
}