Refactor MCP integration to use structured handlers with snake_case

This commit is contained in:
Ian Gulliver
2025-07-05 14:24:50 -07:00
parent f0c7ee76ef
commit 33d54f399c
4 changed files with 110 additions and 139 deletions

View File

@@ -1,41 +0,0 @@
package taskcp_test
import (
"fmt"
"github.com/gopatchy/taskcp"
"github.com/mark3labs/mcp-go/server"
)
func ExampleRegisterMCPTools() {
service := taskcp.New()
project := service.AddProject()
fmt.Printf("Created project: %s\n", project.ID)
task1 := project.InsertTaskBefore("", "Compile the code", func(task *taskcp.Task) {
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
})
task2 := project.InsertTaskBefore("", "Run tests", func(task *taskcp.Task) {
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
})
task1.NextTaskID = task2.ID
project.NextTaskID = task1.ID
mcpServer := server.NewMCPServer(
"TaskCP Server",
"1.0.0",
server.WithToolCapabilities(true),
)
err := taskcp.RegisterMCPTools(mcpServer, service)
if err != nil {
fmt.Printf("Failed to register tools: %v\n", err)
return
}
fmt.Println("MCP tools registered successfully")
}

165
mcp.go
View File

@@ -2,22 +2,73 @@ package taskcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
func RegisterMCPTools(s *server.MCPServer, service *Service) error {
s.AddTool(
type setTaskSuccessArgs struct {
ProjectID string `json:"project_id"`
TaskID string `json:"task_id"`
Result string `json:"result"`
Notes string `json:"notes,omitempty"`
}
type setTaskFailureArgs struct {
ProjectID string `json:"project_id"`
TaskID string `json:"task_id"`
Error string `json:"error"`
Notes string `json:"notes,omitempty"`
}
type taskResponse struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
NextTask *Task `json:"next_task,omitempty"`
}
type errorResponse struct {
Error string `json:"error"`
}
type ServiceHandlerFunc[TArgs any, TResponse any] func(s *Service, ctx context.Context, args TArgs) (*TResponse, error)
func wrapServiceHandler[TArgs any, TResponse any](s *Service, handler ServiceHandlerFunc[TArgs, TResponse]) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var args TArgs
if err := request.BindArguments(&args); err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
response, err := handler(s, ctx, args)
if err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
resultJSON, err := json.MarshalIndent(response, "", " ")
if err != nil {
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
return mcp.NewToolResultText(string(errorJSON)), nil
}
return mcp.NewToolResultText(string(resultJSON)), nil
}
}
func (s *Service) RegisterMCPTools(mcpServer *server.MCPServer) error {
mcpServer.AddTool(
mcp.NewTool(
"SetTaskSuccess",
"set_task_success",
mcp.WithDescription("Mark a task as successfully completed"),
mcp.WithString("projectId",
mcp.WithString("project_id",
mcp.Required(),
mcp.Description("The project ID"),
),
mcp.WithString("taskId",
mcp.WithString("task_id",
mcp.Required(),
mcp.Description("The task ID to mark as successful"),
),
@@ -29,18 +80,18 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
mcp.Description("Additional notes about the task completion"),
),
),
handleSetTaskSuccess(service),
wrapServiceHandler(s, handleSetTaskSuccess),
)
s.AddTool(
mcpServer.AddTool(
mcp.NewTool(
"SetTaskFailure",
"set_task_failure",
mcp.WithDescription("Mark a task as failed"),
mcp.WithString("projectId",
mcp.WithString("project_id",
mcp.Required(),
mcp.Description("The project ID"),
),
mcp.WithString("taskId",
mcp.WithString("task_id",
mcp.Required(),
mcp.Description("The task ID to mark as failed"),
),
@@ -52,78 +103,42 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
mcp.Description("Additional notes about the task failure"),
),
),
handleSetTaskFailure(service),
wrapServiceHandler(s, handleSetTaskFailure),
)
return nil
}
func handleSetTaskSuccess(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
projectId, err := request.RequireString("projectId")
if err != nil {
return nil, fmt.Errorf("failed to get projectId: %w", err)
}
taskId, err := request.RequireString("taskId")
if err != nil {
return nil, fmt.Errorf("failed to get taskId: %w", err)
}
result, err := request.RequireString("result")
if err != nil {
return nil, fmt.Errorf("failed to get result: %w", err)
}
notes := request.GetString("notes", "")
project, err := service.GetProject(projectId)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskSuccess(taskId, result, notes)
message := fmt.Sprintf("Task %s marked as successful", taskId)
if nextTask != nil {
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
}
return mcp.NewToolResultText(message), nil
func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessArgs) (*taskResponse, error) {
project, err := s.GetProject(args.ProjectID)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes)
response := &taskResponse{
TaskID: args.TaskID,
Message: fmt.Sprintf("Task %s marked as successful", args.TaskID),
NextTask: nextTask,
}
return response, nil
}
func handleSetTaskFailure(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
projectId, err := request.RequireString("projectId")
if err != nil {
return nil, fmt.Errorf("failed to get projectId: %w", err)
}
taskId, err := request.RequireString("taskId")
if err != nil {
return nil, fmt.Errorf("failed to get taskId: %w", err)
}
errorMsg, err := request.RequireString("error")
if err != nil {
return nil, fmt.Errorf("failed to get error: %w", err)
}
notes := request.GetString("notes", "")
project, err := service.GetProject(projectId)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskFailure(taskId, errorMsg, notes)
message := fmt.Sprintf("Task %s marked as failed", taskId)
if nextTask != nil {
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
}
return mcp.NewToolResultText(message), nil
func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureArgs) (*taskResponse, error) {
project, err := s.GetProject(args.ProjectID)
if err != nil {
return nil, fmt.Errorf("failed to get project: %w", err)
}
nextTask := project.SetTaskFailure(args.TaskID, args.Error, args.Notes)
response := &taskResponse{
TaskID: args.TaskID,
Message: fmt.Sprintf("Task %s marked as failed", args.TaskID),
NextTask: nextTask,
}
return response, nil
}

View File

@@ -11,7 +11,7 @@ func TestRegisterMCPTools(t *testing.T) {
s := server.NewMCPServer("Test Server", "1.0.0")
err := RegisterMCPTools(s, service)
err := service.RegisterMCPTools(s)
if err != nil {
t.Fatalf("Failed to register MCP tools: %v", err)
}

View File

@@ -27,18 +27,15 @@ const (
)
type Task struct {
ID string
State TaskState
NextTaskID string
CompletionCallback func(task *Task)
ID string `json:"id"`
State TaskState `json:"-"`
Instructions string `json:"instructions"`
Result string `json:"-"`
Error string `json:"-"`
Notes string `json:"-"`
// Written by creator
Instructions string
// Written by executor
Result string
Error string
Notes string
nextTaskID string
completionCallback func(task *Task)
}
func New() *Service {
@@ -66,12 +63,12 @@ func (s *Service) GetProject(id string) (*Project, error) {
return project, nil
}
func (p *Project) InsertTaskBefore(id string, instructions string, completionCallback func(task *Task)) *Task {
task := p.newTask(instructions, completionCallback, id)
func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task)) *Task {
task := p.newTask(instructions, completionCallback, beforeID)
for t := range p.tasks() {
if t.NextTaskID == id {
t.NextTaskID = task.ID
if t.nextTaskID == beforeID {
t.nextTaskID = task.ID
break
}
}
@@ -94,8 +91,8 @@ func (p *Project) SetTaskSuccess(id string, result string, notes string) *Task {
task.State = TaskStateSuccess
task.Result = result
task.Notes = notes
task.CompletionCallback(task)
p.NextTaskID = task.NextTaskID
task.completionCallback(task)
p.NextTaskID = task.nextTaskID
return p.GetNextTask()
}
@@ -105,8 +102,8 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) *Task {
task.State = TaskStateFailure
task.Error = error
task.Notes = notes
task.CompletionCallback(task)
p.NextTaskID = task.NextTaskID
task.completionCallback(task)
p.NextTaskID = task.nextTaskID
return p.GetNextTask()
}
@@ -115,9 +112,9 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
task := &Task{
ID: uuid.New().String(),
State: TaskStatePending,
NextTaskID: nextTaskID,
nextTaskID: nextTaskID,
Instructions: instructions,
CompletionCallback: completionCallback,
completionCallback: completionCallback,
}
p.Tasks[task.ID] = task
return task
@@ -125,7 +122,7 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
func (p *Project) tasks() iter.Seq[*Task] {
return func(yield func(*Task) bool) {
for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].NextTaskID {
for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].nextTaskID {
t := p.Tasks[tid]
if !yield(t) {
return