Files
coding/state/state.go
2023-12-31 16:09:35 -08:00

159 lines
2.6 KiB
Go

package state
import (
"bytes"
"fmt"
"strings"
)
type node struct {
symbol []byte
count int
index int
}
type trieNode struct {
node *node
children [16]*trieNode
}
type State struct {
nodes []*node
root *trieNode
}
func NewState() *State {
st := &State{
root: &trieNode{},
}
for i := 0; i < 256; i++ {
node := &node{
symbol: []byte{byte(i)},
index: i,
}
st.nodes = append(st.nodes, node)
st.insertTrieNode(node)
}
return st
}
func (st State) Clone() *State {
st2 := &State{
root: &trieNode{},
}
for _, node := range st.nodes {
node2 := *node
st2.nodes = append(st2.nodes, &node2)
st2.insertTrieNode(&node2)
}
return st2
}
func (st *State) AddSymbol(symbol []byte) {
node := &node{
symbol: symbol,
index: len(st.nodes),
}
st.nodes = append(st.nodes, node)
st.insertTrieNode(node)
}
// Returns (symbol_length, old_index)
func (st *State) IncrementSymbol(symbols []byte) (int, int) {
node := st.nodeFromSymbols(symbols)
node.count++
origIndex := node.index
for iterIndex := origIndex; iterIndex > 0; iterIndex-- {
prevIndex := iterIndex - 1
iterNode := st.nodes[iterIndex]
prevNode := st.nodes[prevIndex]
if prevNode.count > iterNode.count {
break
} else if prevNode.count == iterNode.count && bytes.Compare(prevNode.symbol, iterNode.symbol) < 0 {
break
}
st.nodes[iterNode.index], st.nodes[prevNode.index] = prevNode, iterNode
prevNode.index, iterNode.index = iterIndex, prevIndex
}
return len(node.symbol), origIndex
}
func (st State) String() string {
strs := []string{}
for _, node := range st.nodes {
if node.count == 0 {
break
}
strs = append(strs, fmt.Sprintf("%#U=%d", node.symbol, node.count))
}
return strings.Join(strs, ", ")
}
func (st State) Symbols() [][]byte {
ret := [][]byte{}
for _, node := range st.nodes {
ret = append(ret, node.symbol)
}
return ret
}
func (st State) insertTrieNode(node *node) {
iter := st.root
for _, b := range node.symbol {
nibbleTrieNode := iter.getOrInsertChild((b & 0xf0) >> 4)
iter = nibbleTrieNode.getOrInsertChild(b & 0x0f)
}
iter.node = node
}
func (st State) nodeFromSymbols(symbols []byte) *node {
var lastFound *node
iter := st.root
for _, b := range symbols {
nibbleTrieNode := iter.children[(b&0xf0)>>4]
if nibbleTrieNode == nil {
break
}
iter = nibbleTrieNode.children[b&0x0f]
if iter == nil {
break
}
if iter.node != nil {
lastFound = iter.node
}
}
return lastFound
}
func (tn *trieNode) getOrInsertChild(val byte) *trieNode {
child := tn.children[val]
if child != nil {
return child
}
child = &trieNode{}
tn.children[val] = child
return child
}