diff --git a/cmd/solver-tune/main.go b/cmd/solver-tune/main.go index f679621..ac97c02 100644 --- a/cmd/solver-tune/main.go +++ b/cmd/solver-tune/main.go @@ -15,10 +15,15 @@ import ( "rooms/solver" ) +type roomGroupData struct { + Size int `json:"size"` + Count int `json:"count"` +} + type tripData struct { - RoomSize int `json:"room_size"` - PreferNotMultiple int `json:"prefer_not_multiple"` - NoPreferCost int `json:"no_prefer_cost"` + PreferNotMultiple int `json:"prefer_not_multiple"` + NoPreferCost int `json:"no_prefer_cost"` + RoomGroups []roomGroupData `json:"room_groups"` } type studentData struct { @@ -189,7 +194,18 @@ func main() { }) } - fmt.Printf("Students: %d, Room size: %d, Constraints: %d\n", n, trip.RoomSize, len(constraints)) + var roomSizes []int + for _, rg := range trip.RoomGroups { + for range rg.Count { + roomSizes = append(roomSizes, rg.Size) + } + } + if len(roomSizes) == 0 { + fmt.Fprintf(os.Stderr, "no room_groups in trip data\n") + os.Exit(1) + } + + fmt.Printf("Students: %d, Room sizes: %v, Constraints: %d\n", n, roomSizes, len(constraints)) fmt.Printf("Prefer Not multiple: %d, No Prefer cost: %d\n", trip.PreferNotMultiple, trip.NoPreferCost) fmt.Printf("Runs per config: %d\n\n", *runs) @@ -207,7 +223,7 @@ func main() { for run := range *runs { rng := rand.New(rand.NewSource(int64(run * 31337))) start := time.Now() - sols := solver.SolveFast(n, trip.RoomSize, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng) + sols := solver.SolveFast(n, roomSizes, trip.PreferNotMultiple, trip.NoPreferCost, constraints, params, rng) elapsed := time.Since(start) if len(sols) > 0 { var assignments [][]int diff --git a/drop.sql b/drop.sql index 07edf98..058282e 100644 --- a/drop.sql +++ b/drop.sql @@ -1,6 +1,7 @@ DROP TABLE IF EXISTS roommate_constraints; DROP TABLE IF EXISTS parents; DROP TABLE IF EXISTS students; +DROP TABLE IF EXISTS room_groups; DROP TABLE IF EXISTS trip_admins; DROP TABLE IF EXISTS trips; DROP TYPE IF EXISTS constraint_level; diff --git a/main.go b/main.go index e536008..b532459 100644 --- a/main.go +++ b/main.go @@ -87,6 +87,9 @@ func main() { http.HandleFunc("GET /api/trips/{tripID}/constraints", handleListConstraints(db)) http.HandleFunc("POST /api/trips/{tripID}/constraints", handleCreateConstraint(db)) http.HandleFunc("DELETE /api/trips/{tripID}/constraints/{constraintID}", handleDeleteConstraint(db)) + http.HandleFunc("GET /api/trips/{tripID}/room-groups", handleListRoomGroups(db)) + http.HandleFunc("POST /api/trips/{tripID}/room-groups", handleCreateRoomGroup(db)) + http.HandleFunc("DELETE /api/trips/{tripID}/room-groups/{groupID}", handleDeleteRoomGroup(db)) http.HandleFunc("POST /api/trips/{tripID}/solve", handleSolve(db)) http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { if err := db.Ping(); err != nil { @@ -295,13 +298,13 @@ func handleListTrips(db *sql.DB) http.HandlerFunc { return } rows, err := db.Query(` - SELECT t.id, t.name, t.room_size, t.prefer_not_multiple, t.no_prefer_cost, COALESCE( + SELECT t.id, t.name, t.prefer_not_multiple, t.no_prefer_cost, COALESCE( json_agg(json_build_object('id', ta.id, 'email', ta.email)) FILTER (WHERE ta.id IS NOT NULL), '[]' ) FROM trips t LEFT JOIN trip_admins ta ON ta.trip_id = t.id - GROUP BY t.id, t.name, t.room_size, t.prefer_not_multiple, t.no_prefer_cost + GROUP BY t.id, t.name, t.prefer_not_multiple, t.no_prefer_cost ORDER BY t.id`) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -316,7 +319,6 @@ func handleListTrips(db *sql.DB) http.HandlerFunc { type trip struct { ID int64 `json:"id"` Name string `json:"name"` - RoomSize int `json:"room_size"` PreferNotMultiple int `json:"prefer_not_multiple"` NoPreferCost int `json:"no_prefer_cost"` Admins []tripAdmin `json:"admins"` @@ -326,7 +328,7 @@ func handleListTrips(db *sql.DB) http.HandlerFunc { for rows.Next() { var t trip var adminsJSON string - if err := rows.Scan(&t.ID, &t.Name, &t.RoomSize, &t.PreferNotMultiple, &t.NoPreferCost, &adminsJSON); err != nil { + if err := rows.Scan(&t.ID, &t.Name, &t.PreferNotMultiple, &t.NoPreferCost, &adminsJSON); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -470,14 +472,14 @@ func handleGetTrip(db *sql.DB) http.HandlerFunc { return } var name string - var roomSize, preferNotMultiple, noPreferCost int - err := db.QueryRow("SELECT name, room_size, prefer_not_multiple, no_prefer_cost FROM trips WHERE id = $1", tripID).Scan(&name, &roomSize, &preferNotMultiple, &noPreferCost) + var preferNotMultiple, noPreferCost int + err := db.QueryRow("SELECT name, prefer_not_multiple, no_prefer_cost FROM trips WHERE id = $1", tripID).Scan(&name, &preferNotMultiple, &noPreferCost) if err != nil { http.Error(w, "trip not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{"id": tripID, "name": name, "room_size": roomSize, "prefer_not_multiple": preferNotMultiple, "no_prefer_cost": noPreferCost}) + json.NewEncoder(w).Encode(map[string]any{"id": tripID, "name": name, "prefer_not_multiple": preferNotMultiple, "no_prefer_cost": noPreferCost}) } } @@ -672,7 +674,6 @@ func handleUpdateTrip(db *sql.DB) http.HandlerFunc { return } var body struct { - RoomSize *int `json:"room_size"` PreferNotMultiple *int `json:"prefer_not_multiple"` NoPreferCost *int `json:"no_prefer_cost"` } @@ -680,12 +681,6 @@ func handleUpdateTrip(db *sql.DB) http.HandlerFunc { http.Error(w, "invalid request body", http.StatusBadRequest) return } - if body.RoomSize != nil { - if *body.RoomSize < 1 { - http.Error(w, "room_size must be at least 1", http.StatusBadRequest) - return - } - } if body.PreferNotMultiple != nil { if *body.PreferNotMultiple < 1 { http.Error(w, "prefer_not_multiple must be at least 1", http.StatusBadRequest) @@ -698,12 +693,6 @@ func handleUpdateTrip(db *sql.DB) http.HandlerFunc { return } } - if body.RoomSize != nil { - if _, err := db.Exec("UPDATE trips SET room_size = $1 WHERE id = $2", *body.RoomSize, tripID); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } if body.PreferNotMultiple != nil { if _, err := db.Exec("UPDATE trips SET prefer_not_multiple = $1 WHERE id = $2", *body.PreferNotMultiple, tripID); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -993,16 +982,18 @@ func handleListConstraints(db *sql.DB) http.HandlerFunc { hardConflicts = append(hardConflicts, chain) } - var roomSize int - db.QueryRow("SELECT room_size FROM trips WHERE id = $1", tripID).Scan(&roomSize) + var maxRoomSize int + db.QueryRow("SELECT COALESCE(MAX(size), 0) FROM room_groups WHERE trip_id = $1", tripID).Scan(&maxRoomSize) mustGroups := map[int64][]string{} for _, id := range studentIDs { root := ufFind(id) mustGroups[root] = append(mustGroups[root], studentName[id]) } - for _, members := range mustGroups { - if len(members) > roomSize { - oversizedGroups = append(oversizedGroups, members) + if maxRoomSize > 0 { + for _, members := range mustGroups { + if len(members) > maxRoomSize { + oversizedGroups = append(oversizedGroups, members) + } } } } @@ -1119,6 +1110,93 @@ func handleDeleteConstraint(db *sql.DB) http.HandlerFunc { } } +func handleListRoomGroups(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, tripID, ok := requireTripAdmin(db, w, r) + if !ok { + return + } + rows, err := db.Query("SELECT id, size, count FROM room_groups WHERE trip_id = $1 ORDER BY id", tripID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rows.Close() + type roomGroup struct { + ID int64 `json:"id"` + Size int `json:"size"` + Count int `json:"count"` + } + var groups []roomGroup + for rows.Next() { + var g roomGroup + if err := rows.Scan(&g.ID, &g.Size, &g.Count); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + groups = append(groups, g) + } + if groups == nil { + groups = []roomGroup{} + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(groups) + } +} + +func handleCreateRoomGroup(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, tripID, ok := requireTripAdmin(db, w, r) + if !ok { + return + } + var body struct { + Size int `json:"size"` + Count int `json:"count"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if body.Size < 1 || body.Count < 1 { + http.Error(w, "size and count must be at least 1", http.StatusBadRequest) + return + } + var id int64 + err := db.QueryRow("INSERT INTO room_groups (trip_id, size, count) VALUES ($1, $2, $3) RETURNING id", tripID, body.Size, body.Count).Scan(&id) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"id": id, "size": body.Size, "count": body.Count}) + } +} + +func handleDeleteRoomGroup(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, tripID, ok := requireTripAdmin(db, w, r) + if !ok { + return + } + groupID, err := strconv.ParseInt(r.PathValue("groupID"), 10, 64) + if err != nil { + http.Error(w, "invalid group ID", http.StatusBadRequest) + return + } + result, err := db.Exec("DELETE FROM room_groups WHERE id = $1 AND trip_id = $2", groupID, tripID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if n, _ := result.RowsAffected(); n == 0 { + http.Error(w, "room group not found", http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + } +} + func handleSolve(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, tripID, ok := requireTripAdmin(db, w, r) @@ -1126,13 +1204,35 @@ func handleSolve(db *sql.DB) http.HandlerFunc { return } - var roomSize, pnMultiple, npCost int - err := db.QueryRow("SELECT room_size, prefer_not_multiple, no_prefer_cost FROM trips WHERE id = $1", tripID).Scan(&roomSize, &pnMultiple, &npCost) + var pnMultiple, npCost int + err := db.QueryRow("SELECT prefer_not_multiple, no_prefer_cost FROM trips WHERE id = $1", tripID).Scan(&pnMultiple, &npCost) if err != nil { http.Error(w, "trip not found", http.StatusNotFound) return } + rgRows, err := db.Query("SELECT size, count FROM room_groups WHERE trip_id = $1 ORDER BY id", tripID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rgRows.Close() + var roomSizes []int + for rgRows.Next() { + var size, count int + if err := rgRows.Scan(&size, &count); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + for range count { + roomSizes = append(roomSizes, size) + } + } + if len(roomSizes) == 0 { + http.Error(w, "no room groups configured", http.StatusBadRequest) + return + } + rows, err := db.Query("SELECT id, name FROM students WHERE trip_id = $1 ORDER BY id", tripID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -1219,14 +1319,14 @@ func handleSolve(db *sql.DB) http.HandlerFunc { } rng := rand.New(rand.NewSource(42)) - solutions := solver.SolveFast(n, roomSize, pnMultiple, npCost, constraints, solver.DefaultParams, rng) + solutions := solver.SolveFast(n, roomSizes, pnMultiple, npCost, constraints, solver.DefaultParams, rng) if solutions == nil { http.Error(w, "hard conflicts exist, resolve before solving", http.StatusBadRequest) return } - numRooms := (n + roomSize - 1) / roomSize + numRooms := len(roomSizes) type roomMember struct { ID int64 `json:"id"` diff --git a/schema.sql b/schema.sql index 586009b..471792d 100644 --- a/schema.sql +++ b/schema.sql @@ -11,11 +11,19 @@ END $$; CREATE TABLE IF NOT EXISTS trips ( id BIGSERIAL PRIMARY KEY, name TEXT NOT NULL, - room_size INTEGER NOT NULL DEFAULT 2, prefer_not_multiple INTEGER NOT NULL DEFAULT 5, no_prefer_cost INTEGER NOT NULL DEFAULT 10 ); +CREATE TABLE IF NOT EXISTS room_groups ( + id BIGSERIAL PRIMARY KEY, + trip_id BIGINT NOT NULL REFERENCES trips(id) ON DELETE CASCADE, + size INTEGER NOT NULL, + count INTEGER NOT NULL, + CHECK(size >= 1), + CHECK(count >= 1) +); + CREATE TABLE IF NOT EXISTS trip_admins ( id BIGSERIAL PRIMARY KEY, trip_id BIGINT NOT NULL REFERENCES trips(id) ON DELETE CASCADE, diff --git a/solver/solver.go b/solver/solver.go index 97be1b8..b372ec3 100644 --- a/solver/solver.go +++ b/solver/solver.go @@ -57,9 +57,9 @@ func normalizeKey(a []int) string { } type solverState struct { - n int - roomSize int - numRooms int + n int + roomSizes []int + numRooms int pnMultiple int npCost int @@ -77,11 +77,11 @@ type solverState struct { mustApartFor [][]int } -func newSolverState(n, roomSize, pnMultiple, npCost int, constraints []Constraint) *solverState { +func newSolverState(n int, roomSizes []int, pnMultiple, npCost int, constraints []Constraint) *solverState { s := &solverState{ n: n, - roomSize: roomSize, - numRooms: (n + roomSize - 1) / roomSize, + roomSizes: roomSizes, + numRooms: len(roomSizes), pnMultiple: pnMultiple, npCost: npCost, constraints: constraints, @@ -235,7 +235,6 @@ func (s *solverState) feasibleForGroup(assignment []int, groupRoot int, room int func (s *solverState) fastHillClimb(assignment []int) int { n := s.n - roomSize := s.roomSize numRooms := s.numRooms roomCounts := make([]int, numRooms) @@ -386,7 +385,7 @@ func (s *solverState) fastHillClimb(assignment []int) int { if room == gRoom { continue } - if roomCounts[room]+len(grp) > roomSize { + if roomCounts[room]+len(grp) > s.roomSizes[room] { continue } if !s.feasibleForGroup(assignment, gRoot, room) { @@ -410,7 +409,7 @@ func (s *solverState) fastHillClimb(assignment []int) int { } newGRoom := roomCounts[gRoom] - len(grp) + len(grp2) newG2Room := roomCounts[g2Room] - len(grp2) + len(grp) - if newGRoom > roomSize || newG2Room > roomSize { + if newGRoom > s.roomSizes[gRoom] || newG2Room > s.roomSizes[g2Room] { continue } if !s.feasibleForGroup(assignment, gRoot, g2Room) { @@ -456,9 +455,7 @@ func (s *solverState) fastHillClimb(assignment []int) int { func (s *solverState) initialPlacement(assignment []int) bool { roomCap := make([]int, s.numRooms) - for i := range roomCap { - roomCap[i] = s.roomSize - } + copy(roomCap, s.roomSizes) var placeGroups func(gi int) bool placeGroups = func(gi int) bool { @@ -517,9 +514,7 @@ func (s *solverState) initialPlacement(assignment []int) bool { func (s *solverState) randomPlacement(assignment []int, rng *rand.Rand) bool { roomCap := make([]int, s.numRooms) - for i := range roomCap { - roomCap[i] = s.roomSize - } + copy(roomCap, s.roomSizes) perm := rng.Perm(len(s.groupList)) for _, pi := range perm { grp := s.groupList[pi] @@ -597,12 +592,12 @@ func (t *solutionTracker) add(a []int, s int) { } } -func SolveFast(n, roomSize, pnMultiple, npCost int, constraints []Constraint, params Params, rng *rand.Rand) []Solution { +func SolveFast(n int, roomSizes []int, pnMultiple, npCost int, constraints []Constraint, params Params, rng *rand.Rand) []Solution { if n == 0 { return nil } - st := newSolverState(n, roomSize, pnMultiple, npCost, constraints) + st := newSolverState(n, roomSizes, pnMultiple, npCost, constraints) if st.hasHardConflict() { return nil } @@ -637,8 +632,8 @@ func SolveFast(n, roomSize, pnMultiple, npCost int, constraints []Constraint, pa for _, room := range a { rc[room]++ } - for _, cnt := range rc { - if cnt > roomSize { + for room, cnt := range rc { + if cnt > st.roomSizes[room] { return false } } @@ -657,7 +652,7 @@ func SolveFast(n, roomSize, pnMultiple, npCost int, constraints []Constraint, pa if room == oldRoom { continue } - if roomCount(assignment, room)+len(grp) > roomSize { + if roomCount(assignment, room)+len(grp) > st.roomSizes[room] { continue } for _, m := range grp { diff --git a/static/trip.html b/static/trip.html index 318e354..45e393b 100644 --- a/static/trip.html +++ b/static/trip.html @@ -95,7 +95,15 @@