diff --git a/main.go b/main.go index 534a24b..f9054a6 100644 --- a/main.go +++ b/main.go @@ -384,14 +384,14 @@ func main() { // Define the search_replace tool searchReplaceTool := mcp.NewTool("search_replace", - mcp.WithDescription("Search and optionally replace text in files using string or regex patterns"), + mcp.WithDescription("Search and optionally replace text in files. Supports context-aware replacements with capture groups."), mcp.WithString("paths", mcp.Required(), mcp.Description("File/directory path or comma-separated paths to search"), ), mcp.WithString("pattern", mcp.Required(), - mcp.Description("Search pattern (string or regex)"), + mcp.Description("Search pattern (string or regex). Used for simple search unless context_pattern is provided."), ), mcp.WithString("replacement", mcp.Description("Replacement text (omit for search-only)"), @@ -405,6 +405,12 @@ func main() { mcp.WithBoolean("include_context", mcp.Description("Include line context in search results (default: false)"), ), + mcp.WithString("context_pattern", + mcp.Description("Context pattern with capture groups, e.g., '(prefix)(target)(suffix)' to replace only the target"), + ), + mcp.WithNumber("target_group", + mcp.Description("Which capture group to replace when using context_pattern (1-based, default: 2 for 3 groups, otherwise last)"), + ), ) mcpServer.AddTool(searchReplaceTool, searchReplaceHandler) @@ -1032,8 +1038,23 @@ func searchReplaceHandler(ctx context.Context, request mcp.CallToolRequest) (*mc useRegex := request.GetBool("regex", false) caseInsensitive := request.GetBool("case_insensitive", false) includeContext := request.GetBool("include_context", false) + contextPattern := request.GetString("context_pattern", "") + targetGroup := int(request.GetFloat("target_group", 0)) - result, err := searchReplace(paths, pattern, replacement, useRegex, caseInsensitive, includeContext) + // If context pattern is provided, use the context-aware replacement + if contextPattern != "" && replacement != nil && targetGroup > 0 { + result, err := searchReplaceWithGroups(paths, contextPattern, *replacement, targetGroup, caseInsensitive) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("search/replace failed: %v", err)), nil + } + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + return mcp.NewToolResultText(string(jsonData)), nil + } + + result, err := searchReplace(paths, pattern, replacement, useRegex, caseInsensitive, includeContext, contextPattern) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("search/replace failed: %v", err)), nil } diff --git a/tool_search_replace.go b/tool_search_replace.go index 3f95571..5438c53 100644 --- a/tool_search_replace.go +++ b/tool_search_replace.go @@ -29,7 +29,7 @@ type SearchMatch struct { Context string `json:"context,omitempty"` } -func searchReplace(paths []string, pattern string, replacement *string, useRegex, caseInsensitive bool, includeContext bool) (*SearchReplaceResult, error) { +func searchReplace(paths []string, pattern string, replacement *string, useRegex, caseInsensitive bool, includeContext bool, contextPattern string) (*SearchReplaceResult, error) { result := &SearchReplaceResult{ Files: []FileSearchReplaceResult{}, } @@ -38,7 +38,59 @@ func searchReplace(paths []string, pattern string, replacement *string, useRegex var searchFunc func(string) [][]int var replaceFunc func(string) string - if useRegex { + // Handle context-aware replacement + if contextPattern != "" && replacement != nil { + // Context pattern should use capture groups like "(prefix)(target)(suffix)" + // The replacement will apply to the target group + flags := "" + if caseInsensitive { + flags = "(?i)" + } + + // First validate the context pattern has at least one capture group + contextRe, err := regexp.Compile(flags + contextPattern) + if err != nil { + return nil, fmt.Errorf("invalid context pattern: %w", err) + } + + // Count capture groups + numGroups := contextRe.NumSubexp() + if numGroups == 0 { + return nil, fmt.Errorf("context pattern must have at least one capture group, e.g., '(prefix)(target)(suffix)'") + } + + searchFunc = func(text string) [][]int { + return contextRe.FindAllStringIndex(text, -1) + } + + replaceFunc = func(text string) string { + // For context replacement, we need to specify which group to replace + // By default, replace the middle group if there are 3 groups, otherwise the last group + targetGroup := numGroups + if numGroups == 3 { + targetGroup = 2 + } + + // Use ReplaceAllStringFunc to handle complex replacements + return contextRe.ReplaceAllStringFunc(text, func(match string) string { + submatches := contextRe.FindStringSubmatch(match) + if len(submatches) <= targetGroup { + return match + } + + // Rebuild the match with the target group replaced + result := "" + for i := 1; i <= numGroups; i++ { + if i == targetGroup { + result += *replacement + } else { + result += submatches[i] + } + } + return result + }) + } + } else if useRegex { flags := "" if caseInsensitive { flags = "(?i)" @@ -136,6 +188,102 @@ func searchReplace(paths []string, pattern string, replacement *string, useRegex return result, nil } +// searchReplaceWithGroups allows specifying which capture group to replace +func searchReplaceWithGroups(paths []string, contextPattern string, replacement string, targetGroup int, caseInsensitive bool) (*SearchReplaceResult, error) { + result := &SearchReplaceResult{ + Files: []FileSearchReplaceResult{}, + } + + flags := "" + if caseInsensitive { + flags = "(?i)" + } + + contextRe, err := regexp.Compile(flags + contextPattern) + if err != nil { + return nil, fmt.Errorf("invalid context pattern: %w", err) + } + + numGroups := contextRe.NumSubexp() + if numGroups == 0 { + return nil, fmt.Errorf("context pattern must have capture groups") + } + + if targetGroup < 1 || targetGroup > numGroups { + return nil, fmt.Errorf("target group %d is out of range (pattern has %d groups)", targetGroup, numGroups) + } + + searchFunc := func(text string) [][]int { + return contextRe.FindAllStringIndex(text, -1) + } + + replaceFunc := func(text string) string { + return contextRe.ReplaceAllStringFunc(text, func(match string) string { + submatches := contextRe.FindStringSubmatch(match) + if len(submatches) <= numGroups { + return match + } + + // Rebuild the match with the target group replaced + result := "" + for i := 1; i <= numGroups; i++ { + if i == targetGroup { + result += replacement + } else { + result += submatches[i] + } + } + return result + }) + } + + for _, path := range paths { + info, err := os.Stat(path) + if err != nil { + result.Files = append(result.Files, FileSearchReplaceResult{ + Path: path, + Error: fmt.Sprintf("stat error: %v", err), + }) + continue + } + + if info.IsDir() { + // Process directory tree + err := filepath.WalkDir(path, func(filePath string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + + // Skip non-text files + if !isTextFile(filePath) { + return nil + } + + fileResult := processFile(filePath, searchFunc, replaceFunc, false) + if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { + result.Files = append(result.Files, fileResult) + result.TotalMatches += len(fileResult.Matches) + result.TotalReplaced += fileResult.Replaced + } + return nil + }) + if err != nil { + return nil, err + } + } else { + // Process single file + fileResult := processFile(path, searchFunc, replaceFunc, false) + if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { + result.Files = append(result.Files, fileResult) + result.TotalMatches += len(fileResult.Matches) + result.TotalReplaced += fileResult.Replaced + } + } + } + + return result, nil +} + func processFile(path string, searchFunc func(string) [][]int, replaceFunc func(string) string, includeContext bool) FileSearchReplaceResult { result := FileSearchReplaceResult{ Path: path,