From 1c4cf25e34f8cca50564cf6946a8b5f980d49838 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Wed, 24 Nov 2021 19:52:22 -0800 Subject: [PATCH] Support multiple scoring criteria --- gen/opcode.go | 52 +++++++------- grow/cli/main.go | 14 +++- grow/definition.go | 173 +++++++++++++++++++++++++++++++++++---------- grow/sample.go | 12 ---- grow/status.go | 4 +- 5 files changed, 176 insertions(+), 79 deletions(-) diff --git a/gen/opcode.go b/gen/opcode.go index 95588aa..d9071e8 100644 --- a/gen/opcode.go +++ b/gen/opcode.go @@ -10,40 +10,40 @@ var opCodes = []vm.OpCodeType{ vm.OpMov, vm.OpAdd, - vm.OpSub, +// vm.OpSub, vm.OpMul, - vm.OpDivU, - vm.OpDivS, +// vm.OpDivU, +// vm.OpDivS, - vm.OpNot, - vm.OpAnd, - vm.OpOr, +// vm.OpNot, +// vm.OpAnd, +// vm.OpOr, vm.OpXor, - vm.OpShR, - vm.OpShL, +// vm.OpShR, +// vm.OpShL, - vm.OpEq, - vm.OpLTU, - vm.OpLTS, - vm.OpGTU, - vm.OpGTS, - vm.OpLTEU, - vm.OpLTES, - vm.OpGTEU, - vm.OpGTES, +// vm.OpEq, +// vm.OpLTU, +// vm.OpLTS, +// vm.OpGTU, +// vm.OpGTS, +// vm.OpLTEU, +// vm.OpLTES, +// vm.OpGTEU, +// vm.OpGTES, - vm.OpJmp, - vm.OpJmpT, - vm.OpJmpF, +// vm.OpJmp, +// vm.OpJmpT, +// vm.OpJmpF, - vm.OpCal, - vm.OpCalT, - vm.OpCalF, +// vm.OpCal, +// vm.OpCalT, +// vm.OpCalF, - vm.OpRet, - vm.OpRetT, - vm.OpRetF, +// vm.OpRet, +// vm.OpRetT, +// vm.OpRetF, vm.OpSqrt, } diff --git a/grow/cli/main.go b/grow/cli/main.go index fb66a89..97c0da3 100644 --- a/grow/cli/main.go +++ b/grow/cli/main.go @@ -1,9 +1,11 @@ package main import "flag" +import "fmt" import "log" import "math/rand" import "os" +import "strings" import "time" import "github.com/firestuff/subcoding/asm" @@ -54,6 +56,16 @@ func main() { log.Fatal(err) } - log.Printf("New best score %d / %d (after %d attempts):\n%s", status.BestScore, status.TargetScore, status.Attempts, src) + log.Printf("New best score [%s] after %d attempts:\n%s", scoreString(status.BestScores), status.Attempts, src) } } + +func scoreString(scores []*grow.Score) string { + strs := []string{} + + 
for _, score := range scores { + strs = append(strs, fmt.Sprintf("%d / %d", score.Current, score.Total)) + } + + return strings.Join(strs, ", ") +} diff --git a/grow/definition.go b/grow/definition.go index c35ba94..7ac8974 100644 --- a/grow/definition.go +++ b/grow/definition.go @@ -1,6 +1,7 @@ package grow import "io" +import "sort" import "gopkg.in/yaml.v2" @@ -15,6 +16,14 @@ type Definition struct { InstructionsPerFunctionStdDev uint64 `yaml:"instructions_per_function_std_dev"` Samples []*Sample `yaml:"samples"` + + // Sample indices ranked by each output dimension + SampleRanks [][]int +} + +type Score struct { + Current uint64 + Total uint64 } func NewDefinition(r io.Reader) (*Definition, error) { @@ -28,14 +37,20 @@ func NewDefinition(r io.Reader) (*Definition, error) { return nil, err } + // TODO: Test & handle non-consistent In and Out dimensions + + def.buildSampleRanks() + return def, nil } func (def *Definition) Grow(statusChan chan<- Status) (*vm.Program, error) { - status := Status{ - TargetScore: def.sumOuts(), + if statusChan != nil { + defer close(statusChan) } + status := Status{} + if statusChan != nil { statusChan <- status } @@ -54,70 +69,153 @@ func (def *Definition) Grow(statusChan chan<- Status) (*vm.Program, error) { Mutate(def, prog) - score, err := def.score(prog) + scores, err := def.score(prog) if err != nil { // Can never get best score continue } - if score > status.BestScore { - err = def.minifyProgram(prog) - if err != nil { - if statusChan != nil { - close(statusChan) - } + if !def.scoreIsBetter(status.BestScores, scores) { + continue + } - return nil, err - } + err = def.minifyProgram(prog) + if err != nil { + return nil, err + } - status.BestScore = score - status.BestProgram = prog.Copy() + status.BestScores = scores + status.BestProgram = prog.Copy() - if statusChan != nil { - statusChan <- status - } + if statusChan != nil { + statusChan <- status + } - if status.BestScore == status.TargetScore { - if statusChan != nil { - 
close(statusChan) - } - - return prog, nil - } + if status.BestScores[0].Current == status.BestScores[0].Total { + return prog, nil } } } -func (def *Definition) score(prog *vm.Program) (uint64, error) { - score := uint64(0) +func (def *Definition) buildSampleRanks() { + for col := 0; col < len(def.Samples[0].Out); col++ { + rank := []int{} + + for i := 0; i < len(def.Samples); i++ { + rank = append(rank, i) + } + + sort.SliceStable(rank, func(i, j int) bool { + return def.Samples[i].Out[col] < def.Samples[j].Out[col] + }) + + def.SampleRanks = append(def.SampleRanks, rank) + } +} + +func (def *Definition) score(prog *vm.Program) ([]*Score, error) { + outputs := [][]uint64{} for _, sample := range def.Samples { state, err := vm.NewState(prog) if err != nil { - return 0, err + return nil, err } sample.SetInputs(state) err = state.Execute() if err != nil { - return 0, err + return nil, err } - score += sample.matchingOuts(state) + output := []uint64{} + for i := 0; i < len(def.Samples[0].Out); i++ { + // TODO: Handle signedness? + output = append(output, state.GlobalMemory().MustReadUnsigned(uint64(i))) + } + outputs = append(outputs, output) } - return score, nil + // TODO: Points for proximity to target values? + // TODO: Points for correlation coefficient with target values across samples? 
+ + return []*Score{ + def.scoreMatching(outputs), + def.scoreRank(outputs), + }, nil } -func (def *Definition) sumOuts() uint64 { - sum := uint64(0) +func (def *Definition) scoreMatching(outputs [][]uint64) *Score { + ret := &Score{} - for _, sample := range def.Samples { - sum += uint64(len(sample.Out)) + for s, sample := range def.Samples { + for o, out := range sample.Out { + ret.Total++ + + if outputs[s][o] == out { + ret.Current++ + } + } } - return sum + return ret +} + +func (def *Definition) scoreRank(outputs [][]uint64) *Score { + ranks := [][]int{} + + for col := 0; col < len(outputs[0]); col++ { + rank := []int{} + + for i := 0; i < len(def.Samples); i++ { + rank = append(rank, i) + } + + sort.SliceStable(rank, func(i, j int) bool { + return outputs[i][col] < outputs[j][col] + }) + + ranks = append(ranks, rank) + } + + ret := &Score{} + + for col, vals := range ranks { + for i, val := range vals { + ret.Total++ + + if val == def.SampleRanks[col][i] { + ret.Current++ + } + } + } + + return ret +} + +func (def *Definition) scoreIsBetter(old, new []*Score) bool { + if old == nil { + return true + } + + for i, score := range new { + best := old[i] + + switch { + case score.Current == best.Current: + continue + + case score.Current > best.Current: + return true + + case score.Current < best.Current: + return false + } + } + + // Unchanged + return false } func (def *Definition) minifyProgram(prog *vm.Program) error { @@ -132,7 +230,7 @@ func (def *Definition) minifyProgram(prog *vm.Program) error { } func (def *Definition) minifyFunction(prog *vm.Program, f int) error { - baseScore, err := def.score(prog) + baseScores, err := def.score(prog) if err != nil { return err } @@ -147,8 +245,9 @@ func (def *Definition) minifyFunction(prog *vm.Program, f int) error { copy(tmp, prog.Functions[f].Instructions) prog.Functions[f].Instructions = append(tmp[:i], tmp[i+1:]...) 
- newScore, err := def.score(prog) - if err == nil && newScore >= baseScore { + newScores, err := def.score(prog) + // XXX: Use all scores + if err == nil && newScores[0].Current >= baseScores[0].Current { loop = true break } else { diff --git a/grow/sample.go b/grow/sample.go index 1283698..e65b1e3 100644 --- a/grow/sample.go +++ b/grow/sample.go @@ -12,15 +12,3 @@ func (s *Sample) SetInputs(state *vm.State) { state.GlobalMemory().WriteUnsigned(uint64(i), val) } } - -func (s *Sample) matchingOuts(state *vm.State) uint64 { - ret := uint64(0) - - for i, val := range s.Out { - if state.GlobalMemory().MustReadUnsigned(uint64(i)) == val { - ret++ - } - } - - return ret -} diff --git a/grow/status.go b/grow/status.go index 641b847..d8bfa47 100644 --- a/grow/status.go +++ b/grow/status.go @@ -3,10 +3,8 @@ package grow import "github.com/firestuff/subcoding/vm" type Status struct { - TargetScore uint64 - Attempts uint64 - BestScore uint64 + BestScores []*Score BestProgram *vm.Program }