From eda5e0cf7389bd1cb6862a4fd8f6a57b3fc2d014 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 5 Jul 2025 21:34:49 -0700 Subject: [PATCH] Add error handling to completion callbacks --- mcp.go | 10 ++++++++-- taskcp.go | 41 +++++++++++++++++++++++++---------------- taskcp_test.go | 44 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 70 insertions(+), 25 deletions(-) diff --git a/mcp.go b/mcp.go index f038a81..32096b2 100644 --- a/mcp.go +++ b/mcp.go @@ -115,7 +115,10 @@ func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessAr return nil, fmt.Errorf("failed to get project: %w", err) } - nextTask := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes) + nextTask, err := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes) + if err != nil { + return nil, fmt.Errorf("completion callback error: %w", err) + } response := &taskResponse{ TaskID: args.TaskID, @@ -132,7 +135,10 @@ func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureAr return nil, fmt.Errorf("failed to get project: %w", err) } - nextTask := project.SetTaskFailure(args.TaskID, args.Error, args.Notes) + nextTask, err := project.SetTaskFailure(args.TaskID, args.Error, args.Notes) + if err != nil { + return nil, fmt.Errorf("completion callback error: %w", err) + } response := &taskResponse{ TaskID: args.TaskID, diff --git a/taskcp.go b/taskcp.go index c55ec00..ccf64df 100644 --- a/taskcp.go +++ b/taskcp.go @@ -9,8 +9,8 @@ import ( ) type Service struct { - Projects map[string]*Project - mcpService string + Projects map[string]*Project + mcpService string } type Project struct { @@ -40,7 +40,7 @@ type Task struct { projectID string mcpService string nextTaskID string - completionCallback func(task *Task) + completionCallback func(task *Task) error } func New(mcpService string) *Service { @@ -70,7 +70,7 @@ func (s *Service) GetProject(id string) (*Project, error) { return project, nil } -func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task)) *Task { +func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task) error) *Task { task := p.newTask(instructions, completionCallback, beforeID) if p.nextTaskID == "" && beforeID == "" { @@ -97,29 +97,39 @@ func (p *Project) GetNextTask() *Task { return task } -func (p *Project) SetTaskSuccess(id string, result string, notes string) *Task { +func (p *Project) SetTaskSuccess(id string, result string, notes string) (*Task, error) { task := p.Tasks[id] task.State = TaskStateSuccess task.Result = result task.Notes = notes - task.completionCallback(task) + + err := task.completionCallback(task) + if err != nil { + return nil, err + } + p.nextTaskID = task.nextTaskID - return p.GetNextTask() + return p.GetNextTask(), nil } -func (p *Project) SetTaskFailure(id string, error string, notes string) *Task { +func (p *Project) SetTaskFailure(id string, error string, notes string) (*Task, error) { task := p.Tasks[id] task.State = TaskStateFailure task.Error = error task.Notes = notes - task.completionCallback(task) + + err := task.completionCallback(task) + if err != nil { + return nil, err + } + p.nextTaskID = task.nextTaskID - return p.GetNextTask() + return p.GetNextTask(), nil } -func (p *Project) newTask(instructions string, completionCallback func(task *Task), nextTaskID string) *Task { +func (p *Project) newTask(instructions string, completionCallback func(task *Task) error, nextTaskID string) *Task { task := &Task{ ID: uuid.New().String(), State: TaskStatePending, @@ -129,10 +139,10 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas projectID: p.ID, mcpService: p.mcpService, } - + task.Instructions = strings.ReplaceAll(task.Instructions, "{SUCCESS_PROMPT}", task.SuccessPrompt()) task.Instructions = strings.ReplaceAll(task.Instructions, "{FAILURE_PROMPT}", task.FailurePrompt()) - + p.Tasks[task.ID] = task return task } @@ -150,13 +160,12 @@ func (p *Project) tasks() iter.Seq[*Task] { func (t *Task) SuccessPrompt() string { return fmt.Sprintf(`To mark this task as successful, use the MCP tool: -%s.set_task_success(project_id="%s", task_id="%s", result="", notes="")`, +%s.set_task_success(project_id="%s", task_id="%s", result="", notes="")`, t.mcpService, t.projectID, t.ID) } func (t *Task) FailurePrompt() string { return fmt.Sprintf(`To mark this task as failed, use the MCP tool: -%s.set_task_failure(project_id="%s", task_id="%s", error="", notes="")`, +%s.set_task_failure(project_id="%s", task_id="%s", error="", notes="")`, t.mcpService, t.projectID, t.ID) } - diff --git a/taskcp_test.go b/taskcp_test.go index 2814466..b42e87d 100644 --- a/taskcp_test.go +++ b/taskcp_test.go @@ -1,6 +1,7 @@ package taskcp_test import ( + "fmt" "testing" "github.com/gopatchy/taskcp" @@ -12,7 +13,7 @@ func TestTaskPrompts(t *testing.T) { service := taskcp.New("my_service") project := service.AddProject() - task := project.InsertTaskBefore("", "Write unit tests", func(task *taskcp.Task) {}) + task := project.InsertTaskBefore("", "Write unit tests", func(task *taskcp.Task) error { return nil }) successPrompt := task.SuccessPrompt() require.Contains(t, successPrompt, "my_service.set_task_success") @@ -29,11 +30,11 @@ func TestPlaceholderExpansion(t *testing.T) { service := taskcp.New("my_service") project := service.AddProject() - task1 := project.InsertTaskBefore("", "Please complete this task. {SUCCESS_PROMPT}", func(task *taskcp.Task) {}) + task1 := project.InsertTaskBefore("", "Please complete this task. {SUCCESS_PROMPT}", func(task *taskcp.Task) error { return nil }) require.Contains(t, task1.Instructions, "my_service.set_task_success") require.NotContains(t, task1.Instructions, "{SUCCESS_PROMPT}") - task2 := project.InsertTaskBefore("", "Try this risky operation. {FAILURE_PROMPT}", func(task *taskcp.Task) {}) + task2 := project.InsertTaskBefore("", "Try this risky operation. {FAILURE_PROMPT}", func(task *taskcp.Task) error { return nil }) require.Contains(t, task2.Instructions, "my_service.set_task_failure") require.NotContains(t, task2.Instructions, "{FAILURE_PROMPT}") } @@ -44,27 +45,56 @@ func TestTaskFlow(t *testing.T) { var completed []string - task1 := project.InsertTaskBefore("", "First task", func(task *taskcp.Task) { + task1 := project.InsertTaskBefore("", "First task", func(task *taskcp.Task) error { completed = append(completed, task.ID) + return nil }) - task2 := project.InsertTaskBefore("", "Second task", func(task *taskcp.Task) { + task2 := project.InsertTaskBefore("", "Second task", func(task *taskcp.Task) error { completed = append(completed, task.ID) + return nil }) current := project.GetNextTask() require.NotNil(t, current) require.Equal(t, task1.ID, current.ID) - next := project.SetTaskSuccess(current.ID, "Task 1 done", "") + next, err := project.SetTaskSuccess(current.ID, "Task 1 done", "") + require.NoError(t, err) require.NotNil(t, next) require.Equal(t, task2.ID, next.ID) require.Equal(t, taskcp.TaskStateRunning, next.State) - next2 := project.SetTaskFailure(next.ID, "Task 2 failed", "Error details") + next2, err := project.SetTaskFailure(next.ID, "Task 2 failed", "Error details") + require.NoError(t, err) require.Nil(t, next2) require.Equal(t, []string{task1.ID, task2.ID}, completed) require.Equal(t, taskcp.TaskStateSuccess, project.Tasks[task1.ID].State) require.Equal(t, taskcp.TaskStateFailure, project.Tasks[task2.ID].State) +} + +func TestCallbackError(t *testing.T) { + service := taskcp.New("test_service") + project := service.AddProject() + + expectedErr := fmt.Errorf("callback error") + + task := project.InsertTaskBefore("", "Task with error callback", func(task *taskcp.Task) error { + return expectedErr + }) + + current := project.GetNextTask() + require.NotNil(t, current) + require.Equal(t, task.ID, current.ID) + + // Test error propagation on success + _, err := project.SetTaskSuccess(current.ID, "Result", "") + require.Error(t, err) + require.Equal(t, expectedErr, err) + + // Test error propagation on failure + _, err = project.SetTaskFailure(current.ID, "Task failed", "") + require.Error(t, err) + require.Equal(t, expectedErr, err) } \ No newline at end of file