Files
taskcp/mcp.go
Ian Gulliver a96b350b28 Simplified API
2025-07-12 15:32:03 -07:00

155 lines
4.0 KiB
Go

package taskcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Argument payloads that wrapServiceHandler decodes from MCP tool calls via
// BindArguments. The JSON tags must match the parameter names declared for
// each tool in RegisterMCPTools; notes is optional in both tools.
type (
	// setTaskSuccessArgs carries the arguments of the set_task_success tool.
	setTaskSuccessArgs struct {
		ProjectID int    `json:"project_id"`
		TaskID    int    `json:"task_id"`
		Result    string `json:"result"`
		Notes     string `json:"notes,omitempty"`
	}

	// setTaskFailureArgs carries the arguments of the set_task_failure tool.
	setTaskFailureArgs struct {
		ProjectID int    `json:"project_id"`
		TaskID    int    `json:"task_id"`
		Error     string `json:"error"`
		Notes     string `json:"notes,omitempty"`
	}
)
// JSON bodies that wrapServiceHandler emits as tool-result text.
type (
	// taskResponse is the success payload of both task tools. NextTask is
	// whatever the completion call returned; a nil pointer encodes as null.
	taskResponse struct {
		NextTask *Task `json:"next_task"`
	}

	// errorResponse is the payload reported when argument binding, the
	// handler, or response encoding fails.
	errorResponse struct {
		Error string `json:"error"`
	}
)
// ServiceHandlerFunc is a typed tool handler. It receives the Service, the
// request context and arguments already decoded into an A. It returns either
// a *R to be JSON-encoded as the tool result, or an error to report to the
// caller. wrapServiceHandler adapts one into an mcp-go tool handler.
//
// NOTE(review): ctx is not the first parameter, contrary to Go convention;
// reordering it would break existing handlers, so it is left in place.
type ServiceHandlerFunc[A, R any] func(s *Service, ctx context.Context, args A) (*R, error)
// wrapServiceHandler adapts a typed ServiceHandlerFunc into an mcp-go tool
// handler bound to s.
//
// The returned function:
//   - decodes the tool-call arguments into a TArgs with request.BindArguments;
//   - invokes handler;
//   - encodes the *TResponse it returns as indented JSON text.
//
// A failure at any of those stages is reported as a tool result. Its text is
// a JSON errorResponse and its IsError flag is set. The Go error return is
// always nil, so failures never surface as MCP protocol-level errors.
func wrapServiceHandler[TArgs any, TResponse any](s *Service, handler ServiceHandlerFunc[TArgs, TResponse]) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	// toolError builds the result shared by every failure path. Flagging it
	// via NewToolResultError (IsError=true) lets MCP clients distinguish a
	// failed call from a successful one without parsing the text body.
	toolError := func(err error) *mcp.CallToolResult {
		// Marshaling a struct holding a single string field cannot fail
		// (invalid UTF-8 is replaced, not rejected), so the error is
		// deliberately discarded.
		errorJSON, _ := json.Marshal(errorResponse{Error: err.Error()})
		return mcp.NewToolResultError(string(errorJSON))
	}

	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		var args TArgs
		if err := request.BindArguments(&args); err != nil {
			return toolError(err), nil
		}

		response, err := handler(s, ctx, args)
		if err != nil {
			return toolError(err), nil
		}

		resultJSON, err := json.MarshalIndent(response, "", " ")
		if err != nil {
			return toolError(err), nil
		}
		return mcp.NewToolResultText(string(resultJSON)), nil
	}
}
// RegisterMCPTools registers the task-completion tools on mcpServer:
//
//   - set_task_success marks a running task as succeeded (handleSetTaskSuccess).
//   - set_task_failure marks a running task as failed (handleSetTaskFailure).
//
// Each tool's declared input schema mirrors the JSON tags of the args struct
// its handler binds into. project_id and task_id are therefore both declared
// as numbers, matching the int fields they decode into. The returned error is
// currently always nil.
func (s *Service) RegisterMCPTools(mcpServer *server.MCPServer) error {
	mcpServer.AddTool(
		mcp.NewTool(
			"set_task_success",
			mcp.WithDescription("Mark a task as successfully completed"),
			mcp.WithNumber("project_id",
				mcp.Required(),
				mcp.Description("The project ID"),
			),
			// task_id binds into setTaskSuccessArgs.TaskID (an int). Advertising
			// it as a string would make schema-following clients send values
			// that fail BindArguments with a JSON type error.
			mcp.WithNumber("task_id",
				mcp.Required(),
				mcp.Description("The task ID to mark as successful"),
			),
			mcp.WithString("result",
				mcp.Required(),
				mcp.Description("The result of the task execution"),
			),
			mcp.WithString("notes",
				mcp.Description("Additional notes about the task completion"),
			),
		),
		wrapServiceHandler(s, handleSetTaskSuccess),
	)

	mcpServer.AddTool(
		mcp.NewTool(
			"set_task_failure",
			mcp.WithDescription("Mark a task as failed"),
			mcp.WithNumber("project_id",
				mcp.Required(),
				mcp.Description("The project ID"),
			),
			// task_id binds into setTaskFailureArgs.TaskID (an int); see above.
			mcp.WithNumber("task_id",
				mcp.Required(),
				mcp.Description("The task ID to mark as failed"),
			),
			mcp.WithString("error",
				mcp.Required(),
				mcp.Description("The error message describing why the task failed"),
			),
			mcp.WithString("notes",
				mcp.Description("Additional notes about the task failure"),
			),
		),
		wrapServiceHandler(s, handleSetTaskFailure),
	)

	return nil
}
// handleSetTaskSuccess implements the set_task_success tool. It performs
// three steps:
//  1. Look up the project by args.ProjectID.
//  2. Fetch the running task args.TaskID within that project.
//  3. Record success on it with args.Result and args.Notes.
//
// Whatever task SetSuccess returns is passed back as NextTask. Each failing
// step is wrapped with a short description of that step. ctx is unused.
func handleSetTaskSuccess(s *Service, ctx context.Context, args setTaskSuccessArgs) (*taskResponse, error) {
	project, projErr := s.GetProject(args.ProjectID)
	if projErr != nil {
		return nil, fmt.Errorf("failed to get project: %w", projErr)
	}

	running, lookupErr := project.GetRunningTask(args.TaskID)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get task: %w", lookupErr)
	}

	next, callbackErr := running.SetSuccess(args.Result, args.Notes)
	if callbackErr != nil {
		return nil, fmt.Errorf("completion callback error: %w", callbackErr)
	}

	return &taskResponse{NextTask: next}, nil
}
// handleSetTaskFailure implements the set_task_failure tool. It performs
// three steps:
//  1. Look up the project by args.ProjectID.
//  2. Fetch the running task args.TaskID within that project.
//  3. Record failure on it with args.Error and args.Notes.
//
// Whatever task SetFailure returns is passed back as NextTask. Each failing
// step is wrapped with a short description of that step. ctx is unused.
func handleSetTaskFailure(s *Service, ctx context.Context, args setTaskFailureArgs) (*taskResponse, error) {
	project, projErr := s.GetProject(args.ProjectID)
	if projErr != nil {
		return nil, fmt.Errorf("failed to get project: %w", projErr)
	}

	running, lookupErr := project.GetRunningTask(args.TaskID)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get task: %w", lookupErr)
	}

	next, callbackErr := running.SetFailure(args.Error, args.Notes)
	if callbackErr != nil {
		return nil, fmt.Errorf("completion callback error: %w", callbackErr)
	}

	return &taskResponse{NextTask: next}, nil
}