Skip to main content
Glama
FreePeak

Multi Database MCP Server

querybuilder.go (23 kB)
package dbtools

import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/FreePeak/db-mcp-server/pkg/db"
	"github.com/FreePeak/db-mcp-server/pkg/logger"
	"github.com/FreePeak/db-mcp-server/pkg/tools"
)

// defaultQueryBuilderTimeoutMs is the timeout, in milliseconds, used when neither the
// caller nor the database configuration supplies a positive timeout.
const defaultQueryBuilderTimeoutMs = 5000

// QueryComponents represents the components of a SQL SELECT query.
//
// Column, table, ON, value and HAVING strings are raw SQL fragments: they are copied
// verbatim into the generated statement by buildQueryFromComponents.
type QueryComponents struct {
	Select  []string     `json:"select"`
	From    string       `json:"from"`
	Joins   []JoinClause `json:"joins"`
	Where   []Condition  `json:"where"`
	GroupBy []string     `json:"groupBy"`
	Having  []string     `json:"having"`
	OrderBy []OrderBy    `json:"orderBy"`
	Limit   int          `json:"limit"`
	Offset  int          `json:"offset"`
}

// JoinClause represents a SQL JOIN clause. Type is one of inner, left, right, full
// (optionally with OUTER), matched case-insensitively; empty means a plain JOIN.
type JoinClause struct {
	Type  string `json:"type"`
	Table string `json:"table"`
	On    string `json:"on"`
}

// Condition represents a WHERE condition. Connector (AND/OR) joins this condition to
// the previous one and is ignored for the first condition.
type Condition struct {
	Column    string `json:"column"`
	Operator  string `json:"operator"`
	Value     string `json:"value"`
	Connector string `json:"connector"`
}

// OrderBy represents an ORDER BY clause entry.
type OrderBy struct {
	Column    string `json:"column"`
	Direction string `json:"direction"`
}

// createQueryBuilderTool creates the "dbQueryBuilder" tool, which validates, builds
// and analyzes SQL queries against a configured database.
func createQueryBuilderTool() *tools.Tool {
	return &tools.Tool{
		Name:        "dbQueryBuilder",
		Description: "Visual SQL query construction with syntax validation",
		Category:    "database",
		InputSchema: tools.ToolInputSchema{
			Type: "object",
			Properties: map[string]interface{}{
				"action": map[string]interface{}{
					"type":        "string",
					"description": "Action to perform (validate, build, analyze)",
					"enum":        []string{"validate", "build", "analyze"},
				},
				"query": map[string]interface{}{
					"type":        "string",
					"description": "SQL query to validate or analyze",
				},
				"components": map[string]interface{}{
					"type":        "object",
					"description": "Query components for building a query",
					"properties": map[string]interface{}{
						"select": map[string]interface{}{
							"type":        "array",
							"description": "Columns to select",
							"items": map[string]interface{}{
								"type": "string",
							},
						},
						"from": map[string]interface{}{
							"type":        "string",
							"description": "Table to select from",
						},
						"joins": map[string]interface{}{
							"type":        "array",
							"description": "Joins to include",
							"items": map[string]interface{}{
								"type": "object",
								"properties": map[string]interface{}{
									"type": map[string]interface{}{
										"type": "string",
										"enum": []string{"inner", "left", "right", "full"},
									},
									"table": map[string]interface{}{
										"type": "string",
									},
									"on": map[string]interface{}{
										"type": "string",
									},
								},
							},
						},
						"where": map[string]interface{}{
							"type":        "array",
							"description": "Where conditions",
							"items": map[string]interface{}{
								"type": "object",
								"properties": map[string]interface{}{
									"column": map[string]interface{}{
										"type": "string",
									},
									"operator": map[string]interface{}{
										"type": "string",
										"enum": []string{"=", "!=", "<", ">", "<=", ">=", "LIKE", "IN", "NOT IN", "IS NULL", "IS NOT NULL"},
									},
									"value": map[string]interface{}{
										"type": "string",
									},
									"connector": map[string]interface{}{
										"type": "string",
										"enum": []string{"AND", "OR"},
									},
								},
							},
						},
						"groupBy": map[string]interface{}{
							"type":        "array",
							"description": "Columns to group by",
							"items": map[string]interface{}{
								"type": "string",
							},
						},
						"having": map[string]interface{}{
							"type":        "array",
							"description": "Having conditions",
							"items": map[string]interface{}{
								"type": "string",
							},
						},
						"orderBy": map[string]interface{}{
							"type":        "array",
							"description": "Columns to order by",
							"items": map[string]interface{}{
								"type": "object",
								"properties": map[string]interface{}{
									"column": map[string]interface{}{
										"type": "string",
									},
									"direction": map[string]interface{}{
										"type": "string",
										"enum": []string{"ASC", "DESC"},
									},
								},
							},
						},
						"limit": map[string]interface{}{
							"type":        "integer",
							"description": "Limit results",
						},
						"offset": map[string]interface{}{
							"type":        "integer",
							"description": "Offset results",
						},
					},
				},
				"timeout": map[string]interface{}{
					"type":        "integer",
					"description": "Execution timeout in milliseconds (default: the database's configured query timeout, or 5000 if it is not set)",
				},
				"database": map[string]interface{}{
					"type":        "string",
					"description": "Database ID to use (optional if only one database is configured)",
				},
			},
			Required: []string{"action", "database"},
		},
		Handler: handleQueryBuilder,
	}
}

// handleQueryBuilder is the Handler of the dbQueryBuilder tool.
//
// It resolves the database named by params["database"] and derives a timeout:
// params["timeout"] in milliseconds when positive, otherwise the database's
// QueryTimeout (seconds) converted to milliseconds, otherwise
// defaultQueryBuilderTimeoutMs. It then dispatches on params["action"]:
//   - "validate": EXPLAINs params["query"] and reports whether that succeeded.
//   - "build":    assembles a query from params["components"] and validates it.
//   - "analyze":  returns the EXPLAIN ANALYZE plan for params["query"].
//
// Missing required parameters, an unknown database or action, and malformed
// components are returned as errors.
func handleQueryBuilder(ctx context.Context, params map[string]interface{}) (interface{}, error) {
	if dbManager == nil {
		return nil, fmt.Errorf("database manager not initialized")
	}

	action, ok := getStringParam(params, "action")
	if !ok {
		return nil, fmt.Errorf("action parameter is required")
	}

	databaseID, ok := getStringParam(params, "database")
	if !ok {
		return nil, fmt.Errorf("database parameter is required")
	}

	// Named "database" rather than "db" so the imported db package is not shadowed.
	database, err := dbManager.GetDatabase(databaseID)
	if err != nil {
		return nil, fmt.Errorf("failed to get database: %w", err)
	}

	// Optional here; the actions that need it check for an empty value below.
	query, _ := getStringParam(params, "query")

	var components QueryComponents
	if componentsMap, ok := params["components"].(map[string]interface{}); ok {
		if err := parseQueryComponents(&components, componentsMap); err != nil {
			return nil, fmt.Errorf("failed to parse query components: %w", err)
		}
	}

	// QueryTimeout is in seconds; this tool works in milliseconds. A non-positive
	// value would yield an already-expired context, so fall back to the default.
	timeout := database.QueryTimeout() * 1000
	if timeout <= 0 {
		timeout = defaultQueryBuilderTimeoutMs
	}
	if timeoutParam, ok := getIntParam(params, "timeout"); ok && timeoutParam > 0 {
		timeout = timeoutParam
	}
	timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
	defer cancel()

	switch action {
	case "validate":
		if query == "" {
			return nil, fmt.Errorf("query parameter is required for validate action")
		}
		return validateQuery(timeoutCtx, database, query)
	case "build":
		if err := validateQueryComponents(&components); err != nil {
			return nil, fmt.Errorf("invalid query components: %w", err)
		}
		builtQuery, err := buildQueryFromComponents(&components)
		if err != nil {
			return nil, fmt.Errorf("failed to build query: %w", err)
		}
		return validateQuery(timeoutCtx, database, builtQuery)
	case "analyze":
		if query == "" {
			return nil, fmt.Errorf("query parameter is required for analyze action")
		}
		return analyzeQueryPlan(timeoutCtx, database, query)
	default:
		return nil, fmt.Errorf("invalid action: %s", action)
	}
}

// parseQueryComponents fills components from data, the decoded "components" argument.
//
// Keys that are absent, or whose value is not of the expected container type, leave
// the corresponding field untouched. Array elements of the wrong kind (a non-string
// column, a non-object join) are reported as errors. Previously they became empty
// entries that produced malformed SQL such as "SELECT , id".
func parseQueryComponents(components *QueryComponents, data map[string]interface{}) error {
	if err := componentStringList(data, "select", &components.Select); err != nil {
		return err
	}

	if from, ok := data["from"].(string); ok {
		components.From = from
	}

	if items, ok := data["joins"].([]interface{}); ok {
		joins := make([]JoinClause, len(items))
		for i, item := range items {
			obj, err := componentObject("joins", i, item)
			if err != nil {
				return err
			}
			joins[i] = JoinClause{
				Type:  componentString(obj, "type"),
				Table: componentString(obj, "table"),
				On:    componentString(obj, "on"),
			}
		}
		components.Joins = joins
	}

	if items, ok := data["where"].([]interface{}); ok {
		conditions := make([]Condition, len(items))
		for i, item := range items {
			obj, err := componentObject("where", i, item)
			if err != nil {
				return err
			}
			conditions[i] = Condition{
				Column:    componentString(obj, "column"),
				Operator:  componentString(obj, "operator"),
				Value:     componentScalar(obj, "value"),
				Connector: componentString(obj, "connector"),
			}
		}
		components.Where = conditions
	}

	if err := componentStringList(data, "groupBy", &components.GroupBy); err != nil {
		return err
	}
	if err := componentStringList(data, "having", &components.Having); err != nil {
		return err
	}

	if items, ok := data["orderBy"].([]interface{}); ok {
		orders := make([]OrderBy, len(items))
		for i, item := range items {
			obj, err := componentObject("orderBy", i, item)
			if err != nil {
				return err
			}
			orders[i] = OrderBy{
				Column:    componentString(obj, "column"),
				Direction: componentString(obj, "direction"),
			}
		}
		components.OrderBy = orders
	}

	if limit, ok := componentInt(data["limit"]); ok {
		components.Limit = limit
	}
	if offset, ok := componentInt(data["offset"]); ok {
		components.Offset = offset
	}
	return nil
}

// componentStringList stores data[key] into *dst when it is an array of strings.
// An absent or non-array value leaves *dst untouched; a non-string element is an error.
func componentStringList(data map[string]interface{}, key string, dst *[]string) error {
	items, ok := data[key].([]interface{})
	if !ok {
		return nil
	}
	list := make([]string, len(items))
	for i, item := range items {
		s, ok := item.(string)
		if !ok {
			return fmt.Errorf("%s[%d] must be a string, got %T", key, i, item)
		}
		list[i] = s
	}
	*dst = list
	return nil
}

// componentObject asserts that item (element i of the array named key) is an object.
func componentObject(key string, i int, item interface{}) (map[string]interface{}, error) {
	obj, ok := item.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("%s[%d] must be an object, got %T", key, i, item)
	}
	return obj, nil
}

// componentString returns obj[key] when it is a string and "" otherwise.
func componentString(obj map[string]interface{}, key string) string {
	s, _ := obj[key].(string)
	return s
}

// componentScalar renders obj[key] as SQL text. Strings are used verbatim, and
// numbers and booleans are formatted, so {"value": 5} yields "5" instead of being
// dropped. Anything else yields "".
func componentScalar(obj map[string]interface{}, key string) string {
	switch v := obj[key].(type) {
	case string:
		return v
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64)
	case int:
		return strconv.Itoa(v)
	case int64:
		return strconv.FormatInt(v, 10)
	case bool:
		return strconv.FormatBool(v)
	default:
		return ""
	}
}

// componentInt converts a decoded JSON number (float64, truncated) or a Go int/int64
// to int. It reports false for any other value, including nil.
func componentInt(v interface{}) (int, bool) {
	switch n := v.(type) {
	case float64:
		return int(n), true
	case int:
		return n, true
	case int64:
		return int(n), true
	default:
		return 0, false
	}
}

// validateQueryComponents checks that components can be assembled into a
// well-formed SELECT statement by buildQueryFromComponents.
//
// Rules:
//   - There must be a FROM table and at least one non-blank SELECT column.
//   - Every JOIN needs a table, an ON condition and, if given, a known join type.
//   - Every WHERE condition needs a column, an operator and, if given, an AND/OR
//     connector.
//   - Every ORDER BY entry needs a column and, if given, an ASC/DESC direction.
//
// Keyword fields are compared case-insensitively. Column, table and value text is not
// inspected, because it is raw SQL.
func validateQueryComponents(components *QueryComponents) error {
	if components.From == "" {
		return fmt.Errorf("FROM clause is required")
	}
	if len(components.Select) == 0 {
		return fmt.Errorf("SELECT clause must have at least one column")
	}
	for _, col := range components.Select {
		if strings.TrimSpace(col) == "" {
			return fmt.Errorf("SELECT columns must not be empty")
		}
	}

	for _, join := range components.Joins {
		if join.Table == "" {
			return fmt.Errorf("JOIN clause must have a table")
		}
		if join.On == "" {
			return fmt.Errorf("JOIN clause must have an ON condition")
		}
		// The type is upper-cased and interpolated into the SQL, so restrict it.
		switch strings.ToUpper(join.Type) {
		case "", "INNER", "LEFT", "RIGHT", "FULL", "LEFT OUTER", "RIGHT OUTER", "FULL OUTER":
		default:
			return fmt.Errorf("JOIN type must be one of inner, left, right, full; got %q", join.Type)
		}
	}

	for _, where := range components.Where {
		if where.Column == "" {
			return fmt.Errorf("WHERE condition must have a column")
		}
		if where.Operator == "" {
			return fmt.Errorf("WHERE condition must have an operator")
		}
		switch strings.ToUpper(where.Connector) {
		case "", "AND", "OR":
		default:
			return fmt.Errorf("WHERE connector must be AND or OR; got %q", where.Connector)
		}
	}

	for _, order := range components.OrderBy {
		if order.Column == "" {
			return fmt.Errorf("ORDER BY clause must have a column")
		}
		switch strings.ToUpper(order.Direction) {
		case "", "ASC", "DESC":
		default:
			return fmt.Errorf("ORDER BY direction must be ASC or DESC")
		}
	}
	return nil
}

// buildQueryFromComponents assembles a SELECT statement from components.
//
// Column, table, ON, value and HAVING text is copied verbatim, with no quoting or
// escaping, so callers must supply trusted SQL fragments.
//   - JOIN types, WHERE connectors and ORDER BY directions are upper-cased.
//   - An empty connector defaults to AND.
//   - An empty direction is omitted, leaving the SQL default (ascending).
//   - HAVING conditions are joined with AND.
//   - LIMIT and OFFSET are emitted only when positive.
//
// The error result is currently always nil.
func buildQueryFromComponents(components *QueryComponents) (string, error) {
	var b strings.Builder

	b.WriteString("SELECT ")
	b.WriteString(strings.Join(components.Select, ", "))
	b.WriteString(" FROM ")
	b.WriteString(components.From)

	for _, join := range components.Joins {
		b.WriteByte(' ')
		if joinType := strings.ToUpper(join.Type); joinType != "" {
			b.WriteString(joinType)
			b.WriteByte(' ')
		}
		b.WriteString("JOIN ")
		b.WriteString(join.Table)
		b.WriteString(" ON ")
		b.WriteString(join.On)
	}

	for i, cond := range components.Where {
		if i == 0 {
			b.WriteString(" WHERE ")
		} else {
			// Without a connector, "a = 1  b = 2" would be emitted; default to AND.
			connector := strings.ToUpper(cond.Connector)
			if connector == "" {
				connector = "AND"
			}
			b.WriteByte(' ')
			b.WriteString(connector)
			b.WriteByte(' ')
		}
		b.WriteString(cond.Column)
		b.WriteByte(' ')
		b.WriteString(cond.Operator)
		if cond.Value != "" {
			b.WriteByte(' ')
			b.WriteString(cond.Value)
		}
	}

	if len(components.GroupBy) > 0 {
		b.WriteString(" GROUP BY ")
		b.WriteString(strings.Join(components.GroupBy, ", "))
	}

	if len(components.Having) > 0 {
		b.WriteString(" HAVING ")
		b.WriteString(strings.Join(components.Having, " AND "))
	}

	if len(components.OrderBy) > 0 {
		orders := make([]string, len(components.OrderBy))
		for i, order := range components.OrderBy {
			orders[i] = order.Column
			if direction := strings.ToUpper(order.Direction); direction != "" {
				orders[i] += " " + direction
			}
		}
		b.WriteString(" ORDER BY ")
		b.WriteString(strings.Join(orders, ", "))
	}

	if components.Limit > 0 {
		b.WriteString(" LIMIT ")
		b.WriteString(strconv.Itoa(components.Limit))
	}
	if components.Offset > 0 {
		b.WriteString(" OFFSET ")
		b.WriteString(strconv.Itoa(components.Offset))
	}

	return b.String(), nil
}

// validateQuery checks query for errors by asking the database to EXPLAIN it.
//
// A failing EXPLAIN is not returned as a Go error. It is reported as
// {"valid": false, "error": <database message>, "query": query}; success yields
// {"valid": true, "query": query}. The returned error is always nil.
func validateQuery(ctx context.Context, database db.Database, query string) (interface{}, error) {
	if err := explainForValidation(ctx, database, query); err != nil {
		return map[string]interface{}{
			"valid": false,
			"error": err.Error(),
			"query": query,
		}, nil
	}

	return map[string]interface{}{
		"valid": true,
		"query": query,
	}, nil
}

// explainForValidation runs "EXPLAIN <query>", then drains and closes the result set.
// It returns the first error reported by the database, whether it comes from Query
// or from iterating the rows. The original code discarded the rows without closing
// them, which leaks the underlying connection.
func explainForValidation(ctx context.Context, database db.Database, query string) error {
	rows, err := database.Query(ctx, "EXPLAIN "+query)
	if err != nil {
		return err
	}
	defer func() {
		if err := rows.Close(); err != nil {
			logger.Error("error closing rows: %v", err)
		}
	}()

	for rows.Next() {
		// Only success matters here; the plan rows themselves are discarded.
	}
	return rows.Err()
}

// analyzeQueryPlan returns the execution plan of query as {"query": query, "plan": <text>}.
//
// It uses PostgreSQL's "EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS)" syntax and reads the
// single plan value from the first row.
//
// NOTE(review): EXPLAIN ANALYZE actually executes the statement, so analyzing an
// INSERT/UPDATE/DELETE modifies data. Other engines may reject this option syntax.
func analyzeQueryPlan(ctx context.Context, database db.Database, query string) (interface{}, error) {
	explainQuery := "EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) " + query
	rows, err := database.Query(ctx, explainQuery)
	if err != nil {
		return nil, fmt.Errorf("failed to analyze query: %w", err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			logger.Error("error closing rows: %v", err)
		}
	}()

	if !rows.Next() {
		// Next also returns false on an iteration error; surface that instead of
		// misreporting it as an empty result.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read explain plan: %w", err)
		}
		return nil, fmt.Errorf("no explain plan returned")
	}

	var plan []byte
	if err := rows.Scan(&plan); err != nil {
		return nil, fmt.Errorf("failed to scan explain plan: %w", err)
	}

	return map[string]interface{}{
		"query": query,
		"plan":  string(plan),
	}, nil
}

// calculateQueryComplexity heuristically classifies query as "Simple", "Moderate"
// or "Complex".
//
// It counts keyword substrings in the upper-cased query (JOINs, extra SELECTs treated
// as subqueries, aggregate calls, GROUP BY, ORDER BY, HAVING, DISTINCT, UNION) and
// weights them into a score. Counting is substring-based, so keywords inside
// identifiers or literals are counted too.
func calculateQueryComplexity(query string) string {
	query = strings.ToUpper(query)

	joins := strings.Count(query, " JOIN ")
	subqueries := strings.Count(query, "SELECT") - 1 // Subtract the main query
	if subqueries < 0 {
		subqueries = 0
	}
	aggregations := strings.Count(query, " SUM(") + strings.Count(query, " COUNT(") +
		strings.Count(query, " AVG(") + strings.Count(query, " MIN(") + strings.Count(query, " MAX(")
	groupBy := strings.Count(query, " GROUP BY ")
	orderBy := strings.Count(query, " ORDER BY ")
	having := strings.Count(query, " HAVING ")
	distinct := strings.Count(query, " DISTINCT ")
	unions := strings.Count(query, " UNION ")

	// Calculate complexity score - adjusted to match test expectations
	score := joins*2 + (subqueries * 3) + aggregations + groupBy + orderBy + having*2 + distinct + unions*3

	// Combinations that are always considered complex, regardless of score.
	if joins >= 3 || (joins >= 2 && subqueries >= 1) || (subqueries >= 1 && aggregations >= 1) {
		return "Complex"
	}

	switch {
	case score <= 2:
		return "Simple"
	case score <= 6:
		return "Moderate"
	default:
		return "Complex"
	}
}

// getSuggestionForError maps a database error message to a human-readable hint.
// Matching is case-insensitive and the first matching rule wins.
func getSuggestionForError(errorMsg string) string {
	msg := strings.ToLower(errorMsg)

	switch {
	case strings.Contains(msg, "syntax error"):
		return "Check SQL syntax for errors such as missing keywords, incorrect operators, or unmatched parentheses"
	case strings.Contains(msg, "unknown column"):
		return "Column name is incorrect or doesn't exist in the specified table"
	case strings.Contains(msg, "unknown table"):
		return "Table name is incorrect or doesn't exist in the database"
	case strings.Contains(msg, "ambiguous"):
		return "Column name is ambiguous. Qualify it with the table name"
	case strings.Contains(msg, "missing") && strings.Contains(msg, "from"):
		return "FROM clause is missing or incorrectly formatted"
	case strings.Contains(msg, "no such table"):
		return "Table specified does not exist in the database"
	default:
		return "Review the query syntax and structure"
	}
}

// errorLinePatterns match line references such as "at line 42" or "LINE 42" in
// database error messages. They are tried in order and compiled once, not per call.
//
//nolint:unused // Used by extractLineNumberFromError (future implementation)
var errorLinePatterns = []*regexp.Regexp{
	regexp.MustCompile(`at line ([0-9]+)`),
	regexp.MustCompile(`line ([0-9]+)`),
	regexp.MustCompile(`LINE ([0-9]+)`),
}

// errorPositionPatterns match numeric character offsets in database error messages.
// They are tried in order.
//
//nolint:unused // Used by extractPositionFromError (future implementation)
var errorPositionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`at character ([0-9]+)`),
	regexp.MustCompile(`position ([0-9]+)`),
}

// extractLineNumberFromError extracts a line number from a database error message.
// It returns 0 when no pattern yields a parsable number.
//
//nolint:unused // Used in future implementation
func extractLineNumberFromError(errMsg string) int {
	for _, re := range errorLinePatterns {
		match := re.FindStringSubmatch(errMsg)
		if len(match) < 2 {
			continue
		}
		lineNum, err := strconv.Atoi(match[1])
		if err != nil {
			logger.Warn("Failed to parse line number: %v", err)
			continue
		}
		return lineNum
	}
	return 0
}

// extractPositionFromError extracts a character position from a database error
// message. It returns 0 when no numeric position is found.
//
//nolint:unused // Used in future implementation
func extractPositionFromError(errMsg string) int {
	for _, re := range errorPositionPatterns {
		match := re.FindStringSubmatch(errMsg)
		if len(match) < 2 {
			continue
		}
		pos, err := strconv.Atoi(match[1])
		if err != nil {
			logger.Warn("Failed to parse position: %v", err)
			continue
		}
		return pos
	}
	// TODO: messages of the form `at or near "X"` name a token rather than an offset.
	// Resolving X needs the original query, so they currently yield 0 like a miss.
	return 0
}

// Mock functions for use when database is not available

// mockValidateQuery performs lightweight, database-free checks on query.
//
// The query (whitespace-trimmed) must start with SELECT, contain " FROM ", have
// balanced parentheses and an even number of single quotes. On failure it reports
// "valid": false with an error, a suggestion and an approximate position. The
// returned error is always nil.
func mockValidateQuery(query string) (interface{}, error) {
	query = strings.TrimSpace(query)
	upper := strings.ToUpper(query)

	if !strings.HasPrefix(upper, "SELECT") {
		return map[string]interface{}{
			"valid":       false,
			"query":       query,
			"error":       "Query must start with SELECT",
			"suggestion":  "Begin your query with the SELECT keyword",
			"errorLine":   1,
			"errorColumn": 1,
		}, nil
	}

	if !strings.Contains(upper, " FROM ") {
		return map[string]interface{}{
			"valid":       false,
			"query":       query,
			"error":       "Missing FROM clause",
			"suggestion":  "Add a FROM clause to specify the table or view to query",
			"errorLine":   1,
			"errorColumn": len("SELECT"),
		}, nil
	}

	if strings.Count(query, "(") != strings.Count(query, ")") {
		return map[string]interface{}{
			"valid":       false,
			"query":       query,
			"error":       "Unbalanced parentheses",
			"suggestion":  "Ensure all opening parentheses have matching closing parentheses",
			"errorLine":   1,
			"errorColumn": 0,
		}, nil
	}

	if strings.Count(query, "'")%2 != 0 {
		return map[string]interface{}{
			"valid":       false,
			"query":       query,
			"error":       "Unclosed string literal",
			"suggestion":  "Ensure all string literals are properly closed with matching quotes",
			"errorLine":   1,
			"errorColumn": 0,
		}, nil
	}

	return map[string]interface{}{
		"valid": true,
		"query": query,
	}, nil
}

// mockAnalyzeQuery produces a database-free, heuristic analysis of query.
//
// The result contains:
//   - issues and suggestions from simple keyword checks;
//   - a fabricated MySQL-style explainPlan, which claims a PRIMARY key range scan
//     when a WHERE clause is present;
//   - the calculateQueryComplexity rating;
//   - "is_mock": true.
//
// The caller's query is returned unmodified.
func mockAnalyzeQuery(query string) (interface{}, error) {
	// Keyword checks run on an upper-cased copy. The original code overwrote query
	// itself, so the returned query and table name came back upper-cased.
	upper := strings.ToUpper(query)
	hasWhere := strings.Contains(upper, " WHERE ")

	var issues []string
	var suggestions []string

	if !hasWhere {
		issues = append(issues, "Query has no WHERE clause")
		suggestions = append(suggestions, "Add a WHERE clause to filter results and improve performance")
	}

	if strings.Count(upper, " JOIN ") > 1 {
		issues = append(issues, "Query contains multiple joins")
		suggestions = append(suggestions, "Multiple joins can impact performance. Consider denormalizing or using indexed columns")
	}

	// A leading wildcard pattern, quoted with either single or double quotes.
	if strings.Contains(upper, " LIKE '%") || strings.Contains(upper, ` LIKE "%`) {
		issues = append(issues, "Query uses LIKE with leading wildcard")
		suggestions = append(suggestions, "Leading wildcards in LIKE conditions cannot use indexes. Consider alternative approaches")
	}

	if strings.Contains(upper, " ORDER BY ") && !strings.Contains(upper, " LIMIT ") {
		issues = append(issues, "ORDER BY without LIMIT")
		suggestions = append(suggestions, "Consider adding a LIMIT clause to prevent sorting large result sets")
	}

	mockExplainPlan := []map[string]interface{}{
		{
			"id":            1,
			"select_type":   "SIMPLE",
			"table":         getTableFromQuery(query),
			"type":          "ALL",
			"possible_keys": nil,
			"key":           nil,
			"key_len":       nil,
			"ref":           nil,
			"rows":          1000,
			"Extra":         "",
		},
	}

	// If the query has a WHERE clause, assume it might use an index.
	if hasWhere {
		mockExplainPlan[0]["type"] = "range"
		mockExplainPlan[0]["possible_keys"] = "PRIMARY"
		mockExplainPlan[0]["key"] = "PRIMARY"
		mockExplainPlan[0]["key_len"] = 4
		mockExplainPlan[0]["rows"] = 100
	}

	return map[string]interface{}{
		"query":       query,
		"explainPlan": mockExplainPlan,
		"issues":      issues,
		"suggestions": suggestions,
		"complexity":  calculateQueryComplexity(query),
		"is_mock":     true,
	}, nil
}

// fromClausePattern finds the first " FROM " (ASCII spaces, any letter case) directly
// in the original query. The original code searched an upper-cased copy instead; its
// byte offsets can differ from the original when upper-casing changes the UTF-8 length.
var fromClausePattern = regexp.MustCompile(`(?i) FROM `)

// getTableFromQuery returns the first token following " FROM " in query, with its
// original spelling.
//
// It returns "unknown_table" when there is no " FROM ". Leading whitespace after FROM
// is skipped. The token ends at whitespace, ',', '(', ')' or ';', so aliases such as
// "users AS u" or "users u" yield "users". It may be "" when FROM is followed directly
// by a parenthesised subquery.
func getTableFromQuery(query string) string {
	loc := fromClausePattern.FindStringIndex(query)
	if loc == nil {
		return "unknown_table"
	}

	rest := strings.TrimLeftFunc(query[loc[1]:], unicode.IsSpace)
	end := strings.IndexFunc(rest, func(r rune) bool {
		return unicode.IsSpace(r) || r == ',' || r == '(' || r == ')' || r == ';'
	})
	if end == -1 {
		end = len(rest)
	}
	return rest[:end]
}

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/FreePeak/db-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.