Refactor MCP integration to use structured handlers with snake_case
This commit is contained in:
@@ -1,41 +0,0 @@
|
||||
package taskcp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gopatchy/taskcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
func ExampleRegisterMCPTools() {
|
||||
service := taskcp.New()
|
||||
|
||||
project := service.AddProject()
|
||||
fmt.Printf("Created project: %s\n", project.ID)
|
||||
|
||||
task1 := project.InsertTaskBefore("", "Compile the code", func(task *taskcp.Task) {
|
||||
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
|
||||
})
|
||||
|
||||
task2 := project.InsertTaskBefore("", "Run tests", func(task *taskcp.Task) {
|
||||
fmt.Printf("Task %s completed with state: %s\n", task.ID, task.State)
|
||||
})
|
||||
|
||||
task1.NextTaskID = task2.ID
|
||||
project.NextTaskID = task1.ID
|
||||
|
||||
mcpServer := server.NewMCPServer(
|
||||
"TaskCP Server",
|
||||
"1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
err := taskcp.RegisterMCPTools(mcpServer, service)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to register tools: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("MCP tools registered successfully")
|
||||
|
||||
}
|
||||
165
mcp.go
165
mcp.go
@@ -2,22 +2,73 @@ package taskcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
func RegisterMCPTools(s *server.MCPServer, service *Service) error {
|
||||
s.AddTool(
|
||||
type setTaskSuccessArgs struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
TaskID string `json:"task_id"`
|
||||
Result string `json:"result"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
type setTaskFailureArgs struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
TaskID string `json:"task_id"`
|
||||
Error string `json:"error"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
type taskResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Message string `json:"message"`
|
||||
NextTask *Task `json:"next_task,omitempty"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type ServiceHandlerFunc[TArgs any, TResponse any] func(s *Service, ctx context.Context, args TArgs) (*TResponse, error)
|
||||
|
||||
func wrapServiceHandler[TArgs any, TResponse any](s *Service, handler ServiceHandlerFunc[TArgs, TResponse]) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var args TArgs
|
||||
if err := request.BindArguments(&args); err != nil {
|
||||
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
|
||||
return mcp.NewToolResultText(string(errorJSON)), nil
|
||||
}
|
||||
|
||||
response, err := handler(s, ctx, args)
|
||||
if err != nil {
|
||||
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
|
||||
return mcp.NewToolResultText(string(errorJSON)), nil
|
||||
}
|
||||
|
||||
resultJSON, err := json.MarshalIndent(response, "", " ")
|
||||
if err != nil {
|
||||
errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
|
||||
return mcp.NewToolResultText(string(errorJSON)), nil
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(string(resultJSON)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) RegisterMCPTools(mcpServer *server.MCPServer) error {
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool(
|
||||
"SetTaskSuccess",
|
||||
"set_task_success",
|
||||
mcp.WithDescription("Mark a task as successfully completed"),
|
||||
mcp.WithString("projectId",
|
||||
mcp.WithString("project_id",
|
||||
mcp.Required(),
|
||||
mcp.Description("The project ID"),
|
||||
),
|
||||
mcp.WithString("taskId",
|
||||
mcp.WithString("task_id",
|
||||
mcp.Required(),
|
||||
mcp.Description("The task ID to mark as successful"),
|
||||
),
|
||||
@@ -29,18 +80,18 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
|
||||
mcp.Description("Additional notes about the task completion"),
|
||||
),
|
||||
),
|
||||
handleSetTaskSuccess(service),
|
||||
wrapServiceHandler(s, handleSetTaskSuccess),
|
||||
)
|
||||
|
||||
s.AddTool(
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool(
|
||||
"SetTaskFailure",
|
||||
"set_task_failure",
|
||||
mcp.WithDescription("Mark a task as failed"),
|
||||
mcp.WithString("projectId",
|
||||
mcp.WithString("project_id",
|
||||
mcp.Required(),
|
||||
mcp.Description("The project ID"),
|
||||
),
|
||||
mcp.WithString("taskId",
|
||||
mcp.WithString("task_id",
|
||||
mcp.Required(),
|
||||
mcp.Description("The task ID to mark as failed"),
|
||||
),
|
||||
@@ -52,78 +103,42 @@ func RegisterMCPTools(s *server.MCPServer, service *Service) error {
|
||||
mcp.Description("Additional notes about the task failure"),
|
||||
),
|
||||
),
|
||||
handleSetTaskFailure(service),
|
||||
wrapServiceHandler(s, handleSetTaskFailure),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleSetTaskSuccess(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
projectId, err := request.RequireString("projectId")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get projectId: %w", err)
|
||||
}
|
||||
|
||||
taskId, err := request.RequireString("taskId")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taskId: %w", err)
|
||||
}
|
||||
|
||||
result, err := request.RequireString("result")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get result: %w", err)
|
||||
}
|
||||
|
||||
notes := request.GetString("notes", "")
|
||||
|
||||
project, err := service.GetProject(projectId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get project: %w", err)
|
||||
}
|
||||
|
||||
nextTask := project.SetTaskSuccess(taskId, result, notes)
|
||||
|
||||
message := fmt.Sprintf("Task %s marked as successful", taskId)
|
||||
if nextTask != nil {
|
||||
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(message), nil
|
||||
func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessArgs) (*taskResponse, error) {
|
||||
project, err := s.GetProject(args.ProjectID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get project: %w", err)
|
||||
}
|
||||
|
||||
nextTask := project.SetTaskSuccess(args.TaskID, args.Result, args.Notes)
|
||||
|
||||
response := &taskResponse{
|
||||
TaskID: args.TaskID,
|
||||
Message: fmt.Sprintf("Task %s marked as successful", args.TaskID),
|
||||
NextTask: nextTask,
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func handleSetTaskFailure(service *Service) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
projectId, err := request.RequireString("projectId")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get projectId: %w", err)
|
||||
}
|
||||
|
||||
taskId, err := request.RequireString("taskId")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taskId: %w", err)
|
||||
}
|
||||
|
||||
errorMsg, err := request.RequireString("error")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get error: %w", err)
|
||||
}
|
||||
|
||||
notes := request.GetString("notes", "")
|
||||
|
||||
project, err := service.GetProject(projectId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get project: %w", err)
|
||||
}
|
||||
|
||||
nextTask := project.SetTaskFailure(taskId, errorMsg, notes)
|
||||
|
||||
message := fmt.Sprintf("Task %s marked as failed", taskId)
|
||||
if nextTask != nil {
|
||||
message += fmt.Sprintf("\nNext task: %s (ID: %s)", nextTask.Instructions, nextTask.ID)
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(message), nil
|
||||
func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureArgs) (*taskResponse, error) {
|
||||
project, err := s.GetProject(args.ProjectID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get project: %w", err)
|
||||
}
|
||||
|
||||
nextTask := project.SetTaskFailure(args.TaskID, args.Error, args.Notes)
|
||||
|
||||
response := &taskResponse{
|
||||
TaskID: args.TaskID,
|
||||
Message: fmt.Sprintf("Task %s marked as failed", args.TaskID),
|
||||
NextTask: nextTask,
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
@@ -11,7 +11,7 @@ func TestRegisterMCPTools(t *testing.T) {
|
||||
|
||||
s := server.NewMCPServer("Test Server", "1.0.0")
|
||||
|
||||
err := RegisterMCPTools(s, service)
|
||||
err := service.RegisterMCPTools(s)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register MCP tools: %v", err)
|
||||
}
|
||||
|
||||
41
taskcp.go
41
taskcp.go
@@ -27,18 +27,15 @@ const (
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
State TaskState
|
||||
NextTaskID string
|
||||
CompletionCallback func(task *Task)
|
||||
ID string `json:"id"`
|
||||
State TaskState `json:"-"`
|
||||
Instructions string `json:"instructions"`
|
||||
Result string `json:"-"`
|
||||
Error string `json:"-"`
|
||||
Notes string `json:"-"`
|
||||
|
||||
// Written by creator
|
||||
Instructions string
|
||||
|
||||
// Written by executor
|
||||
Result string
|
||||
Error string
|
||||
Notes string
|
||||
nextTaskID string
|
||||
completionCallback func(task *Task)
|
||||
}
|
||||
|
||||
func New() *Service {
|
||||
@@ -66,12 +63,12 @@ func (s *Service) GetProject(id string) (*Project, error) {
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (p *Project) InsertTaskBefore(id string, instructions string, completionCallback func(task *Task)) *Task {
|
||||
task := p.newTask(instructions, completionCallback, id)
|
||||
func (p *Project) InsertTaskBefore(beforeID string, instructions string, completionCallback func(task *Task)) *Task {
|
||||
task := p.newTask(instructions, completionCallback, beforeID)
|
||||
|
||||
for t := range p.tasks() {
|
||||
if t.NextTaskID == id {
|
||||
t.NextTaskID = task.ID
|
||||
if t.nextTaskID == beforeID {
|
||||
t.nextTaskID = task.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -94,8 +91,8 @@ func (p *Project) SetTaskSuccess(id string, result string, notes string) *Task {
|
||||
task.State = TaskStateSuccess
|
||||
task.Result = result
|
||||
task.Notes = notes
|
||||
task.CompletionCallback(task)
|
||||
p.NextTaskID = task.NextTaskID
|
||||
task.completionCallback(task)
|
||||
p.NextTaskID = task.nextTaskID
|
||||
|
||||
return p.GetNextTask()
|
||||
}
|
||||
@@ -105,8 +102,8 @@ func (p *Project) SetTaskFailure(id string, error string, notes string) *Task {
|
||||
task.State = TaskStateFailure
|
||||
task.Error = error
|
||||
task.Notes = notes
|
||||
task.CompletionCallback(task)
|
||||
p.NextTaskID = task.NextTaskID
|
||||
task.completionCallback(task)
|
||||
p.NextTaskID = task.nextTaskID
|
||||
|
||||
return p.GetNextTask()
|
||||
}
|
||||
@@ -115,9 +112,9 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
|
||||
task := &Task{
|
||||
ID: uuid.New().String(),
|
||||
State: TaskStatePending,
|
||||
NextTaskID: nextTaskID,
|
||||
nextTaskID: nextTaskID,
|
||||
Instructions: instructions,
|
||||
CompletionCallback: completionCallback,
|
||||
completionCallback: completionCallback,
|
||||
}
|
||||
p.Tasks[task.ID] = task
|
||||
return task
|
||||
@@ -125,7 +122,7 @@ func (p *Project) newTask(instructions string, completionCallback func(task *Tas
|
||||
|
||||
func (p *Project) tasks() iter.Seq[*Task] {
|
||||
return func(yield func(*Task) bool) {
|
||||
for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].NextTaskID {
|
||||
for tid := p.NextTaskID; tid != ""; tid = p.Tasks[tid].nextTaskID {
|
||||
t := p.Tasks[tid]
|
||||
if !yield(t) {
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user