Extract solver into library package with delta-scoring optimization

This commit is contained in:
Ian Gulliver
2026-02-16 12:38:26 -08:00
parent fe6350f93e
commit d40f92628a
4 changed files with 1733 additions and 342 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
tmp/

339
cmd/solver-tune/main.go Normal file
View File

@@ -0,0 +1,339 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"math/rand"
"os"
"slices"
"sort"
"strconv"
"strings"
"time"
"rooms/solver"
)
type tripData struct {
RoomSize int `json:"room_size"`
PreferNotMultiple int `json:"prefer_not_multiple"`
NoPreferCost int `json:"no_prefer_cost"`
}
type studentData struct {
ID int64 `json:"id"`
Name string `json:"name"`
}
type constraintsData struct {
Overalls []struct {
StudentAID int64 `json:"student_a_id"`
StudentBID int64 `json:"student_b_id"`
Kind string `json:"kind"`
} `json:"overalls"`
}
func normalizeKey(a []int) string {
rm := map[int][]int{}
for i, room := range a {
rm[room] = append(rm[room], i)
}
var gs [][]int
for _, members := range rm {
slices.Sort(members)
gs = append(gs, members)
}
slices.SortFunc(gs, func(a, b []int) int { return a[0] - b[0] })
var buf strings.Builder
for _, g := range gs {
for i, m := range g {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(strconv.Itoa(m))
}
buf.WriteByte(';')
}
return buf.String()
}
type runResult struct {
score int
solutions [][]int
elapsed time.Duration
}
func printStats(label string, results []runResult, runs int) {
scores := map[int]int{}
solutionSets := map[string]int{}
var totalTime time.Duration
var totalSolutions int
for _, r := range results {
totalTime += r.elapsed
scores[r.score]++
totalSolutions += len(r.solutions)
for _, sol := range r.solutions {
key := normalizeKey(sol)
solutionSets[key]++
}
}
fmt.Printf("--- %s ---\n", label)
fmt.Printf(" avg time: %v\n", totalTime/time.Duration(runs))
var scoreList []struct {
score int
count int
}
for s, c := range scores {
scoreList = append(scoreList, struct {
score int
count int
}{s, c})
}
sort.Slice(scoreList, func(i, j int) bool { return scoreList[i].score > scoreList[j].score })
fmt.Printf(" score distribution:\n")
for _, sc := range scoreList {
fmt.Printf(" score %d: %d/%d runs (%.0f%%)\n", sc.score, sc.count, runs, float64(sc.count)/float64(runs)*100)
}
fmt.Printf(" unique solutions seen: %d\n", len(solutionSets))
fmt.Printf(" avg solutions per run: %.1f\n", float64(totalSolutions)/float64(runs))
var solFreqs []struct {
key string
count int
}
for k, c := range solutionSets {
solFreqs = append(solFreqs, struct {
key string
count int
}{k, c})
}
sort.Slice(solFreqs, func(i, j int) bool { return solFreqs[i].count > solFreqs[j].count })
stableCount := 0
for _, sf := range solFreqs {
if sf.count == runs {
stableCount++
}
}
fmt.Printf(" solutions found in all runs: %d\n", stableCount)
if len(solFreqs) > 0 {
topN := min(5, len(solFreqs))
fmt.Printf(" top %d solution frequencies: ", topN)
for i := range topN {
if i > 0 {
fmt.Print(", ")
}
fmt.Printf("%d/%d", solFreqs[i].count, runs)
}
fmt.Println()
}
fmt.Println()
}
func main() {
dir := flag.String("dir", "tmp", "directory with trip/students/constraints JSON files")
runs := flag.Int("runs", 20, "number of solver runs per parameter set")
algo := flag.String("algo", "both", "algorithm: hillclimb, fast, sa, hybrid, both, all")
numRandom := flag.String("random", "30", "comma-separated random placement counts (hillclimb)")
numPerturb := flag.String("perturb", "200", "comma-separated perturbation counts (hillclimb)")
perturbMin := flag.Int("pmin", 2, "perturbation min groups (hillclimb)")
perturbMax := flag.Int("pmax", 5, "perturbation max groups (hillclimb)")
saRestarts := flag.String("restarts", "20", "comma-separated SA/hybrid restart counts")
saSteps := flag.String("steps", "10000", "comma-separated SA step counts")
hybridSteps := flag.String("hsteps", "5000", "comma-separated hybrid SA step counts")
saTempHigh := flag.Float64("thigh", 5.0, "SA initial temperature")
saTempLow := flag.Float64("tlow", 0.01, "SA final temperature")
hybridTempHigh := flag.Float64("hthigh", 10.0, "hybrid SA initial temperature")
hybridTempLow := flag.Float64("htlow", 0.1, "hybrid SA final temperature")
flag.Parse()
tripBytes, err := os.ReadFile(*dir + "/1")
if err != nil {
fmt.Fprintf(os.Stderr, "reading trip: %v\n", err)
os.Exit(1)
}
var trip tripData
json.Unmarshal(tripBytes, &trip)
studentsBytes, err := os.ReadFile(*dir + "/students")
if err != nil {
fmt.Fprintf(os.Stderr, "reading students: %v\n", err)
os.Exit(1)
}
var students []studentData
json.Unmarshal(studentsBytes, &students)
constraintsBytes, err := os.ReadFile(*dir + "/constraints")
if err != nil {
fmt.Fprintf(os.Stderr, "reading constraints: %v\n", err)
os.Exit(1)
}
var cd constraintsData
json.Unmarshal(constraintsBytes, &cd)
idx := map[int64]int{}
for i, s := range students {
idx[s.ID] = i
}
n := len(students)
var constraints []solver.Constraint
for _, o := range cd.Overalls {
ai, aOk := idx[o.StudentAID]
bi, bOk := idx[o.StudentBID]
if !aOk || !bOk {
continue
}
constraints = append(constraints, solver.Constraint{
StudentA: ai,
StudentB: bi,
Kind: o.Kind,
})
}
fmt.Printf("Students: %d, Room size: %d, Constraints: %d\n", n, trip.RoomSize, len(constraints))
fmt.Printf("Prefer Not multiple: %d, No Prefer cost: %d\n", trip.PreferNotMultiple, trip.NoPreferCost)
fmt.Printf("Runs per config: %d\n\n", *runs)
if *algo == "hillclimb" || *algo == "both" {
randomCounts := parseIntList(*numRandom)
perturbCounts := parseIntList(*numPerturb)
for _, nr := range randomCounts {
for _, np := range perturbCounts {
params := solver.Params{
NumRandom: nr,
NumPerturb: np,
PerturbMin: *perturbMin,
PerturbMax: *perturbMax,
}
var results []runResult
for run := range *runs {
rng := rand.New(rand.NewSource(int64(run * 31337)))
start := time.Now()
sols := solver.Solve(n, trip.RoomSize, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng)
elapsed := time.Since(start)
if len(sols) > 0 {
var assignments [][]int
for _, s := range sols {
assignments = append(assignments, s.Assignment)
}
results = append(results, runResult{sols[0].Score, assignments, elapsed})
}
}
label := fmt.Sprintf("hillclimb random=%d perturb=%d pmin=%d pmax=%d", nr, np, *perturbMin, *perturbMax)
printStats(label, results, *runs)
}
}
}
if *algo == "fast" || *algo == "both" || *algo == "all" {
randomCounts := parseIntList(*numRandom)
perturbCounts := parseIntList(*numPerturb)
for _, nr := range randomCounts {
for _, np := range perturbCounts {
params := solver.Params{
NumRandom: nr,
NumPerturb: np,
PerturbMin: *perturbMin,
PerturbMax: *perturbMax,
}
var results []runResult
for run := range *runs {
rng := rand.New(rand.NewSource(int64(run * 31337)))
start := time.Now()
sols := solver.SolveFast(n, trip.RoomSize, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng)
elapsed := time.Since(start)
if len(sols) > 0 {
var assignments [][]int
for _, s := range sols {
assignments = append(assignments, s.Assignment)
}
results = append(results, runResult{sols[0].Score, assignments, elapsed})
}
}
label := fmt.Sprintf("fast random=%d perturb=%d pmin=%d pmax=%d", nr, np, *perturbMin, *perturbMax)
printStats(label, results, *runs)
}
}
}
if *algo == "sa" || *algo == "all" {
restartCounts := parseIntList(*saRestarts)
stepCounts := parseIntList(*saSteps)
for _, nr := range restartCounts {
for _, ns := range stepCounts {
params := solver.SAParams{
Restarts: nr,
Steps: ns,
TempHigh: *saTempHigh,
TempLow: *saTempLow,
}
var results []runResult
for run := range *runs {
rng := rand.New(rand.NewSource(int64(run * 31337)))
start := time.Now()
sols := solver.SolveSA(n, trip.RoomSize, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng)
elapsed := time.Since(start)
if len(sols) > 0 {
var assignments [][]int
for _, s := range sols {
assignments = append(assignments, s.Assignment)
}
results = append(results, runResult{sols[0].Score, assignments, elapsed})
}
}
label := fmt.Sprintf("sa restarts=%d steps=%d thigh=%.1f tlow=%.3f", nr, ns, *saTempHigh, *saTempLow)
printStats(label, results, *runs)
}
}
}
if *algo == "hybrid" || *algo == "both" || *algo == "all" {
restartCounts := parseIntList(*saRestarts)
stepCounts := parseIntList(*hybridSteps)
for _, nr := range restartCounts {
for _, ns := range stepCounts {
params := solver.HybridParams{
SARestarts: nr,
SASteps: ns,
TempHigh: *hybridTempHigh,
TempLow: *hybridTempLow,
}
var results []runResult
for run := range *runs {
rng := rand.New(rand.NewSource(int64(run * 31337)))
start := time.Now()
sols := solver.SolveHybrid(n, trip.RoomSize, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng)
elapsed := time.Since(start)
if len(sols) > 0 {
var assignments [][]int
for _, s := range sols {
assignments = append(assignments, s.Assignment)
}
results = append(results, runResult{sols[0].Score, assignments, elapsed})
}
}
label := fmt.Sprintf("hybrid restarts=%d steps=%d thigh=%.1f tlow=%.3f", nr, ns, *hybridTempHigh, *hybridTempLow)
printStats(label, results, *runs)
}
}
}
}
func parseIntList(s string) []int {
parts := strings.Split(s, ",")
var result []int
for _, p := range parts {
v, err := strconv.Atoi(strings.TrimSpace(p))
if err == nil {
result = append(result, v)
}
}
return result
}

362
main.go
View File

@@ -21,6 +21,8 @@ import (
"github.com/lib/pq"
"google.golang.org/api/idtoken"
"rooms/solver"
)
//go:embed schema.sql
@@ -1167,13 +1169,13 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
}
defer crows.Close()
type constraint struct {
aID, bID int64
kind, level string
type dbConstraint struct {
aID, bID int64
kind, level string
}
var allConstraints []constraint
var allConstraints []dbConstraint
for crows.Next() {
var c constraint
var c dbConstraint
if err := crows.Scan(&c.aID, &c.bID, &c.kind, &c.level); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -1182,7 +1184,6 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
}
type pairKey struct{ a, b int64 }
overalls := map[pairKey]string{}
byPair := map[pairKey]map[string]string{}
for _, c := range allConstraints {
pk := pairKey{c.aID, c.bID}
@@ -1192,6 +1193,7 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
byPair[pk][c.level] = c.kind
}
levelPriority := []string{"admin", "parent", "student"}
overalls := map[pairKey]string{}
for pk, levels := range byPair {
for _, lev := range levelPriority {
if kind, ok := levels[lev]; ok {
@@ -1207,349 +1209,25 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
}
n := len(studentIDs)
mustTogether := map[[2]int]bool{}
mustApart := map[[2]int]bool{}
var constraints []solver.Constraint
for pk, kind := range overalls {
ai, bi := idx[pk.a], idx[pk.b]
switch kind {
case "must":
p := [2]int{ai, bi}
if p[0] > p[1] { p[0], p[1] = p[1], p[0] }
mustTogether[p] = true
case "must_not":
p := [2]int{ai, bi}
if p[0] > p[1] { p[0], p[1] = p[1], p[0] }
mustApart[p] = true
}
constraints = append(constraints, solver.Constraint{
StudentA: idx[pk.a],
StudentB: idx[pk.b],
Kind: kind,
})
}
uf := make([]int, n)
for i := range uf { uf[i] = i }
var ufFind func(int) int
ufFind = func(x int) int {
if uf[x] != x { uf[x] = ufFind(uf[x]) }
return uf[x]
}
ufUnion := func(a, b int) {
ra, rb := ufFind(a), ufFind(b)
if ra != rb { uf[ra] = rb }
}
rng := rand.New(rand.NewSource(rand.Int63()))
solutions := solver.SolveFast(n, roomSize, pnMultiple, npCost, constraints, solver.DefaultParams, rng)
for p := range mustTogether {
ufUnion(p[0], p[1])
}
hasConflict := false
for p := range mustApart {
if ufFind(p[0]) == ufFind(p[1]) {
hasConflict = true
break
}
}
if hasConflict {
if solutions == nil {
http.Error(w, "hard conflicts exist, resolve before solving", http.StatusBadRequest)
return
}
groups := map[int][]int{}
for i := range n {
root := ufFind(i)
groups[root] = append(groups[root], i)
}
hasPrefer := make([]bool, n)
for pk, kind := range overalls {
if kind == "prefer" {
hasPrefer[idx[pk.a]] = true
}
}
score := func(assignment []int) int {
s := 0
gotPrefer := make([]bool, n)
for pk, kind := range overalls {
ai, bi := idx[pk.a], idx[pk.b]
sameRoom := assignment[ai] == assignment[bi]
switch kind {
case "prefer":
if sameRoom {
s++
gotPrefer[ai] = true
}
case "prefer_not":
if sameRoom { s -= pnMultiple }
}
}
for i := range n {
if hasPrefer[i] && !gotPrefer[i] {
s -= npCost
}
}
return s
}
feasible := func(assignment []int) bool {
for p := range mustApart {
if assignment[p[0]] == assignment[p[1]] { return false }
}
roomCounts := map[int]int{}
for _, room := range assignment {
roomCounts[room]++
}
for _, cnt := range roomCounts {
if cnt > roomSize { return false }
}
return true
}
numRooms := (n + roomSize - 1) / roomSize
assignment := make([]int, n)
groupList := make([][]int, 0, len(groups))
for _, members := range groups {
groupList = append(groupList, members)
}
slices.SortFunc(groupList, func(a, b []int) int { return len(b) - len(a) })
roomCap := make([]int, numRooms)
for i := range roomCap { roomCap[i] = roomSize }
placed := false
var placeGroups func(gi int) bool
placeGroups = func(gi int) bool {
if gi >= len(groupList) { return true }
grp := groupList[gi]
for room := range numRooms {
if roomCap[room] < len(grp) { continue }
ok := true
for _, member := range grp {
for p := range mustApart {
partner := -1
if p[0] == member { partner = p[1] }
if p[1] == member { partner = p[0] }
if partner >= 0 && assignment[partner] == room {
alreadyPlaced := false
for gj := range gi {
if slices.Contains(groupList[gj], partner) {
alreadyPlaced = true
break
}
}
if alreadyPlaced { ok = false; break }
}
}
if !ok { break }
}
if !ok { continue }
for _, member := range grp { assignment[member] = room }
roomCap[room] -= len(grp)
if placeGroups(gi + 1) { return true }
roomCap[room] += len(grp)
}
return false
}
placed = placeGroups(0)
if !placed {
for i := range n {
assignment[i] = i % numRooms
}
}
initialAssignment := make([]int, n)
copy(initialAssignment, assignment)
bestScore := score(assignment)
var bestSolutions [][]int
seen := map[string]bool{}
normalizeKey := func(a []int) string {
rm := map[int][]int{}
for i, room := range a {
rm[room] = append(rm[room], i)
}
var gs [][]int
for _, members := range rm {
slices.Sort(members)
gs = append(gs, members)
}
slices.SortFunc(gs, func(a, b []int) int { return a[0] - b[0] })
var buf strings.Builder
for _, g := range gs {
for i, m := range g {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(strconv.Itoa(m))
}
buf.WriteByte(';')
}
return buf.String()
}
addSolution := func(a []int, s int) {
if s > bestScore {
bestScore = s
bestSolutions = nil
seen = map[string]bool{}
}
if s == bestScore {
key := normalizeKey(a)
if !seen[key] {
seen[key] = true
bestSolutions = append(bestSolutions, slices.Clone(a))
}
}
}
addSolution(assignment, bestScore)
roomCount := func(a []int, room int) int {
c := 0
for _, r := range a {
if r == room { c++ }
}
return c
}
uniqueGroups := make([]int, 0, len(groups))
for root := range groups {
uniqueGroups = append(uniqueGroups, root)
}
slices.Sort(uniqueGroups)
hillClimb := func(assignment []int) int {
currentScore := score(assignment)
for {
bestDelta := 0
bestMove := -1
bestTarget := -1
bestSwapJ := -1
for gi, gRoot := range uniqueGroups {
grp := groups[gRoot]
gRoom := assignment[grp[0]]
for room := range numRooms {
if room == gRoom { continue }
if roomCount(assignment, room)+len(grp) > roomSize { continue }
for _, m := range grp { assignment[m] = room }
if feasible(assignment) {
delta := score(assignment) - currentScore
if delta > bestDelta {
bestDelta = delta
bestMove = gi
bestTarget = room
bestSwapJ = -1
}
}
for _, m := range grp { assignment[m] = gRoom }
}
for gj := gi + 1; gj < len(uniqueGroups); gj++ {
grp2 := groups[uniqueGroups[gj]]
g2Room := assignment[grp2[0]]
if gRoom == g2Room { continue }
newGRoom := roomCount(assignment, gRoom) - len(grp) + len(grp2)
newG2Room := roomCount(assignment, g2Room) - len(grp2) + len(grp)
if newGRoom > roomSize || newG2Room > roomSize { continue }
for _, m := range grp { assignment[m] = g2Room }
for _, m := range grp2 { assignment[m] = gRoom }
if feasible(assignment) {
delta := score(assignment) - currentScore
if delta > bestDelta {
bestDelta = delta
bestMove = gi
bestTarget = -1
bestSwapJ = gj
}
}
for _, m := range grp { assignment[m] = gRoom }
for _, m := range grp2 { assignment[m] = g2Room }
}
}
if bestDelta <= 0 { break }
grp := groups[uniqueGroups[bestMove]]
gRoom := assignment[grp[0]]
if bestSwapJ < 0 {
for _, m := range grp { assignment[m] = bestTarget }
} else {
grp2 := groups[uniqueGroups[bestSwapJ]]
g2Room := assignment[grp2[0]]
for _, m := range grp { assignment[m] = g2Room }
for _, m := range grp2 { assignment[m] = gRoom }
}
currentScore += bestDelta
}
return currentScore
}
randomPlacement := func() bool {
perm := rand.Perm(len(groupList))
for i := range roomCap { roomCap[i] = roomSize }
for _, pi := range perm {
grp := groupList[pi]
placed := false
order := rand.Perm(numRooms)
for _, room := range order {
if roomCap[room] < len(grp) { continue }
valid := true
for _, member := range grp {
for p := range mustApart {
partner := -1
if p[0] == member { partner = p[1] }
if p[1] == member { partner = p[0] }
if partner >= 0 && assignment[partner] == room {
valid = false
break
}
}
if !valid { break }
}
if !valid { continue }
for _, member := range grp { assignment[member] = room }
roomCap[room] -= len(grp)
placed = true
break
}
if !placed { return false }
}
return true
}
perturb := func(src []int, count int) {
copy(assignment, src)
indices := rand.Perm(len(uniqueGroups))
count = min(count, len(indices))
for _, gi := range indices[:count] {
grp := groups[uniqueGroups[gi]]
oldRoom := assignment[grp[0]]
rooms := rand.Perm(numRooms)
for _, room := range rooms {
if room == oldRoom { continue }
if roomCount(assignment, room)+len(grp) > roomSize { continue }
for _, m := range grp { assignment[m] = room }
if feasible(assignment) { break }
for _, m := range grp { assignment[m] = oldRoom }
}
}
}
copy(assignment, initialAssignment)
addSolution(assignment, hillClimb(assignment))
for range 30 {
if randomPlacement() {
addSolution(assignment, hillClimb(assignment))
}
}
for range 200 {
src := bestSolutions[rand.Intn(len(bestSolutions))]
perturb(src, 2+rand.Intn(3))
addSolution(assignment, hillClimb(assignment))
}
type roomMember struct {
ID int64 `json:"id"`
Name string `json:"name"`
@@ -1559,9 +1237,9 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
Score int `json:"score"`
}
var results []solutionResult
for _, sol := range bestSolutions {
for _, sol := range solutions {
roomMap := map[int][]roomMember{}
for i, room := range sol {
for i, room := range sol.Assignment {
sid := studentIDs[i]
roomMap[room] = append(roomMap[room], roomMember{ID: sid, Name: studentName[sid]})
}
@@ -1573,7 +1251,7 @@ func handleSolve(db *sql.DB) http.HandlerFunc {
}
}
slices.SortFunc(rooms, func(a, b []roomMember) int { return strings.Compare(a[0].Name, b[0].Name) })
results = append(results, solutionResult{Rooms: rooms, Score: bestScore})
results = append(results, solutionResult{Rooms: rooms, Score: sol.Score})
}
slices.SortFunc(results, func(a, b solutionResult) int {
for i := range min(len(a.Rooms), len(b.Rooms)) {

1373
solver/solver.go Normal file

File diff suppressed because it is too large Load Diff