From d16a911a5c858d3439974e667574b824ac97aa78 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 27 Jun 2025 21:35:37 -0700 Subject: [PATCH] Add read_range, write_range, and search_replace tools --- main.go | 188 ++++++++++++++++++++++++++ tool_read_write_range.go | 191 ++++++++++++++++++++++++++ tool_search_replace.go | 281 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 660 insertions(+) create mode 100644 tool_read_write_range.go create mode 100644 tool_search_replace.go diff --git a/main.go b/main.go index 7f655ed..534a24b 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "time" "github.com/mark3labs/mcp-go/mcp" @@ -318,6 +319,95 @@ func main() { ) mcpServer.AddTool(findMissingTestsTool, findMissingTestsHandler) + // Define the read_range tool + readRangeTool := mcp.NewTool("read_range", + mcp.WithDescription("Read file content by line/column or byte range"), + mcp.WithString("file", + mcp.Required(), + mcp.Description("File path to read from"), + ), + mcp.WithNumber("start_line", + mcp.Description("Start line (1-based, use with end_line)"), + ), + mcp.WithNumber("end_line", + mcp.Description("End line (1-based, inclusive)"), + ), + mcp.WithNumber("start_col", + mcp.Description("Start column (1-based, optional)"), + ), + mcp.WithNumber("end_col", + mcp.Description("End column (1-based, optional)"), + ), + mcp.WithNumber("start_byte", + mcp.Description("Start byte offset (0-based, use with end_byte)"), + ), + mcp.WithNumber("end_byte", + mcp.Description("End byte offset (0-based, exclusive)"), + ), + ) + mcpServer.AddTool(readRangeTool, readRangeHandler) + + // Define the write_range tool + writeRangeTool := mcp.NewTool("write_range", + mcp.WithDescription("Write content to file at specific line/column or byte range"), + mcp.WithString("file", + mcp.Required(), + mcp.Description("File path to write to"), + ), + mcp.WithString("content", + mcp.Required(), + mcp.Description("Content to write"), + ), + mcp.WithNumber("start_line", + mcp.Description("Start line (1-based, use with end_line)"), + ), + mcp.WithNumber("end_line", + mcp.Description("End line (1-based, inclusive)"), + ), + mcp.WithNumber("start_col", + mcp.Description("Start column (1-based, optional)"), + ), + mcp.WithNumber("end_col", + mcp.Description("End column (1-based, optional)"), + ), + mcp.WithNumber("start_byte", + mcp.Description("Start byte offset (0-based, use with end_byte)"), + ), + mcp.WithNumber("end_byte", + mcp.Description("End byte offset (0-based, exclusive)"), + ), + mcp.WithString("confirm_old", + mcp.Description("Expected old content for confirmation before replacing"), + ), + ) + mcpServer.AddTool(writeRangeTool, writeRangeHandler) + + // Define the search_replace tool + searchReplaceTool := mcp.NewTool("search_replace", + mcp.WithDescription("Search and optionally replace text in files using string or regex patterns"), + mcp.WithString("paths", + mcp.Required(), + mcp.Description("File/directory path or comma-separated paths to search"), + ), + mcp.WithString("pattern", + mcp.Required(), + mcp.Description("Search pattern (string or regex)"), + ), + mcp.WithString("replacement", + mcp.Description("Replacement text (omit for search-only)"), + ), + mcp.WithBoolean("regex", + mcp.Description("Use regex pattern matching (default: false)"), + ), + mcp.WithBoolean("case_insensitive", + mcp.Description("Case-insensitive matching (default: false)"), + ), + mcp.WithBoolean("include_context", + mcp.Description("Include line context in search results (default: false)"), + ), + ) + mcpServer.AddTool(searchReplaceTool, searchReplaceHandler) + // Start the server if err := server.ServeStdio(mcpServer); err != nil { fmt.Fprintf(os.Stderr, "Server error: %v\n", err) @@ -855,5 +945,103 @@ func findMissingTestsHandler(ctx context.Context, request mcp.CallToolRequest) ( return mcp.NewToolResultError(fmt.Sprintf("failed to marshal missing tests: %v", err)), nil } + return mcp.NewToolResultText(string(jsonData)), nil +} + +func readRangeHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + file, err := request.RequireString("file") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + startLine := int(request.GetFloat("start_line", -1)) + endLine := int(request.GetFloat("end_line", -1)) + startCol := int(request.GetFloat("start_col", -1)) + endCol := int(request.GetFloat("end_col", -1)) + startByte := int(request.GetFloat("start_byte", -1)) + endByte := int(request.GetFloat("end_byte", -1)) + + result, err := readRange(file, startLine, endLine, startCol, endCol, startByte, endByte) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to read range: %v", err)), nil + } + + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func writeRangeHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + file, err := request.RequireString("file") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + content, err := request.RequireString("content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + startLine := int(request.GetFloat("start_line", -1)) + endLine := int(request.GetFloat("end_line", -1)) + startCol := int(request.GetFloat("start_col", -1)) + endCol := int(request.GetFloat("end_col", -1)) + startByte := int(request.GetFloat("start_byte", -1)) + endByte := int(request.GetFloat("end_byte", -1)) + confirmOld := request.GetString("confirm_old", "") + + result, err := writeRange(file, content, startLine, endLine, startCol, endCol, startByte, endByte, confirmOld) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to write range: %v", err)), nil + } + + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func searchReplaceHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // For now, paths will be a single string + pathStr, err := request.RequireString("paths") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Split paths by comma if multiple + paths := strings.Split(pathStr, ",") + for i := range paths { + paths[i] = strings.TrimSpace(paths[i]) + } + + pattern, err := request.RequireString("pattern") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var replacement *string + if r := request.GetString("replacement", ""); r != "" { + replacement = &r + } + + useRegex := request.GetBool("regex", false) + caseInsensitive := request.GetBool("case_insensitive", false) + includeContext := request.GetBool("include_context", false) + + result, err := searchReplace(paths, pattern, replacement, useRegex, caseInsensitive, includeContext) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("search/replace failed: %v", err)), nil + } + + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + return mcp.NewToolResultText(string(jsonData)), nil } \ No newline at end of file diff --git a/tool_read_write_range.go b/tool_read_write_range.go new file mode 100644 index 0000000..9256192 --- /dev/null +++ b/tool_read_write_range.go @@ -0,0 +1,191 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +type ReadRangeResult struct { + Content string `json:"content"` + Lines int `json:"lines"` + Bytes int `json:"bytes"` +} + +type WriteRangeResult struct { + Success bool `json:"success"` + LinesWritten int `json:"lines_written"` + BytesWritten int `json:"bytes_written"` + Message string `json:"message,omitempty"` +} + +// Helper function to convert line/column positions to byte offsets +func lineColToByteRange(data []byte, startLine, endLine, startCol, endCol int) (startByte, endByte int, err error) { + if len(data) == 0 { + return 0, 0, nil + } + + // Convert to 0-based indexing + if startLine > 0 { + startLine-- + } + if endLine > 0 { + endLine-- + } + if startCol > 0 { + startCol-- + } + if endCol > 0 { + endCol-- + } + + currentLine := 0 + currentCol := 0 + startByte = -1 + endByte = -1 + + for i := 0; i < len(data); i++ { + // Check if we're at the start position + if currentLine == startLine && currentCol == startCol && startByte == -1 { + startByte = i + } + + // Check if we're at the end position + if currentLine == endLine { + if endCol < 0 { + // No end column specified, go to end of line + for j := i; j < len(data) && data[j] != '\n'; j++ { + i = j + } + endByte = i + 1 + if endByte > len(data) { + endByte = len(data) + } + break + } else if currentCol == endCol { + endByte = i + break + } + } + + // Move to next character + if data[i] == '\n' { + // End of line reached + if currentLine == endLine && endByte == -1 { + endByte = i + break + } + currentLine++ + currentCol = 0 + } else { + currentCol++ + } + + // If we've passed the end line, set end byte + if currentLine > endLine && endByte == -1 { + endByte = i + break + } + } + + // Handle end of file cases + if startByte == -1 { + return 0, 0, fmt.Errorf("start position (line %d, col %d) not found", startLine+1, startCol+1) + } + if endByte == -1 { + endByte = len(data) + } + + return startByte, endByte, nil +} + +func readRange(file string, startLine, endLine, startCol, endCol, startByte, endByte int) (*ReadRangeResult, error) { + data, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + // Convert line/column to byte range if needed + if startByte < 0 || endByte < 0 { + startByte, endByte, err = lineColToByteRange(data, startLine, endLine, startCol, endCol) + if err != nil { + return nil, err + } + } + + // Validate byte range + if startByte < 0 || startByte > len(data) { + return nil, fmt.Errorf("start byte %d out of range (file size: %d)", startByte, len(data)) + } + if endByte < startByte { + return nil, fmt.Errorf("end byte %d is before start byte %d", endByte, startByte) + } + if endByte > len(data) { + endByte = len(data) + } + + // Extract content + content := string(data[startByte:endByte]) + + return &ReadRangeResult{ + Content: content, + Lines: strings.Count(content, "\n") + 1, + Bytes: len(content), + }, nil +} + +func writeRange(file string, content string, startLine, endLine, startCol, endCol, startByte, endByte int, confirmOld string) (*WriteRangeResult, error) { + data, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + // Convert line/column to byte range if needed + if startByte < 0 || endByte < 0 { + startByte, endByte, err = lineColToByteRange(data, startLine, endLine, startCol, endCol) + if err != nil { + return nil, err + } + } + + // Validate byte range + if startByte < 0 || startByte > len(data) { + return nil, fmt.Errorf("start byte %d out of range (file size: %d)", startByte, len(data)) + } + if endByte < startByte { + return nil, fmt.Errorf("end byte %d is before start byte %d", endByte, startByte) + } + if endByte > len(data) { + endByte = len(data) + } + + // Extract old content + oldContent := string(data[startByte:endByte]) + + // Check confirmation if provided + if confirmOld != "" && oldContent != confirmOld { + return &WriteRangeResult{ + Success: false, + Message: fmt.Sprintf("content mismatch: expected %q but found %q", confirmOld, oldContent), + }, nil + } + + // Build new content + newData := make([]byte, 0, len(data)-len(oldContent)+len(content)) + newData = append(newData, data[:startByte]...) + newData = append(newData, []byte(content)...) + newData = append(newData, data[endByte:]...) + + // Write the file + err = os.WriteFile(file, newData, 0644) + if err != nil { + return nil, fmt.Errorf("failed to write file: %w", err) + } + + return &WriteRangeResult{ + Success: true, + LinesWritten: strings.Count(content, "\n") + 1, + BytesWritten: len(content), + Message: "Successfully written", + }, nil +} \ No newline at end of file diff --git a/tool_search_replace.go b/tool_search_replace.go new file mode 100644 index 0000000..3f95571 --- /dev/null +++ b/tool_search_replace.go @@ -0,0 +1,281 @@ +package main + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" +) + +type SearchReplaceResult struct { + Files []FileSearchReplaceResult `json:"files"` + TotalMatches int `json:"total_matches"` + TotalReplaced int `json:"total_replaced,omitempty"` +} + +type FileSearchReplaceResult struct { + Path string `json:"path"` + Matches []SearchMatch `json:"matches,omitempty"` + Replaced int `json:"replaced,omitempty"` + Error string `json:"error,omitempty"` +} + +type SearchMatch struct { + Line int `json:"line"` + Column int `json:"column"` + Text string `json:"text"` + Context string `json:"context,omitempty"` +} + +func searchReplace(paths []string, pattern string, replacement *string, useRegex, caseInsensitive bool, includeContext bool) (*SearchReplaceResult, error) { + result := &SearchReplaceResult{ + Files: []FileSearchReplaceResult{}, + } + + // Prepare search/replace function + var searchFunc func(string) [][]int + var replaceFunc func(string) string + + if useRegex { + flags := "" + if caseInsensitive { + flags = "(?i)" + } + re, err := regexp.Compile(flags + pattern) + if err != nil { + return nil, fmt.Errorf("invalid regex pattern: %w", err) + } + searchFunc = func(text string) [][]int { + return re.FindAllStringIndex(text, -1) + } + if replacement != nil { + replaceFunc = func(text string) string { + return re.ReplaceAllString(text, *replacement) + } + } + } else { + searchPattern := pattern + if caseInsensitive { + searchPattern = strings.ToLower(pattern) + } + searchFunc = func(text string) [][]int { + searchText := text + if caseInsensitive { + searchText = strings.ToLower(text) + } + var matches [][]int + start := 0 + for { + idx := strings.Index(searchText[start:], searchPattern) + if idx < 0 { + break + } + realIdx := start + idx + matches = append(matches, []int{realIdx, realIdx + len(pattern)}) + start = realIdx + len(pattern) + } + return matches + } + if replacement != nil { + replaceFunc = func(text string) string { + if caseInsensitive { + // Case-insensitive string replacement + return caseInsensitiveReplace(text, pattern, *replacement) + } + return strings.ReplaceAll(text, pattern, *replacement) + } + } + } + + for _, path := range paths { + info, err := os.Stat(path) + if err != nil { + result.Files = append(result.Files, FileSearchReplaceResult{ + Path: path, + Error: fmt.Sprintf("stat error: %v", err), + }) + continue + } + + if info.IsDir() { + // Process directory tree + err := filepath.WalkDir(path, func(filePath string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + + // Skip non-text files + if !isTextFile(filePath) { + return nil + } + + fileResult := processFile(filePath, searchFunc, replaceFunc, includeContext) + if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { + result.Files = append(result.Files, fileResult) + result.TotalMatches += len(fileResult.Matches) + result.TotalReplaced += fileResult.Replaced + } + return nil + }) + if err != nil { + return nil, err + } + } else { + // Process single file + fileResult := processFile(path, searchFunc, replaceFunc, includeContext) + if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { + result.Files = append(result.Files, fileResult) + result.TotalMatches += len(fileResult.Matches) + result.TotalReplaced += fileResult.Replaced + } + } + } + + return result, nil +} + +func processFile(path string, searchFunc func(string) [][]int, replaceFunc func(string) string, includeContext bool) FileSearchReplaceResult { + result := FileSearchReplaceResult{ + Path: path, + } + + // Read file + data, err := os.ReadFile(path) + if err != nil { + result.Error = fmt.Sprintf("read error: %v", err) + return result + } + + content := string(data) + + // If replacement is requested, do it + if replaceFunc != nil { + matches := searchFunc(content) + result.Replaced = len(matches) + + if result.Replaced > 0 { + newContent := replaceFunc(content) + err = os.WriteFile(path, []byte(newContent), 0644) + if err != nil { + result.Error = fmt.Sprintf("write error: %v", err) + result.Replaced = 0 + } + } + return result + } + + // Otherwise, just search + lines := strings.Split(content, "\n") + lineStarts := make([]int, len(lines)) + pos := 0 + for i, line := range lines { + lineStarts[i] = pos + pos += len(line) + 1 // +1 for newline + } + + matches := searchFunc(content) + for _, match := range matches { + // Find line number + lineNum := 0 + for i, start := range lineStarts { + if match[0] >= start && (i == len(lineStarts)-1 || match[0] < lineStarts[i+1]) { + lineNum = i + 1 + break + } + } + + // Calculate column + lineStart := 0 + if lineNum > 0 { + lineStart = lineStarts[lineNum-1] + } + column := match[0] - lineStart + 1 + + searchMatch := SearchMatch{ + Line: lineNum, + Column: column, + Text: content[match[0]:match[1]], + } + + if includeContext && lineNum > 0 && lineNum <= len(lines) { + searchMatch.Context = strings.TrimSpace(lines[lineNum-1]) + } + + result.Matches = append(result.Matches, searchMatch) + } + + return result +} + +func caseInsensitiveReplace(text, old, new string) string { + // Simple case-insensitive replacement + var result strings.Builder + lowerText := strings.ToLower(text) + lowerOld := strings.ToLower(old) + + start := 0 + for { + idx := strings.Index(lowerText[start:], lowerOld) + if idx < 0 { + result.WriteString(text[start:]) + break + } + realIdx := start + idx + result.WriteString(text[start:realIdx]) + result.WriteString(new) + start = realIdx + len(old) + } + + return result.String() +} + +func isTextFile(path string) bool { + ext := strings.ToLower(filepath.Ext(path)) + textExts := map[string]bool{ + ".go": true, + ".txt": true, + ".md": true, + ".json": true, + ".yaml": true, + ".yml": true, + ".toml": true, + ".xml": true, + ".html": true, + ".css": true, + ".js": true, + ".ts": true, + ".py": true, + ".rb": true, + ".java": true, + ".c": true, + ".cpp": true, + ".h": true, + ".hpp": true, + ".rs": true, + ".sh": true, + ".bash": true, + ".zsh": true, + ".fish": true, + ".sql": true, + ".proto": true, + ".mod": true, + ".sum": true, + } + + // Check extension + if textExts[ext] { + return true + } + + // Check for files without extension that might be text + base := filepath.Base(path) + if base == "Makefile" || base == "Dockerfile" || base == "README" || + base == "LICENSE" || base == "CHANGELOG" || base == "TODO" || + strings.HasPrefix(base, ".") { // dotfiles are often text + return true + } + + return false +} \ No newline at end of file