Faster state tracking, stable optimizer

This commit is contained in:
Ian Gulliver
2023-12-30 20:37:18 -08:00
parent 58e3a477ef
commit 8c9dab282b
3 changed files with 29 additions and 37 deletions

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"log" "log"
"os" "os"
"slices"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/securemesh/coding" "github.com/securemesh/coding"
@@ -42,6 +43,7 @@ func optimize(st *state.State, samples [][]byte) *state.State {
} }
type sampleResult struct { type sampleResult struct {
symbol byte
state *state.State state *state.State
score int score int
} }
@@ -50,29 +52,33 @@ func optimize2(baseState *state.State, samples [][]byte) *state.State {
ch := make(chan sampleResult, 100) ch := make(chan sampleResult, 100)
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
s := byte(i) res := sampleResult{
symbol: byte(i),
}
go func () { go func () {
st := baseState.Clone() st := baseState.Clone()
st.IncrementSymbol(s) st.IncrementSymbol(res.symbol)
ch <- sampleResult{ res.state = st
state: st, res.score = totalLength(st, samples)
score: totalLength(st, samples), ch <- res
}
}() }()
} }
var best *state.State = nil results := []sampleResult{}
bestScore := totalLength(baseState, samples)
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
res := <-ch results = append(results, <-ch)
if res.score < bestScore {
best = res.state
bestScore = res.score
}
} }
return best slices.SortFunc(results, func(a, b sampleResult) int { return int(a.symbol) - int(b.symbol) })
best := slices.MaxFunc(results, func(a, b sampleResult) int { return b.score - a.score })
if best.score == totalLength(baseState, samples) {
return nil
}
return best.state
} }
func totalLength(st *state.State, samples [][]byte) int { func totalLength(st *state.State, samples [][]byte) int {

View File

@@ -15,12 +15,10 @@ var chatState = newStateFromSeed([][]byte{
/* 08 */ []byte("i"), /* 08 */ []byte("i"),
/* 09 */ []byte(""), /* 09 */ []byte(""),
/* 10 */ []byte(" "), /* 10 */ []byte(" "),
/* 11 */ []byte(""), /* 11 */ []byte("e"),
/* 12 */ []byte(""), /* 12 */ []byte(""),
/* 13 */ []byte(""), /* 13 */ []byte("t"),
/* 14 */ []byte("et"), /* 14 */ []byte("o"),
/* 15 */ []byte(""),
/* 16 */ []byte("o"),
}) })
func ChatState() *state.State { func ChatState() *state.State {

View File

@@ -12,7 +12,7 @@ type node struct {
} }
type State struct { type State struct {
nodes []*node nodes [256]node
bySymbol map[byte]int bySymbol map[byte]int
} }
@@ -22,10 +22,7 @@ func NewState() *State {
} }
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
st.nodes = append(st.nodes, &node{ st.nodes[i].symbol = byte(i)
symbol: byte(i),
})
st.bySymbol[byte(i)] = i st.bySymbol[byte(i)] = i
} }
@@ -33,16 +30,10 @@ func NewState() *State {
} }
func (st State) Clone() *State { func (st State) Clone() *State {
st2 := &State{ return &State{
nodes: st.nodes,
bySymbol: maps.Clone(st.bySymbol), bySymbol: maps.Clone(st.bySymbol),
} }
for _, node := range st.nodes {
tmp := *node
st2.nodes = append(st2.nodes, &tmp)
}
return st2
} }
// Returns old index // Returns old index
@@ -61,11 +52,8 @@ func (st *State) IncrementSymbol(symbol byte) int {
break break
} }
st.nodes[iterIndex] = prevNode st.nodes[iterIndex], st.nodes[prevIndex] = prevNode, iterNode
st.bySymbol[prevNode.symbol] = iterIndex st.bySymbol[iterNode.symbol], st.bySymbol[prevNode.symbol] = prevIndex, iterIndex
st.nodes[prevIndex] = iterNode
st.bySymbol[iterNode.symbol] = prevIndex
} }
return nodeIndex return nodeIndex