From 50058e877f8798325ee9d1afaf45a67ce069955e Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 27 Jun 2025 20:29:53 -0700 Subject: [PATCH] Add Go AST analysis tools and refactor to use common walk code --- CONTEXT.md | 79 ++++++-- ast.go | 543 +++++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 12 +- go.sum | 26 +++ main.go | 190 +++++++++++++++---- 5 files changed, 796 insertions(+), 54 deletions(-) create mode 100644 ast.go create mode 100644 go.sum diff --git a/CONTEXT.md b/CONTEXT.md index 330ddf7..c5308be 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -16,17 +16,74 @@ gocp is a Go MCP (Model Context Protocol) server that provides tools for building and executing Go code. It uses the go-mcp library to implement the MCP protocol. ## Key Files -- `main.go`: MCP server implementation with build_and_run_go tool +- `main.go`: MCP server implementation with tools and handlers +- `ast.go`: Go AST parsing and code analysis functionality - `go.mod`: Module definition with go-mcp dependency ## Tool Details -- **build_and_run_go**: Executes Go code using `go run` - - Parameters: - - `code` (required): Go source code to execute - - `timeout` (optional): Timeout in seconds (default: 30) - - Returns JSON with: - - `stdout`: Standard output - - `stderr`: Standard error - - `exit_code`: Process exit code - - `error`: Error message if any - - Creates temporary directories with `gocp-*` prefix \ No newline at end of file + +### build_and_run_go +Executes Go code using `go run` +- Parameters: + - `code` (required): Go source code to execute + - `timeout` (optional): Timeout in seconds (default: 30) +- Returns JSON with: + - `stdout`: Standard output + - `stderr`: Standard error + - `exit_code`: Process exit code + - `error`: Error message if any +- Creates temporary directories with `gocp-*` prefix + +### find_symbols +Find all functions, types, interfaces, constants, and variables by name/pattern +- Parameters: + - `dir` (optional): Directory to search (default: current directory) + - `pattern` (optional): Symbol name pattern to search for (case-insensitive substring match) +- Returns JSON array of symbols with: + - `name`: Symbol name + - `type`: Symbol type (function, struct, interface, constant, variable) + - `package`: Package name + - `file`: File path + - `line`: Line number + - `column`: Column number + - `exported`: Whether the symbol is exported + +### get_type_info +Get detailed information about a type including fields, methods, and embedded types +- Parameters: + - `dir` (optional): Directory to search (default: current directory) + - `type` (required): Type name to get information for +- Returns JSON with: + - `name`: Type name + - `package`: Package name + - `file`: File path + - `line`: Line number + - `kind`: Type kind (struct, interface, alias, other) + - `fields`: Array of field information (for structs) + - `methods`: Array of methods + - `embedded`: Array of embedded type names + - `interface`: Array of interface methods (for interfaces) + - `underlying`: Underlying type (for aliases) + +### find_references +Find all references to a symbol (function calls, type usage, etc.) +- Parameters: + - `dir` (optional): Directory to search (default: current directory) + - `symbol` (required): Symbol name to find references for +- Returns JSON array of references with: + - `file`: File path + - `line`: Line number + - `column`: Column number + - `context`: Code context around the reference + - `kind`: Reference kind (identifier, selector) + +### list_packages +List all Go packages in directory tree +- Parameters: + - `dir` (optional): Directory to search (default: current directory) +- Returns JSON array of packages with: + - `import_path`: Import path relative to search directory + - `name`: Package name + - `dir`: Directory path + - `go_files`: List of Go source files + - `imports`: List of imported packages \ No newline at end of file diff --git a/ast.go b/ast.go new file mode 100644 index 0000000..b454718 --- /dev/null +++ b/ast.go @@ -0,0 +1,543 @@ +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "strings" +) + +type Symbol struct { + Name string `json:"name"` + Type string `json:"type"` + Package string `json:"package"` + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Exported bool `json:"exported"` +} + +type TypeInfo struct { + Name string `json:"name"` + Package string `json:"package"` + File string `json:"file"` + Line int `json:"line"` + Kind string `json:"kind"` + Fields []FieldInfo `json:"fields,omitempty"` + Methods []MethodInfo `json:"methods,omitempty"` + Embedded []string `json:"embedded,omitempty"` + Interface []MethodInfo `json:"interface,omitempty"` + Underlying string `json:"underlying,omitempty"` +} + +type FieldInfo struct { + Name string `json:"name"` + Type string `json:"type"` + Tag string `json:"tag,omitempty"` + Exported bool `json:"exported"` +} + +type MethodInfo struct { + Name string `json:"name"` + Signature string `json:"signature"` + Receiver string `json:"receiver,omitempty"` + Exported bool `json:"exported"` +} + +type Reference struct { + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Context string `json:"context"` + Kind string `json:"kind"` +} + +type Package struct { + ImportPath string `json:"import_path"` + Name string `json:"name"` + Dir string `json:"dir"` + GoFiles []string `json:"go_files"` + Imports []string `json:"imports"` +} + +type fileVisitor func(path string, src []byte, file *ast.File, fset *token.FileSet) error + +func walkGoFiles(dir string, visitor fileVisitor) error { + fset := token.NewFileSet() + + return filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() || !strings.HasSuffix(path, ".go") || strings.Contains(path, "vendor/") { + return nil + } + + src, err := os.ReadFile(path) + if err != nil { + return nil + } + + file, err := parser.ParseFile(fset, path, src, parser.ParseComments) + if err != nil { + return nil + } + + return visitor(path, src, file, fset) + }) +} + +func findSymbols(dir string, pattern string) ([]Symbol, error) { + var symbols []Symbol + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + if strings.HasSuffix(path, "_test.go") && !strings.Contains(pattern, "Test") { + return nil + } + + pkgName := file.Name.Name + + ast.Inspect(file, func(n ast.Node) bool { + switch decl := n.(type) { + case *ast.FuncDecl: + name := decl.Name.Name + if matchesPattern(name, pattern) { + pos := fset.Position(decl.Pos()) + symbols = append(symbols, Symbol{ + Name: name, + Type: "function", + Package: pkgName, + File: path, + Line: pos.Line, + Column: pos.Column, + Exported: ast.IsExported(name), + }) + } + + case *ast.GenDecl: + for _, spec := range decl.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + name := s.Name.Name + if matchesPattern(name, pattern) { + pos := fset.Position(s.Pos()) + kind := "type" + switch s.Type.(type) { + case *ast.InterfaceType: + kind = "interface" + case *ast.StructType: + kind = "struct" + } + symbols = append(symbols, Symbol{ + Name: name, + Type: kind, + Package: pkgName, + File: path, + Line: pos.Line, + Column: pos.Column, + Exported: ast.IsExported(name), + }) + } + + case *ast.ValueSpec: + for _, name := range s.Names { + if matchesPattern(name.Name, pattern) { + pos := fset.Position(name.Pos()) + kind := "variable" + if decl.Tok == token.CONST { + kind = "constant" + } + symbols = append(symbols, Symbol{ + Name: name.Name, + Type: kind, + Package: pkgName, + File: path, + Line: pos.Line, + Column: pos.Column, + Exported: ast.IsExported(name.Name), + }) + } + } + } + } + } + return true + }) + + return nil + }) + + return symbols, err +} + +func getTypeInfo(dir string, typeName string) (*TypeInfo, error) { + var result *TypeInfo + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + if result != nil { + return nil + } + + ast.Inspect(file, func(n ast.Node) bool { + if result != nil { + return false + } + + switch decl := n.(type) { + case *ast.GenDecl: + for _, spec := range decl.Specs { + if ts, ok := spec.(*ast.TypeSpec); ok && ts.Name.Name == typeName { + pos := fset.Position(ts.Pos()) + info := &TypeInfo{ + Name: typeName, + Package: file.Name.Name, + File: path, + Line: pos.Line, + } + + switch t := ts.Type.(type) { + case *ast.StructType: + info.Kind = "struct" + info.Fields = extractFields(t) + info.Embedded = extractEmbedded(t) + + case *ast.InterfaceType: + info.Kind = "interface" + info.Interface = extractInterfaceMethods(t) + + case *ast.Ident: + info.Kind = "alias" + info.Underlying = t.Name + + case *ast.SelectorExpr: + info.Kind = "alias" + if x, ok := t.X.(*ast.Ident); ok { + info.Underlying = x.Name + "." + t.Sel.Name + } + + default: + info.Kind = "other" + } + + info.Methods = extractMethods(file, typeName) + result = info + return false + } + } + } + return true + }) + + return nil + }) + + if result == nil && err == nil { + return nil, fmt.Errorf("type %s not found", typeName) + } + + return result, err +} + +func findReferences(dir string, symbol string) ([]Reference, error) { + var refs []Reference + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.Ident: + if node.Name == symbol { + pos := fset.Position(node.Pos()) + kind := identifyReferenceKind(node) + context := extractContext(src, pos) + + refs = append(refs, Reference{ + File: path, + Line: pos.Line, + Column: pos.Column, + Context: context, + Kind: kind, + }) + } + + case *ast.SelectorExpr: + if node.Sel.Name == symbol { + pos := fset.Position(node.Sel.Pos()) + context := extractContext(src, pos) + + refs = append(refs, Reference{ + File: path, + Line: pos.Line, + Column: pos.Column, + Context: context, + Kind: "selector", + }) + } + } + return true + }) + + return nil + }) + + return refs, err +} + +func listPackages(dir string, includeTests bool) ([]Package, error) { + packages := make(map[string]*Package) + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Skip test files if not requested + if !includeTests && strings.HasSuffix(path, "_test.go") { + return nil + } + + pkgDir := filepath.Dir(path) + + // Initialize package if not seen before + if _, exists := packages[pkgDir]; !exists { + importPath := strings.TrimPrefix(pkgDir, dir) + importPath = strings.TrimPrefix(importPath, "/") + if importPath == "" { + importPath = "." + } + + packages[pkgDir] = &Package{ + ImportPath: importPath, + Name: file.Name.Name, + Dir: pkgDir, + GoFiles: []string{}, + Imports: []string{}, + } + } + + // Add file to package + fileName := filepath.Base(path) + packages[pkgDir].GoFiles = append(packages[pkgDir].GoFiles, fileName) + + // Collect unique imports + imports := make(map[string]bool) + for _, imp := range file.Imports { + importPath := strings.Trim(imp.Path.Value, `"`) + imports[importPath] = true + } + + // Merge imports into package + existingImports := make(map[string]bool) + for _, imp := range packages[pkgDir].Imports { + existingImports[imp] = true + } + for imp := range imports { + if !existingImports[imp] { + packages[pkgDir].Imports = append(packages[pkgDir].Imports, imp) + } + } + + return nil + }) + + if err != nil { + return nil, err + } + + var result []Package + for _, pkg := range packages { + result = append(result, *pkg) + } + + return result, nil +} + +func matchesPattern(name, pattern string) bool { + if pattern == "" { + return true + } + pattern = strings.ToLower(pattern) + name = strings.ToLower(name) + return strings.Contains(name, pattern) +} + +func extractFields(st *ast.StructType) []FieldInfo { + var fields []FieldInfo + + for _, field := range st.Fields.List { + fieldType := exprToString(field.Type) + tag := "" + if field.Tag != nil { + tag = field.Tag.Value + } + + if len(field.Names) == 0 { + fields = append(fields, FieldInfo{ + Name: "", + Type: fieldType, + Tag: tag, + Exported: true, + }) + } else { + for _, name := range field.Names { + fields = append(fields, FieldInfo{ + Name: name.Name, + Type: fieldType, + Tag: tag, + Exported: ast.IsExported(name.Name), + }) + } + } + } + + return fields +} + +func extractEmbedded(st *ast.StructType) []string { + var embedded []string + + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + embedded = append(embedded, exprToString(field.Type)) + } + } + + return embedded +} + +func extractInterfaceMethods(it *ast.InterfaceType) []MethodInfo { + var methods []MethodInfo + + for _, method := range it.Methods.List { + if len(method.Names) > 0 { + for _, name := range method.Names { + sig := exprToString(method.Type) + methods = append(methods, MethodInfo{ + Name: name.Name, + Signature: sig, + Exported: ast.IsExported(name.Name), + }) + } + } + } + + return methods +} + +func extractMethods(file *ast.File, typeName string) []MethodInfo { + var methods []MethodInfo + + for _, decl := range file.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok && fn.Recv != nil { + for _, recv := range fn.Recv.List { + recvType := exprToString(recv.Type) + if strings.Contains(recvType, typeName) { + sig := funcSignature(fn.Type) + methods = append(methods, MethodInfo{ + Name: fn.Name.Name, + Signature: sig, + Receiver: recvType, + Exported: ast.IsExported(fn.Name.Name), + }) + } + } + } + } + + return methods +} + +func exprToString(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.StarExpr: + return "*" + exprToString(e.X) + case *ast.SelectorExpr: + return exprToString(e.X) + "." + e.Sel.Name + case *ast.ArrayType: + if e.Len == nil { + return "[]" + exprToString(e.Elt) + } + return "[" + exprToString(e.Len) + "]" + exprToString(e.Elt) + case *ast.MapType: + return "map[" + exprToString(e.Key) + "]" + exprToString(e.Value) + case *ast.InterfaceType: + if len(e.Methods.List) == 0 { + return "interface{}" + } + return "interface{...}" + case *ast.FuncType: + return funcSignature(e) + case *ast.ChanType: + switch e.Dir { + case ast.SEND: + return "chan<- " + exprToString(e.Value) + case ast.RECV: + return "<-chan " + exprToString(e.Value) + default: + return "chan " + exprToString(e.Value) + } + case *ast.BasicLit: + return e.Value + default: + return fmt.Sprintf("%T", expr) + } +} + +func funcSignature(fn *ast.FuncType) string { + params := fieldListToString(fn.Params) + results := fieldListToString(fn.Results) + + if results == "" { + return fmt.Sprintf("func(%s)", params) + } + return fmt.Sprintf("func(%s) %s", params, results) +} + +func fieldListToString(fl *ast.FieldList) string { + if fl == nil || len(fl.List) == 0 { + return "" + } + + var parts []string + for _, field := range fl.List { + fieldType := exprToString(field.Type) + if len(field.Names) == 0 { + parts = append(parts, fieldType) + } else { + for _, name := range field.Names { + parts = append(parts, name.Name+" "+fieldType) + } + } + } + + if len(parts) == 1 && !strings.Contains(parts[0], " ") { + return parts[0] + } + return "(" + strings.Join(parts, ", ") + ")" +} + +func identifyReferenceKind(ident *ast.Ident) string { + return "identifier" +} + +func extractContext(src []byte, pos token.Position) string { + lines := strings.Split(string(src), "\n") + if pos.Line <= 0 || pos.Line > len(lines) { + return "" + } + + start := pos.Line - 2 + if start < 0 { + start = 0 + } + end := pos.Line + 1 + if end > len(lines) { + end = len(lines) + } + + context := strings.Join(lines[start:end], "\n") + return strings.TrimSpace(context) +} \ No newline at end of file diff --git a/go.mod b/go.mod index 31aa5fe..e06f7a5 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,13 @@ module github.com/flamingcow/gocp -go 1.21 +go 1.23.2 -require github.com/mark3labs/mcp-go v0.1.0 \ No newline at end of file +toolchain go1.24.4 + +require github.com/mark3labs/mcp-go v0.32.0 + +require ( + github.com/google/uuid v1.6.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a735303 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 068b14d..a6537e5 100644 --- a/main.go +++ b/main.go @@ -22,119 +22,227 @@ type RunResult struct { } func main() { - // Create MCP server - s := server.NewMCPServer( - "go-executor", + mcpServer := server.NewMCPServer( + "gocp", "1.0.0", - server.WithToolCapabilities(true), + server.WithToolCapabilities(false), ) // Define the build_and_run_go tool - buildAndRunTool := mcp.NewTool( - "build_and_run_go", + buildAndRunTool := mcp.NewTool("build_and_run_go", mcp.WithDescription("Build and execute Go code"), - mcp.WithString("code", mcp.Required(), mcp.Description("The Go source code to build and run")), - mcp.WithNumber("timeout", mcp.Description("Timeout in seconds (default: 30)")), + mcp.WithString("code", + mcp.Required(), + mcp.Description("The Go source code to build and run"), + ), + mcp.WithNumber("timeout", + mcp.Description("Timeout in seconds (default: 30)"), + ), ) + mcpServer.AddTool(buildAndRunTool, buildAndRunHandler) - // Add tool handler - s.AddTool(buildAndRunTool, buildAndRunHandler) + // Define the find_symbols tool + findSymbolsTool := mcp.NewTool("find_symbols", + mcp.WithDescription("Find all functions, types, interfaces, constants, and variables by name/pattern"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + mcp.WithString("pattern", + mcp.Description("Symbol name pattern to search for (case-insensitive substring match)"), + ), + ) + mcpServer.AddTool(findSymbolsTool, findSymbolsHandler) + + // Define the get_type_info tool + getTypeInfoTool := mcp.NewTool("get_type_info", + mcp.WithDescription("Get detailed information about a type including fields, methods, and embedded types"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + mcp.WithString("type", + mcp.Required(), + mcp.Description("Type name to get information for"), + ), + ) + mcpServer.AddTool(getTypeInfoTool, getTypeInfoHandler) + + // Define the find_references tool + findReferencesTool := mcp.NewTool("find_references", + mcp.WithDescription("Find all references to a symbol (function calls, type usage, etc.)"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + mcp.WithString("symbol", + mcp.Required(), + mcp.Description("Symbol name to find references for"), + ), + ) + mcpServer.AddTool(findReferencesTool, findReferencesHandler) + + // Define the list_packages tool + listPackagesTool := mcp.NewTool("list_packages", + mcp.WithDescription("List all Go packages in directory tree"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + mcp.WithBoolean("include_tests", + mcp.Description("Include test files in package listings (default: false)"), + ), + ) + mcpServer.AddTool(listPackagesTool, listPackagesHandler) // Start the server - if err := s.Serve(); err != nil { + if err := server.ServeStdio(mcpServer); err != nil { fmt.Fprintf(os.Stderr, "Server error: %v\n", err) os.Exit(1) } } -func buildAndRunHandler(ctx context.Context, args map[string]interface{}) (*mcp.CallToolResult, error) { - // Extract code parameter - code, ok := args["code"].(string) - if !ok { - return nil, fmt.Errorf("code parameter is required and must be a string") +func buildAndRunHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + code, err := request.RequireString("code") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // Extract timeout parameter (optional) - timeout := 30.0 - if t, ok := args["timeout"].(float64); ok { - timeout = t - } + timeout := request.GetFloat("timeout", 30.0) - // Build and run the code - stdout, stderr, exitCode, err := buildAndRunGo(code, time.Duration(timeout)*time.Second) + stdout, stderr, exitCode, runErr := buildAndRunGo(code, time.Duration(timeout)*time.Second) - // Create structured result result := RunResult{ Stdout: stdout, Stderr: stderr, ExitCode: exitCode, } - if err != nil { - result.Error = err.Error() + if runErr != nil { + result.Error = runErr.Error() } - // Convert to JSON jsonData, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to marshal result: %w", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil } - return mcp.NewCallToolResult( - mcp.NewTextContent(string(jsonData)), - ), nil + return mcp.NewToolResultText(string(jsonData)), nil } func buildAndRunGo(code string, timeout time.Duration) (stdout, stderr string, exitCode int, err error) { - // Create temporary directory tmpDir, err := os.MkdirTemp("", "gocp-*") if err != nil { return "", "", -1, fmt.Errorf("failed to create temp dir: %w", err) } defer os.RemoveAll(tmpDir) - // Write code to temporary file tmpFile := filepath.Join(tmpDir, "main.go") if err := os.WriteFile(tmpFile, []byte(code), 0644); err != nil { return "", "", -1, fmt.Errorf("failed to write code: %w", err) } - // Create context with timeout ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - // Initialize go.mod in temp directory modCmd := exec.CommandContext(ctx, "go", "mod", "init", "temp") modCmd.Dir = tmpDir if err := modCmd.Run(); err != nil { return "", "", -1, fmt.Errorf("failed to initialize go.mod: %w", err) } - // Run the code directly with go run runCmd := exec.CommandContext(ctx, "go", "run", tmpFile) runCmd.Dir = tmpDir - // Capture stdout and stderr separately var stdoutBuf, stderrBuf bytes.Buffer runCmd.Stdout = &stdoutBuf runCmd.Stderr = &stderrBuf - // Run the command err = runCmd.Run() - // Get exit code exitCode = 0 if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = exitErr.ExitCode() - err = nil // Clear error since we got the exit code + err = nil } else if ctx.Err() == context.DeadlineExceeded { return stdoutBuf.String(), stderrBuf.String(), -1, fmt.Errorf("execution timeout exceeded") } else { - // Some other error occurred return stdoutBuf.String(), stderrBuf.String(), -1, err } } return stdoutBuf.String(), stderrBuf.String(), exitCode, nil +} + +func findSymbolsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + pattern := request.GetString("pattern", "") + + symbols, err := findSymbols(dir, pattern) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find symbols: %v", err)), nil + } + + jsonData, err := json.Marshal(symbols) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal symbols: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func getTypeInfoHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + typeName, err := request.RequireString("type") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + info, err := getTypeInfo(dir, typeName) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get type info: %v", err)), nil + } + + jsonData, err := json.Marshal(info) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal type info: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findReferencesHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + symbol, err := request.RequireString("symbol") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + refs, err := findReferences(dir, symbol) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find references: %v", err)), nil + } + + jsonData, err := json.Marshal(refs) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal references: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func listPackagesHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + includeTests := request.GetBool("include_tests", false) + + packages, err := listPackages(dir, includeTests) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list packages: %v", err)), nil + } + + jsonData, err := json.Marshal(packages) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal packages: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil } \ No newline at end of file