diff --git a/coding_test.go b/coding_test.go index 8961acc..f4196fe 100644 --- a/coding_test.go +++ b/coding_test.go @@ -12,7 +12,7 @@ import ( func TestSimple(t *testing.T) { msg := []byte("this is a test. this is only a test.") - encoded := coding.Encode(seeds.ChatHeap(), msg) + encoded := coding.Encode(seeds.ChatState(), msg) t.Logf("orig=%d encoded=%d", len(msg), len(encoded)) } @@ -27,7 +27,7 @@ func TestSMS(t *testing.T) { for s.Scan() { msg := s.Bytes() - e := coding.Encode(seeds.ChatHeap(), msg) + e := coding.Encode(seeds.ChatState(), msg) orig += len(msg) encoded += len(e) } diff --git a/encode.go b/encode.go index 4136a8a..99d12a6 100644 --- a/encode.go +++ b/encode.go @@ -6,15 +6,15 @@ import ( "github.com/icza/bitio" "github.com/samber/lo" "github.com/securemesh/coding/codes" - "github.com/securemesh/coding/heap" + "github.com/securemesh/coding/state" ) -func Encode(h *heap.Heap, msg []byte) []byte { +func Encode(st *state.State, msg []byte) []byte { buf := &bytes.Buffer{} w := bitio.NewWriter(buf) for _, b := range msg { - index := h.IncrementSymbol(b) + index := st.IncrementSymbol(b) code := codes.CodeForIndex(index) lo.Must0(w.WriteBits(uint64(code.Value), uint8(code.Bits))) } diff --git a/genseed/genseed.go b/genseed/genseed.go index 772ea12..2c74acd 100644 --- a/genseed/genseed.go +++ b/genseed/genseed.go @@ -7,7 +7,7 @@ import ( "github.com/samber/lo" "github.com/securemesh/coding" - "github.com/securemesh/coding/heap" + "github.com/securemesh/coding/state" "github.com/securemesh/coding/seeds" ) @@ -18,56 +18,66 @@ func main() { return len(sample) })) - def := heap.NewHeap() + def := state.NewState() log.Printf("def=%d [%s]", totalLength(def, samples), def) - chat := seeds.ChatHeap() + chat := seeds.ChatState() log.Printf("chat=%d [%s]", totalLength(chat, samples), chat) - opt := optimize(heap.NewHeap(), samples) + chatOpt := optimize(chat, samples) + if chatOpt == nil { + log.Printf("\toptimal from further additions") + } else { + log.Printf("\tnot optimal [%s]", chatOpt) + } + + opt := optimize(state.NewState(), samples) log.Printf("opt=%d [%s]", totalLength(opt, samples), opt) } -func optimize(h *heap.Heap, samples [][]byte) *heap.Heap { +func optimize(st *state.State, samples [][]byte) *state.State { + var best *state.State + for true { - better := optimize2(h, samples) + better := optimize2(st, samples) if better == nil { - return h + return best } - h = better - log.Printf("\titer=%d [%s]", totalLength(h, samples), h) + best = better + st = better + log.Printf("\titer=%d [%s]", totalLength(st, samples), st) } - return h + return st } type sampleResult struct { - heap *heap.Heap + state *state.State score int } -func optimize2(baseHeap *heap.Heap, samples [][]byte) *heap.Heap { +func optimize2(baseState *state.State, samples [][]byte) *state.State { ch := make(chan sampleResult, 100) for i := 0; i < 256; i++ { s := byte(i) go func () { - h := baseHeap.Clone() - h.IncrementSymbol(s) + st := baseState.Clone() + st.IncrementSymbol(s) ch <- sampleResult{ - heap: h, - score: totalLength(h, samples), + state: st, + score: totalLength(st, samples), } }() } - var best *heap.Heap = nil - bestScore := totalLength(baseHeap, samples) + var best *state.State = nil + bestScore := totalLength(baseState, samples) for i := 0; i < 256; i++ { res := <-ch if res.score < bestScore { - best = res.heap + best = res.state bestScore = res.score } } @@ -75,9 +85,9 @@ func optimize2(baseHeap *heap.Heap, samples [][]byte) *heap.Heap { return best } -func totalLength(heap *heap.Heap, samples [][]byte) int { +func totalLength(st *state.State, samples [][]byte) int { return lo.SumBy(samples, func(sample []byte) int { - return len(coding.Encode(heap.Clone(), sample)) + return len(coding.Encode(st.Clone(), sample)) }) } diff --git a/heap/heap.go b/heap/heap.go deleted file mode 100644 index 79deb3d..0000000 --- a/heap/heap.go +++ /dev/null @@ -1,86 +0,0 @@ -package heap - -import ( - "fmt" - "maps" - "slices" - "strings" -) - -type node struct { - symbol byte - count int -} - -type Heap struct { - nodes [256]node - bySymbol map[byte]int -} - -func NewHeap() *Heap { - h := &Heap{ - bySymbol: map[byte]int{}, - } - - for i := 0; i < 256; i++ { - h.nodes[i].symbol = byte(i) - h.bySymbol[byte(i)] = i - } - - return h -} - -func (h Heap) Clone() *Heap { - return &Heap{ - nodes: h.nodes, - bySymbol: maps.Clone(h.bySymbol), - } -} - -func (h *Heap) IncrementSymbol(symbol byte) int { - nodeIndex := h.bySymbol[symbol] - h.nodes[nodeIndex].count++ - - iterIndex := nodeIndex - for iterIndex != 0 { - parentIndex := h.parentIndex(iterIndex) - - if h.nodes[iterIndex].count < h.nodes[parentIndex].count || (h.nodes[iterIndex].count == h.nodes[parentIndex].count && h.nodes[iterIndex].symbol > h.nodes[parentIndex].symbol) { - break - } - - h.nodes[iterIndex], h.nodes[parentIndex] = h.nodes[parentIndex], h.nodes[iterIndex] - h.bySymbol[h.nodes[iterIndex].symbol] = iterIndex - h.bySymbol[h.nodes[parentIndex].symbol] = parentIndex - iterIndex = parentIndex - } - - return nodeIndex -} - -func (h Heap) String() string { - nodes := []node{} - - for _, node := range h.nodes { - if node.count == 0 { - continue - } - - nodes = append(nodes, node) - } - - slices.SortStableFunc(nodes, func(a, b node) int { return int(a.symbol) - int(b.symbol) }) - slices.SortStableFunc(nodes, func(a, b node) int { return a.count - b.count }) - - strs := []string{} - - for _, node := range nodes { - strs = append(strs, fmt.Sprintf("{%#U}=%d", node.symbol, node.count)) - } - - return strings.Join(strs, ", ") -} - -func (h Heap) parentIndex(nodeIndex int) int { - return (nodeIndex - 1) / 2 -} diff --git a/seeds/seeds.go b/seeds/seeds.go index 01c1b84..e5dd082 100644 --- a/seeds/seeds.go +++ b/seeds/seeds.go @@ -1,41 +1,39 @@ package seeds import ( - "github.com/securemesh/coding/heap" + "github.com/securemesh/coding/state" ) -var chatHeap = newHeapFromSeed([][]byte{ - /* 01 */ []byte("\x07'(,-8?ACDFHJLMNPRSTUWYbcfgjkpxzê"), - /* 02 */ []byte("\n.dvw"), - /* 03 */ []byte("Ihlmor"), - /* 04 */ []byte("nu"), - /* 05 */ []byte("ey"), - /* 06 */ []byte("i"), - /* 07 */ []byte("s"), - /* 08 */ []byte(""), +var chatState = newStateFromSeed([][]byte{ + /* 01 */ []byte("',.0:?CIbgjkpvxz\xea"), + /* 02 */ []byte("\nfw"), + /* 03 */ []byte("cdmuy"), + /* 04 */ []byte("l"), + /* 05 */ []byte("r"), + /* 06 */ []byte("t"), + /* 07 */ []byte("ahos"), + /* 08 */ []byte("in"), /* 09 */ []byte(""), /* 10 */ []byte(""), - /* 11 */ []byte("at"), + /* 11 */ []byte(" "), /* 12 */ []byte(""), - /* 13 */ []byte(""), - /* 14 */ []byte(""), - /* 15 */ []byte(" "), + /* 13 */ []byte("e"), }) -func ChatHeap() *heap.Heap { - return chatHeap.Clone() +func ChatState() *state.State { + return chatState.Clone() } -func newHeapFromSeed(seed [][]byte) *heap.Heap { - h := heap.NewHeap() +func newStateFromSeed(seed [][]byte) *state.State { + st := state.NewState() for i := range seed { for _, s := range seed[i:] { for _, b := range s { - h.IncrementSymbol(b) + st.IncrementSymbol(b) } } } - return h + return st } diff --git a/state/state.go b/state/state.go new file mode 100644 index 0000000..4e92eb6 --- /dev/null +++ b/state/state.go @@ -0,0 +1,86 @@ +package state + +import ( + "fmt" + "maps" + "strings" +) + +type node struct { + symbol byte + count int +} + +type State struct { + nodes []*node + bySymbol map[byte]int +} + +func NewState() *State { + st := &State{ + bySymbol: map[byte]int{}, + } + + for i := 0; i < 256; i++ { + st.nodes = append(st.nodes, &node{ + symbol: byte(i), + }) + + st.bySymbol[byte(i)] = i + } + + return st +} + +func (st State) Clone() *State { + st2 := &State{ + bySymbol: maps.Clone(st.bySymbol), + } + + for _, node := range st.nodes { + tmp := *node + st2.nodes = append(st2.nodes, &tmp) + } + + return st2 +} + +// Returns old index +func (st *State) IncrementSymbol(symbol byte) int { + nodeIndex := st.bySymbol[symbol] + st.nodes[nodeIndex].count++ + + for iterIndex := nodeIndex; iterIndex > 0; iterIndex-- { + prevIndex := iterIndex - 1 + iterNode := st.nodes[iterIndex] + prevNode := st.nodes[prevIndex] + + if prevNode.count > iterNode.count { + break + } else if prevNode.count == iterNode.count && prevNode.symbol < iterNode.symbol { + break + } + + st.nodes[iterIndex] = prevNode + st.bySymbol[prevNode.symbol] = iterIndex + + st.nodes[prevIndex] = iterNode + st.bySymbol[iterNode.symbol] = prevIndex + } + + return nodeIndex +} + +func (st State) String() string { + strs := []string{} + + for _, node := range st.nodes { + if node.count == 0 { + break + } + + strs = append(strs, fmt.Sprintf("{%#U}=%d", node.symbol, node.count)) + } + + return strings.Join(strs, ", ") +}