package vm import "bytes" import "encoding/hex" import "fmt" import "strings" import "github.com/lunixbochs/struc" import "github.com/pkg/errors" type OperandType uint8 type OpCodeType uint32 const ( Literal OperandType = 0 FunctionMemoryIndex = 1 GlobalMemoryIndex = 2 ) const ( OpNoOp OpCodeType = 0x00000000 OpNop = OpNoOp OpCall = 0x00000001 OpCal = OpCall OpReturn = 0x00000002 OpRet = OpReturn OpMove = 0x00000100 OpMov = OpMove OpAdd = 0x00000200 OpSubtract = 0x00000201 OpSub = OpSubtract OpMultiply = 0x00000202 OpMul = OpMultiply OpDivideUnsigned = 0x00000203 OpDivU = OpDivideUnsigned OpDivideSigned = 0x00000204 OpDivS = OpDivideSigned OpIsEqual = 0x00000300 OpEq = OpIsEqual OpIsLessThanUnsigned = 0x00000301 OpLTU = OpIsLessThanUnsigned OpIsLessThanSigned = 0x00000302 OpLTS = OpIsLessThanSigned OpIsGreaterThanUnsigned = 0x00000303 OpGTU = OpIsGreaterThanUnsigned OpIsGreaterThanSigned = 0x00000304 OpGTS = OpIsGreaterThanSigned OpIsLessThanOrEqualUnsigned = 0x00000305 OpLTEU = OpIsLessThanOrEqualUnsigned OpIsLessThanOrEqualSigned = 0x00000306 OpLTES = OpIsLessThanOrEqualSigned OpIsGreaterThanOrEqualUnsigned = 0x00000307 OpGTEU = OpIsGreaterThanOrEqualUnsigned OpIsGreaterThanOrEqualSigned = 0x00000308 OpGTES = OpIsGreaterThanOrEqualSigned OpJump = 0x00000400 OpJmp = OpJump OpJumpIfTrue = 0x00000401 OpJmpT = OpJumpIfTrue OpJumpIfFalse = 0x00000402 OpJmpF = OpJumpIfFalse ) const InstructionBytes = 32 const GlobalMemoryEntries = 16 const FunctionMemoryEntries = 16 type OpHandler func(*State, *Instruction) type Operand struct { Type OperandType Reserved [3]byte Value uint64 } type Instruction struct { OpCode OpCodeType Reserved [4]byte Operand1 Operand Operand2 Operand opHandler OpHandler `struc:"skip"` } type State struct { Running bool Error error Functions [][]*Instruction FunctionIndex int64 InstructionIndex int64 ComparisonResult bool GlobalMemory [GlobalMemoryEntries]uint64 Stack []*StackFrame } type StackFrame struct { PreviousFunctionIndex int64 PreviousInstructionIndex int64 FunctionMemory [FunctionMemoryEntries]uint64 } var OpHandlers = map[OpCodeType]OpHandler{ OpNoOp: (*State).NoOp, OpCall: (*State).Call, OpReturn: (*State).Return, OpMove: (*State).Move, OpAdd: (*State).Add, OpSubtract: (*State).Subtract, OpMultiply: (*State).Multiply, OpDivideUnsigned: (*State).DivideUnsigned, OpDivideSigned: (*State).DivideSigned, OpIsEqual: (*State).IsEqual, OpIsLessThanUnsigned: (*State).IsLessThanUnsigned, OpIsLessThanSigned: (*State).IsLessThanSigned, OpIsGreaterThanUnsigned: (*State).IsGreaterThanUnsigned, OpIsGreaterThanSigned: (*State).IsGreaterThanSigned, OpIsLessThanOrEqualUnsigned: (*State).IsLessThanOrEqualUnsigned, OpIsLessThanOrEqualSigned: (*State).IsLessThanOrEqualSigned, OpIsGreaterThanOrEqualUnsigned: (*State).IsGreaterThanOrEqualUnsigned, OpIsGreaterThanOrEqualSigned: (*State).IsGreaterThanOrEqualSigned, OpJump: (*State).Jump, OpJumpIfTrue: (*State).JumpIfTrue, OpJumpIfFalse: (*State).JumpIfFalse, } func NewInstruction(byteCode []byte) (*Instruction, error) { instr := &Instruction{} reader := bytes.NewReader(byteCode) err := struc.Unpack(reader, instr) if err != nil { return nil, errors.Wrap(err, "Error decoding instruction") } return instr, nil } func NewState(byteCodes [][]byte) (*State, error) { state := &State{} for i, byteCode := range byteCodes { instrs := []*Instruction{} for start := 0; start < len(byteCode); start += InstructionBytes { chunk := byteCode[start : start+InstructionBytes] instr, err := NewInstruction(chunk) if err != nil { return nil, errors.Wrapf(err, "At function index %d, byte offset %d", i, start) } instrs = append(instrs, instr) } instrs = append(instrs, &Instruction{ OpCode: OpReturn, }) state.Functions = append(state.Functions, instrs) } return state, nil } func (state *State) StackFrame() *StackFrame { return state.Stack[len(state.Stack)-1] } func (state *State) Function() []*Instruction { return state.Functions[state.FunctionIndex] } func (state *State) Execute() { state.setHandlers() state.call(0) state.Running = true for state.Running { state.ProcessInstruction() } } func (state *State) setError(err error) { state.Error = err state.Running = false } func (state *State) setHandlers() { for _, fnc := range state.Functions { for _, instr := range fnc { handler, found := OpHandlers[instr.OpCode] if !found { state.setError(fmt.Errorf("Invalid OpCode: 0x%08x", instr.OpCode)) return } instr.opHandler = handler } } } func (state *State) ProcessInstruction() { fnc := state.Function() instr := fnc[state.InstructionIndex] state.InstructionIndex += 1 instr.opHandler(state, instr) } func (state *State) ReadUnsigned(op *Operand) uint64 { switch op.Type { case Literal: return op.Value case FunctionMemoryIndex: if op.Value >= FunctionMemoryEntries { state.setError(fmt.Errorf("Invalid function memory index: %016x", op.Value)) return 0 } return state.StackFrame().FunctionMemory[op.Value] case GlobalMemoryIndex: if op.Value >= GlobalMemoryEntries { state.setError(fmt.Errorf("Invalid global memory index: %016x", op.Value)) return 0 } return state.GlobalMemory[op.Value] default: state.setError(fmt.Errorf("Unknown operand type: 0x%02x", op.Type)) return 0 } } func (state *State) ReadSigned(op *Operand) int64 { return int64(state.ReadUnsigned(op)) } func (state *State) WriteUnsigned(op *Operand, value uint64) { switch op.Type { case Literal: state.setError(fmt.Errorf("Write to literal operand")) case FunctionMemoryIndex: if op.Value >= FunctionMemoryEntries { state.setError(fmt.Errorf("Invalid function memory index: %016x", op.Value)) return } state.StackFrame().FunctionMemory[op.Value] = value case GlobalMemoryIndex: if op.Value >= GlobalMemoryEntries { state.setError(fmt.Errorf("Invalid global memory index: %016x", op.Value)) return } state.GlobalMemory[op.Value] = value default: state.setError(fmt.Errorf("Unknown operand type: 0x%02x", op.Type)) } } func (state *State) WriteSigned(op *Operand, value int64) { state.WriteUnsigned(op, uint64(value)) } func (state *State) NoOp(instr *Instruction) { } func (state *State) Call(instr *Instruction) { in := state.ReadSigned(&instr.Operand1) state.call(in) } func (state *State) call(functionOffset int64) { if state.FunctionIndex+functionOffset >= int64(len(state.Functions)) { state.setError(fmt.Errorf("Invalid function call index: %d + %d = %d", state.FunctionIndex, functionOffset, state.FunctionIndex+functionOffset)) return } stackFrame := &StackFrame{ PreviousFunctionIndex: state.FunctionIndex, PreviousInstructionIndex: state.InstructionIndex, } state.Stack = append(state.Stack, stackFrame) state.FunctionIndex += functionOffset state.InstructionIndex = 0 } func (state *State) Return(instr *Instruction) { state.ret() } func (state *State) ret() { state.FunctionIndex = state.StackFrame().PreviousFunctionIndex state.InstructionIndex = state.StackFrame().PreviousInstructionIndex state.Stack = state.Stack[:len(state.Stack)-1] if len(state.Stack) == 0 { state.Running = false } } func (state *State) Move(instr *Instruction) { in := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in) } func (state *State) Add(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1+in2) } func (state *State) Subtract(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1-in2) } func (state *State) Multiply(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1*in2) } func (state *State) DivideUnsigned(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1/in2) } func (state *State) DivideSigned(instr *Instruction) { in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.WriteSigned(&instr.Operand1, in1/in2) } func (state *State) IsEqual(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 == in2) } func (state *State) IsLessThanUnsigned(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 < in2) } func (state *State) IsLessThanSigned(instr *Instruction) { in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 < in2) } func (state *State) IsGreaterThanUnsigned(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 > in2) } func (state *State) IsGreaterThanSigned(instr *Instruction) { in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 > in2) } func (state *State) IsLessThanOrEqualUnsigned(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 <= in2) } func (state *State) IsLessThanOrEqualSigned(instr *Instruction) { in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 <= in2) } func (state *State) IsGreaterThanOrEqualUnsigned(instr *Instruction) { in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 >= in2) } func (state *State) IsGreaterThanOrEqualSigned(instr *Instruction) { in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 >= in2) } func (state *State) Jump(instr *Instruction) { in := state.ReadSigned(&instr.Operand1) state.InstructionIndex += in - 1 } func (state *State) JumpIfTrue(instr *Instruction) { if state.ComparisonResult == true { state.Jump(instr) } } func (state *State) JumpIfFalse(instr *Instruction) { if state.ComparisonResult == false { state.Jump(instr) } } func main() { asm := [][]string{ []string{ "0000020000000000010000000000000000000000000000000000000000000001", "0000000100000000000000000000000000000001000000000000000000000000", "0000030100000000010000000000000000000000000000000000000000000003", "000004010000000000000000fffffffffffffffd000000000000000000000000", }, []string{ "0000020000000000020000000000000000000000000000000000000000000001", }, } functionByteCode := [][]byte{} for _, fnc := range asm { fncString := strings.Join(fnc, "") byteCode, err := hex.DecodeString(fncString) if err != nil { panic(err) } functionByteCode = append(functionByteCode, byteCode) } state, err := NewState(functionByteCode) if err != nil { panic(err) } state.Execute() if state.Error != nil { fmt.Printf("ERROR: %s\n", state.Error) fmt.Printf("\tNext function index: 0x%016x\n", state.FunctionIndex) fmt.Printf("\tNext instruction index: 0x%016x\n", state.InstructionIndex) fmt.Printf("\n") } fmt.Printf("Global memory:\n") for i, v := range state.GlobalMemory { fmt.Printf("\t0x%08x: 0x%016x %d %d\n", i, v, v, int64(v)) } }