From ee04ab585a2e94264a7a12c0d73afbf446b6f97d Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Fri, 27 Jun 2025 22:54:45 -0700 Subject: [PATCH] Add AST analysis tools and go run/test execution tools --- CONTEXT.md | 32 ++- main.go | 363 +++++++++++++++++++++++++ tool_analyze_channels.go | 325 ++++++++++++++++++++++ tool_analyze_defer_patterns.go | 305 +++++++++++++++++++++ tool_analyze_goroutines.go | 243 +++++++++++++++++ tool_analyze_memory_allocations.go | 374 ++++++++++++++++++++++++++ tool_analyze_naming_conventions.go | 418 +++++++++++++++++++++++++++++ tool_find_empty_blocks.go | 366 +++++++++++++++++++++++++ tool_find_errors.go | 4 +- tool_find_inefficiencies.go | 4 +- tool_find_init_functions.go | 360 +++++++++++++++++++++++++ tool_find_method_receivers.go | 129 +++++++++ tool_find_panic_recover.go | 175 ++++++++++++ tool_find_reflection_usage.go | 279 +++++++++++++++++++ tool_find_type_assertions.go | 209 +++++++++++++++ tool_go_common.go | 25 ++ tool_go_run.go | 84 ++++++ tool_gotest.go | 106 ++++++++ 18 files changed, 3796 insertions(+), 5 deletions(-) create mode 100644 tool_analyze_channels.go create mode 100644 tool_analyze_defer_patterns.go create mode 100644 tool_analyze_goroutines.go create mode 100644 tool_analyze_memory_allocations.go create mode 100644 tool_analyze_naming_conventions.go create mode 100644 tool_find_empty_blocks.go create mode 100644 tool_find_init_functions.go create mode 100644 tool_find_method_receivers.go create mode 100644 tool_find_panic_recover.go create mode 100644 tool_find_reflection_usage.go create mode 100644 tool_find_type_assertions.go create mode 100644 tool_go_common.go create mode 100644 tool_go_run.go create mode 100644 tool_gotest.go diff --git a/CONTEXT.md b/CONTEXT.md index c5308be..6c8ed64 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -86,4 +86,34 @@ List all Go packages in directory tree - `name`: Package name - `dir`: Directory path - `go_files`: List of Go source files - - `imports`: List of imported packages \ No newline at end of file + - `imports`: List of imported packages + +### go_run +Execute go run command with specified path and optional flags +- Parameters: + - `path` (required): Path to Go file or package to run + - `flags` (optional): Optional flags for go run (space-separated) + - `timeout` (optional): Timeout in seconds (default: 30) +- Returns JSON with: + - `stdout`: Standard output from go run + - `stderr`: Standard error from go run + - `exit_code`: Process exit code + - `error`: Error message if any + - `command`: The full command that was executed + - `work_dir`: Working directory where command was run + +### go_test +Execute go test command with specified path and optional flags +- Parameters: + - `path` (required): Path to Go package or directory to test + - `flags` (optional): Optional flags for go test (space-separated, e.g., '-v -cover -race') + - `timeout` (optional): Timeout in seconds (default: 60) +- Returns JSON with: + - `stdout`: Standard output from go test + - `stderr`: Standard error from go test + - `exit_code`: Process exit code + - `error`: Error message if any + - `command`: The full command that was executed + - `work_dir`: Working directory where command was run + - `passed`: Boolean indicating if tests passed + - `test_count`: Number of tests found (if detectable) \ No newline at end of file diff --git a/main.go b/main.go index 3ef0b86..9bf1146 100644 --- a/main.go +++ b/main.go @@ -423,6 +423,137 @@ func main() { ) mcpServer.AddTool(searchReplaceTool, searchReplaceHandler) + // Define the find_method_receivers tool + findMethodReceiversTool := mcp.NewTool("find_method_receivers", + mcp.WithDescription("Track pointer vs value receivers inconsistencies and suggest standardization"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findMethodReceiversTool, findMethodReceiversHandler) + + // Define the analyze_goroutines tool + analyzeGoroutinesTool := mcp.NewTool("analyze_goroutines", + mcp.WithDescription("Find goroutine leaks, missing waitgroups, and unsafe concurrent access patterns"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(analyzeGoroutinesTool, analyzeGoroutinesHandler) + + // Define the find_panic_recover tool + findPanicRecoverTool := mcp.NewTool("find_panic_recover", + mcp.WithDescription("Locate panic/recover patterns and suggest error handling improvements"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findPanicRecoverTool, findPanicRecoverHandler) + + // Define the analyze_channels tool + analyzeChannelsTool := mcp.NewTool("analyze_channels", + mcp.WithDescription("Detect channel deadlocks, unbuffered channel issues, and goroutine communication patterns"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(analyzeChannelsTool, analyzeChannelsHandler) + + // Define the find_type_assertions tool + findTypeAssertionsTool := mcp.NewTool("find_type_assertions", + mcp.WithDescription("Find unsafe type assertions without ok checks"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findTypeAssertionsTool, findTypeAssertionsHandler) + + // Define the analyze_memory_allocations tool + analyzeMemoryAllocationsTool := mcp.NewTool("analyze_memory_allocations", + mcp.WithDescription("Identify excessive allocations, escaping variables, and suggest optimizations"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(analyzeMemoryAllocationsTool, analyzeMemoryAllocationsHandler) + + // Define the find_reflection_usage tool + findReflectionUsageTool := mcp.NewTool("find_reflection_usage", + mcp.WithDescription("Track reflect package usage for performance and type safety analysis"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findReflectionUsageTool, findReflectionUsageHandler) + + // Define the find_init_functions tool + findInitFunctionsTool := mcp.NewTool("find_init_functions", + mcp.WithDescription("Track init() functions and their initialization order dependencies"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findInitFunctionsTool, findInitFunctionsHandler) + + // Define the analyze_defer_patterns tool + analyzeDeferPatternsTool := mcp.NewTool("analyze_defer_patterns", + mcp.WithDescription("Find incorrect defer usage and resource leak risks"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(analyzeDeferPatternsTool, analyzeDeferPatternsHandler) + + // Define the find_empty_blocks tool + findEmptyBlocksTool := mcp.NewTool("find_empty_blocks", + mcp.WithDescription("Locate empty if/else/for blocks and suggest removal"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(findEmptyBlocksTool, findEmptyBlocksHandler) + + // Define the analyze_naming_conventions tool + analyzeNamingConventionsTool := mcp.NewTool("analyze_naming_conventions", + mcp.WithDescription("Check Go naming conventions (camelCase, exported names, etc.)"), + mcp.WithString("dir", + mcp.Description("Directory to search (default: current directory)"), + ), + ) + mcpServer.AddTool(analyzeNamingConventionsTool, analyzeNamingConventionsHandler) + + // Define the go_run tool + goRunTool := mcp.NewTool("go_run", + mcp.WithDescription("Execute go run command with specified path and optional flags"), + mcp.WithString("path", + mcp.Required(), + mcp.Description("Path to Go file or package to run"), + ), + mcp.WithString("flags", + mcp.Description("Optional flags for go run (space-separated)"), + ), + mcp.WithNumber("timeout", + mcp.Description("Timeout in seconds (default: 30)"), + ), + ) + mcpServer.AddTool(goRunTool, goRunHandler) + + // Define the go_test tool + goTestTool := mcp.NewTool("go_test", + mcp.WithDescription("Execute go test command with specified path and optional flags"), + mcp.WithString("path", + mcp.Required(), + mcp.Description("Path to Go package or directory to test"), + ), + mcp.WithString("flags", + mcp.Description("Optional flags for go test (space-separated, e.g., '-v -cover -race')"), + ), + mcp.WithNumber("timeout", + mcp.Description("Timeout in seconds (default: 60)"), + ), + ) + mcpServer.AddTool(goTestTool, goTestHandler) + // Start the server if err := server.ServeStdio(mcpServer); err != nil { fmt.Fprintf(os.Stderr, "Server error: %v\n", err) @@ -1066,5 +1197,237 @@ func searchReplaceHandler(ctx context.Context, request mcp.CallToolRequest) (*mc return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil } + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findMethodReceiversHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findMethodReceivers(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze method receivers: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func analyzeGoroutinesHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := analyzeGoroutines(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze goroutines: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findPanicRecoverHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findPanicRecover(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find panic/recover: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func analyzeChannelsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := analyzeChannels(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze channels: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findTypeAssertionsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findTypeAssertions(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find type assertions: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func analyzeMemoryAllocationsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := analyzeMemoryAllocations(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze memory allocations: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findReflectionUsageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findReflectionUsage(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find reflection usage: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findInitFunctionsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findInitFunctions(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find init functions: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func analyzeDeferPatternsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := analyzeDeferPatterns(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze defer patterns: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func findEmptyBlocksHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := findEmptyBlocks(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to find empty blocks: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func analyzeNamingConventionsHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dir := request.GetString("dir", "./") + + analysis, err := analyzeNamingConventions(dir) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to analyze naming conventions: %v", err)), nil + } + + jsonData, err := json.Marshal(analysis) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal analysis: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func goRunHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + path, err := request.RequireString("path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + flagsStr := request.GetString("flags", "") + timeout := request.GetFloat("timeout", 30.0) + + // Parse flags + var flags []string + if flagsStr != "" { + flags = strings.Fields(flagsStr) + } + + result, err := goRun(path, flags, time.Duration(timeout)*time.Second) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to run go run: %v", err)), nil + } + + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonData)), nil +} + +func goTestHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + path, err := request.RequireString("path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + flagsStr := request.GetString("flags", "") + timeout := request.GetFloat("timeout", 60.0) + + // Parse flags + var flags []string + if flagsStr != "" { + flags = strings.Fields(flagsStr) + } + + result, err := goTest(path, flags, time.Duration(timeout)*time.Second) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to run go test: %v", err)), nil + } + + jsonData, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + return mcp.NewToolResultText(string(jsonData)), nil } \ No newline at end of file diff --git a/tool_analyze_channels.go b/tool_analyze_channels.go new file mode 100644 index 0000000..fa0979e --- /dev/null +++ b/tool_analyze_channels.go @@ -0,0 +1,325 @@ +package main + +import ( + "go/ast" + "go/token" +) + +type ChannelUsage struct { + Name string `json:"name"` + Type string `json:"type"` // "make", "send", "receive", "range", "select", "close" + ChannelType string `json:"channel_type"` // "unbuffered", "buffered", "unknown" + BufferSize int `json:"buffer_size,omitempty"` + Position Position `json:"position"` + Context string `json:"context"` +} + +type ChannelAnalysis struct { + Channels []ChannelUsage `json:"channels"` + Issues []ChannelIssue `json:"issues"` +} + +type ChannelIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func analyzeChannels(dir string) (*ChannelAnalysis, error) { + analysis := &ChannelAnalysis{ + Channels: []ChannelUsage{}, + Issues: []ChannelIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Track channel variables + channelVars := make(map[string]*ChannelInfo) + + // First pass: identify channel declarations + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.ValueSpec: + for i, name := range node.Names { + if isChanType(node.Type) { + channelVars[name.Name] = &ChannelInfo{ + name: name.Name, + chanType: "unknown", + } + } else if i < len(node.Values) { + if info := extractChannelMake(node.Values[i]); info != nil { + info.name = name.Name + channelVars[name.Name] = info + } + } + } + + case *ast.AssignStmt: + for i, lhs := range node.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && i < len(node.Rhs) { + if info := extractChannelMake(node.Rhs[i]); info != nil { + info.name = ident.Name + channelVars[ident.Name] = info + + pos := fset.Position(node.Pos()) + usage := ChannelUsage{ + Name: ident.Name, + Type: "make", + ChannelType: info.chanType, + BufferSize: info.bufferSize, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Channels = append(analysis.Channels, usage) + } + } + } + } + return true + }) + + // Second pass: analyze channel operations + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.SendStmt: + pos := fset.Position(node.Pos()) + chanName := extractChannelName(node.Chan) + usage := ChannelUsage{ + Name: chanName, + Type: "send", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + if info, ok := channelVars[chanName]; ok { + usage.ChannelType = info.chanType + usage.BufferSize = info.bufferSize + } + analysis.Channels = append(analysis.Channels, usage) + + // Check for potential deadlock + if isInMainGoroutine(file, node) && !hasGoroutineNearby(file, node) { + if info, ok := channelVars[chanName]; ok && info.chanType == "unbuffered" { + issue := ChannelIssue{ + Type: "potential_deadlock", + Description: "Send on unbuffered channel without goroutine may deadlock", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case *ast.UnaryExpr: + if node.Op == token.ARROW { + pos := fset.Position(node.Pos()) + chanName := extractChannelName(node.X) + usage := ChannelUsage{ + Name: chanName, + Type: "receive", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + if info, ok := channelVars[chanName]; ok { + usage.ChannelType = info.chanType + usage.BufferSize = info.bufferSize + } + analysis.Channels = append(analysis.Channels, usage) + } + + case *ast.RangeStmt: + if isChanExpression(node.X) { + pos := fset.Position(node.Pos()) + chanName := extractChannelName(node.X) + usage := ChannelUsage{ + Name: chanName, + Type: "range", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Channels = append(analysis.Channels, usage) + } + + case *ast.CallExpr: + if ident, ok := node.Fun.(*ast.Ident); ok && ident.Name == "close" { + if len(node.Args) > 0 { + pos := fset.Position(node.Pos()) + chanName := extractChannelName(node.Args[0]) + usage := ChannelUsage{ + Name: chanName, + Type: "close", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Channels = append(analysis.Channels, usage) + } + } + + case *ast.SelectStmt: + analyzeSelectStatement(node, fset, src, analysis, channelVars) + } + return true + }) + + return nil + }) + + return analysis, err +} + +type ChannelInfo struct { + name string + chanType string // "buffered", "unbuffered", "unknown" + bufferSize int +} + +func extractChannelMake(expr ast.Expr) *ChannelInfo { + call, ok := expr.(*ast.CallExpr) + if !ok { + return nil + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + return nil + } + + if len(call.Args) < 1 || !isChanType(call.Args[0]) { + return nil + } + + info := &ChannelInfo{} + + if len(call.Args) == 1 { + info.chanType = "unbuffered" + info.bufferSize = 0 + } else if len(call.Args) >= 2 { + info.chanType = "buffered" + if lit, ok := call.Args[1].(*ast.BasicLit); ok && lit.Kind == token.INT { + // Parse buffer size if it's a literal + if size := lit.Value; size == "0" { + info.chanType = "unbuffered" + } else { + info.bufferSize = 1 // Default to 1 if we can't parse + } + } + } + + return info +} + +func isChanType(expr ast.Expr) bool { + _, ok := expr.(*ast.ChanType) + return ok +} + +func isChanExpression(expr ast.Expr) bool { + // Simple check - could be improved + switch expr.(type) { + case *ast.Ident, *ast.SelectorExpr: + return true + } + return false +} + +func extractChannelName(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return exprToString(e) + default: + return "unknown" + } +} + +func analyzeSelectStatement(sel *ast.SelectStmt, fset *token.FileSet, src []byte, analysis *ChannelAnalysis, channelVars map[string]*ChannelInfo) { + pos := fset.Position(sel.Pos()) + hasDefault := false + + for _, clause := range sel.Body.List { + comm, ok := clause.(*ast.CommClause) + if !ok { + continue + } + + if comm.Comm == nil { + hasDefault = true + continue + } + + // Analyze communication in select + switch c := comm.Comm.(type) { + case *ast.SendStmt: + chanName := extractChannelName(c.Chan) + usage := ChannelUsage{ + Name: chanName, + Type: "select", + Position: newPosition(fset.Position(c.Pos())), + Context: "select send", + } + if info, ok := channelVars[chanName]; ok { + usage.ChannelType = info.chanType + } + analysis.Channels = append(analysis.Channels, usage) + + case *ast.AssignStmt: + // Receive in select + if len(c.Rhs) > 0 { + if unary, ok := c.Rhs[0].(*ast.UnaryExpr); ok && unary.Op == token.ARROW { + chanName := extractChannelName(unary.X) + usage := ChannelUsage{ + Name: chanName, + Type: "select", + Position: newPosition(fset.Position(c.Pos())), + Context: "select receive", + } + if info, ok := channelVars[chanName]; ok { + usage.ChannelType = info.chanType + } + analysis.Channels = append(analysis.Channels, usage) + } + } + } + } + + if !hasDefault && len(sel.Body.List) == 1 { + issue := ChannelIssue{ + Type: "single_case_select", + Description: "Select with single case and no default - consider using simple channel operation", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } +} + +func isInMainGoroutine(file *ast.File, target ast.Node) bool { + // Check if node is not inside a goroutine + var inGoroutine bool + ast.Inspect(file, func(n ast.Node) bool { + if _, ok := n.(*ast.GoStmt); ok { + if containsNode(n, target) { + inGoroutine = true + return false + } + } + return true + }) + return !inGoroutine +} + +func hasGoroutineNearby(file *ast.File, target ast.Node) bool { + // Check if there's a goroutine in the same function + var hasGo bool + ast.Inspect(file, func(n ast.Node) bool { + if fn, ok := n.(*ast.FuncDecl); ok && containsNode(fn, target) { + ast.Inspect(fn, func(inner ast.Node) bool { + if _, ok := inner.(*ast.GoStmt); ok { + hasGo = true + return false + } + return true + }) + return false + } + return true + }) + return hasGo +} \ No newline at end of file diff --git a/tool_analyze_defer_patterns.go b/tool_analyze_defer_patterns.go new file mode 100644 index 0000000..9360c03 --- /dev/null +++ b/tool_analyze_defer_patterns.go @@ -0,0 +1,305 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +type DeferUsage struct { + Statement string `json:"statement"` + Position Position `json:"position"` + InLoop bool `json:"in_loop"` + InFunction string `json:"in_function"` + Context string `json:"context"` +} + +type DeferAnalysis struct { + Defers []DeferUsage `json:"defers"` + Issues []DeferIssue `json:"issues"` +} + +type DeferIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func analyzeDeferPatterns(dir string) (*DeferAnalysis, error) { + analysis := &DeferAnalysis{ + Defers: []DeferUsage{}, + Issues: []DeferIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + var currentFunc string + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.FuncDecl: + currentFunc = node.Name.Name + analyzeFunctionDefers(node, fset, src, analysis) + + case *ast.FuncLit: + currentFunc = "anonymous function" + analyzeFunctionDefers(&ast.FuncDecl{Body: node.Body}, fset, src, analysis) + + case *ast.DeferStmt: + pos := fset.Position(node.Pos()) + usage := DeferUsage{ + Statement: extractDeferStatement(node), + Position: newPosition(pos), + InLoop: isInLoop(file, node), + InFunction: currentFunc, + Context: extractContext(src, pos), + } + analysis.Defers = append(analysis.Defers, usage) + + // Check for issues + if usage.InLoop { + issue := DeferIssue{ + Type: "defer_in_loop", + Description: "defer in loop will accumulate until function returns", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + // Check for defer of result of function call + if hasNestedCall(node.Call) { + issue := DeferIssue{ + Type: "defer_nested_call", + Description: "defer evaluates function arguments immediately - nested calls execute now", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + // Check for useless defer patterns + checkUselessDefer(node, file, fset, analysis) + } + return true + }) + + return nil + }) + + return analysis, err +} + +func extractDeferStatement(deferStmt *ast.DeferStmt) string { + switch call := deferStmt.Call.Fun.(type) { + case *ast.Ident: + return "defer " + call.Name + "(...)" + case *ast.SelectorExpr: + return "defer " + exprToString(call) + "(...)" + case *ast.FuncLit: + return "defer func() { ... }" + default: + return "defer " + } +} + +func hasNestedCall(call *ast.CallExpr) bool { + // Check if any argument is a function call + for _, arg := range call.Args { + if _, ok := arg.(*ast.CallExpr); ok { + return true + } + } + return false +} + +func analyzeFunctionDefers(fn *ast.FuncDecl, fset *token.FileSet, src []byte, analysis *DeferAnalysis) { + if fn.Body == nil { + return + } + + var defers []*ast.DeferStmt + var hasReturn bool + var returnPos token.Position + + // Collect all defers and check for early returns + ast.Inspect(fn.Body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.DeferStmt: + defers = append(defers, node) + case *ast.ReturnStmt: + hasReturn = true + returnPos = fset.Position(node.Pos()) + case *ast.FuncLit: + // Don't analyze nested functions + return false + } + return true + }) + + // Check defer ordering issues + if len(defers) > 1 { + checkDeferOrdering(defers, fset, analysis) + } + + // Check for defer after return path + if hasReturn { + for _, def := range defers { + defPos := fset.Position(def.Pos()) + if defPos.Line > returnPos.Line { + issue := DeferIssue{ + Type: "unreachable_defer", + Description: "defer statement after return is unreachable", + Position: newPosition(defPos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + } + + // Check for missing defer on resource cleanup + checkMissingDefers(fn, fset, analysis) +} + +func checkDeferOrdering(defers []*ast.DeferStmt, fset *token.FileSet, analysis *DeferAnalysis) { + // Check for dependent defers in wrong order + for i := 0; i < len(defers)-1; i++ { + for j := i + 1; j < len(defers); j++ { + if areDefersDependentWrongOrder(defers[i], defers[j]) { + pos := fset.Position(defers[j].Pos()) + issue := DeferIssue{ + Type: "defer_order_issue", + Description: "defer statements may execute in wrong order (LIFO)", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + } +} + +func areDefersDependentWrongOrder(first, second *ast.DeferStmt) bool { + // Simple heuristic: check for Close() after Flush() or similar patterns + firstName := extractMethodName(first.Call) + secondName := extractMethodName(second.Call) + + // Common patterns where order matters + orderPatterns := map[string]string{ + "Flush": "Close", + "Unlock": "Lock", + "Done": "Add", + } + + for before, after := range orderPatterns { + if firstName == after && secondName == before { + return true + } + } + + return false +} + +func extractMethodName(call *ast.CallExpr) string { + switch fun := call.Fun.(type) { + case *ast.Ident: + return fun.Name + case *ast.SelectorExpr: + return fun.Sel.Name + } + return "" +} + +func checkUselessDefer(deferStmt *ast.DeferStmt, file *ast.File, fset *token.FileSet, analysis *DeferAnalysis) { + // Check if defer is the last statement before return + ast.Inspect(file, func(n ast.Node) bool { + if block, ok := n.(*ast.BlockStmt); ok { + for i, stmt := range block.List { + if stmt == deferStmt && i < len(block.List)-1 { + // Check if next statement is return + if _, ok := block.List[i+1].(*ast.ReturnStmt); ok { + pos := fset.Position(deferStmt.Pos()) + issue := DeferIssue{ + Type: "unnecessary_defer", + Description: "defer immediately before return is unnecessary", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + return false + } + } + } + } + return true + }) +} + +func checkMissingDefers(fn *ast.FuncDecl, fset *token.FileSet, analysis *DeferAnalysis) { + // Look for resource acquisition without corresponding defer + resources := make(map[string]token.Position) // resource var -> position + deferred := make(map[string]bool) + + ast.Inspect(fn.Body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.AssignStmt: + // Check for resource acquisition + for i, lhs := range node.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && i < len(node.Rhs) { + if isResourceAcquisition(node.Rhs[i]) { + resources[ident.Name] = fset.Position(node.Pos()) + } + } + } + + case *ast.DeferStmt: + // Check if defer releases a resource + if varName := extractDeferredResourceVar(node.Call); varName != "" { + deferred[varName] = true + } + } + return true + }) + + // Report resources without defers + for resource, pos := range resources { + if !deferred[resource] { + issue := DeferIssue{ + Type: "missing_defer", + Description: "Resource '" + resource + "' acquired but not deferred for cleanup", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } +} + +func isResourceAcquisition(expr ast.Expr) bool { + call, ok := expr.(*ast.CallExpr) + if !ok { + return false + } + + // Check for common resource acquisition patterns + switch fun := call.Fun.(type) { + case *ast.SelectorExpr: + method := fun.Sel.Name + resourceMethods := []string{"Open", "Create", "Dial", "Connect", "Lock", "RLock", "Begin"} + for _, rm := range resourceMethods { + if method == rm || strings.HasPrefix(method, "Open") || strings.HasPrefix(method, "New") { + return true + } + } + } + return false +} + +func extractDeferredResourceVar(call *ast.CallExpr) string { + // Extract the variable being cleaned up in defer + switch fun := call.Fun.(type) { + case *ast.SelectorExpr: + if ident, ok := fun.X.(*ast.Ident); ok { + method := fun.Sel.Name + if method == "Close" || method == "Unlock" || method == "RUnlock" || + method == "Done" || method == "Release" { + return ident.Name + } + } + } + return "" +} \ No newline at end of file diff --git a/tool_analyze_goroutines.go b/tool_analyze_goroutines.go new file mode 100644 index 0000000..13c1245 --- /dev/null +++ b/tool_analyze_goroutines.go @@ -0,0 +1,243 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +type GoroutineUsage struct { + Position Position `json:"position"` + Function string `json:"function"` + InLoop bool `json:"in_loop"` + HasWaitGroup bool `json:"has_wait_group"` + Context string `json:"context"` +} + +type GoroutineAnalysis struct { + Goroutines []GoroutineUsage `json:"goroutines"` + Issues []GoroutineIssue `json:"issues"` +} + +type GoroutineIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func analyzeGoroutines(dir string) (*GoroutineAnalysis, error) { + analysis := &GoroutineAnalysis{ + Goroutines: []GoroutineUsage{}, + Issues: []GoroutineIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Track WaitGroup usage + waitGroupVars := make(map[string]bool) + hasWaitGroupImport := false + + // Check imports + for _, imp := range file.Imports { + if imp.Path != nil && imp.Path.Value == `"sync"` { + hasWaitGroupImport = true + break + } + } + + // First pass: find WaitGroup variables + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.ValueSpec: + for i, name := range node.Names { + if i < len(node.Values) { + if isWaitGroupType(node.Type) || isWaitGroupExpr(node.Values[i]) { + waitGroupVars[name.Name] = true + } + } + } + case *ast.AssignStmt: + for i, lhs := range node.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && i < len(node.Rhs) { + if isWaitGroupExpr(node.Rhs[i]) { + waitGroupVars[ident.Name] = true + } + } + } + } + return true + }) + + // Second pass: analyze goroutines + ast.Inspect(file, func(n ast.Node) bool { + if goStmt, ok := n.(*ast.GoStmt); ok { + pos := fset.Position(goStmt.Pos()) + funcName := extractFunctionName(goStmt.Call) + inLoop := isInLoop(file, goStmt) + hasWG := hasNearbyWaitGroup(file, goStmt, waitGroupVars) + + usage := GoroutineUsage{ + Position: newPosition(pos), + Function: funcName, + InLoop: inLoop, + HasWaitGroup: hasWG, + Context: extractContext(src, pos), + } + analysis.Goroutines = append(analysis.Goroutines, usage) + + // Check for issues + if inLoop && !hasWG { + issue := GoroutineIssue{ + Type: "goroutine_leak_risk", + Description: "Goroutine launched in loop without WaitGroup may cause resource leak", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + // Check for goroutines without synchronization + if !hasWG && !hasChannelCommunication(goStmt.Call) && hasWaitGroupImport { + issue := GoroutineIssue{ + Type: "missing_synchronization", + Description: "Goroutine launched without apparent synchronization mechanism", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + return true + }) + + return nil + }) + + return analysis, err +} + +func isWaitGroupType(expr ast.Expr) bool { + if expr == nil { + return false + } + if sel, ok := expr.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + return ident.Name == "sync" && sel.Sel.Name == "WaitGroup" + } + } + return false +} + +func isWaitGroupExpr(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.CompositeLit: + return isWaitGroupType(e.Type) + case *ast.UnaryExpr: + if e.Op == token.AND { + return isWaitGroupExpr(e.X) + } + } + return false +} + +func extractFunctionName(call *ast.CallExpr) string { + switch fun := call.Fun.(type) { + case *ast.Ident: + return fun.Name + case *ast.SelectorExpr: + return exprToString(fun.X) + "." + fun.Sel.Name + case *ast.FuncLit: + return "anonymous function" + default: + return "unknown" + } +} + +func isInLoop(file *ast.File, target ast.Node) bool { + var inLoop bool + ast.Inspect(file, func(n ast.Node) bool { + switch n.(type) { + case *ast.ForStmt, *ast.RangeStmt: + // Check if target is within this loop + if containsNode(n, target) { + inLoop = true + return false + } + } + return true + }) + return inLoop +} + +func containsNode(parent, child ast.Node) bool { + var found bool + ast.Inspect(parent, func(n ast.Node) bool { + if n == child { + found = true + return false + } + return true + }) + return found +} + +func hasNearbyWaitGroup(file *ast.File, goStmt *ast.GoStmt, waitGroupVars map[string]bool) bool { + // Look for WaitGroup.Add calls in the same block or parent function + var hasWG bool + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.CallExpr: + if sel, ok := node.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + if waitGroupVars[ident.Name] && sel.Sel.Name == "Add" { + // Check if this Add call is near the goroutine + if isNearby(file, node, goStmt) { + hasWG = true + return false + } + } + } + } + } + return true + }) + + return hasWG +} + +func isNearby(file *ast.File, node1, node2 ast.Node) bool { + // Simple proximity check - in same function + var func1, func2 *ast.FuncDecl + + ast.Inspect(file, func(n ast.Node) bool { + if fn, ok := n.(*ast.FuncDecl); ok { + if containsNode(fn, node1) { + func1 = fn + } + if containsNode(fn, node2) { + func2 = fn + } + } + return true + }) + + return func1 == func2 && func1 != nil +} + +func hasChannelCommunication(call *ast.CallExpr) bool { + // Check if the function likely uses channels for synchronization + hasChannel := false + ast.Inspect(call, func(n ast.Node) bool { + switch n.(type) { + case *ast.ChanType, *ast.SendStmt: + hasChannel = true + return false + } + if ident, ok := n.(*ast.Ident); ok { + if strings.Contains(strings.ToLower(ident.Name), "chan") { + hasChannel = true + return false + } + } + return true + }) + return hasChannel +} \ No newline at end of file diff --git a/tool_analyze_memory_allocations.go b/tool_analyze_memory_allocations.go new file mode 100644 index 0000000..4f90dbe --- /dev/null +++ b/tool_analyze_memory_allocations.go @@ -0,0 +1,374 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +type MemoryAllocation struct { + Type string `json:"type"` // "make", "new", "composite", "append", "string_concat" + Description string `json:"description"` + InLoop bool `json:"in_loop"` + Position Position `json:"position"` + Context string `json:"context"` +} + +type AllocationAnalysis struct { + Allocations []MemoryAllocation `json:"allocations"` + Issues []AllocationIssue `json:"issues"` +} + +type AllocationIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func analyzeMemoryAllocations(dir string) (*AllocationAnalysis, error) { + analysis := &AllocationAnalysis{ + Allocations: []MemoryAllocation{}, + Issues: []AllocationIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Analyze allocations + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.CallExpr: + analyzeCallExpr(node, file, fset, src, analysis) + + case *ast.CompositeLit: + pos := fset.Position(node.Pos()) + alloc := MemoryAllocation{ + Type: "composite", + Description: "Composite literal: " + exprToString(node.Type), + InLoop: isInLoop(file, node), + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + + if alloc.InLoop { + issue := AllocationIssue{ + Type: "allocation_in_loop", + Description: "Composite literal allocation inside loop", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case *ast.BinaryExpr: + if node.Op == token.ADD { + analyzeStringConcat(node, file, fset, src, analysis) + } + + case *ast.UnaryExpr: + if node.Op == token.AND { + // Taking address of value causes allocation + pos := fset.Position(node.Pos()) + if isEscaping(file, node) { + alloc := MemoryAllocation{ + Type: "address_of", + Description: "Taking address of value (escapes to heap)", + InLoop: isInLoop(file, node), + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + } + } + } + return true + }) + + // Look for specific patterns + findAllocationPatterns(file, fset, src, analysis) + + return nil + }) + + return analysis, err +} + +func analyzeCallExpr(call *ast.CallExpr, file *ast.File, fset *token.FileSet, src []byte, analysis *AllocationAnalysis) { + ident, ok := call.Fun.(*ast.Ident) + if !ok { + // Check for method calls like strings.Builder + if sel, ok := call.Fun.(*ast.SelectorExpr); ok { + analyzeMethodCall(sel, call, file, fset, src, analysis) + } + return + } + + pos := fset.Position(call.Pos()) + inLoop := isInLoop(file, call) + + switch ident.Name { + case "make": + if len(call.Args) > 0 { + typeStr := exprToString(call.Args[0]) + sizeStr := "" + if len(call.Args) > 1 { + sizeStr = " with size" + } + + alloc := MemoryAllocation{ + Type: "make", + Description: "make(" + typeStr + ")" + sizeStr, + InLoop: inLoop, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + + if inLoop { + issue := AllocationIssue{ + Type: "make_in_loop", + Description: "make() called inside loop - consider pre-allocating", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case "new": + if len(call.Args) > 0 { + typeStr := exprToString(call.Args[0]) + + alloc := MemoryAllocation{ + Type: "new", + Description: "new(" + typeStr + ")", + InLoop: inLoop, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + + if inLoop { + issue := AllocationIssue{ + Type: "new_in_loop", + Description: "new() called inside loop - consider pre-allocating", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case "append": + alloc := MemoryAllocation{ + Type: "append", + Description: "append() may cause reallocation", + InLoop: inLoop, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + + if inLoop && !hasPreallocation(file, call) { + issue := AllocationIssue{ + Type: "append_in_loop", + Description: "append() in loop without pre-allocation - consider pre-allocating slice", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } +} + +func analyzeMethodCall(sel *ast.SelectorExpr, call *ast.CallExpr, file *ast.File, fset *token.FileSet, src []byte, analysis *AllocationAnalysis) { + // Check for common allocation patterns in method calls + methodName := sel.Sel.Name + + // Check for strings.Builder inefficiencies + if methodName == "WriteString" || methodName == "Write" { + if ident, ok := sel.X.(*ast.Ident); ok { + if isStringBuilderType(file, ident) && isInLoop(file, call) { + // This is okay - strings.Builder is designed for this + return + } + } + } +} + +func analyzeStringConcat(binExpr *ast.BinaryExpr, file *ast.File, fset *token.FileSet, src []byte, analysis *AllocationAnalysis) { + // Check if this is string concatenation + if !isStringType(binExpr.X) && !isStringType(binExpr.Y) { + return + } + + pos := fset.Position(binExpr.Pos()) + inLoop := isInLoop(file, binExpr) + + if inLoop { + alloc := MemoryAllocation{ + Type: "string_concat", + Description: "String concatenation with +", + InLoop: true, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + + issue := AllocationIssue{ + Type: "string_concat_in_loop", + Description: "String concatenation in loop - use strings.Builder instead", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } +} + +func isStringType(expr ast.Expr) bool { + // Simple heuristic - check for string literals or string-like identifiers + switch e := expr.(type) { + case *ast.BasicLit: + return e.Kind == token.STRING + case *ast.Ident: + // This is a simplification - ideally we'd have type info + return strings.Contains(strings.ToLower(e.Name), "str") || + strings.Contains(strings.ToLower(e.Name), "msg") || + strings.Contains(strings.ToLower(e.Name), "text") + } + return false +} + +func isEscaping(file *ast.File, unary *ast.UnaryExpr) bool { + // Simple escape analysis - if address is assigned or passed to function + var escapes bool + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.AssignStmt: + for _, rhs := range node.Rhs { + if rhs == unary { + escapes = true + return false + } + } + case *ast.CallExpr: + for _, arg := range node.Args { + if arg == unary { + escapes = true + return false + } + } + case *ast.ReturnStmt: + for _, result := range node.Results { + if result == unary { + escapes = true + return false + } + } + } + return true + }) + + return escapes +} + +func hasPreallocation(file *ast.File, appendCall *ast.CallExpr) bool { + // Check if the slice being appended to was pre-allocated + if len(appendCall.Args) == 0 { + return false + } + + // Get the slice being appended to + sliceName := extractSliceName(appendCall.Args[0]) + if sliceName == "" { + return false + } + + // Look for make() call with capacity + var hasCapacity bool + ast.Inspect(file, func(n ast.Node) bool { + if assign, ok := n.(*ast.AssignStmt); ok { + for i, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && ident.Name == sliceName { + if i < len(assign.Rhs) { + if call, ok := assign.Rhs[i].(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == "make" { + // Check if make has capacity argument + if len(call.Args) >= 3 { + hasCapacity = true + return false + } + } + } + } + } + } + } + return true + }) + + return hasCapacity +} + +func extractSliceName(expr ast.Expr) string { + if ident, ok := expr.(*ast.Ident); ok { + return ident.Name + } + return "" +} + +func isStringBuilderType(file *ast.File, ident *ast.Ident) bool { + // Check if identifier is of type strings.Builder + var isBuilder bool + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.ValueSpec: + for i, name := range node.Names { + if name.Name == ident.Name { + if node.Type != nil { + if sel, ok := node.Type.(*ast.SelectorExpr); ok { + if pkg, ok := sel.X.(*ast.Ident); ok { + isBuilder = pkg.Name == "strings" && sel.Sel.Name == "Builder" + return false + } + } + } else if i < len(node.Values) { + // Check initialization + if comp, ok := node.Values[i].(*ast.CompositeLit); ok { + if sel, ok := comp.Type.(*ast.SelectorExpr); ok { + if pkg, ok := sel.X.(*ast.Ident); ok { + isBuilder = pkg.Name == "strings" && sel.Sel.Name == "Builder" + return false + } + } + } + } + } + } + } + return true + }) + + return isBuilder +} + +func findAllocationPatterns(file *ast.File, fset *token.FileSet, src []byte, analysis *AllocationAnalysis) { + // Look for interface{} allocations + ast.Inspect(file, func(n ast.Node) bool { + if callExpr, ok := n.(*ast.CallExpr); ok { + // Check for fmt.Sprintf and similar + if sel, ok := callExpr.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok && ident.Name == "fmt" { + if strings.HasPrefix(sel.Sel.Name, "Sprint") { + pos := fset.Position(callExpr.Pos()) + alloc := MemoryAllocation{ + Type: "fmt_sprintf", + Description: "fmt." + sel.Sel.Name + " allocates for interface{} conversions", + InLoop: isInLoop(file, callExpr), + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Allocations = append(analysis.Allocations, alloc) + } + } + } + } + return true + }) +} \ No newline at end of file diff --git a/tool_analyze_naming_conventions.go b/tool_analyze_naming_conventions.go new file mode 100644 index 0000000..eb59ea3 --- /dev/null +++ b/tool_analyze_naming_conventions.go @@ -0,0 +1,418 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" + "unicode" +) + +type NamingViolation struct { + Name string `json:"name"` + Type string `json:"type"` // "function", "variable", "constant", "type", "package" + Issue string `json:"issue"` + Suggestion string `json:"suggestion,omitempty"` + Position Position `json:"position"` +} + +type NamingAnalysis struct { + Violations []NamingViolation `json:"violations"` + Statistics NamingStats `json:"statistics"` +} + +type NamingStats struct { + TotalSymbols int `json:"total_symbols"` + ExportedSymbols int `json:"exported_symbols"` + UnexportedSymbols int `json:"unexported_symbols"` + ViolationCount int `json:"violation_count"` +} + +func analyzeNamingConventions(dir string) (*NamingAnalysis, error) { + analysis := &NamingAnalysis{ + Violations: []NamingViolation{}, + Statistics: NamingStats{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Check package name + checkPackageName(file, fset, analysis) + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.FuncDecl: + analysis.Statistics.TotalSymbols++ + checkFunctionName(node, fset, analysis) + + case *ast.GenDecl: + for _, spec := range node.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + analysis.Statistics.TotalSymbols++ + checkTypeName(s, node, fset, analysis) + + case *ast.ValueSpec: + for _, name := range s.Names { + analysis.Statistics.TotalSymbols++ + if node.Tok == token.CONST { + checkConstantName(name, fset, analysis) + } else { + checkVariableName(name, fset, analysis) + } + } + } + } + } + return true + }) + + return nil + }) + + analysis.Statistics.ViolationCount = len(analysis.Violations) + return analysis, err +} + +func checkPackageName(file *ast.File, fset *token.FileSet, analysis *NamingAnalysis) { + name := file.Name.Name + pos := fset.Position(file.Name.Pos()) + + // Package names should be lowercase + if !isAllLowercase(name) { + violation := NamingViolation{ + Name: name, + Type: "package", + Issue: "Package name should be lowercase", + Suggestion: strings.ToLower(name), + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check for underscores + if strings.Contains(name, "_") && name != "main" { + violation := NamingViolation{ + Name: name, + Type: "package", + Issue: "Package name should not contain underscores", + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } +} + +func checkFunctionName(fn *ast.FuncDecl, fset *token.FileSet, analysis *NamingAnalysis) { + name := fn.Name.Name + pos := fset.Position(fn.Name.Pos()) + isExported := ast.IsExported(name) + + if isExported { + analysis.Statistics.ExportedSymbols++ + } else { + analysis.Statistics.UnexportedSymbols++ + } + + // Check receiver naming + if fn.Recv != nil && len(fn.Recv.List) > 0 { + for _, recv := range fn.Recv.List { + for _, recvName := range recv.Names { + checkReceiverName(recvName, recv.Type, fset, analysis) + } + } + } + + // Check CamelCase + if !isCamelCase(name) && !isSpecialFunction(name) { + violation := NamingViolation{ + Name: name, + Type: "function", + Issue: "Function name should be in CamelCase", + Suggestion: toCamelCase(name), + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check exported function starts with capital + if isExported && !unicode.IsUpper(rune(name[0])) { + violation := NamingViolation{ + Name: name, + Type: "function", + Issue: "Exported function should start with capital letter", + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check for Get prefix on getters + if strings.HasPrefix(name, "Get") && fn.Recv != nil && !returnsError(fn) { + violation := NamingViolation{ + Name: name, + Type: "function", + Issue: "Getter methods should not use Get prefix", + Suggestion: name[3:], // Remove "Get" prefix + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } +} + +func checkTypeName(typeSpec *ast.TypeSpec, genDecl *ast.GenDecl, fset *token.FileSet, analysis *NamingAnalysis) { + name := typeSpec.Name.Name + pos := fset.Position(typeSpec.Name.Pos()) + isExported := ast.IsExported(name) + + if isExported { + analysis.Statistics.ExportedSymbols++ + } else { + analysis.Statistics.UnexportedSymbols++ + } + + // Check CamelCase + if !isCamelCase(name) { + violation := NamingViolation{ + Name: name, + Type: "type", + Issue: "Type name should be in CamelCase", + Suggestion: toCamelCase(name), + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check interface naming + if _, ok := typeSpec.Type.(*ast.InterfaceType); ok { + if isExported && !strings.HasSuffix(name, "er") && !isWellKnownInterface(name) { + // Only suggest for single-method interfaces + if iface, ok := typeSpec.Type.(*ast.InterfaceType); ok && len(iface.Methods.List) == 1 { + violation := NamingViolation{ + Name: name, + Type: "type", + Issue: "Single-method interface should end with 'er'", + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + } + } +} + +func checkConstantName(name *ast.Ident, fset *token.FileSet, analysis *NamingAnalysis) { + pos := fset.Position(name.Pos()) + isExported := ast.IsExported(name.Name) + + if isExported { + analysis.Statistics.ExportedSymbols++ + } else { + analysis.Statistics.UnexportedSymbols++ + } + + // Constants can be CamelCase or ALL_CAPS + if !isCamelCase(name.Name) && !isAllCaps(name.Name) { + violation := NamingViolation{ + Name: name.Name, + Type: "constant", + Issue: "Constant should be in CamelCase or ALL_CAPS", + Suggestion: toCamelCase(name.Name), + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } +} + +func checkVariableName(name *ast.Ident, fset *token.FileSet, analysis *NamingAnalysis) { + pos := fset.Position(name.Pos()) + isExported := ast.IsExported(name.Name) + + if isExported { + analysis.Statistics.ExportedSymbols++ + } else { + analysis.Statistics.UnexportedSymbols++ + } + + // Skip blank identifier + if name.Name == "_" { + return + } + + // Check for single letter names (except common ones) + if len(name.Name) == 1 && !isCommonSingleLetter(name.Name) { + violation := NamingViolation{ + Name: name.Name, + Type: "variable", + Issue: "Single letter variable names should be avoided except for common cases (i, j, k for loops)", + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check CamelCase + if !isCamelCase(name.Name) && len(name.Name) > 1 { + violation := NamingViolation{ + Name: name.Name, + Type: "variable", + Issue: "Variable name should be in camelCase", + Suggestion: toCamelCase(name.Name), + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } +} + +func checkReceiverName(name *ast.Ident, recvType ast.Expr, fset *token.FileSet, analysis *NamingAnalysis) { + pos := fset.Position(name.Pos()) + + // Receiver names should be short + if len(name.Name) > 3 { + typeName := extractReceiverTypeName(recvType) + suggestion := "" + if typeName != "" && len(typeName) > 0 { + suggestion = strings.ToLower(string(typeName[0])) + } + + violation := NamingViolation{ + Name: name.Name, + Type: "receiver", + Issue: "Receiver name should be a short, typically one-letter abbreviation", + Suggestion: suggestion, + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } + + // Check for "self" or "this" + if name.Name == "self" || name.Name == "this" { + violation := NamingViolation{ + Name: name.Name, + Type: "receiver", + Issue: "Avoid 'self' or 'this' for receiver names", + Position: newPosition(pos), + } + analysis.Violations = append(analysis.Violations, violation) + } +} + +func extractReceiverTypeName(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return extractReceiverTypeName(t.X) + } + return "" +} + +func isCamelCase(name string) bool { + if len(name) == 0 { + return false + } + + // Check for underscores + if strings.Contains(name, "_") { + return false + } + + // Allow all lowercase for short names (like "i", "ok", "err") + if len(name) <= 3 && isAllLowercase(name) { + return true + } + + // Check for proper camelCase/PascalCase + hasUpper := false + hasLower := false + for _, r := range name { + if unicode.IsUpper(r) { + hasUpper = true + } else if unicode.IsLower(r) { + hasLower = true + } + } + + // Single case is ok for short names + return len(name) <= 3 || (hasUpper && hasLower) || isAllLowercase(name) +} + +func isAllLowercase(s string) bool { + for _, r := range s { + if unicode.IsLetter(r) && !unicode.IsLower(r) { + return false + } + } + return true +} + +func isAllCaps(s string) bool { + for _, r := range s { + if unicode.IsLetter(r) && !unicode.IsUpper(r) { + return false + } + } + return true +} + +func toCamelCase(s string) string { + words := strings.FieldsFunc(s, func(r rune) bool { + return r == '_' || r == '-' + }) + + if len(words) == 0 { + return s + } + + // First word stays lowercase for camelCase + result := strings.ToLower(words[0]) + + // Capitalize first letter of subsequent words + for i := 1; i < len(words); i++ { + if len(words[i]) > 0 { + result += strings.ToUpper(words[i][:1]) + strings.ToLower(words[i][1:]) + } + } + + return result +} + +func isSpecialFunction(name string) bool { + // Special functions that don't follow normal naming + special := []string{"init", "main", "String", "Error", "MarshalJSON", "UnmarshalJSON"} + for _, s := range special { + if name == s { + return true + } + } + return false +} + +func isWellKnownInterface(name string) bool { + // Well-known interfaces that don't end in 'er' + known := []string{"Interface", "Handler", "ResponseWriter", "Context", "Value"} + for _, k := range known { + if name == k { + return true + } + } + return false +} + +func isCommonSingleLetter(name string) bool { + // Common single letter variables that are acceptable + common := []string{"i", "j", "k", "n", "m", "x", "y", "z", "s", "b", "r", "w", "t"} + for _, c := range common { + if name == c { + return true + } + } + return false +} + +func returnsError(fn *ast.FuncDecl) bool { + if fn.Type.Results == nil { + return false + } + + for _, result := range fn.Type.Results.List { + if ident, ok := result.Type.(*ast.Ident); ok && ident.Name == "error" { + return true + } + } + return false +} \ No newline at end of file diff --git a/tool_find_empty_blocks.go b/tool_find_empty_blocks.go new file mode 100644 index 0000000..d47ce51 --- /dev/null +++ b/tool_find_empty_blocks.go @@ -0,0 +1,366 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +type EmptyBlock struct { + Type string `json:"type"` // "if", "else", "for", "switch_case", "function", etc. + Description string `json:"description"` + Position Position `json:"position"` + Context string `json:"context"` +} + +type EmptyBlockAnalysis struct { + EmptyBlocks []EmptyBlock `json:"empty_blocks"` + Issues []EmptyBlockIssue `json:"issues"` +} + +type EmptyBlockIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findEmptyBlocks(dir string) (*EmptyBlockAnalysis, error) { + analysis := &EmptyBlockAnalysis{ + EmptyBlocks: []EmptyBlock{}, + Issues: []EmptyBlockIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.IfStmt: + analyzeIfStatement(node, fset, src, analysis) + + case *ast.ForStmt: + if isEmptyBlock(node.Body) { + pos := fset.Position(node.Pos()) + empty := EmptyBlock{ + Type: "for", + Description: "Empty for loop", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + + // Check if it's an infinite loop + if node.Cond == nil && node.Init == nil && node.Post == nil { + issue := EmptyBlockIssue{ + Type: "empty_infinite_loop", + Description: "Empty infinite loop - possible bug or incomplete implementation", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case *ast.RangeStmt: + if isEmptyBlock(node.Body) { + pos := fset.Position(node.Pos()) + empty := EmptyBlock{ + Type: "range", + Description: "Empty range loop", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + } + + case *ast.SwitchStmt: + analyzeSwitchStatement(node, fset, src, analysis) + + case *ast.TypeSwitchStmt: + analyzeTypeSwitchStatement(node, fset, src, analysis) + + case *ast.FuncDecl: + if node.Body != nil && isEmptyBlock(node.Body) { + pos := fset.Position(node.Pos()) + empty := EmptyBlock{ + Type: "function", + Description: "Empty function: " + node.Name.Name, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + + // Check if it's an interface stub + if !isInterfaceStub(node) && !isTestHelper(node.Name.Name) { + issue := EmptyBlockIssue{ + Type: "empty_function", + Description: "Function '" + node.Name.Name + "' has empty body", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case *ast.BlockStmt: + // Check for standalone empty blocks + if isEmptyBlock(node) && !isPartOfControlStructure(file, node) { + pos := fset.Position(node.Pos()) + empty := EmptyBlock{ + Type: "block", + Description: "Empty code block", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + } + } + return true + }) + + return nil + }) + + return analysis, err +} + +func isEmptyBlock(block *ast.BlockStmt) bool { + if block == nil { + return true + } + + // Check if block has no statements + if len(block.List) == 0 { + return true + } + + // Check if all statements are empty + for _, stmt := range block.List { + if !isEmptyStatement(stmt) { + return false + } + } + + return true +} + +func isEmptyStatement(stmt ast.Stmt) bool { + switch s := stmt.(type) { + case *ast.EmptyStmt: + return true + case *ast.BlockStmt: + return isEmptyBlock(s) + default: + return false + } +} + +func analyzeIfStatement(ifStmt *ast.IfStmt, fset *token.FileSet, src []byte, analysis *EmptyBlockAnalysis) { + // Check if body + if isEmptyBlock(ifStmt.Body) { + pos := fset.Position(ifStmt.Pos()) + empty := EmptyBlock{ + Type: "if", + Description: "Empty if block", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + + // Check if there's an else block + if ifStmt.Else == nil { + issue := EmptyBlockIssue{ + Type: "empty_if_no_else", + Description: "Empty if block with no else - condition may be unnecessary", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + // Check else block + if ifStmt.Else != nil { + switch elseNode := ifStmt.Else.(type) { + case *ast.BlockStmt: + if isEmptyBlock(elseNode) { + pos := fset.Position(elseNode.Pos()) + empty := EmptyBlock{ + Type: "else", + Description: "Empty else block", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + + issue := EmptyBlockIssue{ + Type: "empty_else", + Description: "Empty else block - can be removed", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + case *ast.IfStmt: + // Recursively analyze else if + analyzeIfStatement(elseNode, fset, src, analysis) + } + } +} + +func analyzeSwitchStatement(switchStmt *ast.SwitchStmt, fset *token.FileSet, src []byte, analysis *EmptyBlockAnalysis) { + for _, stmt := range switchStmt.Body.List { + if caseClause, ok := stmt.(*ast.CaseClause); ok { + if len(caseClause.Body) == 0 { + pos := fset.Position(caseClause.Pos()) + caseDesc := "default" + if len(caseClause.List) > 0 { + caseDesc = "case" + } + + empty := EmptyBlock{ + Type: "switch_case", + Description: "Empty " + caseDesc + " clause", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + + // Check if it's not a fallthrough case + if !hasFallthrough(switchStmt, caseClause) { + issue := EmptyBlockIssue{ + Type: "empty_switch_case", + Description: "Empty " + caseDesc + " clause with no fallthrough", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + } + } +} + +func analyzeTypeSwitchStatement(typeSwitch *ast.TypeSwitchStmt, fset *token.FileSet, src []byte, analysis *EmptyBlockAnalysis) { + for _, stmt := range typeSwitch.Body.List { + if caseClause, ok := stmt.(*ast.CaseClause); ok { + if len(caseClause.Body) == 0 { + pos := fset.Position(caseClause.Pos()) + caseDesc := "default" + if len(caseClause.List) > 0 { + caseDesc = "type case" + } + + empty := EmptyBlock{ + Type: "type_switch_case", + Description: "Empty " + caseDesc + " clause", + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.EmptyBlocks = append(analysis.EmptyBlocks, empty) + } + } + } +} + +func isPartOfControlStructure(file *ast.File, block *ast.BlockStmt) bool { + var isControl bool + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.IfStmt: + if node.Body == block || node.Else == block { + isControl = true + return false + } + case *ast.ForStmt: + if node.Body == block { + isControl = true + return false + } + case *ast.RangeStmt: + if node.Body == block { + isControl = true + return false + } + case *ast.SwitchStmt: + if node.Body == block { + isControl = true + return false + } + case *ast.TypeSwitchStmt: + if node.Body == block { + isControl = true + return false + } + case *ast.FuncDecl: + if node.Body == block { + isControl = true + return false + } + case *ast.FuncLit: + if node.Body == block { + isControl = true + return false + } + } + return true + }) + return isControl +} + +func hasFallthrough(switchStmt *ast.SwitchStmt, caseClause *ast.CaseClause) bool { + // Check if the previous case has a fallthrough + var prevCase *ast.CaseClause + for _, stmt := range switchStmt.Body.List { + if cc, ok := stmt.(*ast.CaseClause); ok { + if cc == caseClause && prevCase != nil { + // Check if previous case ends with fallthrough + if len(prevCase.Body) > 0 { + if _, ok := prevCase.Body[len(prevCase.Body)-1].(*ast.BranchStmt); ok { + return true + } + } + } + prevCase = cc + } + } + return false +} + +func isInterfaceStub(fn *ast.FuncDecl) bool { + // Check if function has a receiver (method) + if fn.Recv == nil || len(fn.Recv.List) == 0 { + return false + } + + // Check for common stub patterns in name + name := fn.Name.Name + stubPatterns := []string{"Stub", "Mock", "Fake", "Dummy", "NoOp", "Noop"} + for _, pattern := range stubPatterns { + if strings.Contains(name, pattern) { + return true + } + } + + // Check if receiver type contains stub patterns + if len(fn.Recv.List) > 0 { + recvType := exprToString(fn.Recv.List[0].Type) + for _, pattern := range stubPatterns { + if strings.Contains(recvType, pattern) { + return true + } + } + } + + return false +} + +func isTestHelper(name string) bool { + // Common test helper patterns + helpers := []string{"setUp", "tearDown", "beforeEach", "afterEach", "beforeAll", "afterAll"} + nameLower := strings.ToLower(name) + + for _, helper := range helpers { + if strings.ToLower(helper) == nameLower { + return true + } + } + + // Check for test-related prefixes + return strings.HasPrefix(name, "Test") || + strings.HasPrefix(name, "Benchmark") || + strings.HasPrefix(name, "Example") +} \ No newline at end of file diff --git a/tool_find_errors.go b/tool_find_errors.go index 9476b19..63061b7 100644 --- a/tool_find_errors.go +++ b/tool_find_errors.go @@ -34,7 +34,7 @@ func findErrors(dir string) ([]ErrorInfo, error) { case *ast.ExprStmt: if call, ok := x.X.(*ast.CallExpr); ok { // Check if this function likely returns an error - if returnsError(call, file) { + if callReturnsError(call) { pos := fset.Position(call.Pos()) context := extractContext(src, pos) info.UnhandledErrors = append(info.UnhandledErrors, ErrorContext{ @@ -84,7 +84,7 @@ func findErrors(dir string) ([]ErrorInfo, error) { return errors, err } -func returnsError(call *ast.CallExpr, file *ast.File) bool { +func callReturnsError(call *ast.CallExpr) bool { // Simple heuristic: check if the function name suggests it returns an error switch fun := call.Fun.(type) { case *ast.Ident: diff --git a/tool_find_inefficiencies.go b/tool_find_inefficiencies.go index 25f400d..539eb8d 100644 --- a/tool_find_inefficiencies.go +++ b/tool_find_inefficiencies.go @@ -34,7 +34,7 @@ func findInefficiencies(dir string) ([]InefficiencyInfo, error) { if forStmt, ok := n.(*ast.ForStmt); ok { ast.Inspect(forStmt.Body, func(inner ast.Node) bool { if binExpr, ok := inner.(*ast.BinaryExpr); ok && binExpr.Op == token.ADD { - if isStringType(binExpr.X) || isStringType(binExpr.Y) { + if ineffIsStringType(binExpr.X) || ineffIsStringType(binExpr.Y) { pos := fset.Position(binExpr.Pos()) info.StringConcat = append(info.StringConcat, InefficiencyItem{ Type: "string_concatenation_in_loop", @@ -78,7 +78,7 @@ func findInefficiencies(dir string) ([]InefficiencyInfo, error) { return inefficiencies, err } -func isStringType(expr ast.Expr) bool { +func ineffIsStringType(expr ast.Expr) bool { if ident, ok := expr.(*ast.Ident); ok { return ident.Name == "string" } diff --git a/tool_find_init_functions.go b/tool_find_init_functions.go new file mode 100644 index 0000000..d215b87 --- /dev/null +++ b/tool_find_init_functions.go @@ -0,0 +1,360 @@ +package main + +import ( + "go/ast" + "go/token" + "sort" + "strings" +) + +type InitFunction struct { + Package string `json:"package"` + FilePath string `json:"file_path"` + Position Position `json:"position"` + Dependencies []string `json:"dependencies"` // Packages this init might depend on + HasSideEffects bool `json:"has_side_effects"` + Context string `json:"context"` +} + +type InitAnalysis struct { + InitFunctions []InitFunction `json:"init_functions"` + Issues []InitIssue `json:"issues"` + InitOrder []string `json:"init_order"` // Suggested initialization order +} + +type InitIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findInitFunctions(dir string) (*InitAnalysis, error) { + analysis := &InitAnalysis{ + InitFunctions: []InitFunction{}, + Issues: []InitIssue{}, + InitOrder: []string{}, + } + + packageInits := make(map[string][]InitFunction) // package -> init functions + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + pkgName := file.Name.Name + + ast.Inspect(file, func(n ast.Node) bool { + if funcDecl, ok := n.(*ast.FuncDecl); ok && funcDecl.Name.Name == "init" { + pos := fset.Position(funcDecl.Pos()) + + initFunc := InitFunction{ + Package: pkgName, + FilePath: path, + Position: newPosition(pos), + Dependencies: extractInitDependencies(file, funcDecl), + HasSideEffects: hasInitSideEffects(funcDecl), + Context: extractContext(src, pos), + } + + analysis.InitFunctions = append(analysis.InitFunctions, initFunc) + packageInits[pkgName] = append(packageInits[pkgName], initFunc) + + // Analyze init function for issues + analyzeInitFunction(funcDecl, fset, analysis) + } + return true + }) + + return nil + }) + + if err != nil { + return nil, err + } + + // Analyze package-level init dependencies + analyzeInitDependencies(analysis, packageInits) + + // Sort init functions by package and file + sort.Slice(analysis.InitFunctions, func(i, j int) bool { + if analysis.InitFunctions[i].Package != analysis.InitFunctions[j].Package { + return analysis.InitFunctions[i].Package < analysis.InitFunctions[j].Package + } + return analysis.InitFunctions[i].FilePath < analysis.InitFunctions[j].FilePath + }) + + return analysis, nil +} + +func extractInitDependencies(file *ast.File, init *ast.FuncDecl) []string { + deps := make(map[string]bool) + + // Look for package references in init + ast.Inspect(init, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.SelectorExpr: + if ident, ok := node.X.(*ast.Ident); ok { + // Check if this is a package reference + if isPackageIdent(file, ident.Name) { + deps[ident.Name] = true + } + } + case *ast.CallExpr: + // Check for function calls that might depend on other packages + if sel, ok := node.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + if isPackageIdent(file, ident.Name) { + deps[ident.Name] = true + } + } + } + } + return true + }) + + // Convert to slice + var result []string + for dep := range deps { + result = append(result, dep) + } + sort.Strings(result) + return result +} + +func isPackageIdent(file *ast.File, name string) bool { + // Check if name matches an import + for _, imp := range file.Imports { + importPath := strings.Trim(imp.Path.Value, `"`) + pkgName := importPath + if idx := strings.LastIndex(importPath, "/"); idx >= 0 { + pkgName = importPath[idx+1:] + } + + if imp.Name != nil { + // Named import + if imp.Name.Name == name { + return true + } + } else if pkgName == name { + // Default import name + return true + } + } + return false +} + +func hasInitSideEffects(init *ast.FuncDecl) bool { + var hasSideEffects bool + + ast.Inspect(init, func(n ast.Node) bool { + switch n.(type) { + case *ast.AssignStmt: + // Check for global variable assignments + hasSideEffects = true + return false + case *ast.CallExpr: + // Function calls likely have side effects + hasSideEffects = true + return false + case *ast.GoStmt: + // Starting goroutines + hasSideEffects = true + return false + case *ast.SendStmt: + // Channel operations + hasSideEffects = true + return false + } + return true + }) + + return hasSideEffects +} + +func analyzeInitFunction(init *ast.FuncDecl, fset *token.FileSet, analysis *InitAnalysis) { + // Check for complex init logic + stmtCount := countStatements(init.Body) + if stmtCount > 20 { + pos := fset.Position(init.Pos()) + issue := InitIssue{ + Type: "complex_init", + Description: "init() function is complex - consider refactoring", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + // Check for blocking operations + ast.Inspect(init, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.CallExpr: + if sel, ok := node.Fun.(*ast.SelectorExpr); ok { + funcName := sel.Sel.Name + // Check for potentially blocking operations + if isBlockingCall(sel) { + pos := fset.Position(node.Pos()) + issue := InitIssue{ + Type: "blocking_init", + Description: "init() contains potentially blocking call: " + funcName, + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + case *ast.GoStmt: + pos := fset.Position(node.Pos()) + issue := InitIssue{ + Type: "goroutine_in_init", + Description: "init() starts a goroutine - may cause race conditions", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + + case *ast.ForStmt: + if node.Cond == nil { + pos := fset.Position(node.Pos()) + issue := InitIssue{ + Type: "infinite_loop_in_init", + Description: "init() contains infinite loop - will block program startup", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + return true + }) +} + +func countStatements(block *ast.BlockStmt) int { + count := 0 + ast.Inspect(block, func(n ast.Node) bool { + switch n.(type) { + case ast.Stmt: + count++ + } + return true + }) + return count +} + +func isBlockingCall(sel *ast.SelectorExpr) bool { + // Check for common blocking operations + methodName := sel.Sel.Name + blockingMethods := []string{ + "Sleep", "Wait", "Lock", "RLock", "Dial", "Connect", + "Open", "Create", "ReadFile", "WriteFile", "Sync", + } + + for _, blocking := range blockingMethods { + if strings.Contains(methodName, blocking) { + return true + } + } + + // Check for HTTP/network operations + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "http" || ident.Name == "net" { + return true + } + } + + return false +} + +func analyzeInitDependencies(analysis *InitAnalysis, packageInits map[string][]InitFunction) { + // Build dependency graph + depGraph := make(map[string][]string) + + for pkg, inits := range packageInits { + deps := make(map[string]bool) + for _, init := range inits { + for _, dep := range init.Dependencies { + deps[dep] = true + } + } + + var depList []string + for dep := range deps { + depList = append(depList, dep) + } + depGraph[pkg] = depList + } + + // Check for circular dependencies + for pkg := range depGraph { + if hasCycle, cycle := detectCycle(pkg, depGraph, make(map[string]bool), []string{pkg}); hasCycle { + issue := InitIssue{ + Type: "circular_init_dependency", + Description: "Circular init dependency detected: " + strings.Join(cycle, " -> "), + Position: Position{}, // Package-level issue + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + // Suggest initialization order (topological sort) + analysis.InitOrder = topologicalSort(depGraph) +} + +func detectCycle(current string, graph map[string][]string, visited map[string]bool, path []string) (bool, []string) { + if visited[current] { + // Find where the cycle starts + for i, pkg := range path { + if pkg == current { + return true, path[i:] + } + } + } + + visited[current] = true + + for _, dep := range graph[current] { + newPath := append(path, dep) + if hasCycle, cycle := detectCycle(dep, graph, visited, newPath); hasCycle { + return true, cycle + } + } + + delete(visited, current) + return false, nil +} + +func topologicalSort(graph map[string][]string) []string { + // Simple topological sort for init order + var result []string + visited := make(map[string]bool) + temp := make(map[string]bool) + + var visit func(string) bool + visit = func(pkg string) bool { + if temp[pkg] { + return false // Cycle detected + } + if visited[pkg] { + return true + } + + temp[pkg] = true + for _, dep := range graph[pkg] { + if !visit(dep) { + return false + } + } + temp[pkg] = false + visited[pkg] = true + result = append([]string{pkg}, result...) // Prepend + return true + } + + var packages []string + for pkg := range graph { + packages = append(packages, pkg) + } + sort.Strings(packages) + + for _, pkg := range packages { + if !visited[pkg] { + visit(pkg) + } + } + + return result +} \ No newline at end of file diff --git a/tool_find_method_receivers.go b/tool_find_method_receivers.go new file mode 100644 index 0000000..619d54a --- /dev/null +++ b/tool_find_method_receivers.go @@ -0,0 +1,129 @@ +package main + +import ( + "go/ast" + "go/token" +) + +type MethodReceiver struct { + TypeName string `json:"type_name"` + MethodName string `json:"method_name"` + ReceiverType string `json:"receiver_type"` // "pointer" or "value" + ReceiverName string `json:"receiver_name"` + Position Position `json:"position"` +} + +type ReceiverAnalysis struct { + Methods []MethodReceiver `json:"methods"` + Issues []ReceiverIssue `json:"issues"` +} + +type ReceiverIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findMethodReceivers(dir string) (*ReceiverAnalysis, error) { + analysis := &ReceiverAnalysis{ + Methods: []MethodReceiver{}, + Issues: []ReceiverIssue{}, + } + + typeReceivers := make(map[string]map[string]bool) // type -> receiver type -> exists + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + ast.Inspect(file, func(n ast.Node) bool { + if funcDecl, ok := n.(*ast.FuncDecl); ok && funcDecl.Recv != nil { + if len(funcDecl.Recv.List) > 0 { + recv := funcDecl.Recv.List[0] + var typeName string + var receiverType string + var receiverName string + + if len(recv.Names) > 0 { + receiverName = recv.Names[0].Name + } + + switch t := recv.Type.(type) { + case *ast.Ident: + typeName = t.Name + receiverType = "value" + case *ast.StarExpr: + if ident, ok := t.X.(*ast.Ident); ok { + typeName = ident.Name + receiverType = "pointer" + } + } + + if typeName != "" { + pos := fset.Position(funcDecl.Pos()) + method := MethodReceiver{ + TypeName: typeName, + MethodName: funcDecl.Name.Name, + ReceiverType: receiverType, + ReceiverName: receiverName, + Position: newPosition(pos), + } + analysis.Methods = append(analysis.Methods, method) + + // Track receiver types per type + if typeReceivers[typeName] == nil { + typeReceivers[typeName] = make(map[string]bool) + } + typeReceivers[typeName][receiverType] = true + } + } + } + return true + }) + return nil + }) + + if err != nil { + return nil, err + } + + // Analyze for inconsistencies + for typeName, receivers := range typeReceivers { + if receivers["pointer"] && receivers["value"] { + // Find all methods with this issue + for _, method := range analysis.Methods { + if method.TypeName == typeName { + issue := ReceiverIssue{ + Type: "mixed_receivers", + Description: "Type " + typeName + " has methods with both pointer and value receivers", + Position: method.Position, + } + analysis.Issues = append(analysis.Issues, issue) + break // Only report once per type + } + } + } + } + + // Check for methods that should use pointer receivers + for _, method := range analysis.Methods { + if method.ReceiverType == "value" && shouldUsePointerReceiver(method.MethodName) { + issue := ReceiverIssue{ + Type: "should_use_pointer", + Description: "Method " + method.MethodName + " on " + method.TypeName + " should probably use a pointer receiver", + Position: method.Position, + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + return analysis, nil +} + +func shouldUsePointerReceiver(methodName string) bool { + // Methods that typically modify state should use pointer receivers + prefixes := []string{"Set", "Add", "Remove", "Delete", "Update", "Append", "Clear", "Reset"} + for _, prefix := range prefixes { + if len(methodName) > len(prefix) && methodName[:len(prefix)] == prefix { + return true + } + } + return false +} \ No newline at end of file diff --git a/tool_find_panic_recover.go b/tool_find_panic_recover.go new file mode 100644 index 0000000..fc93c37 --- /dev/null +++ b/tool_find_panic_recover.go @@ -0,0 +1,175 @@ +package main + +import ( + "go/ast" + "go/token" +) + +type PanicRecoverUsage struct { + Type string `json:"type"` // "panic" or "recover" + Position Position `json:"position"` + InDefer bool `json:"in_defer"` + Message string `json:"message,omitempty"` + Context string `json:"context"` +} + +type PanicRecoverAnalysis struct { + Usages []PanicRecoverUsage `json:"usages"` + Issues []PanicRecoverIssue `json:"issues"` +} + +type PanicRecoverIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findPanicRecover(dir string) (*PanicRecoverAnalysis, error) { + analysis := &PanicRecoverAnalysis{ + Usages: []PanicRecoverUsage{}, + Issues: []PanicRecoverIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Track function boundaries and defer statements + var currentFunc *ast.FuncDecl + deferDepth := 0 + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.FuncDecl: + currentFunc = node + deferDepth = 0 + + case *ast.DeferStmt: + deferDepth++ + // Check for recover in defer + ast.Inspect(node, func(inner ast.Node) bool { + if call, ok := inner.(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == "recover" { + pos := fset.Position(call.Pos()) + usage := PanicRecoverUsage{ + Type: "recover", + Position: newPosition(pos), + InDefer: true, + Context: extractContext(src, pos), + } + analysis.Usages = append(analysis.Usages, usage) + } + } + return true + }) + deferDepth-- + + case *ast.CallExpr: + if ident, ok := node.Fun.(*ast.Ident); ok { + pos := fset.Position(node.Pos()) + + switch ident.Name { + case "panic": + message := extractPanicMessage(node) + usage := PanicRecoverUsage{ + Type: "panic", + Position: newPosition(pos), + InDefer: deferDepth > 0, + Message: message, + Context: extractContext(src, pos), + } + analysis.Usages = append(analysis.Usages, usage) + + // Check if panic is in main or init + if currentFunc != nil && (currentFunc.Name.Name == "main" || currentFunc.Name.Name == "init") { + issue := PanicRecoverIssue{ + Type: "panic_in_main_init", + Description: "Panic in " + currentFunc.Name.Name + " function will crash the program", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case "recover": + if deferDepth == 0 { + issue := PanicRecoverIssue{ + Type: "recover_outside_defer", + Description: "recover() called outside defer statement - it will always return nil", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + usage := PanicRecoverUsage{ + Type: "recover", + Position: newPosition(pos), + InDefer: deferDepth > 0, + Context: extractContext(src, pos), + } + analysis.Usages = append(analysis.Usages, usage) + } + } + } + return true + }) + + // Check for functions with panic but no recover + checkPanicWithoutRecover(file, fset, analysis) + + return nil + }) + + return analysis, err +} + +func extractPanicMessage(call *ast.CallExpr) string { + if len(call.Args) > 0 { + switch arg := call.Args[0].(type) { + case *ast.BasicLit: + return arg.Value + case *ast.Ident: + return arg.Name + case *ast.SelectorExpr: + return exprToString(arg) + default: + return "complex expression" + } + } + return "" +} + +func checkPanicWithoutRecover(file *ast.File, fset *token.FileSet, analysis *PanicRecoverAnalysis) { + // For each function, check if it has panic but no recover + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + hasPanic := false + hasRecover := false + var panicPos token.Position + + ast.Inspect(funcDecl, func(inner ast.Node) bool { + if call, ok := inner.(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok { + if ident.Name == "panic" { + hasPanic = true + panicPos = fset.Position(call.Pos()) + } else if ident.Name == "recover" { + hasRecover = true + } + } + } + return true + }) + + if hasPanic && !hasRecover && funcDecl.Name.Name != "main" && funcDecl.Name.Name != "init" { + issue := PanicRecoverIssue{ + Type: "panic_without_recover", + Description: "Function " + funcDecl.Name.Name + " calls panic() but has no recover() - consider adding error handling", + Position: newPosition(panicPos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + return true + }) +} \ No newline at end of file diff --git a/tool_find_reflection_usage.go b/tool_find_reflection_usage.go new file mode 100644 index 0000000..7aa59a0 --- /dev/null +++ b/tool_find_reflection_usage.go @@ -0,0 +1,279 @@ +package main + +import ( + "go/ast" + "go/token" +) + +type ReflectionUsage struct { + Type string `json:"type"` // "TypeOf", "ValueOf", "MethodByName", etc. + Target string `json:"target"` + Position Position `json:"position"` + Context string `json:"context"` +} + +type ReflectionAnalysis struct { + Usages []ReflectionUsage `json:"usages"` + Issues []ReflectionIssue `json:"issues"` +} + +type ReflectionIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findReflectionUsage(dir string) (*ReflectionAnalysis, error) { + analysis := &ReflectionAnalysis{ + Usages: []ReflectionUsage{}, + Issues: []ReflectionIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + // Check if reflect package is imported + hasReflectImport := false + for _, imp := range file.Imports { + if imp.Path != nil && imp.Path.Value == `"reflect"` { + hasReflectImport = true + break + } + } + + if !hasReflectImport { + return nil + } + + ast.Inspect(file, func(n ast.Node) bool { + if callExpr, ok := n.(*ast.CallExpr); ok { + analyzeReflectCall(callExpr, file, fset, src, analysis) + } + return true + }) + + return nil + }) + + return analysis, err +} + +func analyzeReflectCall(call *ast.CallExpr, file *ast.File, fset *token.FileSet, src []byte, analysis *ReflectionAnalysis) { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + + // Check if it's a reflect package call + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name != "reflect" { + return + } + + pos := fset.Position(call.Pos()) + methodName := sel.Sel.Name + target := "" + if len(call.Args) > 0 { + target = exprToString(call.Args[0]) + } + + usage := ReflectionUsage{ + Type: methodName, + Target: target, + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Usages = append(analysis.Usages, usage) + + // Analyze specific reflection patterns + switch methodName { + case "TypeOf", "ValueOf": + if isInLoop(file, call) { + issue := ReflectionIssue{ + Type: "reflection_in_loop", + Description: "reflect." + methodName + " called in loop - consider caching result", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case "MethodByName", "FieldByName": + // These are particularly slow + issue := ReflectionIssue{ + Type: "slow_reflection", + Description: "reflect." + methodName + " is slow - consider caching or avoiding if possible", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + + if isInLoop(file, call) { + issue := ReflectionIssue{ + Type: "slow_reflection_in_loop", + Description: "reflect." + methodName + " in loop is very inefficient", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case "DeepEqual": + if isInHotPath(file, call) { + issue := ReflectionIssue{ + Type: "deep_equal_performance", + Description: "reflect.DeepEqual is expensive - consider custom comparison for hot paths", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case "Copy", "AppendSlice", "MakeSlice", "MakeMap", "MakeChan": + // These allocate memory + if isInLoop(file, call) { + issue := ReflectionIssue{ + Type: "reflect_allocation_in_loop", + Description: "reflect." + methodName + " allocates memory in loop", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + + // Check for unsafe reflection patterns + checkUnsafeReflection(call, file, fset, analysis) +} + +func checkUnsafeReflection(call *ast.CallExpr, file *ast.File, fset *token.FileSet, analysis *ReflectionAnalysis) { + // Look for patterns like Value.Interface() without type checking + ast.Inspect(file, func(n ast.Node) bool { + if selExpr, ok := n.(*ast.SelectorExpr); ok { + if selExpr.Sel.Name == "Interface" { + // Check if this is on a reflect.Value + if isReflectValueExpr(file, selExpr.X) { + // Check if result is used in type assertion without ok check + if parent := findParentNode(file, selExpr); parent != nil { + if typeAssert, ok := parent.(*ast.TypeAssertExpr); ok { + if !isUsedWithOkCheck(file, typeAssert) { + pos := fset.Position(typeAssert.Pos()) + issue := ReflectionIssue{ + Type: "unsafe_interface_conversion", + Description: "Type assertion on reflect.Value.Interface() without ok check", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + } + } + } + } + } + return true + }) +} + +func isReflectValueExpr(file *ast.File, expr ast.Expr) bool { + // Simple heuristic - check if expression contains "reflect.Value" operations + switch e := expr.(type) { + case *ast.CallExpr: + if sel, ok := e.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok && ident.Name == "reflect" { + return sel.Sel.Name == "ValueOf" || sel.Sel.Name == "Indirect" + } + } + case *ast.Ident: + // Check if it's a variable of type reflect.Value + return isReflectValueVar(file, e.Name) + } + return false +} + +func isReflectValueVar(file *ast.File, varName string) bool { + var isValue bool + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.ValueSpec: + for i, name := range node.Names { + if name.Name == varName { + if node.Type != nil { + if sel, ok := node.Type.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + isValue = ident.Name == "reflect" && sel.Sel.Name == "Value" + return false + } + } + } else if i < len(node.Values) { + // Check if assigned from reflect.ValueOf + if call, ok := node.Values[i].(*ast.CallExpr); ok { + if sel, ok := call.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + isValue = ident.Name == "reflect" && sel.Sel.Name == "ValueOf" + return false + } + } + } + } + } + } + } + return true + }) + return isValue +} + +func findParentNode(file *ast.File, target ast.Node) ast.Node { + var parent ast.Node + ast.Inspect(file, func(n ast.Node) bool { + // This is a simplified parent finder + switch node := n.(type) { + case *ast.TypeAssertExpr: + if node.X == target { + parent = node + return false + } + case *ast.CallExpr: + for _, arg := range node.Args { + if arg == target { + parent = node + return false + } + } + } + return true + }) + return parent +} + +func isInHotPath(file *ast.File, node ast.Node) bool { + // Check if node is in a function that looks like a hot path + var inHotPath bool + ast.Inspect(file, func(n ast.Node) bool { + if fn, ok := n.(*ast.FuncDecl); ok && containsNode(fn, node) { + // Check function name for common hot path patterns + name := fn.Name.Name + if name == "ServeHTTP" || name == "Handle" || name == "Process" || + name == "Execute" || name == "Run" || name == "Do" { + inHotPath = true + return false + } + // Check if function is called frequently (in loops) + if isFunctionCalledInLoop(file, fn.Name.Name) { + inHotPath = true + return false + } + } + return true + }) + return inHotPath +} + +func isFunctionCalledInLoop(file *ast.File, funcName string) bool { + var calledInLoop bool + ast.Inspect(file, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == funcName { + if isInLoop(file, call) { + calledInLoop = true + return false + } + } + } + return true + }) + return calledInLoop +} \ No newline at end of file diff --git a/tool_find_type_assertions.go b/tool_find_type_assertions.go new file mode 100644 index 0000000..7f3cb44 --- /dev/null +++ b/tool_find_type_assertions.go @@ -0,0 +1,209 @@ +package main + +import ( + "go/ast" + "go/token" +) + +type TypeAssertion struct { + Expression string `json:"expression"` + TargetType string `json:"target_type"` + HasOkCheck bool `json:"has_ok_check"` + Position Position `json:"position"` + Context string `json:"context"` +} + +type TypeAssertionAnalysis struct { + Assertions []TypeAssertion `json:"assertions"` + Issues []TypeAssertionIssue `json:"issues"` +} + +type TypeAssertionIssue struct { + Type string `json:"type"` + Description string `json:"description"` + Position Position `json:"position"` +} + +func findTypeAssertions(dir string) (*TypeAssertionAnalysis, error) { + analysis := &TypeAssertionAnalysis{ + Assertions: []TypeAssertion{}, + Issues: []TypeAssertionIssue{}, + } + + err := walkGoFiles(dir, func(path string, src []byte, file *ast.File, fset *token.FileSet) error { + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.TypeAssertExpr: + pos := fset.Position(node.Pos()) + assertion := TypeAssertion{ + Expression: exprToString(node.X), + TargetType: exprToString(node.Type), + HasOkCheck: false, // Will be updated if in assignment + Position: newPosition(pos), + Context: extractContext(src, pos), + } + + // Check if this assertion is used with ok check + hasOk := isUsedWithOkCheck(file, node) + assertion.HasOkCheck = hasOk + + analysis.Assertions = append(analysis.Assertions, assertion) + + // Report issue if no ok check + if !hasOk && !isInSafeContext(file, node) { + issue := TypeAssertionIssue{ + Type: "unsafe_type_assertion", + Description: "Type assertion without ok check may panic", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } + + case *ast.TypeSwitchStmt: + // Analyze type switch + analyzeTypeSwitch(node, fset, src, analysis) + } + return true + }) + + return nil + }) + + return analysis, err +} + +func isUsedWithOkCheck(file *ast.File, assertion *ast.TypeAssertExpr) bool { + var hasOk bool + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.AssignStmt: + // Check if assertion is on RHS and has 2 LHS values + for _, rhs := range node.Rhs { + if rhs == assertion && len(node.Lhs) >= 2 { + hasOk = true + return false + } + } + case *ast.ValueSpec: + // Check in var declarations + for _, value := range node.Values { + if value == assertion && len(node.Names) >= 2 { + hasOk = true + return false + } + } + } + return true + }) + + return hasOk +} + +func isInSafeContext(file *ast.File, assertion *ast.TypeAssertExpr) bool { + // Check if assertion is in a context where panic is acceptable + var isSafe bool + + ast.Inspect(file, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.FuncDecl: + if containsNode(node, assertion) { + // Check if function has recover + hasRecover := false + ast.Inspect(node, func(inner ast.Node) bool { + if call, ok := inner.(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == "recover" { + hasRecover = true + return false + } + } + return true + }) + if hasRecover { + isSafe = true + return false + } + } + case *ast.IfStmt: + // Check if assertion is guarded by type check + if containsNode(node, assertion) && isTypeCheckCondition(node.Cond) { + isSafe = true + return false + } + } + return true + }) + + return isSafe +} + +func isTypeCheckCondition(expr ast.Expr) bool { + // Check for _, ok := x.(Type) pattern in condition + switch e := expr.(type) { + case *ast.BinaryExpr: + // Check for ok == true or similar + if ident, ok := e.X.(*ast.Ident); ok && ident.Name == "ok" { + return true + } + if ident, ok := e.Y.(*ast.Ident); ok && ident.Name == "ok" { + return true + } + case *ast.Ident: + // Direct ok check + return e.Name == "ok" + } + return false +} + +func analyzeTypeSwitch(typeSwitch *ast.TypeSwitchStmt, fset *token.FileSet, src []byte, analysis *TypeAssertionAnalysis) { + pos := fset.Position(typeSwitch.Pos()) + + var expr string + hasDefault := false + caseCount := 0 + + // Extract the expression being switched on + switch assign := typeSwitch.Assign.(type) { + case *ast.AssignStmt: + if len(assign.Rhs) > 0 { + if typeAssert, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok { + expr = exprToString(typeAssert.X) + } + } + case *ast.ExprStmt: + if typeAssert, ok := assign.X.(*ast.TypeAssertExpr); ok { + expr = exprToString(typeAssert.X) + } + } + + // Count cases and check for default + for _, clause := range typeSwitch.Body.List { + if cc, ok := clause.(*ast.CaseClause); ok { + if cc.List == nil { + hasDefault = true + } else { + caseCount += len(cc.List) + } + } + } + + // Type switches are generally safe, but we can note them + assertion := TypeAssertion{ + Expression: expr, + TargetType: "type switch", + HasOkCheck: true, // Type switches are inherently safe + Position: newPosition(pos), + Context: extractContext(src, pos), + } + analysis.Assertions = append(analysis.Assertions, assertion) + + // Check for single-case type switch + if caseCount == 1 && !hasDefault { + issue := TypeAssertionIssue{ + Type: "single_case_type_switch", + Description: "Type switch with single case - consider using type assertion instead", + Position: newPosition(pos), + } + analysis.Issues = append(analysis.Issues, issue) + } +} \ No newline at end of file diff --git a/tool_go_common.go b/tool_go_common.go new file mode 100644 index 0000000..205a96b --- /dev/null +++ b/tool_go_common.go @@ -0,0 +1,25 @@ +package main + +import ( + "bytes" + "os/exec" +) + +// runCommand executes a command and returns stdout, stderr, exit code, and error +func runCommand(cmd *exec.Cmd) (stdout, stderr string, exitCode int, err error) { + var stdoutBuf, stderrBuf bytes.Buffer + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + err = cmd.Run() + + exitCode = 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + err = nil // Don't treat non-zero exit as error for go commands + } + } + + return stdoutBuf.String(), stderrBuf.String(), exitCode, err +} \ No newline at end of file diff --git a/tool_go_run.go b/tool_go_run.go new file mode 100644 index 0000000..8e91d07 --- /dev/null +++ b/tool_go_run.go @@ -0,0 +1,84 @@ +package main + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +type GoRunResult struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` + Command string `json:"command"` + WorkDir string `json:"work_dir"` +} + +func goRun(path string, flags []string, timeout time.Duration) (*GoRunResult, error) { + // Resolve absolute path + absPath, err := filepath.Abs(path) + if err != nil { + return &GoRunResult{ + Error: err.Error(), + Command: "go run " + path, + }, nil + } + + // Determine working directory and target file/package + var workDir string + var target string + + // Check if path is a file or directory + info, err := os.Stat(absPath) + if err != nil { + return &GoRunResult{ + Error: err.Error(), + Command: "go run " + path, + }, nil + } + + if info.IsDir() { + // Running a package + workDir = absPath + target = "." + } else { + // Running a specific file + workDir = filepath.Dir(absPath) + target = filepath.Base(absPath) + } + + // Build command arguments + args := []string{"run"} + args = append(args, flags...) + args = append(args, target) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "go", args...) + cmd.Dir = workDir + + stdout, stderr, exitCode, cmdErr := runCommand(cmd) + + result := &GoRunResult{ + Stdout: stdout, + Stderr: stderr, + ExitCode: exitCode, + Command: "go " + strings.Join(args, " "), + WorkDir: workDir, + } + + if cmdErr != nil { + if ctx.Err() == context.DeadlineExceeded { + result.Error = "execution timeout exceeded" + } else { + result.Error = cmdErr.Error() + } + } + + return result, nil +} \ No newline at end of file diff --git a/tool_gotest.go b/tool_gotest.go new file mode 100644 index 0000000..28da903 --- /dev/null +++ b/tool_gotest.go @@ -0,0 +1,106 @@ +package main + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +type GoTestResult struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` + Command string `json:"command"` + WorkDir string `json:"work_dir"` + Passed bool `json:"passed"` + TestCount int `json:"test_count,omitempty"` +} + +func goTest(path string, flags []string, timeout time.Duration) (*GoTestResult, error) { + // Resolve absolute path + absPath, err := filepath.Abs(path) + if err != nil { + return &GoTestResult{ + Error: err.Error(), + Command: "go test " + path, + }, nil + } + + // Determine working directory and target + var workDir string + var target string + + // Check if path is a file or directory + info, err := os.Stat(absPath) + if err != nil { + return &GoTestResult{ + Error: err.Error(), + Command: "go test " + path, + }, nil + } + + if info.IsDir() { + // Testing a package + workDir = absPath + target = "." + } else { + // Testing a specific file (though go test typically works with packages) + workDir = filepath.Dir(absPath) + target = "." + } + + // Build command arguments + args := []string{"test"} + args = append(args, flags...) + args = append(args, target) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "go", args...) + cmd.Dir = workDir + + stdout, stderr, exitCode, cmdErr := runCommand(cmd) + + result := &GoTestResult{ + Stdout: stdout, + Stderr: stderr, + ExitCode: exitCode, + Command: "go " + strings.Join(args, " "), + WorkDir: workDir, + Passed: exitCode == 0, + } + + // Try to extract test count from output + if strings.Contains(stdout, "PASS") || strings.Contains(stdout, "FAIL") { + result.TestCount = countTests(stdout) + } + + if cmdErr != nil { + if ctx.Err() == context.DeadlineExceeded { + result.Error = "execution timeout exceeded" + } else { + result.Error = cmdErr.Error() + } + } + + return result, nil +} + +// Helper function to count tests from go test output +func countTests(output string) int { + count := 0 + lines := strings.Split(output, "\n") + for _, line := range lines { + if strings.HasPrefix(strings.TrimSpace(line), "--- PASS:") || + strings.HasPrefix(strings.TrimSpace(line), "--- FAIL:") || + strings.HasPrefix(strings.TrimSpace(line), "--- SKIP:") { + count++ + } + } + return count +} \ No newline at end of file