Faster state tracking, stable optimizer
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"log"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/securemesh/coding"
|
||||
@@ -42,6 +43,7 @@ func optimize(st *state.State, samples [][]byte) *state.State {
|
||||
}
|
||||
|
||||
type sampleResult struct {
|
||||
symbol byte
|
||||
state *state.State
|
||||
score int
|
||||
}
|
||||
@@ -50,29 +52,33 @@ func optimize2(baseState *state.State, samples [][]byte) *state.State {
|
||||
ch := make(chan sampleResult, 100)
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
s := byte(i)
|
||||
res := sampleResult{
|
||||
symbol: byte(i),
|
||||
}
|
||||
|
||||
go func () {
|
||||
st := baseState.Clone()
|
||||
st.IncrementSymbol(s)
|
||||
ch <- sampleResult{
|
||||
state: st,
|
||||
score: totalLength(st, samples),
|
||||
}
|
||||
st.IncrementSymbol(res.symbol)
|
||||
res.state = st
|
||||
res.score = totalLength(st, samples)
|
||||
ch <- res
|
||||
}()
|
||||
}
|
||||
|
||||
var best *state.State = nil
|
||||
bestScore := totalLength(baseState, samples)
|
||||
results := []sampleResult{}
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
res := <-ch
|
||||
if res.score < bestScore {
|
||||
best = res.state
|
||||
bestScore = res.score
|
||||
}
|
||||
results = append(results, <-ch)
|
||||
}
|
||||
|
||||
return best
|
||||
slices.SortFunc(results, func(a, b sampleResult) int { return int(a.symbol) - int(b.symbol) })
|
||||
best := slices.MaxFunc(results, func(a, b sampleResult) int { return b.score - a.score })
|
||||
|
||||
if best.score == totalLength(baseState, samples) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return best.state
|
||||
}
|
||||
|
||||
func totalLength(st *state.State, samples [][]byte) int {
|
||||
|
||||
@@ -15,12 +15,10 @@ var chatState = newStateFromSeed([][]byte{
|
||||
/* 08 */ []byte("i"),
|
||||
/* 09 */ []byte(""),
|
||||
/* 10 */ []byte(" "),
|
||||
/* 11 */ []byte(""),
|
||||
/* 11 */ []byte("e"),
|
||||
/* 12 */ []byte(""),
|
||||
/* 13 */ []byte(""),
|
||||
/* 14 */ []byte("et"),
|
||||
/* 15 */ []byte(""),
|
||||
/* 16 */ []byte("o"),
|
||||
/* 13 */ []byte("t"),
|
||||
/* 14 */ []byte("o"),
|
||||
})
|
||||
|
||||
func ChatState() *state.State {
|
||||
|
||||
@@ -12,7 +12,7 @@ type node struct {
|
||||
}
|
||||
|
||||
type State struct {
|
||||
nodes []*node
|
||||
nodes [256]node
|
||||
bySymbol map[byte]int
|
||||
}
|
||||
|
||||
@@ -22,10 +22,7 @@ func NewState() *State {
|
||||
}
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
st.nodes = append(st.nodes, &node{
|
||||
symbol: byte(i),
|
||||
})
|
||||
|
||||
st.nodes[i].symbol = byte(i)
|
||||
st.bySymbol[byte(i)] = i
|
||||
}
|
||||
|
||||
@@ -33,16 +30,10 @@ func NewState() *State {
|
||||
}
|
||||
|
||||
func (st State) Clone() *State {
|
||||
st2 := &State{
|
||||
return &State{
|
||||
nodes: st.nodes,
|
||||
bySymbol: maps.Clone(st.bySymbol),
|
||||
}
|
||||
|
||||
for _, node := range st.nodes {
|
||||
tmp := *node
|
||||
st2.nodes = append(st2.nodes, &tmp)
|
||||
}
|
||||
|
||||
return st2
|
||||
}
|
||||
|
||||
// Returns old index
|
||||
@@ -61,11 +52,8 @@ func (st *State) IncrementSymbol(symbol byte) int {
|
||||
break
|
||||
}
|
||||
|
||||
st.nodes[iterIndex] = prevNode
|
||||
st.bySymbol[prevNode.symbol] = iterIndex
|
||||
|
||||
st.nodes[prevIndex] = iterNode
|
||||
st.bySymbol[iterNode.symbol] = prevIndex
|
||||
st.nodes[iterIndex], st.nodes[prevIndex] = prevNode, iterNode
|
||||
st.bySymbol[iterNode.symbol], st.bySymbol[prevNode.symbol] = prevIndex, iterIndex
|
||||
}
|
||||
|
||||
return nodeIndex
|
||||
|
||||
Reference in New Issue
Block a user