diff --git a/main.go b/main.go index 25481dd..cd59871 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ type Instruction struct { type State struct { Error error - FunctionByteCode [][]byte + Functions [][]*Instruction FunctionIndex int64 InstructionIndex int64 @@ -81,51 +81,73 @@ var OpHandlers = map[uint32]OpHandler{ 0x00000402: (*State).JumpIfFalse, } -func NewState(functionByteCode [][]byte) *State { - return &State{ - FunctionByteCode: functionByteCode, +func NewInstruction(byteCode []byte) (*Instruction, error) { + instr := &Instruction{} + + reader := bytes.NewReader(byteCode) + + err := struc.Unpack(reader, instr) + if err != nil { + return nil, errors.Wrap(err, "Error decoding instruction") } + + return instr, nil +} + +func NewState(byteCodes [][]byte) (*State, error) { + state := &State{} + + for i, byteCode := range byteCodes { + instrs := []*Instruction{} + + for start := 0; start < len(byteCode); start += InstructionBytes { + chunk := byteCode[start : start+InstructionBytes] + instr, err := NewInstruction(chunk) + if err != nil { + return nil, errors.Wrapf(err, "At function index %d, byte offset %d", i, start) + } + instrs = append(instrs, instr) + } + + state.Functions = append(state.Functions, instrs) + } + + return state, nil } func (state *State) StackFrame() *StackFrame { - return state.Stack[len(state.Stack) - 1] + return state.Stack[len(state.Stack)-1] } -func (state *State) Function() []byte { - return state.FunctionByteCode[state.FunctionIndex] +func (state *State) Function() []*Instruction { + return state.Functions[state.FunctionIndex] } func (state *State) Execute() { state.call(0) - for len(state.Stack) > 0 && state.Error == nil { - start := state.InstructionIndex * InstructionBytes - chunk := state.Function()[start : start+InstructionBytes] - state.InstructionIndex += 1 - state.ProcessInstruction(chunk) + for len(state.Stack) > 0 { + state.ProcessInstruction() - if state.InstructionIndex >= int64(len(state.Function()) / InstructionBytes) { + if state.Error != nil { + break + } + + if state.InstructionIndex >= int64(len(state.Function())) { state.ret() } } - - if state.Error != nil { - state.InstructionIndex -= 1 - } } -func (state *State) ProcessInstruction(byteCode []byte) { - reader := bytes.NewReader(byteCode) +func (state *State) ProcessInstruction() { + fnc := state.Function() - instr := &Instruction{} - - err := struc.Unpack(reader, instr) - if err != nil { - state.Error = errors.Wrap(err, "Error decoding instruction") + if state.InstructionIndex >= int64(len(fnc)) { + state.ret() return } - fmt.Printf("%+v\n", instr) + instr := fnc[state.InstructionIndex] handler, found := OpHandlers[instr.OpCode] if !found { @@ -133,6 +155,8 @@ func (state *State) ProcessInstruction(byteCode []byte) { return } + state.InstructionIndex += 1 + handler(state, instr) } @@ -194,18 +218,16 @@ func (state *State) WriteSigned(op *Operand, value int64) { } func (state *State) NoOp(instr *Instruction) { - fmt.Printf("NoOp\n") } func (state *State) Call(instr *Instruction) { - fmt.Printf("Call\n") in := state.ReadSigned(&instr.Operand1) state.call(in) } func (state *State) call(functionOffset int64) { - if state.FunctionIndex + functionOffset >= int64(len(state.FunctionByteCode)) { - state.Error = fmt.Errorf("Invalid function call index: %d + %d = %d", state.FunctionIndex, functionOffset, state.FunctionIndex + functionOffset) + if state.FunctionIndex+functionOffset >= int64(len(state.Functions)) { + state.Error = fmt.Errorf("Invalid function call index: %d + %d = %d", state.FunctionIndex, functionOffset, state.FunctionIndex+functionOffset) return } @@ -219,7 +241,6 @@ func (state *State) call(functionOffset int64) { } func (state *State) Return(instr *Instruction) { - fmt.Printf("Return\n") state.ret() } @@ -230,124 +251,106 @@ func (state *State) ret() { } func (state *State) Move(instr *Instruction) { - fmt.Printf("Move\n") in := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in) } func (state *State) Add(instr *Instruction) { - fmt.Printf("Add\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1+in2) } func (state *State) Subtract(instr *Instruction) { - fmt.Printf("Subtract\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1-in2) } func (state *State) Multiply(instr *Instruction) { - fmt.Printf("Multiply\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1*in2) } func (state *State) DivideUnsigned(instr *Instruction) { - fmt.Printf("DivideUnsigned\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.WriteUnsigned(&instr.Operand1, in1/in2) } func (state *State) DivideSigned(instr *Instruction) { - fmt.Printf("DivideSigned\n") in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.WriteSigned(&instr.Operand1, in1/in2) } func (state *State) IsEqual(instr *Instruction) { - fmt.Printf("IsEqual\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 == in2) } func (state *State) IsLessThanUnsigned(instr *Instruction) { - fmt.Printf("IsLessThanUnsigned\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 < in2) } func (state *State) IsLessThanSigned(instr *Instruction) { - fmt.Printf("IsLessThanSigned\n") in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 < in2) } func (state *State) IsGreaterThanUnsigned(instr *Instruction) { - fmt.Printf("IsGreaterThanUnsigned\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 > in2) } func (state *State) IsGreaterThanSigned(instr *Instruction) { - fmt.Printf("IsGreaterThanSigned\n") in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 > in2) } func (state *State) IsLessThanOrEqualUnsigned(instr *Instruction) { - fmt.Printf("IsLessThanOrEqualUnsigned\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 <= in2) } func (state *State) IsLessThanOrEqualSigned(instr *Instruction) { - fmt.Printf("IsLessThanOrEqualSigned\n") in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 <= in2) } func (state *State) IsGreaterThanOrEqualUnsigned(instr *Instruction) { - fmt.Printf("IsGreaterThanOrEqualUnsigned\n") in1 := state.ReadUnsigned(&instr.Operand1) in2 := state.ReadUnsigned(&instr.Operand2) state.ComparisonResult = (in1 >= in2) } func (state *State) IsGreaterThanOrEqualSigned(instr *Instruction) { - fmt.Printf("IsGreaterThanOrEqualSigned\n") in1 := state.ReadSigned(&instr.Operand1) in2 := state.ReadSigned(&instr.Operand2) state.ComparisonResult = (in1 >= in2) } func (state *State) Jump(instr *Instruction) { - fmt.Printf("Jump\n") in := state.ReadSigned(&instr.Operand1) state.InstructionIndex += in - 1 } func (state *State) JumpIfTrue(instr *Instruction) { - fmt.Printf("JumpIfTrue\n") if state.ComparisonResult == true { state.Jump(instr) } } func (state *State) JumpIfFalse(instr *Instruction) { - fmt.Printf("JumpIfFalse\n") if state.ComparisonResult == false { state.Jump(instr) } @@ -379,18 +382,20 @@ func main() { functionByteCode = append(functionByteCode, byteCode) } - state := NewState(functionByteCode) + state, err := NewState(functionByteCode) + if err != nil { + panic(err) + } + state.Execute() if state.Error != nil { fmt.Printf("ERROR: %s\n", state.Error) - fmt.Printf("\tat function index 0x%016x, instruction index 0x%016x\n", state.FunctionIndex, state.InstructionIndex) + fmt.Printf("\tNext function index: 0x%016x\n", state.FunctionIndex) + fmt.Printf("\tNext instruction index: 0x%016x\n", state.InstructionIndex) + fmt.Printf("\n") } - fmt.Printf("\n") - fmt.Printf("Comparison Result: %t\n", state.ComparisonResult) - - fmt.Printf("\n") fmt.Printf("Global memory:\n") for i, v := range state.GlobalMemory { fmt.Printf("\t0x%08x: 0x%016x %d %d\n", i, v, v, int64(v))