Files

159 lines
2.6 KiB
Go
Raw Permalink Normal View History

2023-12-30 20:14:01 -07:00
package state
import (
2023-12-31 16:09:35 -08:00
"bytes"
2023-12-30 20:14:01 -07:00
"fmt"
"strings"
)
// node is one entry in a State's symbol table.
type node struct {
	// symbol is the byte sequence this entry stands for.
	symbol []byte
	// count is the number of times IncrementSymbol has matched this symbol.
	count int
	// index is the entry's current position in State.nodes; it is updated
	// whenever IncrementSymbol moves the entry.
	index int
}
2023-12-31 16:09:35 -08:00
// trieNode is one level of the lookup trie that maps symbol bytes to their
// node. Each symbol byte is consumed as two 4-bit steps (high nibble, then low
// nibble), so every level has 16 possible children.
type trieNode struct {
// node is the symbol that ends exactly at this trie position, or nil if no
// symbol ends here.
node *node
// children is indexed by nibble value (0-15); a nil entry means no stored
// symbol continues along that nibble.
children [16]*trieNode
}
2023-12-30 20:14:01 -07:00
type State struct {
2023-12-31 16:09:35 -08:00
nodes []*node
root *trieNode
2023-12-30 20:14:01 -07:00
}
func NewState() *State {
st := &State{
2023-12-31 16:09:35 -08:00
root: &trieNode{},
2023-12-30 20:14:01 -07:00
}
for i := 0; i < 256; i++ {
2023-12-31 14:15:10 -08:00
node := &node{
2023-12-31 16:09:35 -08:00
symbol: []byte{byte(i)},
2023-12-31 14:28:28 -08:00
index: i,
2023-12-31 14:15:10 -08:00
}
st.nodes = append(st.nodes, node)
2023-12-31 16:09:35 -08:00
st.insertTrieNode(node)
2023-12-30 20:14:01 -07:00
}
return st
}
func (st State) Clone() *State {
2023-12-31 14:15:10 -08:00
st2 := &State{
2023-12-31 16:09:35 -08:00
root: &trieNode{},
2023-12-30 20:14:01 -07:00
}
2023-12-31 14:15:10 -08:00
for _, node := range st.nodes {
2023-12-31 16:09:35 -08:00
node2 := *node
st2.nodes = append(st2.nodes, &node2)
st2.insertTrieNode(&node2)
2023-12-31 14:15:10 -08:00
}
return st2
2023-12-30 20:14:01 -07:00
}
2023-12-31 16:09:35 -08:00
// AddSymbol registers symbol as a new entry with a count of zero, appended at
// the end of the table and inserted into the lookup trie.
//
// The bytes are copied, so the caller may reuse or modify its slice afterwards
// without desynchronising the stored symbol from its trie path.
func (st *State) AddSymbol(symbol []byte) {
	n := &node{
		symbol: append([]byte(nil), symbol...),
		index:  len(st.nodes),
	}
	st.nodes = append(st.nodes, n)
	st.insertTrieNode(n)
}
// Returns (symbol_length, old_index)
func (st *State) IncrementSymbol(symbols []byte) (int, int) {
node := st.nodeFromSymbols(symbols)
2023-12-31 14:28:28 -08:00
node.count++
origIndex := node.index
2023-12-30 20:14:01 -07:00
2023-12-31 14:28:28 -08:00
for iterIndex := origIndex; iterIndex > 0; iterIndex-- {
2023-12-30 20:14:01 -07:00
prevIndex := iterIndex - 1
iterNode := st.nodes[iterIndex]
prevNode := st.nodes[prevIndex]
if prevNode.count > iterNode.count {
break
2023-12-31 16:09:35 -08:00
} else if prevNode.count == iterNode.count && bytes.Compare(prevNode.symbol, iterNode.symbol) < 0 {
2023-12-30 20:14:01 -07:00
break
}
2023-12-31 14:28:28 -08:00
st.nodes[iterNode.index], st.nodes[prevNode.index] = prevNode, iterNode
prevNode.index, iterNode.index = iterIndex, prevIndex
2023-12-30 20:14:01 -07:00
}
2023-12-31 16:09:35 -08:00
return len(node.symbol), origIndex
2023-12-30 20:14:01 -07:00
}
func (st State) String() string {
strs := []string{}
for _, node := range st.nodes {
if node.count == 0 {
break
}
2023-12-31 16:09:35 -08:00
strs = append(strs, fmt.Sprintf("%#U=%d", node.symbol, node.count))
2023-12-30 20:14:01 -07:00
}
return strings.Join(strs, ", ")
}
2023-12-31 14:41:26 -08:00
2023-12-31 16:09:35 -08:00
// Symbols returns the symbol of every entry, in table order. The returned
// outer slice is freshly allocated, but its elements are the same byte slices
// held by the state.
func (st State) Symbols() [][]byte {
	out := make([][]byte, len(st.nodes))
	for i, n := range st.nodes {
		out[i] = n.symbol
	}
	return out
}
// insertTrieNode records node in the trie under its symbol, creating any
// missing intermediate trie nodes. Each symbol byte contributes two levels:
// its high nibble followed by its low nibble. Whatever entry was previously
// stored at that position is replaced.
func (st State) insertTrieNode(node *node) {
	cur := st.root
	for _, b := range node.symbol {
		hi, lo := b>>4, b&0x0f
		cur = cur.getOrInsertChild(hi).getOrInsertChild(lo)
	}
	cur.node = node
}
// nodeFromSymbols walks the trie along the bytes of symbols and returns the
// entry for the longest stored symbol that is a prefix of symbols. It returns
// nil when no stored symbol is found along that path, which includes the case
// where symbols is empty.
func (st State) nodeFromSymbols(symbols []byte) *node {
	var best *node
	cur := st.root
	for _, b := range symbols {
		hi := cur.children[b>>4]
		if hi == nil {
			return best
		}
		cur = hi.children[b&0x0f]
		if cur == nil {
			return best
		}
		// Remember the deepest position at which a complete symbol ends.
		if cur.node != nil {
			best = cur.node
		}
	}
	return best
}
func (tn *trieNode) getOrInsertChild(val byte) *trieNode {
child := tn.children[val]
if child != nil {
return child
}
child = &trieNode{}
tn.children[val] = child
return child
2023-12-31 14:41:26 -08:00
}