diff --git a/codes/codes.go b/codes/codes.go index 0c4475d..5c349c7 100644 --- a/codes/codes.go +++ b/codes/codes.go @@ -5,7 +5,7 @@ type Code struct { Bits int } -var codes = [256]Code{ +var codes = []Code{ {Value: 0b0000, Bits: 4}, {Value: 0b0001, Bits: 4}, {Value: 0b0010, Bits: 4}, @@ -260,8 +260,10 @@ var codes = [256]Code{ {Value: 0b1111111011, Bits: 10}, {Value: 0b1111111100, Bits: 10}, {Value: 0b1111111101, Bits: 10}, - {Value: 0b1111111110, Bits: 10}, - {Value: 0b1111111111, Bits: 10}, + {Value: 0b11111111100, Bits: 11}, + {Value: 0b11111111101, Bits: 11}, + {Value: 0b11111111110, Bits: 11}, + {Value: 0b11111111111, Bits: 11}, } func CodeForIndex(index int) Code { diff --git a/encode.go b/encode.go index 99d12a6..c466c5e 100644 --- a/encode.go +++ b/encode.go @@ -13,8 +13,9 @@ func Encode(st *state.State, msg []byte) []byte { buf := &bytes.Buffer{} w := bitio.NewWriter(buf) - for _, b := range msg { - index := st.IncrementSymbol(b) + for i := 0; i < len(msg); { + l, index := st.IncrementSymbol(msg[i:]) + i += l code := codes.CodeForIndex(index) lo.Must0(w.WriteBits(uint64(code.Value), uint8(code.Bits))) } diff --git a/genseed/genseed.go b/genseed/genseed.go index dce754e..7d2b56d 100644 --- a/genseed/genseed.go +++ b/genseed/genseed.go @@ -2,14 +2,15 @@ package main import ( "bufio" + "bytes" "log" "os" "slices" "github.com/samber/lo" "github.com/securemesh/coding" - "github.com/securemesh/coding/state" "github.com/securemesh/coding/seeds" + "github.com/securemesh/coding/state" ) func main() { @@ -30,6 +31,8 @@ func main() { } func optimize(st *state.State, samples [][]byte) *state.State { + st.AddSymbol([]byte("it ")) + for true { better := optimize2(st, samples) if better == nil { @@ -43,20 +46,21 @@ func optimize(st *state.State, samples [][]byte) *state.State { } type sampleResult struct { - symbol byte - state *state.State - score int + symbol []byte + state *state.State + score int } func optimize2(baseState *state.State, samples [][]byte) *state.State { ch := make(chan sampleResult, 100) + symbols := baseState.Symbols() - for i := 0; i < 256; i++ { + for _, symbol := range symbols { res := sampleResult{ - symbol: byte(i), + symbol: symbol, } - go func () { + go func() { st := baseState.Clone() st.IncrementSymbol(res.symbol) res.state = st @@ -67,11 +71,11 @@ func optimize2(baseState *state.State, samples [][]byte) *state.State { results := []sampleResult{} - for i := 0; i < 256; i++ { + for _ = range symbols { results = append(results, <-ch) } - slices.SortFunc(results, func(a, b sampleResult) int { return int(a.symbol) - int(b.symbol) }) + slices.SortFunc(results, func(a, b sampleResult) int { return bytes.Compare(a.symbol, b.symbol) }) best := slices.MaxFunc(results, func(a, b sampleResult) int { return b.score - a.score }) if best.score == totalLength(baseState, samples) { diff --git a/seeds/seeds.go b/seeds/seeds.go index 6994fae..d4008b7 100644 --- a/seeds/seeds.go +++ b/seeds/seeds.go @@ -31,7 +31,7 @@ func newStateFromSeed(seed [][]byte) *state.State { for i := range seed { for _, s := range seed[i:] { for _, b := range s { - st.IncrementSymbol(b) + st.IncrementSymbol([]byte{b}) } } } diff --git a/state/state.go b/state/state.go index 5cd3167..1d62ec7 100644 --- a/state/state.go +++ b/state/state.go @@ -1,34 +1,40 @@ package state import ( + "bytes" "fmt" "strings" ) type node struct { - symbol byte + symbol []byte count int index int } +type trieNode struct { + node *node + children [16]*trieNode +} + type State struct { - nodes []*node - bySymbol map[byte]*node + nodes []*node + root *trieNode } func NewState() *State { st := &State{ - bySymbol: map[byte]*node{}, + root: &trieNode{}, } for i := 0; i < 256; i++ { node := &node{ - symbol: byte(i), + symbol: []byte{byte(i)}, index: i, } st.nodes = append(st.nodes, node) - st.bySymbol[node.symbol] = node + st.insertTrieNode(node) } return st @@ -36,21 +42,31 @@ func NewState() *State { func (st State) Clone() *State { st2 := &State{ - bySymbol: map[byte]*node{}, + root: &trieNode{}, } for _, node := range st.nodes { - tmp := *node - st2.nodes = append(st2.nodes, &tmp) - st2.bySymbol[tmp.symbol] = &tmp + node2 := *node + st2.nodes = append(st2.nodes, &node2) + st2.insertTrieNode(&node2) } return st2 } -// Returns old index -func (st *State) IncrementSymbol(symbol byte) int { - node := st.nodeFromSymbol(symbol) +func (st *State) AddSymbol(symbol []byte) { + node := &node{ + symbol: symbol, + index: len(st.nodes), + } + + st.nodes = append(st.nodes, node) + st.insertTrieNode(node) +} + +// Returns (symbol_length, old_index) +func (st *State) IncrementSymbol(symbols []byte) (int, int) { + node := st.nodeFromSymbols(symbols) node.count++ origIndex := node.index @@ -61,7 +77,7 @@ func (st *State) IncrementSymbol(symbol byte) int { if prevNode.count > iterNode.count { break - } else if prevNode.count == iterNode.count && prevNode.symbol < iterNode.symbol { + } else if prevNode.count == iterNode.count && bytes.Compare(prevNode.symbol, iterNode.symbol) < 0 { break } @@ -69,7 +85,7 @@ func (st *State) IncrementSymbol(symbol byte) int { prevNode.index, iterNode.index = iterIndex, prevIndex } - return origIndex + return len(node.symbol), origIndex } func (st State) String() string { @@ -80,12 +96,63 @@ func (st State) String() string { break } - strs = append(strs, fmt.Sprintf("{%#U}=%d", node.symbol, node.count)) + strs = append(strs, fmt.Sprintf("%#U=%d", node.symbol, node.count)) } return strings.Join(strs, ", ") } -func (st State) nodeFromSymbol(symbol byte) *node { - return st.bySymbol[symbol] +func (st State) Symbols() [][]byte { + ret := [][]byte{} + + for _, node := range st.nodes { + ret = append(ret, node.symbol) + } + + return ret +} + +func (st State) insertTrieNode(node *node) { + iter := st.root + + for _, b := range node.symbol { + nibbleTrieNode := iter.getOrInsertChild((b & 0xf0) >> 4) + iter = nibbleTrieNode.getOrInsertChild(b & 0x0f) + } + + iter.node = node +} + +func (st State) nodeFromSymbols(symbols []byte) *node { + var lastFound *node + iter := st.root + + for _, b := range symbols { + nibbleTrieNode := iter.children[(b&0xf0)>>4] + if nibbleTrieNode == nil { + break + } + + iter = nibbleTrieNode.children[b&0x0f] + if iter == nil { + break + } + + if iter.node != nil { + lastFound = iter.node + } + } + + return lastFound +} + +func (tn *trieNode) getOrInsertChild(val byte) *trieNode { + child := tn.children[val] + if child != nil { + return child + } + + child = &trieNode{} + tn.children[val] = child + return child }