diff --git a/taskcp.go b/taskcp.go index 2948bd9..21b4b35 100644 --- a/taskcp.go +++ b/taskcp.go @@ -33,6 +33,7 @@ const ( type Task struct { ID string `json:"id"` State TaskState `json:"-"` + Title string `json:"title"` Instructions string `json:"instructions"` Data map[string]any `json:"data,omitempty"` Result string `json:"-"` @@ -41,9 +42,8 @@ type Task struct { NextTaskID string `json:"-"` - projectID string - mcpService string - completionCallback func(task *Task) error + project *Project + completionCallback func(project *Project, task *Task) error } func New(mcpService string) *Service { @@ -73,8 +73,8 @@ func (s *Service) GetProject(id string) (*Project, error) { return project, nil } -func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task) error) *Task { - task := p.newTask(instructions, completionCallback, beforeID) +func (p *Project) InsertTaskBefore(beforeID string, title string, instructions string, completionCallback func(project *Project, task *Task) error) *Task { + task := p.newTask(title, instructions, completionCallback, beforeID) if p.nextTaskID == "" && beforeID == "" { p.nextTaskID = task.ID @@ -106,9 +106,11 @@ func (p *Project) SetTaskSuccess(id string, result string, notes string) (*Task, task.Result = result task.Notes = notes - err := task.completionCallback(task) - if err != nil { - return nil, err + if task.completionCallback != nil { + err := task.completionCallback(task.project, task) + if err != nil { + return nil, err + } } p.nextTaskID = task.NextTaskID @@ -122,9 +124,11 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) (*Task, task.Error = error task.Notes = notes - err := task.completionCallback(task) - if err != nil { - return nil, err + if task.completionCallback != nil { + err := task.completionCallback(task.project, task) + if err != nil { + return nil, err + } } p.nextTaskID = task.NextTaskID @@ -132,16 +136,16 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) (*Task, return p.GetNextTask(), nil } -func (p *Project) newTask(instructions string, completionCallback func(task *Task) error, nextTaskID string) *Task { +func (p *Project) newTask(title string, instructions string, completionCallback func(project *Project, task *Task) error, nextTaskID string) *Task { task := &Task{ ID: uuid.New().String(), State: TaskStatePending, NextTaskID: nextTaskID, + Title: title, Instructions: instructions, Data: map[string]any{}, completionCallback: completionCallback, - projectID: p.ID, - mcpService: p.mcpService, + project: p, } task.Instructions = strings.ReplaceAll(task.Instructions, "{SUCCESS_PROMPT}", task.SuccessPrompt()) @@ -165,13 +169,13 @@ func (p *Project) tasks() iter.Seq[*Task] { func (t *Task) SuccessPrompt() string { return fmt.Sprintf(`To mark this task as successful, use the MCP tool: %s.set_task_success(project_id="%s", task_id="%s", result="", notes="")`, - t.mcpService, t.projectID, t.ID) + t.project.mcpService, t.project.ID, t.ID) } func (t *Task) FailurePrompt() string { return fmt.Sprintf(`To mark this task as failed, use the MCP tool: %s.set_task_failure(project_id="%s", task_id="%s", error="", notes="")`, - t.mcpService, t.projectID, t.ID) + t.project.mcpService, t.project.ID, t.ID) } func (t *Task) String() string { diff --git a/taskcp_test.go b/taskcp_test.go index b42e87d..d3ef8ff 100644 --- a/taskcp_test.go +++ b/taskcp_test.go @@ -8,18 +8,17 @@ import ( "github.com/stretchr/testify/require" ) - func TestTaskPrompts(t *testing.T) { service := taskcp.New("my_service") project := service.AddProject() - - task := project.InsertTaskBefore("", "Write unit tests", func(task *taskcp.Task) error { return nil }) - + + task := project.InsertTaskBefore("", "Write unit tests", "", func(project *taskcp.Project, task *taskcp.Task) error { return nil }) + successPrompt := task.SuccessPrompt() require.Contains(t, successPrompt, "my_service.set_task_success") require.Contains(t, successPrompt, `project_id="`+project.ID+`"`) require.Contains(t, successPrompt, `task_id="`+task.ID+`"`) - + failurePrompt := task.FailurePrompt() require.Contains(t, failurePrompt, "my_service.set_task_failure") require.Contains(t, failurePrompt, `project_id="`+project.ID+`"`) @@ -29,12 +28,12 @@ func TestTaskPrompts(t *testing.T) { func TestPlaceholderExpansion(t *testing.T) { service := taskcp.New("my_service") project := service.AddProject() - - task1 := project.InsertTaskBefore("", "Please complete this task. {SUCCESS_PROMPT}", func(task *taskcp.Task) error { return nil }) + + task1 := project.InsertTaskBefore("", "Please complete this task.", "{SUCCESS_PROMPT}", func(project *taskcp.Project, task *taskcp.Task) error { return nil }) require.Contains(t, task1.Instructions, "my_service.set_task_success") require.NotContains(t, task1.Instructions, "{SUCCESS_PROMPT}") - - task2 := project.InsertTaskBefore("", "Try this risky operation. {FAILURE_PROMPT}", func(task *taskcp.Task) error { return nil }) + + task2 := project.InsertTaskBefore("", "Try this risky operation.", "{FAILURE_PROMPT}", func(project *taskcp.Project, task *taskcp.Task) error { return nil }) require.Contains(t, task2.Instructions, "my_service.set_task_failure") require.NotContains(t, task2.Instructions, "{FAILURE_PROMPT}") } @@ -42,33 +41,33 @@ func TestPlaceholderExpansion(t *testing.T) { func TestTaskFlow(t *testing.T) { service := taskcp.New("test_service") project := service.AddProject() - + var completed []string - - task1 := project.InsertTaskBefore("", "First task", func(task *taskcp.Task) error { + + task1 := project.InsertTaskBefore("", "First task", "", func(project *taskcp.Project, task *taskcp.Task) error { completed = append(completed, task.ID) return nil }) - - task2 := project.InsertTaskBefore("", "Second task", func(task *taskcp.Task) error { + + task2 := project.InsertTaskBefore("", "Second task", "", func(project *taskcp.Project, task *taskcp.Task) error { completed = append(completed, task.ID) return nil }) - + current := project.GetNextTask() require.NotNil(t, current) require.Equal(t, task1.ID, current.ID) - + next, err := project.SetTaskSuccess(current.ID, "Task 1 done", "") require.NoError(t, err) require.NotNil(t, next) require.Equal(t, task2.ID, next.ID) require.Equal(t, taskcp.TaskStateRunning, next.State) - + next2, err := project.SetTaskFailure(next.ID, "Task 2 failed", "Error details") require.NoError(t, err) require.Nil(t, next2) - + require.Equal(t, []string{task1.ID, task2.ID}, completed) require.Equal(t, taskcp.TaskStateSuccess, project.Tasks[task1.ID].State) require.Equal(t, taskcp.TaskStateFailure, project.Tasks[task2.ID].State) @@ -77,24 +76,24 @@ func TestTaskFlow(t *testing.T) { func TestCallbackError(t *testing.T) { service := taskcp.New("test_service") project := service.AddProject() - + expectedErr := fmt.Errorf("callback error") - - task := project.InsertTaskBefore("", "Task with error callback", func(task *taskcp.Task) error { + + task := project.InsertTaskBefore("", "Task with error callback", "", func(project *taskcp.Project, task *taskcp.Task) error { return expectedErr }) - + current := project.GetNextTask() require.NotNil(t, current) require.Equal(t, task.ID, current.ID) - + // Test error propagation on success _, err := project.SetTaskSuccess(current.ID, "Result", "") require.Error(t, err) require.Equal(t, expectedErr, err) - + // Test error propagation on failure _, err = project.SetTaskFailure(current.ID, "Task failed", "") require.Error(t, err) require.Equal(t, expectedErr, err) -} \ No newline at end of file +}