diff --git a/main.go b/main.go index 4c10cbf..b91eb97 100644 --- a/main.go +++ b/main.go @@ -387,7 +387,7 @@ func main() { // Define the search_replace tool searchReplaceTool := mcp.NewTool("search_replace", - mcp.WithDescription("Search and optionally replace text in files. Supports context-aware replacements with capture groups."), + mcp.WithDescription("Search and optionally replace text in files. Supports context-aware replacements with before/after patterns."), mcp.WithString("paths", mcp.Required(), mcp.Description("File/directory path or comma-separated paths to search"), @@ -408,11 +408,11 @@ func main() { mcp.WithBoolean("include_context", mcp.Description("Include line context in search results (default: false)"), ), - mcp.WithString("context_pattern", - mcp.Description("Context pattern with capture groups, e.g., '(prefix)(target)(suffix)' to replace only the target"), + mcp.WithString("before_pattern", + mcp.Description("Pattern that must appear before the main pattern (for context-aware replacement)"), ), - mcp.WithNumber("target_group", - mcp.Description("Which capture group to replace when using context_pattern (1-based, default: 2 for 3 groups, otherwise last)"), + mcp.WithString("after_pattern", + mcp.Description("Pattern that must appear after the main pattern (for context-aware replacement)"), ), ) mcpServer.AddTool(searchReplaceTool, searchReplaceHandler) @@ -1035,30 +1035,17 @@ func searchReplaceHandler(ctx context.Context, request mcp.CallToolRequest) (*mc } var replacement *string - if r := request.GetString("replacement", ""); r != "" { + if r, exists := request.GetOptionalString("replacement"); exists { replacement = &r } useRegex := request.GetBool("regex", false) caseInsensitive := request.GetBool("case_insensitive", false) includeContext := request.GetBool("include_context", false) - contextPattern := request.GetString("context_pattern", "") - targetGroup := int(request.GetFloat("target_group", 0)) + beforePattern := request.GetString("before_pattern", "") + afterPattern := request.GetString("after_pattern", "") - // If context pattern is provided, use the context-aware replacement - if contextPattern != "" && replacement != nil && targetGroup > 0 { - result, err := searchReplaceWithGroups(paths, contextPattern, *replacement, targetGroup, caseInsensitive) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("search/replace failed: %v", err)), nil - } - jsonData, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil - } - return mcp.NewToolResultText(string(jsonData)), nil - } - - result, err := searchReplace(paths, pattern, replacement, useRegex, caseInsensitive, includeContext, contextPattern) + result, err := searchReplace(paths, pattern, replacement, useRegex, caseInsensitive, includeContext, beforePattern, afterPattern) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("search/replace failed: %v", err)), nil } diff --git a/tool_search_replace.go b/tool_search_replace.go index 5438c53..7ee49d6 100644 --- a/tool_search_replace.go +++ b/tool_search_replace.go @@ -29,7 +29,7 @@ type SearchMatch struct { Context string `json:"context,omitempty"` } -func searchReplace(paths []string, pattern string, replacement *string, useRegex, caseInsensitive bool, includeContext bool, contextPattern string) (*SearchReplaceResult, error) { +func searchReplace(paths []string, pattern string, replacement *string, useRegex, caseInsensitive bool, includeContext bool, beforePattern, afterPattern string) (*SearchReplaceResult, error) { result := &SearchReplaceResult{ Files: []FileSearchReplaceResult{}, } @@ -38,53 +38,54 @@ func searchReplace(paths []string, pattern string, replacement *string, useRegex var searchFunc func(string) [][]int var replaceFunc func(string) string - // Handle context-aware replacement - if contextPattern != "" && replacement != nil { - // Context pattern should use capture groups like "(prefix)(target)(suffix)" - // The replacement will apply to the target group + // Handle context-aware replacement with before/after patterns + if (beforePattern != "" || afterPattern != "") && replacement != nil { + // Build a pattern that captures before, target, and after parts + contextPattern := "" + if beforePattern != "" { + contextPattern += "(" + regexp.QuoteMeta(beforePattern) + ")" + } + if useRegex { + contextPattern += "(" + pattern + ")" + } else { + contextPattern += "(" + regexp.QuoteMeta(pattern) + ")" + } + if afterPattern != "" { + contextPattern += "(" + regexp.QuoteMeta(afterPattern) + ")" + } + flags := "" if caseInsensitive { flags = "(?i)" } - // First validate the context pattern has at least one capture group contextRe, err := regexp.Compile(flags + contextPattern) if err != nil { return nil, fmt.Errorf("invalid context pattern: %w", err) } - // Count capture groups - numGroups := contextRe.NumSubexp() - if numGroups == 0 { - return nil, fmt.Errorf("context pattern must have at least one capture group, e.g., '(prefix)(target)(suffix)'") - } - searchFunc = func(text string) [][]int { return contextRe.FindAllStringIndex(text, -1) } replaceFunc = func(text string) string { - // For context replacement, we need to specify which group to replace - // By default, replace the middle group if there are 3 groups, otherwise the last group - targetGroup := numGroups - if numGroups == 3 { - targetGroup = 2 - } - - // Use ReplaceAllStringFunc to handle complex replacements return contextRe.ReplaceAllStringFunc(text, func(match string) string { submatches := contextRe.FindStringSubmatch(match) - if len(submatches) <= targetGroup { - return match - } - // Rebuild the match with the target group replaced + // Rebuild the match with the target replaced result := "" - for i := 1; i <= numGroups; i++ { - if i == targetGroup { - result += *replacement - } else { - result += submatches[i] + if beforePattern != "" && len(submatches) > 1 { + result += submatches[1] // before part + } + result += *replacement // replacement for target + if afterPattern != "" { + // The after part is at index 3 if before exists, otherwise at index 2 + afterIndex := 2 + if beforePattern != "" { + afterIndex = 3 + } + if len(submatches) > afterIndex { + result += submatches[afterIndex] } } return result @@ -188,101 +189,6 @@ func searchReplace(paths []string, pattern string, replacement *string, useRegex return result, nil } -// searchReplaceWithGroups allows specifying which capture group to replace -func searchReplaceWithGroups(paths []string, contextPattern string, replacement string, targetGroup int, caseInsensitive bool) (*SearchReplaceResult, error) { - result := &SearchReplaceResult{ - Files: []FileSearchReplaceResult{}, - } - - flags := "" - if caseInsensitive { - flags = "(?i)" - } - - contextRe, err := regexp.Compile(flags + contextPattern) - if err != nil { - return nil, fmt.Errorf("invalid context pattern: %w", err) - } - - numGroups := contextRe.NumSubexp() - if numGroups == 0 { - return nil, fmt.Errorf("context pattern must have capture groups") - } - - if targetGroup < 1 || targetGroup > numGroups { - return nil, fmt.Errorf("target group %d is out of range (pattern has %d groups)", targetGroup, numGroups) - } - - searchFunc := func(text string) [][]int { - return contextRe.FindAllStringIndex(text, -1) - } - - replaceFunc := func(text string) string { - return contextRe.ReplaceAllStringFunc(text, func(match string) string { - submatches := contextRe.FindStringSubmatch(match) - if len(submatches) <= numGroups { - return match - } - - // Rebuild the match with the target group replaced - result := "" - for i := 1; i <= numGroups; i++ { - if i == targetGroup { - result += replacement - } else { - result += submatches[i] - } - } - return result - }) - } - - for _, path := range paths { - info, err := os.Stat(path) - if err != nil { - result.Files = append(result.Files, FileSearchReplaceResult{ - Path: path, - Error: fmt.Sprintf("stat error: %v", err), - }) - continue - } - - if info.IsDir() { - // Process directory tree - err := filepath.WalkDir(path, func(filePath string, d fs.DirEntry, err error) error { - if err != nil || d.IsDir() { - return nil - } - - // Skip non-text files - if !isTextFile(filePath) { - return nil - } - - fileResult := processFile(filePath, searchFunc, replaceFunc, false) - if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { - result.Files = append(result.Files, fileResult) - result.TotalMatches += len(fileResult.Matches) - result.TotalReplaced += fileResult.Replaced - } - return nil - }) - if err != nil { - return nil, err - } - } else { - // Process single file - fileResult := processFile(path, searchFunc, replaceFunc, false) - if len(fileResult.Matches) > 0 || fileResult.Replaced > 0 || fileResult.Error != "" { - result.Files = append(result.Files, fileResult) - result.TotalMatches += len(fileResult.Matches) - result.TotalReplaced += fileResult.Replaced - } - } - } - - return result, nil -} func processFile(path string, searchFunc func(string) [][]int, replaceFunc func(string) string, includeContext bool) FileSearchReplaceResult { result := FileSearchReplaceResult{