From 33d54f399c24fd13c1304bee830d0137213b7eb2 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 5 Jul 2025 14:24:50 -0700 Subject: [PATCH] Refactor MCP integration to use structured handlers with snake_case --- example_mcp_test.go | 41 ----------- mcp.go | 165 ++++++++++++++++++++++++-------------------- mcp_test.go | 2 +- taskcp.go | 41 +++++------ 4 files changed, 110 insertions(+), 139 deletions(-) delete mode 100644 example_mcp_test.go diff --git a/example_mcp_test.go b/example_mcp_test.go deleted file mode 100644 index 78156ec..0000000 --- a/example_mcp_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package taskcp_test - -import ( - "fmt" - - "github.com/gopatchy/taskcp" - "github.com/mark3labs/mcp-go/server" -) - -func ExampleRegisterMCPTools() { - service := taskcp.New() - - project := service.AddProject() - fmt.Printf("Created project: %s\n", project.ID) - - task1 := project.InsertTaskBefore("", "Compile the code", func(task *taskcp.Task) { - fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State) - }) - - task2 := project.InsertTaskBefore("", "Run tests", func(task *taskcp.Task) { - fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State) - }) - - task1.NextTaskID = task2.ID - project.NextTaskID = task1.ID - - mcpServer := server.NewMCPServer( - "TaskCP Server", - "1.0.0", - server.WithToolCapabilities(true), - ) - - err := taskcp.RegisterMCPTools(mcpServer, service) - if err != nil { - fmt.Printf("Failed to register tools: %v\n", err) - return - } - - fmt.Println("MCP tools registered successfully") - -} \ No newline at end of file diff --git a/mcp.go b/mcp.go index e415a48..f038a81 100644 --- a/mcp.go +++ b/mcp.go @@ -2,22 +2,73 @@ package taskcp import ( "context" + "encoding/json" "fmt" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) -func RegisterMCPTools(s *server.MCPServer, service *Service) error { - s.AddTool( +type setTaskSuccessArgs struct { + ProjectID string `json:"project_id"` + TaskID string `json:"task_id"` + Result string `json:"result"` + Notes string `json:"notes,omitempty"` +} + +type setTaskFailureArgs struct { + ProjectID string `json:"project_id"` + TaskID string `json:"task_id"` + Error string `json:"error"` + Notes string `json:"notes,omitempty"` +} + +type taskResponse struct { + TaskID string `json:"task_id"` + Message string `json:"message"` + NextTask *Task `json:"next_task,omitempty"` +} + +type errorResponse struct { + Error string `json:"error"` +} + +type ServiceHandlerFunc[TArgs any, TResponse any] func(s *Service, ctx context.Context, args TArgs) (*TResponse, error) + +func wrapServiceHandler[TArgs any, TResponse any](s *Service, handler ServiceHandlerFunc[TArgs, TResponse]) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var args TArgs + if err := request.BindArguments(&args); err != nil { + errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()}) + return mcp.NewToolResultText(string(errorJSON)), nil + } + + response, err := handler(s, ctx, args) + if err != nil { + errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()}) + return mcp.NewToolResultText(string(errorJSON)), nil + } + + resultJSON, err := json.MarshalIndent(response, "", " ") + if err != nil { + errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()}) + return mcp.NewToolResultText(string(errorJSON)), nil + } + + return mcp.NewToolResultText(string(resultJSON)), nil + } +} + +func (s *Service) RegisterMCPTools(mcpServer *server.MCPServer) error { + mcpServer.AddTool( mcp.NewTool( - "SetTaskSuccess", + "set_task_success", mcp.WithDescription("Mark a task as successfully completed"), - mcp.WithString("projectId", + mcp.WithString("project_id", mcp.Required(), mcp.Description("The project ID"), ), - mcp.WithString("taskId", + mcp.WithString("task_id", mcp.Required(), mcp.Description("The task ID to mark as successful"), ), @@ -29,18 +80,18 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error { mcp.Description("Additional notes about the task completion"), ), ), - handleSetTaskSuccess(service), + wrapServiceHandler(s, handleSetTaskSuccess), ) - s.AddTool( + mcpServer.AddTool( mcp.NewTool( - "SetTaskFailure", + "set_task_failure", mcp.WithDescription("Mark a task as failed"), - mcp.WithString("projectId", + mcp.WithString("project_id", mcp.Required(), mcp.Description("The project ID"), ), - mcp.WithString("taskId", + mcp.WithString("task_id", mcp.Required(), mcp.Description("The task ID to mark as failed"), ), @@ -52,78 +103,42 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error { mcp.Description("Additional notes about the task failure"), ), ), - handleSetTaskFailure(service), + wrapServiceHandler(s, handleSetTaskFailure), ) return nil } -func handleSetTaskSuccess(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - projectId, err := request.RequireString("projectId") - if err != nil { - return nil, fmt.Errorf("failed to get projectId: %w", err) - } - - taskId, err := request.RequireString("taskId") - if err != nil { - return nil, fmt.Errorf("failed to get taskId: %w", err) - } - - result, err := request.RequireString("result") - if err != nil { - return nil, fmt.Errorf("failed to get result: %w", err) - } - - notes := request.GetString("notes", "") - - project, err := service.GetProject(projectId) - if err != nil { - return nil, fmt.Errorf("failed to get project: %w", err) - } - - nextTask := project.SetTaskSuccess(taskId, result, notes) - - message := fmt.Sprintf("Task %s marked as successful", taskId) - if nextTask != nil { - message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID) - } - - return mcp.NewToolResultText(message), nil +func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessArgs) (*taskResponse, error) { + project, err := s.GetProject(args.ProjectID) + if err != nil { + return nil, fmt.Errorf("failed to get project: %w", err) } + + nextTask := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes) + + response := &taskResponse{ + TaskID: args.TaskID, + Message: fmt.Sprintf("Task %s marked as successful", args.TaskID), + NextTask: nextTask, + } + + return response, nil } -func handleSetTaskFailure(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - projectId, err := request.RequireString("projectId") - if err != nil { - return nil, fmt.Errorf("failed to get projectId: %w", err) - } - - taskId, err := request.RequireString("taskId") - if err != nil { - return nil, fmt.Errorf("failed to get taskId: %w", err) - } - - errorMsg, err := request.RequireString("error") - if err != nil { - return nil, fmt.Errorf("failed to get error: %w", err) - } - - notes := request.GetString("notes", "") - - project, err := service.GetProject(projectId) - if err != nil { - return nil, fmt.Errorf("failed to get project: %w", err) - } - - nextTask := project.SetTaskFailure(taskId, errorMsg, notes) - - message := fmt.Sprintf("Task %s marked as failed", taskId) - if nextTask != nil { - message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID) - } - - return mcp.NewToolResultText(message), nil +func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureArgs) (*taskResponse, error) { + project, err := s.GetProject(args.ProjectID) + if err != nil { + return nil, fmt.Errorf("failed to get project: %w", err) } + + nextTask := project.SetTaskFailure(args.TaskID, args.Error, args.Notes) + + response := &taskResponse{ + TaskID: args.TaskID, + Message: fmt.Sprintf("Task %s marked as failed", args.TaskID), + NextTask: nextTask, + } + + return response, nil } \ No newline at end of file diff --git a/mcp_test.go b/mcp_test.go index 8f2d0ac..50e78b8 100644 --- a/mcp_test.go +++ b/mcp_test.go @@ -11,7 +11,7 @@ func TestRegisterMCPTools(t *testing.T) { s := server.NewMCPServer("Test Server", "1.0.0") - err := RegisterMCPTools(s, service) + err := service.RegisterMCPTools(s) if err != nil { t.Fatalf("Failed to register MCP tools: %v", err) } diff --git a/taskcp.go b/taskcp.go index 92c86ee..386f5e7 100644 --- a/taskcp.go +++ b/taskcp.go @@ -27,18 +27,15 @@ const ( ) type Task struct { - ID string - State TaskState - NextTaskID string - CompletionCallback func(task *Task) + ID string `json:"id"` + State TaskState `json:"-"` + Instructions string `json:"instructions"` + Result string `json:"-"` + Error string `json:"-"` + Notes string `json:"-"` - // Written by creator - Instructions string - - // Written by executor - Result string - Error string - Notes string + nextTaskID string + completionCallback func(task *Task) } func New() *Service { @@ -66,12 +63,12 @@ func (s *Service) GetProject(id string) (*Project, error) { return project, nil } -func (p *Project) InsertTaskBefore(id string, instructions string, completionCallback func(task *Task)) *Task { - task := p.newTask(instructions, completionCallback, id) +func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task)) *Task { + task := p.newTask(instructions, completionCallback, beforeID) for t := range p.tasks() { - if t.NextTaskID == id { - t.NextTaskID = task.ID + if t.nextTaskID == beforeID { + t.nextTaskID = task.ID break } } @@ -94,8 +91,8 @@ func (p *Project) SetTaskSuccess(id string, result string, notes string) *Task { task.State = TaskStateSuccess task.Result = result task.Notes = notes - task.CompletionCallback(task) - p.NextTaskID = task.NextTaskID + task.completionCallback(task) + p.NextTaskID = task.nextTaskID return p.GetNextTask() } @@ -105,8 +102,8 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) *Task { task.State = TaskStateFailure task.Error = error task.Notes = notes - task.CompletionCallback(task) - p.NextTaskID = task.NextTaskID + task.completionCallback(task) + p.NextTaskID = task.nextTaskID return p.GetNextTask() } @@ -115,9 +112,9 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas task := &Task{ ID: uuid.New().String(), State: TaskStatePending, - NextTaskID: nextTaskID, + nextTaskID: nextTaskID, Instructions: instructions, - CompletionCallback: completionCallback, + completionCallback: completionCallback, } p.Tasks[task.ID] = task return task @@ -125,7 +122,7 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas func (p *Project) tasks() iter.Seq[*Task] { return func(yield func(*Task) bool) { - for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].NextTaskID { + for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].nextTaskID { t := p.Tasks[tid] if !yield(t) { return