Multi-byte symbols
This commit is contained in:
@@ -5,7 +5,7 @@ type Code struct {
|
||||
Bits int
|
||||
}
|
||||
|
||||
var codes = [256]Code{
|
||||
var codes = []Code{
|
||||
{Value: 0b0000, Bits: 4},
|
||||
{Value: 0b0001, Bits: 4},
|
||||
{Value: 0b0010, Bits: 4},
|
||||
@@ -260,8 +260,10 @@ var codes = [256]Code{
|
||||
{Value: 0b1111111011, Bits: 10},
|
||||
{Value: 0b1111111100, Bits: 10},
|
||||
{Value: 0b1111111101, Bits: 10},
|
||||
{Value: 0b1111111110, Bits: 10},
|
||||
{Value: 0b1111111111, Bits: 10},
|
||||
{Value: 0b11111111100, Bits: 11},
|
||||
{Value: 0b11111111101, Bits: 11},
|
||||
{Value: 0b11111111110, Bits: 11},
|
||||
{Value: 0b11111111111, Bits: 11},
|
||||
}
|
||||
|
||||
func CodeForIndex(index int) Code {
|
||||
|
||||
@@ -13,8 +13,9 @@ func Encode(st *state.State, msg []byte) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
w := bitio.NewWriter(buf)
|
||||
|
||||
for _, b := range msg {
|
||||
index := st.IncrementSymbol(b)
|
||||
for i := 0; i < len(msg); {
|
||||
l, index := st.IncrementSymbol(msg[i:])
|
||||
i += l
|
||||
code := codes.CodeForIndex(index)
|
||||
lo.Must0(w.WriteBits(uint64(code.Value), uint8(code.Bits)))
|
||||
}
|
||||
|
||||
@@ -2,14 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/securemesh/coding"
|
||||
"github.com/securemesh/coding/state"
|
||||
"github.com/securemesh/coding/seeds"
|
||||
"github.com/securemesh/coding/state"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -30,6 +31,8 @@ func main() {
|
||||
}
|
||||
|
||||
func optimize(st *state.State, samples [][]byte) *state.State {
|
||||
st.AddSymbol([]byte("it "))
|
||||
|
||||
for true {
|
||||
better := optimize2(st, samples)
|
||||
if better == nil {
|
||||
@@ -43,20 +46,21 @@ func optimize(st *state.State, samples [][]byte) *state.State {
|
||||
}
|
||||
|
||||
type sampleResult struct {
|
||||
symbol byte
|
||||
state *state.State
|
||||
score int
|
||||
symbol []byte
|
||||
state *state.State
|
||||
score int
|
||||
}
|
||||
|
||||
func optimize2(baseState *state.State, samples [][]byte) *state.State {
|
||||
ch := make(chan sampleResult, 100)
|
||||
symbols := baseState.Symbols()
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
for _, symbol := range symbols {
|
||||
res := sampleResult{
|
||||
symbol: byte(i),
|
||||
symbol: symbol,
|
||||
}
|
||||
|
||||
go func () {
|
||||
go func() {
|
||||
st := baseState.Clone()
|
||||
st.IncrementSymbol(res.symbol)
|
||||
res.state = st
|
||||
@@ -67,11 +71,11 @@ func optimize2(baseState *state.State, samples [][]byte) *state.State {
|
||||
|
||||
results := []sampleResult{}
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
for _ = range symbols {
|
||||
results = append(results, <-ch)
|
||||
}
|
||||
|
||||
slices.SortFunc(results, func(a, b sampleResult) int { return int(a.symbol) - int(b.symbol) })
|
||||
slices.SortFunc(results, func(a, b sampleResult) int { return bytes.Compare(a.symbol, b.symbol) })
|
||||
best := slices.MaxFunc(results, func(a, b sampleResult) int { return b.score - a.score })
|
||||
|
||||
if best.score == totalLength(baseState, samples) {
|
||||
|
||||
@@ -31,7 +31,7 @@ func newStateFromSeed(seed [][]byte) *state.State {
|
||||
for i := range seed {
|
||||
for _, s := range seed[i:] {
|
||||
for _, b := range s {
|
||||
st.IncrementSymbol(b)
|
||||
st.IncrementSymbol([]byte{b})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
103
state/state.go
103
state/state.go
@@ -1,34 +1,40 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type node struct {
|
||||
symbol byte
|
||||
symbol []byte
|
||||
count int
|
||||
index int
|
||||
}
|
||||
|
||||
// trieNode is one level of the symbol lookup trie. Each level is indexed by a
// 4-bit value (hence 16 children); the insert and lookup code descends two
// levels per symbol byte, high nibble first and then low nibble.
type trieNode struct {
|
||||
// node is the symbol that ends at this position, or nil if no symbol ends here.
node *node
|
||||
// children holds the next trie level, indexed by nibble value; nil entries are absent.
children [16]*trieNode
|
||||
}
|
||||
|
||||
type State struct {
|
||||
nodes []*node
|
||||
bySymbol map[byte]*node
|
||||
nodes []*node
|
||||
root *trieNode
|
||||
}
|
||||
|
||||
func NewState() *State {
|
||||
st := &State{
|
||||
bySymbol: map[byte]*node{},
|
||||
root: &trieNode{},
|
||||
}
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
node := &node{
|
||||
symbol: byte(i),
|
||||
symbol: []byte{byte(i)},
|
||||
index: i,
|
||||
}
|
||||
|
||||
st.nodes = append(st.nodes, node)
|
||||
st.bySymbol[node.symbol] = node
|
||||
st.insertTrieNode(node)
|
||||
}
|
||||
|
||||
return st
|
||||
@@ -36,21 +42,31 @@ func NewState() *State {
|
||||
|
||||
func (st State) Clone() *State {
|
||||
st2 := &State{
|
||||
bySymbol: map[byte]*node{},
|
||||
root: &trieNode{},
|
||||
}
|
||||
|
||||
for _, node := range st.nodes {
|
||||
tmp := *node
|
||||
st2.nodes = append(st2.nodes, &tmp)
|
||||
st2.bySymbol[tmp.symbol] = &tmp
|
||||
node2 := *node
|
||||
st2.nodes = append(st2.nodes, &node2)
|
||||
st2.insertTrieNode(&node2)
|
||||
}
|
||||
|
||||
return st2
|
||||
}
|
||||
|
||||
// Returns old index
|
||||
func (st *State) IncrementSymbol(symbol byte) int {
|
||||
node := st.nodeFromSymbol(symbol)
|
||||
// AddSymbol registers symbol as a new entry in the state. The new node starts
// with a zero count, is appended at the end of st.nodes (its index is the
// previous length), and is made reachable through the lookup trie.
//
// NOTE(review): the slice is stored as-is, not copied, so the caller must not
// modify it afterwards. No check is made for a symbol that is already present;
// re-adding one would repoint its trie slot at the new node — confirm callers
// never do that.
func (st *State) AddSymbol(symbol []byte) {
|
||||
node := &node{
|
||||
symbol: symbol,
|
||||
index: len(st.nodes),
|
||||
}
|
||||
|
||||
st.nodes = append(st.nodes, node)
|
||||
st.insertTrieNode(node)
|
||||
}
|
||||
|
||||
// Returns (symbol_length, old_index)
|
||||
func (st *State) IncrementSymbol(symbols []byte) (int, int) {
|
||||
node := st.nodeFromSymbols(symbols)
|
||||
node.count++
|
||||
origIndex := node.index
|
||||
|
||||
@@ -61,7 +77,7 @@ func (st *State) IncrementSymbol(symbol byte) int {
|
||||
|
||||
if prevNode.count > iterNode.count {
|
||||
break
|
||||
} else if prevNode.count == iterNode.count && prevNode.symbol < iterNode.symbol {
|
||||
} else if prevNode.count == iterNode.count && bytes.Compare(prevNode.symbol, iterNode.symbol) < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -69,7 +85,7 @@ func (st *State) IncrementSymbol(symbol byte) int {
|
||||
prevNode.index, iterNode.index = iterIndex, prevIndex
|
||||
}
|
||||
|
||||
return origIndex
|
||||
return len(node.symbol), origIndex
|
||||
}
|
||||
|
||||
func (st State) String() string {
|
||||
@@ -80,12 +96,63 @@ func (st State) String() string {
|
||||
break
|
||||
}
|
||||
|
||||
strs = append(strs, fmt.Sprintf("{%#U}=%d", node.symbol, node.count))
|
||||
strs = append(strs, fmt.Sprintf("%#U=%d", node.symbol, node.count))
|
||||
}
|
||||
|
||||
return strings.Join(strs, ", ")
|
||||
}
|
||||
|
||||
func (st State) nodeFromSymbol(symbol byte) *node {
|
||||
return st.bySymbol[symbol]
|
||||
// Symbols returns the symbol of every node, in the current order of st.nodes.
// The outer slice is freshly allocated, but each element is the node's own
// symbol slice (not a copy), so callers should treat the contents as read-only.
func (st State) Symbols() [][]byte {
|
||||
ret := [][]byte{}
|
||||
|
||||
for _, node := range st.nodes {
|
||||
ret = append(ret, node.symbol)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// insertTrieNode makes node reachable from the trie root under its symbol
// bytes, creating intermediate trie levels as needed. Any node previously
// stored at that exact position is replaced.
//
// The value receiver is sufficient here because only st.root (a pointer) is
// followed; the State struct itself is not modified.
func (st State) insertTrieNode(node *node) {
|
||||
iter := st.root
|
||||
|
||||
for _, b := range node.symbol {
|
||||
// Two trie levels per byte: high nibble first, then low nibble.
nibbleTrieNode := iter.getOrInsertChild((b & 0xf0) >> 4)
|
||||
iter = nibbleTrieNode.getOrInsertChild(b & 0x0f)
|
||||
}
|
||||
|
||||
// For an empty symbol the loop does not run and node is attached to the root.
iter.node = node
|
||||
}
|
||||
|
||||
// nodeFromSymbols walks the trie along the bytes of symbols and returns the
// node of the longest registered symbol that is a prefix of the input.
// It returns nil when no prefix matches, which includes an empty input —
// callers must handle that before dereferencing the result.
func (st State) nodeFromSymbols(symbols []byte) *node {
|
||||
var lastFound *node
|
||||
iter := st.root
|
||||
|
||||
for _, b := range symbols {
|
||||
// High nibble level; a missing child means no longer symbol can match.
nibbleTrieNode := iter.children[(b&0xf0)>>4]
|
||||
if nibbleTrieNode == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Low nibble level completes this byte.
iter = nibbleTrieNode.children[b&0x0f]
|
||||
if iter == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Remember the deepest complete symbol seen so far and keep going in
// case a longer one also matches.
if iter.node != nil {
|
||||
lastFound = iter.node
|
||||
}
|
||||
}
|
||||
|
||||
return lastFound
|
||||
}
|
||||
|
||||
// getOrInsertChild returns the child of tn stored at index val, allocating and
// linking an empty trieNode first if that slot is nil.
// val is used directly as an index into tn.children, so it must be a valid
// index for that array.
func (tn *trieNode) getOrInsertChild(val byte) *trieNode {
|
||||
child := tn.children[val]
|
||||
if child != nil {
|
||||
return child
|
||||
}
|
||||
|
||||
child = &trieNode{}
|
||||
tn.children[val] = child
|
||||
return child
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user