Refactor MCP integration to use structured handlers with snake_case

This commit is contained in:
Ian Gulliver
2025-07-05 14:24:50 -07:00
parent f0c7ee76ef
commit 33d54f399c
4 changed files with 110 additions and 139 deletions

View File

@@ -1,41 +0,0 @@
package taskcp_test
import (
"fmt"
"github.com/gopatchy/taskcp"
"github.com/mark3labs/mcp-go/server"
)
func ExampleRegisterMCPTools() {
service := taskcp.New()
project := service.AddProject()
fmt.Printf("Created project: %s\n", project.ID)
task1 := project.InsertTaskBefore("", "Compile the code", func(task *taskcp.Task) {
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
})
task2 := project.InsertTaskBefore("", "Run tests", func(task *taskcp.Task) {
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
})
task1.NextTaskID = task2.ID
project.NextTaskID = task1.ID
mcpServer := server.NewMCPServer(
"TaskCP Server",
"1.0.0",
server.WithToolCapabilities(true),
)
err := taskcp.RegisterMCPTools(mcpServer, service)
if err != nil {
fmt.Printf("Failed to register tools: %v\n", err)
return
}
fmt.Println("MCP tools registered successfully")
}

165
mcp.go
View File

@@ -2,22 +2,73 @@ package taskcp
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server" "github.com/mark3labs/mcp-go/server"
) )
func RegisterMCPTools(s *server.MCPServer, service *Service) error { type setTaskSuccessArgs struct {
s.AddTool( ProjectID string `json:"project_id"`
TaskID string `json:"task_id"`
Result string `json:"result"`
Notes string `json:"notes,omitempty"`
}
type setTaskFailureArgs struct {
ProjectID string `json:"project_id"`
TaskID string `json:"task_id"`
Error string `json:"error"`
Notes string `json:"notes,omitempty"`
}
type taskResponse struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
NextTask *Task `json:"next_task,omitempty"`
}
type errorResponse struct {
Error string `json:"error"`
}
type ServiceHandlerFunc[TArgs any, TResponse any] func(s *Service, ctx context.Context, args TArgs) (*TResponse, error)
func wrapServiceHandler[TArgs any, TResponse any](s *Service, handler ServiceHandlerFunc[TArgs, TResponse]) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var args TArgs
if err := request.BindArguments(&args); err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
response, err := handler(s, ctx, args)
if err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
resultJSON, err := json.MarshalIndent(response, "", " ")
if err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
return mcp.NewToolResultText(string(resultJSON)), nil
}
}
func (s *Service) RegisterMCPTools(mcpServer *server.MCPServer) error {
mcpServer.AddTool(
mcp.NewTool( mcp.NewTool(
"SetTaskSuccess", "set_task_success",
mcp.WithDescription("Mark a task as successfully completed"), mcp.WithDescription("Mark a task as successfully completed"),
mcp.WithString("projectId", mcp.WithString("project_id",
mcp.Required(), mcp.Required(),
mcp.Description("The project ID"), mcp.Description("The project ID"),
), ),
mcp.WithString("taskId", mcp.WithString("task_id",
mcp.Required(), mcp.Required(),
mcp.Description("The task ID to mark as successful"), mcp.Description("The task ID to mark as successful"),
), ),
@@ -29,18 +80,18 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
mcp.Description("Additional notes about the task completion"), mcp.Description("Additional notes about the task completion"),
), ),
), ),
handleSetTaskSuccess(service), wrapServiceHandler(s, handleSetTaskSuccess),
) )
s.AddTool( mcpServer.AddTool(
mcp.NewTool( mcp.NewTool(
"SetTaskFailure", "set_task_failure",
mcp.WithDescription("Mark a task as failed"), mcp.WithDescription("Mark a task as failed"),
mcp.WithString("projectId", mcp.WithString("project_id",
mcp.Required(), mcp.Required(),
mcp.Description("The project ID"), mcp.Description("The project ID"),
), ),
mcp.WithString("taskId", mcp.WithString("task_id",
mcp.Required(), mcp.Required(),
mcp.Description("The task ID to mark as failed"), mcp.Description("The task ID to mark as failed"),
), ),
@@ -52,78 +103,42 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
mcp.Description("Additional notes about the task failure"), mcp.Description("Additional notes about the task failure"),
), ),
), ),
handleSetTaskFailure(service), wrapServiceHandler(s, handleSetTaskFailure),
) )
return nil return nil
} }
func handleSetTaskSuccess(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessArgs) (*taskResponse, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { project, err := s.GetProject(args.ProjectID)
projectId, err := request.RequireString("projectId") if err != nil {
if err != nil { return nil, fmt.Errorf("failed to get project: %w", err)
return nil, fmt.Errorf("failed to get projectId: %w", err)
}
taskId, err := request.RequireString("taskId")
if err != nil {
return nil, fmt.Errorf("failed to get taskId: %w", err)
}
result, err := request.RequireString("result")
if err != nil {
return nil, fmt.Errorf("failed to get result: %w", err)
}
notes := request.GetString("notes", "")
project, err := service.GetProject(projectId)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskSuccess(taskId, result, notes)
message := fmt.Sprintf("Task %s marked as successful", taskId)
if nextTask != nil {
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
}
return mcp.NewToolResultText(message), nil
} }
nextTask := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes)
response := &taskResponse{
TaskID: args.TaskID,
Message: fmt.Sprintf("Task %s marked as successful", args.TaskID),
NextTask: nextTask,
}
return response, nil
} }
func handleSetTaskFailure(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureArgs) (*taskResponse, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { project, err := s.GetProject(args.ProjectID)
projectId, err := request.RequireString("projectId") if err != nil {
if err != nil { return nil, fmt.Errorf("failed to get project: %w", err)
return nil, fmt.Errorf("failed to get projectId: %w", err)
}
taskId, err := request.RequireString("taskId")
if err != nil {
return nil, fmt.Errorf("failed to get taskId: %w", err)
}
errorMsg, err := request.RequireString("error")
if err != nil {
return nil, fmt.Errorf("failed to get error: %w", err)
}
notes := request.GetString("notes", "")
project, err := service.GetProject(projectId)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskFailure(taskId, errorMsg, notes)
message := fmt.Sprintf("Task %s marked as failed", taskId)
if nextTask != nil {
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
}
return mcp.NewToolResultText(message), nil
} }
nextTask := project.SetTaskFailure(args.TaskID, args.Error, args.Notes)
response := &taskResponse{
TaskID: args.TaskID,
Message: fmt.Sprintf("Task %s marked as failed", args.TaskID),
NextTask: nextTask,
}
return response, nil
} }

View File

@@ -11,7 +11,7 @@ func TestRegisterMCPTools(t *testing.T) {
s := server.NewMCPServer("Test Server", "1.0.0") s := server.NewMCPServer("Test Server", "1.0.0")
err := RegisterMCPTools(s, service) err := service.RegisterMCPTools(s)
if err != nil { if err != nil {
t.Fatalf("Failed to register MCP tools: %v", err) t.Fatalf("Failed to register MCP tools: %v", err)
} }

View File

@@ -27,18 +27,15 @@ const (
) )
type Task struct { type Task struct {
ID string ID string `json:"id"`
State TaskState State TaskState `json:"-"`
NextTaskID string Instructions string `json:"instructions"`
CompletionCallback func(task *Task) Result string `json:"-"`
Error string `json:"-"`
Notes string `json:"-"`
// Written by creator nextTaskID string
Instructions string completionCallback func(task *Task)
// Written by executor
Result string
Error string
Notes string
} }
func New() *Service { func New() *Service {
@@ -66,12 +63,12 @@ func (s *Service) GetProject(id string) (*Project, error) {
return project, nil return project, nil
} }
func (p *Project) InsertTaskBefore(id string, instructions string, completionCallback func(task *Task)) *Task { func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task)) *Task {
task := p.newTask(instructions, completionCallback, id) task := p.newTask(instructions, completionCallback, beforeID)
for t := range p.tasks() { for t := range p.tasks() {
if t.NextTaskID == id { if t.nextTaskID == beforeID {
t.NextTaskID = task.ID t.nextTaskID = task.ID
break break
} }
} }
@@ -94,8 +91,8 @@ func (p *Project) SetTaskSuccess(id string, result string, notes string) *Task {
task.State = TaskStateSuccess task.State = TaskStateSuccess
task.Result = result task.Result = result
task.Notes = notes task.Notes = notes
task.CompletionCallback(task) task.completionCallback(task)
p.NextTaskID = task.NextTaskID p.NextTaskID = task.nextTaskID
return p.GetNextTask() return p.GetNextTask()
} }
@@ -105,8 +102,8 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) *Task {
task.State = TaskStateFailure task.State = TaskStateFailure
task.Error = error task.Error = error
task.Notes = notes task.Notes = notes
task.CompletionCallback(task) task.completionCallback(task)
p.NextTaskID = task.NextTaskID p.NextTaskID = task.nextTaskID
return p.GetNextTask() return p.GetNextTask()
} }
@@ -115,9 +112,9 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
task := &Task{ task := &Task{
ID: uuid.New().String(), ID: uuid.New().String(),
State: TaskStatePending, State: TaskStatePending,
NextTaskID: nextTaskID, nextTaskID: nextTaskID,
Instructions: instructions, Instructions: instructions,
CompletionCallback: completionCallback, completionCallback: completionCallback,
} }
p.Tasks[task.ID] = task p.Tasks[task.ID] = task
return task return task
@@ -125,7 +122,7 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
func (p *Project) tasks() iter.Seq[*Task] { func (p *Project) tasks() iter.Seq[*Task] {
return func(yield func(*Task) bool) { return func(yield func(*Task) bool) {
for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].NextTaskID { for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].nextTaskID {
t := p.Tasks[tid] t := p.Tasks[tid]
if !yield(t) { if !yield(t) {
return return