Multi-byte symbols

This commit is contained in:
Ian Gulliver
2023-12-31 16:09:35 -08:00
parent 04f1ef728f
commit 2f88ac2708
5 changed files with 107 additions and 33 deletions

View File

@@ -1,34 +1,40 @@
package state
import (
"bytes"
"fmt"
"strings"
)
// node is one entry in a State's symbol table. It records the symbol's
// bytes, how many times the symbol has been incremented (count), and the
// node's current position (index), which IncrementSymbol swaps with a
// neighbouring node's index when it reorders entries.
type node struct {
symbol byte
symbol []byte
count int
index int
}
// trieNode is one level of a 16-way trie keyed by 4-bit nibbles. Each symbol
// byte occupies two levels: the high nibble selects a child, then the low
// nibble selects a child of that child (see insertTrieNode).
type trieNode struct {
// node is set when a symbol ends exactly at this trie position; nil otherwise.
node *node
// children is indexed by nibble value (0-15); absent branches are nil.
children [16]*trieNode
}
// State is the symbol table: the list of nodes plus a nibble trie (root)
// through which nodes are looked up by their symbol bytes.
type State struct {
nodes []*node
bySymbol map[byte]*node
nodes []*node
root *trieNode
}
func NewState() *State {
st := &State{
bySymbol: map[byte]*node{},
root: &trieNode{},
}
for i := 0; i < 256; i++ {
node := &node{
symbol: byte(i),
symbol: []byte{byte(i)},
index: i,
}
st.nodes = append(st.nodes, node)
st.bySymbol[node.symbol] = node
st.insertTrieNode(node)
}
return st
@@ -36,21 +42,31 @@ func NewState() *State {
// Clone returns a new State containing a copy of every node in st, in the
// same order, with a freshly built lookup trie. Each node struct is copied by
// value, so counts and indexes in the clone are independent of the original;
// the symbol slice header is copied as-is, so the underlying symbol bytes are
// shared between the two states.
func (st State) Clone() *State {
st2 := &State{
bySymbol: map[byte]*node{},
root: &trieNode{},
}
for _, node := range st.nodes {
tmp := *node
st2.nodes = append(st2.nodes, &tmp)
st2.bySymbol[tmp.symbol] = &tmp
node2 := *node
st2.nodes = append(st2.nodes, &node2)
st2.insertTrieNode(&node2)
}
return st2
}
// Returns old index
func (st *State) IncrementSymbol(symbol byte) int {
node := st.nodeFromSymbol(symbol)
// AddSymbol registers symbol as a new entry. The entry is appended to the end
// of the node list (its index is the previous number of nodes), starts with a
// zero count, and is inserted into the lookup trie.
//
// The symbol bytes are copied, so the caller may reuse or modify its slice
// afterwards without affecting the stored symbol.
func (st *State) AddSymbol(symbol []byte) {
	// Detach from the caller's backing array: the trie path is derived from
	// these bytes at insert time, so a later mutation by the caller would
	// otherwise leave node.symbol disagreeing with where the node lives.
	owned := make([]byte, len(symbol))
	copy(owned, symbol)

	n := &node{
		symbol: owned,
		index:  len(st.nodes),
	}
	st.nodes = append(st.nodes, n)
	st.insertTrieNode(n)
}
// Returns (symbol_length, old_index)
func (st *State) IncrementSymbol(symbols []byte) (int, int) {
node := st.nodeFromSymbols(symbols)
node.count++
origIndex := node.index
@@ -61,7 +77,7 @@ func (st *State) IncrementSymbol(symbol byte) int {
if prevNode.count > iterNode.count {
break
} else if prevNode.count == iterNode.count && prevNode.symbol < iterNode.symbol {
} else if prevNode.count == iterNode.count && bytes.Compare(prevNode.symbol, iterNode.symbol) < 0 {
break
}
@@ -69,7 +85,7 @@ func (st *State) IncrementSymbol(symbol byte) int {
prevNode.index, iterNode.index = iterIndex, prevIndex
}
return origIndex
return len(node.symbol), origIndex
}
func (st State) String() string {
@@ -80,12 +96,63 @@ func (st State) String() string {
break
}
strs = append(strs, fmt.Sprintf("{%#U}=%d", node.symbol, node.count))
strs = append(strs, fmt.Sprintf("%#U=%d", node.symbol, node.count))
}
return strings.Join(strs, ", ")
}
func (st State) nodeFromSymbol(symbol byte) *node {
return st.bySymbol[symbol]
// Symbols returns the symbol of every node, in the order the nodes currently
// appear in the state's node list. The result is always non-nil. The returned
// byte slices are the stored symbols themselves, not copies, so callers must
// not modify them.
func (st State) Symbols() [][]byte {
	// The final length is known up front, so allocate once.
	ret := make([][]byte, 0, len(st.nodes))
	for _, n := range st.nodes {
		ret = append(ret, n.symbol)
	}
	return ret
}
// insertTrieNode records node in the lookup trie under its symbol bytes.
// Every byte of the symbol descends two levels, high nibble first and then
// low nibble, creating missing branches on the way. The position reached
// after the last byte is pointed at node, replacing whatever was there.
func (st State) insertTrieNode(node *node) {
	cur := st.root
	for i := range node.symbol {
		b := node.symbol[i]
		// For a byte, b>>4 is the high nibble and b&0x0f the low nibble.
		cur = cur.getOrInsertChild(b >> 4).getOrInsertChild(b & 0x0f)
	}
	cur.node = node
}
// nodeFromSymbols walks the trie along the leading bytes of symbols and
// returns the node of the longest stored symbol that is a prefix of symbols.
// It returns nil when no stored symbol matches, including when symbols is
// empty.
func (st State) nodeFromSymbols(symbols []byte) *node {
	var best *node
	cur := st.root
	for i := 0; i < len(symbols); i++ {
		hi, lo := symbols[i]>>4, symbols[i]&0x0f

		upper := cur.children[hi]
		if upper == nil {
			return best
		}
		cur = upper.children[lo]
		if cur == nil {
			return best
		}
		// A symbol ends here; remember it, but keep going in case a
		// longer symbol also matches.
		if cur.node != nil {
			best = cur.node
		}
	}
	return best
}
// getOrInsertChild returns the child of tn at slot val, first creating an
// empty child there if the slot is unoccupied. val indexes the 16-entry
// children array directly.
func (tn *trieNode) getOrInsertChild(val byte) *trieNode {
	if tn.children[val] == nil {
		tn.children[val] = &trieNode{}
	}
	return tn.children[val]
}