# Text-to-SQL Agent Architecture & Extension Guide
This document explains how the LangGraph text-to-SQL agent works and how to extend it with new tools and nodes.
## Table of Contents
1. [Architecture Overview](#architecture-overview)
2. [Workflow Graph](#workflow-graph)
3. [State Management](#state-management)
4. [Helper Methods and Features](#helper-methods-and-features)
5. [Adding New Nodes](#adding-new-nodes)
6. [Adding New Tools](#adding-new-tools)
7. [Modifying the Workflow](#modifying-the-workflow)
8. [Examples](#examples)
## Architecture Overview
The text-to-SQL agent is built using **LangGraph**, which provides a state machine-based approach to building agents. The agent follows a graph-based workflow where each node performs a specific task, and edges control the flow between nodes.
### Key Design Principles
1. **Node-Edge Separation**:
- **Nodes** are pure state processors - they take state, process it, return updates
- **Edges** handle all orchestration logic - they make routing decisions based on state flags
- No conditional logic inside nodes (all moved to edges)
2. **Single Responsibility**: Each node has one clear purpose
3. **State-Driven Routing**: Edges check state flags to decide next step
4. **Extensibility**: Easy to add new nodes/paths without modifying existing code
### Key Components
1. **AgentState**: TypedDict that holds the agent's state throughout execution, including decision flags for edges
2. **Nodes**: Functions that perform specific tasks (explore schema, generate SQL, refine SQL, execute query, etc.)
3. **Edges**: Connections between nodes that control workflow
4. **Conditional Edges**: Routes to different nodes based on state conditions (orchestration logic)
5. **Tools**: MCP tools that interact with the database
## Workflow Graph
### Visual Workflow Diagram
```mermaid
graph TD
Start([User Question]) --> ExploreSchema[explore_schema<br/>- Check cache<br/>- List tables<br/>- Identify relevant tables<br/>- Describe structures<br/>- Fetch foreign keys<br/>- Set schema_complete flag]
ExploreSchema -->|schema_complete=true| GenerateSQL[generate_sql<br/>- Validate query<br/>- Build prompt with schema<br/>- Chain-of-thought reasoning<br/>- Call LLM<br/>- Extract & validate SQL<br/>- Execute test query<br/>- Calculate confidence & analysis<br/>- Set refinement flags]
ExploreSchema -->|schema_complete=false| ExploreSchema
GenerateSQL -->|should_refine=true| RefineSQL[refine_sql<br/>- Check max_refinements<br/>- Fetch missing schema if needed<br/>- Refine SQL using analysis<br/>- Re-execute test query<br/>- Recalculate confidence<br/>- Check if needs more refinement]
GenerateSQL -->|should_refine=false, has SQL| ExecuteQuery[execute_query<br/>- Use stored SQL<br/>- Reuse test results if possible<br/>- Run full query if needed]
GenerateSQL -->|Tool calls| Tools[tools<br/>- Process calls<br/>- Execute tools<br/>- Return results]
GenerateSQL -->|No SQL| End1([END])
RefineSQL -->|should_refine=true, attempts < max| RefineSQL
RefineSQL -->|should_refine=false OR max reached| ExecuteQuery
RefineSQL -->|Tool calls| Tools
RefineSQL -->|No SQL| End1
Tools --> GenerateSQL
ExecuteQuery --> CheckResult{Check Result}
CheckResult -->|Success| End2([END - Success])
CheckResult -->|Retry needed<br/>attempts < max| RefineQuery[refine_query<br/>- Parse error<br/>- Extract SQL from state<br/>- Pass error context<br/>- Update schema if needed]
CheckResult -->|Max attempts reached| End3([END - Error])
RefineQuery --> GenerateSQL
style Start fill:#e1f5ff
style End1 fill:#ffe1e1
style End2 fill:#e1ffe1
style End3 fill:#ffe1e1
style ExploreSchema fill:#fff4e1
style GenerateSQL fill:#fff4e1
style RefineSQL fill:#ffe1f4
style ExecuteQuery fill:#fff4e1
style RefineQuery fill:#ffe1f4
style Tools fill:#e1e1ff
style CheckResult fill:#e1e1ff
```
### Workflow Conditions Explained
#### 1. Schema Exploration Loop
**Edge Function**: `_should_continue_schema_exploration()` (lines ~2026-2032)
**Condition**: `schema_complete` flag
- **`schema_complete=true`** → Routes to `generate_sql`
- **`schema_complete=false`** → Routes back to `explore_schema` (loops until complete)
**When `schema_complete` is set**:
- Set to `true` when all relevant table descriptions and foreign keys are fetched
- Set to `false` if schema fetching fails or is incomplete
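As a sketch, this loop might be wired with a conditional edge like the following (names match the diagram above):

```python
workflow.add_conditional_edges(
    "explore_schema",
    self._should_continue_schema_exploration,
    {
        "complete": "generate_sql",    # schema ready: move on
        "continue": "explore_schema",  # loop until schema_complete=true
    },
)
```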
---
#### 2. After SQL Generation
**Edge Function**: `_should_refine_sql()` (lines ~2034-2061)
**Checks** (in order):
1. **`should_refine` flag** (set by `generate_sql` node):
- **`should_refine=true`** → Routes to `refine_sql`
- **`should_refine=false`** → Continue to next checks
2. **Tool calls** (if last message is AIMessage with tool_calls):
- Routes to `tools` node
3. **SQL presence** (if last message contains SQL):
- Routes to `execute_query`
4. **No SQL found**:
- Routes to `END`
**When `should_refine` is set to `true`**:
- Confidence < 0.6 (LOW_CONFIDENCE_THRESHOLD)
- Query execution error exists
- SQL syntax is invalid
- Critical issues detected in analysis
**When `should_refine` is set to `false`** (even if above conditions exist):
- Confidence ≥ 0.9 (VERY_HIGH_CONFIDENCE_THRESHOLD) with valid SQL and no error
- OR Confidence ≥ 0.8 (HIGH_CONFIDENCE_THRESHOLD) with valid SQL and no previous error
---
#### 3. After SQL Refinement
**Edge Function**: `_should_refine_sql()` (same as above, lines ~2034-2061)
**Checks** (in order):
1. **`should_refine` flag** (set by `refine_sql` node):
- **`should_refine=true` AND `refinement_attempts < max_refinements`** → Routes back to `refine_sql` (loops)
- **`should_refine=false` OR `refinement_attempts >= max_refinements`** → Continue to next checks
2. **Tool calls** (if last message is AIMessage with tool_calls):
- Routes to `tools` node
3. **SQL presence** (if last message contains SQL):
- Routes to `execute_query`
4. **No SQL found**:
- Routes to `END`
**When `should_refine` is set to `true` in `refine_sql`**:
- Confidence < 0.6 after refinement
- Query execution error exists after refinement
- SQL syntax is invalid after refinement
- Critical issues detected in analysis after refinement
**Max Refinements Limit**:
- Default: 3 refinements (`DEFAULT_MAX_REFINEMENTS = 3`)
- Configurable via `max_refinements` parameter in `__init__`
- Prevents infinite refinement loops
---
#### 4. After Query Execution
**Edge Function**: `_check_query_result()` (lines ~2222-2242)
**Checks**:
1. **Last message contains "successfully"**:
- Routes to `END` (success)
2. **Last message contains "error" or "failed"**:
- **`query_attempts < max_attempts`** → Routes to `refine_query` (retry)
- **`query_attempts >= max_attempts`** → Routes to `END` (error, give up)
3. **Default** (if unclear):
- Routes to `END` (assume success)
**Max Attempts Limit**:
- Default: 3 attempts (`DEFAULT_MAX_ATTEMPTS = 3`)
- Configurable via `max_query_attempts` parameter in `__init__`
- Counts the full cycle: `generate_sql` → `refine_sql` → `execute_query` → `refine_query` → `generate_sql`
---
### All Possible Workflow Flows
#### Flow 1: Simple Success (No Refinement Needed)
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (success)
→ END (Success)
```
#### Flow 2: Single Refinement Success
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence < 0.6, should_refine=true)
→ refine_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (success)
→ END (Success)
```
#### Flow 3: Multiple Refinements Success
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence < 0.6, should_refine=true)
→ refine_sql (attempt 1, confidence < 0.6, should_refine=true)
→ refine_sql (attempt 2, confidence < 0.6, should_refine=true)
→ refine_sql (attempt 3, confidence ≥ 0.8, should_refine=false)
→ execute_query (success)
→ END (Success)
```
#### Flow 4: Max Refinements Reached
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence < 0.6, should_refine=true)
→ refine_sql (attempt 1, confidence < 0.6, should_refine=true)
→ refine_sql (attempt 2, confidence < 0.6, should_refine=true)
→ refine_sql (attempt 3, confidence < 0.6, max_refinements reached)
→ execute_query (proceeds despite low confidence)
→ END (Success or Error)
```
#### Flow 5: Execution Error with Retry
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (error, query_attempts=1 < 3)
→ refine_query (parse error, update schema)
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (success)
→ END (Success)
```
#### Flow 6: Execution Error - Max Attempts Reached
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (error, query_attempts=1 < 3)
→ refine_query
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (error, query_attempts=2 < 3)
→ refine_query
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (error, query_attempts=3 >= 3)
→ END (Error - Max attempts reached)
```
#### Flow 7: Tool Calls
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (LLM requests tool call)
→ tools (execute tool)
→ generate_sql (continue with tool results)
→ execute_query (success)
→ END (Success)
```
#### Flow 8: Invalid Query (No SQL Generated)
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (no SQL extracted, should_refine=false)
→ END (No SQL)
```
#### Flow 9: Schema Exploration Loop
```
User Question
→ explore_schema (schema_complete=false)
→ explore_schema (schema_complete=false)
→ explore_schema (schema_complete=true)
→ generate_sql
→ ...
```
#### Flow 10: Complex Flow (Refinement + Execution Error + Retry)
```
User Question
→ explore_schema (schema_complete=true)
→ generate_sql (confidence < 0.6, should_refine=true)
→ refine_sql (attempt 1, confidence < 0.6, should_refine=true)
→ refine_sql (attempt 2, confidence ≥ 0.8, should_refine=false)
→ execute_query (error, query_attempts=1 < 3)
→ refine_query (parse error, update schema)
→ generate_sql (confidence ≥ 0.8, should_refine=false)
→ execute_query (success)
→ END (Success)
```
---
### Edge Decision Functions
#### `_should_continue_schema_exploration(state: AgentState) -> Literal["complete", "continue"]`
**Code Location**: Lines ~2026-2032
**Logic**:
```python
return "complete" if state.get("schema_complete", False) else "continue"
```
**Routes**:
- `"complete"` → `generate_sql`
- `"continue"` → `explore_schema` (loops)
---
#### `_should_refine_sql(state: AgentState) -> Literal["refine", "execute", "use_tools", "end"]`
**Code Location**: Lines ~2034-2061
**Logic** (checks in order):
1. If `should_refine = True` → return `"refine"`
2. If last message is AIMessage with tool_calls → return `"use_tools"`
3. If last message contains SQL → return `"execute"`
4. Otherwise → return `"end"`
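A minimal sketch of these ordered checks, assuming the state flags documented in this guide (the actual method may differ in detail):

```python
def _should_refine_sql(self, state: AgentState) -> Literal["refine", "execute", "use_tools", "end"]:
    # 1. Refinement requested (and, after refine_sql, attempts remain)
    if state.get("should_refine") and state.get("refinement_attempts", 0) < self.max_refinements:
        return "refine"
    last = state["messages"][-1] if state.get("messages") else None
    # 2. LLM asked for a tool call
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
        return "use_tools"
    # 3. SQL is available to execute
    if state.get("final_sql") or self._extract_sql_from_messages(state["messages"]):
        return "execute"
    # 4. Nothing actionable
    return "end"
```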
**Routes**:
- `"refine"` → `refine_sql`
- `"execute"` → `execute_query`
- `"use_tools"` → `tools`
- `"end"` → `END`
**Used by**: Both `generate_sql` and `refine_sql` nodes (after refinement)
**Important Notes**:
- The same edge function is used for both `generate_sql` and `refine_sql` nodes
- After `refine_sql`, if `should_refine=true` and `refinement_attempts < max_refinements`, it routes back to `refine_sql` (creates a loop)
- The loop continues until confidence is high enough OR max refinements is reached
---
#### `_check_query_result(state: AgentState) -> Literal["success", "retry", "error"]`
**Code Location**: Lines ~2222-2242
**Logic**:
1. Check last message (ToolMessage from `execute_query`)
2. If contains "successfully" → return `"success"`
3. If contains "error" or "failed":
- If `query_attempts < max_attempts` → return `"retry"`
- If `query_attempts >= max_attempts` → return `"error"`
4. Default → return `"success"`
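In code, a sketch consistent with these checks:

```python
def _check_query_result(self, state: AgentState) -> Literal["success", "retry", "error"]:
    last = state["messages"][-1]  # ToolMessage from execute_query
    content = str(getattr(last, "content", "")).lower()
    if "successfully" in content:
        return "success"
    if "error" in content or "failed" in content:
        if state.get("query_attempts", 0) < state.get("max_attempts", 3):  # DEFAULT_MAX_ATTEMPTS
            return "retry"
        return "error"
    return "success"  # default: treat ambiguous output as success
```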
**Routes**:
- `"success"` → `END`
- `"retry"` → `refine_query`
- `"error"` → `END`
---
### State Flags Used for Routing
| Flag | Set By | Used By | Purpose |
|------|--------|---------|---------|
| `schema_complete` | `explore_schema` | `_should_continue_schema_exploration` | Indicates if schema exploration is complete |
| `should_refine` | `generate_sql`, `refine_sql` | `_should_refine_sql` | Indicates if SQL needs refinement |
| `refinement_attempts` | `refine_sql` | `_should_refine_sql` | Counts refinement cycles (max: 3) |
| `query_attempts` | `generate_sql` | `_check_query_result` | Counts full query cycles (max: 3) |
| `sql_is_valid` | `generate_sql`, `refine_sql` | `generate_sql`, `refine_sql` | Indicates if SQL syntax is valid |
| `has_critical_issues` | `generate_sql`, `refine_sql` | `generate_sql`, `refine_sql` | Indicates if critical issues detected |
| `confidence` | `generate_sql`, `refine_sql` | `generate_sql`, `refine_sql` | Confidence score (0-1) for routing decisions |
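Putting the routing together, a `_build_graph()` consistent with the diagram and edge functions above might look like this sketch (the real implementation may differ in detail):

```python
def _build_graph(self) -> StateGraph:
    workflow = StateGraph(AgentState)

    workflow.add_node("explore_schema", self._explore_schema_node)
    workflow.add_node("generate_sql", self._generate_sql_node)
    workflow.add_node("refine_sql", self._refine_sql_node)
    workflow.add_node("execute_query", self._execute_query_node)
    workflow.add_node("refine_query", self._refine_query_node)
    workflow.add_node("tools", self._tools_node)

    workflow.set_entry_point("explore_schema")
    workflow.add_conditional_edges(
        "explore_schema",
        self._should_continue_schema_exploration,
        {"complete": "generate_sql", "continue": "explore_schema"},
    )
    sql_routes = {"refine": "refine_sql", "execute": "execute_query", "use_tools": "tools", "end": END}
    workflow.add_conditional_edges("generate_sql", self._should_refine_sql, sql_routes)
    workflow.add_conditional_edges("refine_sql", self._should_refine_sql, sql_routes)
    workflow.add_edge("tools", "generate_sql")
    workflow.add_conditional_edges(
        "execute_query",
        self._check_query_result,
        {"success": END, "retry": "refine_query", "error": END},
    )
    workflow.add_edge("refine_query", "generate_sql")
    return workflow.compile()
```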
---
### Complete Flow Examples
#### Example 1: Perfect First Attempt
**Query**: "Show me all employees"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.95, `should_refine=false`
3. `execute_query` → success
4. `END` (Success)
**Total Nodes**: 3 nodes executed
---
#### Example 2: Single Refinement Needed
**Query**: "Find top 2 employees per department"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.50, `should_refine=true` (low confidence)
3. `refine_sql` (attempt 1) → confidence=0.85, `should_refine=false` (improved)
4. `execute_query` → success
5. `END` (Success)
**Total Nodes**: 4 nodes executed
---
#### Example 3: Multiple Refinements
**Query**: "Complex multi-part query with window functions"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.40, `should_refine=true`
3. `refine_sql` (attempt 1) → confidence=0.55, `should_refine=true` (still low)
4. `refine_sql` (attempt 2) → confidence=0.70, `should_refine=true` (still low)
5. `refine_sql` (attempt 3) → confidence=0.82, `should_refine=false` (finally good)
6. `execute_query` → success
7. `END` (Success)
**Total Nodes**: 6 nodes executed (3 refinement cycles)
---
#### Example 4: Max Refinements Reached
**Query**: "Very complex query that's hard to get right"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.35, `should_refine=true`
3. `refine_sql` (attempt 1) → confidence=0.45, `should_refine=true`
4. `refine_sql` (attempt 2) → confidence=0.50, `should_refine=true`
5. `refine_sql` (attempt 3) → confidence=0.55, `should_refine=true` BUT `refinement_attempts=3 >= max_refinements`
6. `execute_query` → proceeds despite low confidence
7. `END` (Success or Error depending on execution)
**Total Nodes**: 6 nodes executed (max refinements reached)
---
#### Example 5: Execution Error with Retry
**Query**: "Show employees with salary > 50000"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.90, `should_refine=false`
3. `execute_query` → error: "Unknown column 'salary'" (query_attempts=1 < 3)
4. `refine_query` → parse error, update schema, pass error context
5. `generate_sql` → confidence=0.88, `should_refine=false` (uses correct column name)
6. `execute_query` → success
7. `END` (Success)
**Total Nodes**: 5 nodes executed (1 retry cycle)
---
#### Example 6: Execution Error - Max Attempts
**Query**: "Query that keeps failing"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.85, `should_refine=false`
3. `execute_query` → error (query_attempts=1 < 3)
4. `refine_query` → parse error, update schema
5. `generate_sql` → confidence=0.80, `should_refine=false`
6. `execute_query` → error (query_attempts=2 < 3)
7. `refine_query` → parse error, update schema
8. `generate_sql` → confidence=0.75, `should_refine=false`
9. `execute_query` → error (query_attempts=3 >= 3)
10. `END` (Error - Max attempts reached)
**Total Nodes**: 9 nodes executed (3 full retry cycles)
---
#### Example 7: Refinement + Execution Error
**Query**: "Complex query that needs refinement and has execution errors"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → confidence=0.45, `should_refine=true`
3. `refine_sql` (attempt 1) → confidence=0.65, `should_refine=true`
4. `refine_sql` (attempt 2) → confidence=0.82, `should_refine=false`
5. `execute_query` → error: "Unknown column" (query_attempts=1 < 3)
6. `refine_query` → parse error, update schema
7. `generate_sql` → confidence=0.88, `should_refine=false`
8. `execute_query` → success
9. `END` (Success)
**Total Nodes**: 8 nodes executed (2 refinements + 1 retry)
---
#### Example 8: Tool Calls
**Query**: "Query that requires additional tool usage"
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → LLM requests tool call (e.g., `describe_table`)
3. `tools` → execute tool, return results
4. `generate_sql` → confidence=0.90, `should_refine=false` (now has tool results)
5. `execute_query` → success
6. `END` (Success)
**Total Nodes**: 5 nodes executed (includes tool usage)
---
#### Example 9: Invalid Query
**Query**: "sns" (nonsensical input)
**Flow**:
1. `explore_schema` → `schema_complete=true`
2. `generate_sql` → `_validate_query_is_database_question()` returns False
3. `END` (No SQL - query validation failed)
**Total Nodes**: 2 nodes executed
---
#### Example 10: Schema Exploration Loop
**Query**: "Query requiring schema that needs multiple fetches"
**Flow**:
1. `explore_schema` → `schema_complete=false` (some tables missing)
2. `explore_schema` → `schema_complete=false` (still missing)
3. `explore_schema` → `schema_complete=true` (all tables fetched)
4. `generate_sql` → confidence=0.88, `should_refine=false`
5. `execute_query` → success
6. `END` (Success)
**Total Nodes**: 5 nodes executed (3 schema exploration cycles)
---
### Flow Summary Statistics
| Flow Type | Min Nodes | Max Nodes | Typical Nodes |
|-----------|-----------|-----------|---------------|
| Simple Success | 3 | 3 | 3 |
| Single Refinement | 4 | 4 | 4 |
| Multiple Refinements | 5 | 6 | 5-6 |
| Execution Error (1 retry) | 5 | 5 | 5 |
| Execution Error (max retries) | 9 | 9 | 9 |
| Refinement + Error | 6 | 8 | 7 |
| Tool Calls | 5 | 7 | 5-6 |
| Invalid Query | 2 | 2 | 2 |
| Schema Loop | 4 | 6 | 4-5 |
**Note**: Node counts are the number of node executions per flow, including repeated `explore_schema` runs; `END` is not counted as a node.
### Detailed Node Descriptions
#### 1. `explore_schema` Node
**Purpose**: Discover and cache database schema information intelligently
**Code Location**: `_explore_schema_node()` method (lines ~1084-1395)
**What it does**:
1. **Schema Caching Check**: Uses cached schema data if available (lines ~185-197)
- Checks `self.schema_cache` for previously fetched table descriptions
- Returns early if schema info is already available
2. **Table List Retrieval**: Gets list of all tables (lines ~209-226)
- Uses cached table list if available, otherwise calls `list_tables` MCP tool
- Caches table list in `self.schema_cache["tables"]` for future use
3. **Table Resources**: Gets table resources to extract table names (lines ~228-243)
- Uses cached resources if available, otherwise calls `list_table_resources` MCP tool
- Extracts table names and schema name from resources
- Caches resources in `self.schema_cache["table_resources"]`
4. **Intelligent Table Selection**: Identifies relevant tables using LLM (lines ~245-250)
- Calls `_identify_relevant_tables()` (lines ~129-153) to analyze user query
- Uses LLM to determine which tables are needed for the query
- Falls back to all tables if identification fails
5. **Table Descriptions**: Fetches table structures in parallel (lines ~1250-1370)
- Validates user_query is not None/empty (returns early if missing)
- Checks cache first - only fetches descriptions for tables not in cache
- Uses `describe_table` MCP tool with timeout protection for each relevant table
- Executes all description fetches in parallel using `asyncio.gather()`
- Handles individual table fetch failures gracefully (continues with others)
- Merges cached and newly fetched descriptions
- Safe table name extraction from resources (handles malformed strings)
6. **Foreign Key Relationships**: Fetches FK info in parallel (lines ~319-341)
- Uses `get_foreign_keys` MCP tool for each relevant table
- Executes all FK fetches in parallel using `asyncio.gather()`
- Stores FK relationships in `schema_info["foreign_keys"]`
7. **Schema Caching**: Updates cache with all fetched data (lines ~343-354)
- Merges new descriptions with existing cache
- Stores table descriptions, table names, and all tables in cache
- Cache persists across queries in the same agent instance
**Key Features**:
- **Intelligent Table Selection**: Only explores tables relevant to the query (not all tables)
- **Foreign Key Relationships**: Fetches FK info to understand table relationships for JOINs
- **Sample Data Inspection**: Can fetch sample rows (though currently not used in this node)
- **Parallel Execution**: All fetches (descriptions, FKs) happen concurrently
- **Incremental Caching**: Only fetches data not already in cache
- **Cache Merging**: Preserves all cached data when adding new data
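For illustration, the parallel description fetch might be structured like this (a sketch; `_call_tool_with_timeout`'s exact signature is an assumption):

```python
import asyncio

async def _fetch_descriptions(self, tables: list[str]) -> dict[str, str]:
    """Describe tables concurrently, tolerating per-table failures."""
    async def fetch_one(table: str):
        try:
            # Assumed helper signature: tool name plus an arguments dict
            desc = await self._call_tool_with_timeout("describe_table", {"table_name": table})
            return table, desc
        except Exception:
            return table, None  # one failing table must not abort the others

    results = await asyncio.gather(*(fetch_one(t) for t in tables))
    return {table: desc for table, desc in results if desc is not None}
```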
**Input State**: `messages`, `schema_info` (empty initially), `relevant_tables` (empty initially)
**Output State**:
- `schema_info` (populated with: `tables`, `table_descriptions`, `foreign_keys`, `table_names`, `all_tables`)
- `relevant_tables` (list of tables identified as relevant)
- `schema_complete` (flag indicating if schema exploration is complete - used by edge for routing)
**Edge Routing**:
- After this node, `_should_continue_schema_exploration()` edge function checks `schema_complete` flag
- If `true`: routes to `generate_sql`
- If `false`: routes back to `explore_schema` to fetch missing schema
**Performance Optimizations**:
- Early exit if all relevant tables are cached (uses set for O(1) lookups)
- Parallel execution of table descriptions and foreign key fetches
- Incremental caching (only fetches missing table descriptions)
- Skips LLM call for table identification if ≤3 tables
**Helper Methods Used**:
- `_identify_relevant_tables()` (lines ~197-221): Uses LLM to identify relevant tables (skips for ≤3 tables)
- `_get_sample_data()` (lines ~223-244): Fetches sample rows from a table (supports `row_limit` parameter)
#### 2. `generate_sql` Node
**Purpose**: Convert natural language to SQL using LLM with confidence scoring
**Code Location**: `_generate_sql_node()` method (lines ~1228-1609)
**What it does**:
1. **Query Validation** (first attempt only):
- Validates user_query is not None/empty before processing
- Calls `_validate_query_is_database_question()` to check if query is valid
- Uses fast heuristics (keyword matching) first, then LLM only for ambiguous cases
- Rejects gibberish, greetings, and non-database questions early
- Returns error message if query is invalid (skips SQL generation)
- Handles validation exceptions gracefully
2. **Build System Prompt**:
- Includes database schema information
- Lists available tables
- Shows relevant tables for the query
- Includes table structures with exact column names
- Includes foreign key relationships for JOIN understanding
- Extracts and lists column names explicitly
- Includes error context if retrying (from `previous_error`)
- **Window Function Guidance**: Provides comprehensive guidance on when and how to use window functions (ROW_NUMBER(), RANK(), DENSE_RANK(), LAG(), LEAD(), SUM() OVER, COUNT() OVER, etc.) for ranking, comparison, and aggregation within groups
- **Tie Handling**: Guides LLM to handle ties correctly for "most/least" queries (returns ALL entities with max/min value, not just one)
3. **Chain-of-Thought Reasoning**:
- Prompts LLM to think step-by-step:
1. Identify which tables are needed
2. Determine what columns to select
3. Figure out JOIN conditions
4. Add any filters or aggregations
5. Consider window functions for ranking/comparison/aggregation within groups
6. Write the final SQL query
4. **Generate SQL**:
- Calls LLM with system prompt and chain-of-thought prompt
- Gets SQL response from LLM
5. **SQL Extraction and Cleanup**:
- Removes markdown code blocks (```sql, ```)
- Extracts SQL from code blocks or text
- Auto-fixes incomplete queries (adds missing SELECT keyword)
- Removes parameterized query placeholders (?)
6. **Execute Test Query**:
- Executes query with LIMIT to get sample results
- Uses `run_query_json` MCP tool
- Captures query results or error messages
7. **Calculate Confidence Score and Analysis**:
- Calls `_score_and_analyze_query()` (single LLM call optimization)
- **Explicitly checks if SQL answers the question** (not just syntax correctness)
- Uses query structure, actual results, and error messages
- Returns confidence score (0-1) and analysis text
- Low confidence (< 0.6) if query doesn't answer the question
8. **Validate SQL Syntax**:
- Calls `_validate_sql_syntax()`
- Checks if query starts with SELECT and has proper structure
9. **Check for Critical Issues**:
- Calls `_has_critical_issues()`
- Detects missing keywords, syntax errors, incomplete queries
10. **Set Refinement Flags** (for edge routing):
- Determines if refinement is needed based on:
- Query error exists
- Confidence < 0.6
- SQL is invalid
- Critical issues detected
- Sets `should_refine`, `refine_reason`, and other flags in state
11. **Build Response Message**:
- Includes confidence score
- Includes SQL query
- Includes sample results (if available) or error message
- Always includes analysis/reasoning
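The decision in step 10 reduces to a small predicate over the documented thresholds (a sketch, not the verbatim implementation):

```python
LOW_CONFIDENCE_THRESHOLD = 0.6
HIGH_CONFIDENCE_THRESHOLD = 0.8
VERY_HIGH_CONFIDENCE_THRESHOLD = 0.9

def _needs_refinement(confidence, sql_is_valid, has_critical_issues, query_error, previous_error):
    # Very high / high confidence with valid SQL and no error short-circuits refinement
    if sql_is_valid and not query_error:
        if confidence >= VERY_HIGH_CONFIDENCE_THRESHOLD:
            return False
        if confidence >= HIGH_CONFIDENCE_THRESHOLD and not previous_error:
            return False
    # Otherwise refine on any documented trigger
    return bool(query_error) or confidence < LOW_CONFIDENCE_THRESHOLD or not sql_is_valid or has_critical_issues
```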
**Key Features**:
- **Chain-of-Thought Reasoning**: Structured step-by-step query generation
- **Foreign Key Integration**: FK relationships in prompt for better JOINs
- **Column Names Summary**: Explicit list of exact column names
- **Window Function Guidance**: Comprehensive prompts guide LLM to use window functions (ROW_NUMBER(), RANK(), DENSE_RANK(), LAG(), LEAD(), SUM() OVER, etc.) for appropriate scenarios like "top N per group", rankings within groups, comparing rows, running totals, percentiles, moving averages
- **Tie Handling**: Prompts ensure "most/least" queries return ALL entities with max/min value, not just one
- **SQL Validation**: Validates syntax and auto-fixes simple issues
- **Confidence Scoring**: Based on actual query execution results
- **Edge-Based Refinement Decision**: Sets flags for edge to route to refinement node
**Important**: This node does NOT perform refinement. It only sets flags. The `refine_sql` node handles actual refinement.
**Input State**: `messages`, `schema_info`, `query_attempts`, `previous_error` (optional), `previous_sql` (optional)
**Output State**:
- `messages` (with SQL, confidence score, sample results, analysis)
- `query_attempts` (incremented)
- `final_sql` (stored SQL to avoid re-extraction)
- `test_query_results` (stored test results if successful)
- `should_refine` (flag for edge routing)
- `refine_reason` (reason for refinement if needed)
- `confidence`, `sql_is_valid`, `has_critical_issues`, `confidence_reasoning`, `query_error` (flags for edge)
**Edge Routing**:
- After this node, `_should_refine_sql()` edge function checks `should_refine` flag
- If `true`: routes to `refine_sql`
- If `false` and has SQL: routes to `execute_query`
- If tool calls: routes to `tools`
- If no SQL: routes to `END`
**Helper Methods Used**:
- `_validate_query_is_database_question()` (lines ~368-442): Validates if user query is a database question (with timeout protection, handles None/empty)
- `_call_llm_with_timeout()` (lines ~164-193): Calls LLM with timeout protection (validates timeout > 0)
- `_call_tool_with_timeout()` (lines ~195-224): Calls MCP tools with timeout protection (validates timeout > 0)
- `_score_and_analyze_query()` (lines ~444-603): Calculates confidence and analysis in single LLM call (checks if SQL answers question, handles timeouts, validates regex groups)
- `_validate_sql_syntax()` (lines ~583-602): Validates SQL syntax (handles None/empty)
- `_has_critical_issues()` (lines ~604-638): Detects critical issues (handles None/empty)
- `_manage_column_cache_size()` (lines ~156-162): Manages column cache size with LRU eviction
- `_extract_sql_from_messages()` (lines ~1988-2020): Extracts SQL from messages (validates content, handles empty results)
#### 3. `refine_sql` Node
**Purpose**: Refine SQL query based on confidence score and analysis. Can loop back to itself if confidence is still low (up to max_refinements).
**Code Location**: `_refine_sql_node()` method (lines ~1908-2070)
**What it does**:
1. **Get SQL and Analysis from State**:
- Retrieves `final_sql` from state (or extracts from messages)
- Gets `confidence_reasoning`, `query_error`, `test_query_results` from state
2. **Fetch Missing Schema** (if needed):
- Calls `_fetch_missing_table_info()` to get schema for new tables mentioned in errors/analysis
3. **Refine SQL**:
- Calls `_refine_sql_with_analysis()` to generate corrected SQL
- Uses analysis/reasoning and error context
4. **Re-execute Test Query**:
- Executes refined query with LIMIT to get new results
- Captures new query results or error messages
5. **Recalculate Confidence**:
- Calls `_score_and_analyze_query()` for refined query
- Gets updated confidence score and analysis
6. **Validate Refined SQL**:
- Validates syntax and checks for critical issues
7. **Update State**:
- Updates message with refined SQL and new confidence
- Recomputes `should_refine`: `true` if confidence is still low, `false` if confidence is high or max refinements reached
**Key Features**:
- **Separate Refinement Logic**: Handles all refinement in dedicated node
- **Schema Re-fetching**: Gets missing schema info automatically
- **Window Function Guidance**: Refinement prompts include guidance on using window functions when appropriate
- **Re-validation**: Re-executes and re-analyzes refined query
- **Refinement Loop**: Can loop back to itself if confidence is still low (up to max_refinements, default: 3)
- **Max Refinements Protection**: Prevents infinite loops by limiting refinement cycles
**Input State**: `messages`, `schema_info`, `final_sql`, `confidence_reasoning`, `query_error`, `test_query_results`, `refine_reason`, `refinement_attempts`
**Output State**:
- `messages` (with refined SQL, new confidence, analysis)
- `final_sql` (refined SQL)
- `test_query_results` (new test results)
- `confidence`, `sql_is_valid`, `has_critical_issues`, `confidence_reasoning`, `query_error` (updated flags)
- `should_refine` (True if confidence still low, False if confidence is high or max refinements reached)
- `refinement_attempts` (incremented by 1)
**Edge Routing**:
- After this node, `_should_refine_sql()` edge function checks `should_refine` flag and `refinement_attempts`
- If `should_refine=true` AND `refinement_attempts < max_refinements`: routes back to `refine_sql` (loops)
- If `should_refine=false` OR `refinement_attempts >= max_refinements`: routes to `execute_query`
- If tool calls: routes to `tools`
- If no SQL: routes to `END`
**Helper Methods Used**:
- `_fetch_missing_table_info()` (lines ~670-775): Fetches schema for new tables (with timeout protection, safe regex extraction)
- `_refine_sql_with_analysis()` (lines ~777-948): Refines SQL using analysis (with timeout protection, handles empty responses, validates extracted SQL)
- `_call_llm_with_timeout()` (lines ~164-193): Calls LLM with timeout protection (validates timeout > 0)
- `_call_tool_with_timeout()` (lines ~195-224): Calls MCP tools with timeout protection (validates timeout > 0)
- `_score_and_analyze_query()` (lines ~444-603): Recalculates confidence after refinement (with timeout protection, validates regex groups)
#### 4. `execute_query` Node
**Purpose**: Execute the generated SQL query
**Code Location**: `_execute_query_node()` method (lines ~1780-1865)
**What it does**:
1. **Get SQL from State**:
- Uses stored SQL from `final_sql` in state (avoids re-extraction)
- Falls back to extracting SQL from messages if not in state
2. **Reuse Test Results** (optimization):
- Checks if test query results are available and contain all data
- If original query has LIMIT ≤ TEST_QUERY_LIMIT, reuses test results
- **Respects original LIMIT clause** - only returns as many rows as requested
- Example: If original SQL has `LIMIT 1`, returns only 1 row even if test query fetched 3 rows
- Avoids redundant query execution
3. **Execute Query**:
- Uses `run_query` MCP tool if full execution needed
- Formats result as markdown
- Handles errors gracefully
**Input State**: `messages` (with SQL), `final_sql` (optional), `test_query_results` (optional)
**Output State**: `messages` (with execution result)
**Performance Optimizations**:
- Uses stored SQL from state (`final_sql`) to avoid re-extraction
- Reuses test query results if they contain all data (avoids redundant execution)
- Respects original LIMIT when reusing results (prevents returning more rows than requested)
- Only executes full query when test results are insufficient
- Safe LIMIT parsing with try/except (handles non-numeric values)
**Error Handling**:
- Timeout protection for database queries (prevents hanging, validates timeout > 0)
- Handles tool unavailability gracefully
- Validates query execution result is not None
- Safe LIMIT value parsing (handles ValueError/TypeError)
- Comprehensive error messages for debugging
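A sketch of the test-result reuse check, including the safe LIMIT parsing (the value of `TEST_QUERY_LIMIT` is assumed here):

```python
import re

TEST_QUERY_LIMIT = 10  # assumed value for illustration

def _reusable_test_rows(sql: str, test_rows: list | None) -> list | None:
    """Return rows to reuse, trimmed to the query's own LIMIT, or None to run the full query."""
    if not test_rows:
        return None
    match = re.search(r"\bLIMIT\s+(\S+)", sql, re.IGNORECASE)
    if not match:
        return None  # no LIMIT: the test sample may be incomplete
    try:
        limit = int(match.group(1))
    except (ValueError, TypeError):
        return None  # non-numeric LIMIT: fall back to full execution
    if limit <= TEST_QUERY_LIMIT:
        return test_rows[:limit]  # respect the original LIMIT clause
    return None
```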
#### 5. `refine_query` Node
**Purpose**: Improve query based on error feedback with intelligent error parsing
**Code Location**: `_refine_query_node()` method (lines ~1889-1966)
**What it does**:
1. **Extract Error Message**:
- Finds error message from ToolMessage in message history
- Looks for messages containing "error" in content
2. **Parse SQL Error**:
- Calls `_parse_sql_error()` to extract:
- Error type (unknown_column, unknown_table, syntax_error, etc.)
- Column names (handles `column` and `table.column` formats)
- Table names
- Original error message
3. **Re-fetch Schema** (if column errors):
- If error type is "unknown_column", re-fetches table descriptions
- Uses `describe_table` MCP tool to get fresh schema info
- Updates `schema_info["table_descriptions"]` with fresh data
4. **Build Error Guidance**:
- Provides targeted guidance based on error type:
- Unknown column: Shows which column doesn't exist and suggests checking column names
- Unknown table: Shows which table doesn't exist
- Syntax error: Provides syntax error guidance
- Includes parsed error details in refinement message
5. **Prepare Refinement Message**:
- Creates HumanMessage with error context
- Includes parsed error details and guidance
- Updates schema_info if refreshed
**Key Features**:
- **Intelligent Error Parsing**: Extracts actionable information from error messages
- **Targeted Error Guidance**: Specific guidance based on error type
- **Schema Refresh**: Re-fetches schema when column errors detected
- **Error Type Detection**: Identifies unknown columns, unknown tables, syntax errors
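A sketch of what this parsing might look like for MySQL-style error messages (the patterns are illustrative):

```python
import re

def _parse_sql_error(error_message: str) -> dict:
    parsed = {"type": "other", "columns": [], "tables": [], "message": error_message}
    col = re.search(r"Unknown column '([\w.]+)'", error_message)
    if col:
        parsed["type"] = "unknown_column"
        parsed["columns"].append(col.group(1).split(".")[-1])  # handles `column` and `table.column`
        return parsed
    tbl = re.search(r"Table '(?:\w+\.)?(\w+)' doesn't exist", error_message)
    if tbl:
        parsed["type"] = "unknown_table"
        parsed["tables"].append(tbl.group(1))
        return parsed
    if "syntax" in error_message.lower():
        parsed["type"] = "syntax_error"
    return parsed
```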
**Input State**: `messages`, `schema_info`, `previous_sql` (optional)
**Output State**:
- `messages` (with retry message)
- `schema_info` (possibly updated with fresh descriptions)
- `previous_error` (error message for generate_sql context)
- `previous_sql` (previous SQL for generate_sql context)
**Edge Routing**:
- Always routes to `generate_sql` to regenerate with error context
**Helper Methods Used**:
- `_parse_sql_error()` (lines ~720-785): Parses SQL error messages to extract error type, column names, table names
- `_extract_sql_from_messages()` (lines ~1276-1295): Extracts SQL from messages
#### 6. `tools` Node
**Purpose**: Handle tool calls from LLM
**What it does**:
- Processes tool calls from AI message
- Executes tools via MCP client
- Returns tool results
**Input State**: `messages` (with tool calls)
**Output State**: `messages` (with tool results)
**Code Location**: `_tools_node()` method
## State Management
### AgentState Structure
```python
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]  # Conversation history
    schema_info: dict      # Cached schema information (table_descriptions, foreign_keys, sample_data, relevant_tables)
    query_attempts: int    # Number of query attempts
    max_attempts: int      # Maximum retry attempts
    relevant_tables: list  # Tables identified as relevant to the query (populated by explore_schema)
    previous_error: Optional[str]       # Previous query error for refinement context
    previous_sql: Optional[str]         # Previous SQL query for refinement context
    final_sql: Optional[str]            # Final SQL query, stored to avoid re-extraction
    test_query_results: Optional[list]  # Test query results, stored to avoid re-execution
    # Decision flags used by the edge functions (see "State Flags Used for Routing")
    schema_complete: bool
    should_refine: bool
    refinement_attempts: int
    confidence: float
    sql_is_valid: bool
    has_critical_issues: bool
```
**Schema Info Structure**:
- `table_descriptions`: Dict mapping table names to their DESCRIBE output
- `foreign_keys`: Dict mapping table names to their foreign key relationships
- `sample_data`: Dict mapping table names to sample row data (lists of dicts)
- `table_names`: List of all table names
- `all_tables`: List of all available tables
- `relevant_tables`: List of tables identified as relevant to current query
- `tables`: Raw table list from MCP tool
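For orientation, a populated `schema_info` might look roughly like this (values abbreviated; the foreign-key entry format is illustrative):

```python
schema_info = {
    "tables": "...raw list_tables output...",
    "table_descriptions": {"employees": "id INT, name VARCHAR(100), dept_id INT, ..."},
    "foreign_keys": {"employees": [{"column": "dept_id", "references": "departments.id"}]},  # illustrative format
    "sample_data": {"employees": [{"id": 1, "name": "Ada", "dept_id": 2}]},
    "table_names": ["employees", "departments"],
    "all_tables": ["employees", "departments", "salaries"],
    "relevant_tables": ["employees", "departments"],
}
```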
### State Flow
1. **Initial State**: User question → `messages` contains HumanMessage
2. **Schema Exploration**: `schema_info` gets populated, `schema_complete` flag set
3. **SQL Generation**: `messages` gets AIMessage with SQL, `query_attempts` increments, refinement flags set
4. **SQL Refinement** (if needed): `should_refine=true` → `refine_sql` node refines SQL, updates flags
5. **Execution**: `messages` gets ToolMessage with result
6. **Error Recovery** (if error): `messages` gets HumanMessage with error context, routes back to `generate_sql`
## Adding New Nodes
### Step-by-Step Guide
#### Step 1: Create the Node Function
```python
async def _your_new_node(self, state: AgentState) -> dict:
"""
Your node description
Args:
state: Current agent state
Returns:
dict: State updates to apply
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
# Your node logic here
result = await self._do_something()
# Return state updates
return {
"messages": [AIMessage(content=str(result))],
# Add other state updates as needed
}
```
#### Step 2: Add Node to Graph
In `_build_graph()` method:
```python
def _build_graph(self) -> StateGraph:
workflow = StateGraph(AgentState)
# Add your new node
workflow.add_node("your_node_name", self._your_new_node)
# ... rest of graph setup
```
#### Step 3: Connect Node with Edges
```python
# Add edge from another node to your node
workflow.add_edge("previous_node", "your_node_name")
# Or add conditional edge
workflow.add_conditional_edges(
"previous_node",
self._should_go_to_your_node, # Decision function
{
"yes": "your_node_name",
"no": "other_node"
}
)
# Add edge from your node to next node
workflow.add_edge("your_node_name", "next_node")
```
#### Step 4: Create Decision Function (if using conditional edge)
```python
def _should_go_to_your_node(self, state: AgentState) -> Literal["yes", "no"]:
"""Decide whether to route to your node"""
# Check state conditions
if some_condition:
return "yes"
return "no"
```
### Example: Adding a "Validate SQL" Node
```python
# Step 1: Create the node
async def _validate_sql_node(self, state: AgentState) -> dict:
"""Validate SQL syntax before execution"""
messages = state["messages"]
# Extract SQL from last AI message
last_msg = messages[-1] if messages else None
sql = self._extract_sql(last_msg)
    # Validate using EXPLAIN (checks the plan without running the full query)
    try:
        run_query_tool = await self._get_tool("run_query")
        if run_query_tool is None:
            raise RuntimeError("run_query tool unavailable")
result = await run_query_tool.ainvoke({
"input": {
"sql": f"EXPLAIN {sql}",
"format": "json"
}
})
return {
"messages": [ToolMessage(
content="SQL validation passed",
tool_call_id="validation_success"
)]
}
except Exception as e:
return {
"messages": [ToolMessage(
content=f"SQL validation failed: {str(e)}",
tool_call_id="validation_error"
)]
}
# Step 2: Add to graph
def _build_graph(self) -> StateGraph:
workflow = StateGraph(AgentState)
# ... existing nodes ...
workflow.add_node("validate_sql", self._validate_sql_node)
# Step 3: Connect it
workflow.add_edge("generate_sql", "validate_sql")
workflow.add_edge("validate_sql", "execute_query")
return workflow.compile()
```
## Adding New Tools
### Step-by-Step Guide
#### Step 1: Add Tool to MCP Server
In `mysql-db-server.py`:
```python
@mcp.tool()
def your_new_tool(param1: str, param2: Optional[int] = None) -> str:
"""Description of what your tool does"""
# Your tool implementation
result = do_something(param1, param2)
return result
```
#### Step 2: Use Tool in Agent
The tool is automatically available via `get_tools()`. Use it in any node:
```python
async def _your_node(self, state: AgentState) -> dict:
# Get the tool
your_tool = await self._get_tool("your_new_tool")
if your_tool:
# Call the tool
result = await your_tool.ainvoke({
"param1": "value",
"param2": 123
})
return {
"messages": [ToolMessage(
content=str(result),
tool_call_id="your_tool_call"
)]
}
```
### Example: Adding a "Get Table Statistics" Tool
```python
# In mysql-db-server.py
@mcp.tool()
def get_table_stats(table_name: str, db_schema: Optional[str] = None) -> str:
"""Get statistics about a table (row count, size, etc.)"""
db = get_connection()
schema = db_schema or db.allowed_schema
    try:
        with db.cursor() as cur:
            # Identifier names can't be parameterized; assumes they come from trusted tool input
            cur.execute(f"SELECT COUNT(*) AS row_count FROM `{schema}`.`{table_name}`")
            count = cur.fetchone()  # assumes a dict cursor (rows returned as dicts)
            # Use placeholders for values to avoid SQL injection
            cur.execute(
                """
                SELECT table_rows, data_length, index_length
                FROM information_schema.tables
                WHERE table_schema = %s AND table_name = %s
                """,
                (schema, table_name),
            )
            stats = cur.fetchone()
return json.dumps({
"row_count": count["row_count"],
"estimated_rows": stats["table_rows"],
"data_size": stats["data_length"],
"index_size": stats["index_length"]
}, indent=2)
finally:
db.close()
# In text_to_sql_agent.py - use in a node
async def _analyze_tables_node(self, state: AgentState) -> dict:
"""Analyze table statistics"""
schema_info = state.get("schema_info", {})
table_names = schema_info.get("table_names", [])
get_stats_tool = await self._get_tool("get_table_stats")
stats = {}
if get_stats_tool:
for table_name in table_names:
try:
result = await get_stats_tool.ainvoke({
"table_name": table_name
})
stats[table_name] = result
except Exception:
continue
schema_info["table_stats"] = stats
return {"schema_info": schema_info}
```
## Modifying the Workflow
### Adding a New Path
To add a new execution path:
```python
def _build_graph(self) -> StateGraph:
workflow = StateGraph(AgentState)
# ... existing nodes ...
# Add new conditional routing
workflow.add_conditional_edges(
"generate_sql",
self._route_after_generation,
{
"execute": "execute_query",
"validate": "validate_sql", # New path
"use_tools": "tools",
"end": END
}
)
return workflow.compile()
def _route_after_generation(self, state: AgentState) -> Literal["execute", "validate", "use_tools", "end"]:
"""Route after SQL generation"""
# Add validation step for complex queries
messages = state["messages"]
last_msg = messages[-1] if messages else None
if last_msg and isinstance(last_msg, AIMessage):
sql = self._extract_sql(last_msg)
# Route to validation for complex queries
if "JOIN" in sql.upper() or "GROUP BY" in sql.upper():
return "validate"
# ... other routing logic
```
### Inserting a Node in the Middle
```python
# Original: generate_sql -> execute_query
# New: generate_sql -> validate_sql -> execute_query
def _build_graph(self) -> StateGraph:
workflow = StateGraph(AgentState)
# ... existing setup ...
# Remove old edge
# workflow.add_edge("generate_sql", "execute_query") # Remove this
# Add new path
workflow.add_edge("generate_sql", "validate_sql")
workflow.add_edge("validate_sql", "execute_query")
return workflow.compile()
```
## Examples
### Example 1: Add Query Explanation Node
```python
async def _explain_query_node(self, state: AgentState) -> dict:
"""Generate explanation of what the SQL does"""
messages = state["messages"]
# Get the SQL
sql = None
for msg in reversed(messages):
if isinstance(msg, AIMessage) and "sql" in msg.content.lower():
sql = self._extract_sql(msg)
break
if not sql:
return {}
# Generate explanation
explain_prompt = f"""Explain what this SQL query does in simple terms:
{sql}
Explanation:"""
response = await self.llm.ainvoke([HumanMessage(content=explain_prompt)])
return {
"messages": [AIMessage(content=f"Query Explanation:\n{response.content}")]
}
# Add to graph after successful execution
workflow.add_node("explain_query", self._explain_query_node)
workflow.add_conditional_edges(
"execute_query",
self._should_explain,
{
"explain": "explain_query",
"done": END
}
)
def _should_explain(self, state: AgentState) -> Literal["explain", "done"]:
"""Decide if we should explain the query"""
# Only explain on first successful query
query_attempts = state.get("query_attempts", 0)
if query_attempts == 1:
return "explain"
return "done"
```
### Example 2: Add Result Summarization Node
```python
async def _summarize_results_node(self, state: AgentState) -> dict:
"""Convert SQL results to natural language summary"""
messages = state["messages"]
# Get original question
user_query = None
for msg in messages:
if isinstance(msg, HumanMessage):
user_query = msg.content
break
# Get query results
results = None
for msg in reversed(messages):
if isinstance(msg, ToolMessage) and "successfully" in msg.content.lower():
            # Extract JSON results (fall back to the raw content if parsing fails)
            import json
            try:
                results = json.loads(msg.content.split("\n\n")[1])
            except (json.JSONDecodeError, IndexError):
                results = msg.content
break
if not results or not user_query:
return {}
# Generate summary
summary_prompt = f"""The user asked: "{user_query}"
The query returned these results:
{json.dumps(results[:5], indent=2) if isinstance(results, list) else str(results)}
Provide a natural language summary in 1-2 sentences."""
response = await self.llm.ainvoke([HumanMessage(content=summary_prompt)])
return {
"messages": [AIMessage(content=f"Summary: {response.content}")]
}
# Add to graph
workflow.add_node("summarize_results", self._summarize_results_node)
workflow.add_edge("execute_query", "summarize_results") # After successful execution
workflow.add_edge("summarize_results", END)
```
### Example 3: Schema Caching
Schema caching is implemented in the agent. See the `explore_schema` node description above for details. The `schema_cache` is initialized in `__init__()` (line ~47) and populated in `_initialize_tools()` (lines ~64-81) and `_explore_schema_node()` (lines ~343-354). It stores:
- Table lists and resources (static data)
- Table descriptions (incremental caching)
- Foreign key relationships
- Sample data
The cache persists across queries in the same agent instance, significantly improving performance for subsequent queries.
## Common Patterns
### Pattern 1: Pre-Processing Node
A node that runs before the main logic to prepare data.
```python
async def _preprocess_node(self, state: AgentState) -> dict:
"""Prepare data before main processing"""
messages = state["messages"]
# Extract and normalize user query
user_query = self._extract_user_query(messages)
normalized = self._normalize_query(user_query)
return {
"messages": [HumanMessage(content=normalized)]
}
```
### Pattern 2: Post-Processing Node
A node that runs after main logic to format results.
```python
async def _format_results_node(self, state: AgentState) -> dict:
"""Format results for better presentation"""
messages = state["messages"]
# Get results and format them
results = self._extract_results(messages)
formatted = self._format_as_table(results)
return {
"messages": [AIMessage(content=formatted)]
}
```
### Pattern 3: Validation Node
A node that validates data before proceeding.
```python
async def _validate_node(self, state: AgentState) -> dict:
"""Validate state before proceeding"""
# Check if required data exists
if not self._has_required_data(state):
return {
"messages": [AIMessage(content="Error: Missing required data")]
}
return {} # No changes, continue
```
## Debugging Tips
### 1. Add Logging to Nodes
```python
import logging
logger = logging.getLogger("text_to_sql_agent")
async def _your_node(self, state: AgentState) -> dict:
logger.info(f"Entering your_node with state: {state}")
# ... your logic ...
logger.info(f"Exiting your_node with updates: {updates}")
return updates
```
### 2. Inspect State Between Nodes
```python
def _debug_state(self, state: AgentState) -> dict:
"""Debug node to inspect state"""
print(f"Current state:")
print(f" Messages: {len(state.get('messages', []))}")
print(f" Schema info keys: {list(state.get('schema_info', {}).keys())}")
print(f" Query attempts: {state.get('query_attempts', 0)}")
return {} # Don't modify state
```
### 3. Test Individual Nodes
```python
# Test a node in isolation
async def test_node():
agent = TextToSQLAgent(...)
test_state = {
"messages": [HumanMessage(content="test question")],
"schema_info": {},
"query_attempts": 0,
"max_attempts": 3
}
result = await agent._your_node(test_state)
print(result)
```
## Best Practices
1. **Keep Nodes Focused**: Each node should do one thing well
2. **Return State Updates**: Always return a dict with state updates
3. **Handle Errors Gracefully**: Use try/except in nodes
4. **Use Type Hints**: Helps with debugging and IDE support
5. **Document Nodes**: Add docstrings explaining what each node does
6. **Test Incrementally**: Add one node at a time and test
7. **Use Conditional Edges Wisely**: Don't create overly complex routing
8. **Cache Expensive Operations**: Schema fetching, tool initialization, etc.
## Quick Reference
### Adding a Node Checklist
- [ ] Create node function with `async def _node_name(self, state: AgentState) -> dict`
- [ ] Add node to graph: `workflow.add_node("node_name", self._node_name)`
- [ ] Connect with edges: `workflow.add_edge("from", "node_name")`
- [ ] Connect to next node: `workflow.add_edge("node_name", "to")`
- [ ] Test the node in isolation
- [ ] Test the full workflow
### Adding a Tool Checklist
- [ ] Add tool to MCP server (`mysql-db-server.py`)
- [ ] Restart MCP server
- [ ] Use tool in node: `tool = await self._get_tool("tool_name")`
- [ ] Call tool: `result = await tool.ainvoke({...})`
- [ ] Handle tool results in state updates
## Quick Start: Adding Your First Node
Here's a minimal example to add a "validate SQL" node:
```python
# 1. Add the node function
async def _validate_sql_node(self, state: AgentState) -> dict:
"""Validate SQL before execution"""
messages = state["messages"]
    sql = self._extract_sql(messages[-1]) if messages else None
    # Simple validation - check for SELECT (also guards against missing SQL)
    if not sql or not sql.upper().startswith("SELECT"):
return {
"messages": [ToolMessage(
content="Error: Only SELECT queries allowed",
tool_call_id="validation_error"
)]
}
return {} # Validation passed, continue
# 2. Add to graph in _build_graph()
workflow.add_node("validate_sql", self._validate_sql_node)
# 3. Insert in workflow
workflow.add_edge("generate_sql", "validate_sql")
workflow.add_edge("validate_sql", "execute_query")
```
That's it! Your new node is now part of the workflow.
## Quick Reference Templates
### Template: Adding a New Node
```python
# Step 1: Create node function
async def _your_node_name(self, state: AgentState) -> dict:
"""What your node does"""
# Get state
messages = state.get("messages", [])
schema_info = state.get("schema_info", {})
# Your logic here
result = await self._do_something()
# Return state updates
return {
"messages": [AIMessage(content=str(result))],
# Add other updates as needed
}
# Step 2: Add to _build_graph()
def _build_graph(self) -> StateGraph:
workflow = StateGraph(AgentState)
# ... existing nodes ...
workflow.add_node("your_node_name", self._your_node_name)
# Step 3: Connect edges
workflow.add_edge("previous_node", "your_node_name")
workflow.add_edge("your_node_name", "next_node")
return workflow.compile()
```
### Template: Adding a New Tool
```python
# Step 1: Add tool to mysql-db-server.py
@mcp.tool()
def your_tool_name(param1: str, param2: int = 10) -> str:
"""What your tool does"""
# Your tool implementation
return result
# Step 2: Use in any node
async def _your_node(self, state: AgentState) -> dict:
tool = await self._get_tool("your_tool_name")
if tool:
result = await tool.ainvoke({
"param1": "value",
"param2": 20
})
return {"messages": [ToolMessage(content=str(result))]}
return {}
```
### Template: Conditional Routing
```python
# Decision function
def _should_route_to_node(self, state: AgentState) -> Literal["yes", "no"]:
"""Decide routing based on state"""
if some_condition(state):
return "yes"
return "no"
# Add conditional edge
workflow.add_conditional_edges(
"source_node",
self._should_route_to_node,
{
"yes": "target_node",
"no": "other_node"
}
)
```
## Helper Methods and Features
This section documents all helper methods and features used throughout the agent, including performance optimizations.
### Schema Exploration Helpers
#### `_identify_relevant_tables(user_query: str, all_tables: list) -> list`
**Code Location**: Lines ~140-195
**What it does**:
Uses LLM to analyze the user query and identify which tables are relevant, rather than exploring all tables. This significantly reduces schema exploration time for large databases.
**How it works**:
1. **Optimization**: Skips LLM call if ≤3 tables (not worth the overhead)
2. Sends a prompt to LLM with the user query and all available tables
3. LLM returns comma-separated list of relevant table names
4. Parses and validates table names (uses set for O(1) lookup)
5. Falls back to all tables if identification fails or returns empty
**Performance Optimizations**:
- Skips LLM call for small databases (≤3 tables)
- Uses set for O(1) table name validation
- Single strip chain for parsing (optimized string operations)
**Benefits**:
- Faster schema exploration (only explores 2-3 tables instead of 20+)
- Less token usage (smaller prompts)
- More focused schema context
- No unnecessary LLM calls for small databases
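A sketch of this helper under the behavior described above (`_call_llm_with_timeout`'s signature is an assumption):

```python
async def _identify_relevant_tables(self, user_query: str, all_tables: list[str]) -> list[str]:
    if len(all_tables) <= 3:
        return all_tables  # skip the LLM call for small databases
    prompt = (
        f"Question: {user_query}\n"
        f"Available tables: {', '.join(all_tables)}\n"
        "Return a comma-separated list of only the tables needed."
    )
    try:
        # Assumed helper signature: list of messages in, LLM response out
        response = await self._call_llm_with_timeout([HumanMessage(content=prompt)])
        valid = set(all_tables)  # set for O(1) validation
        picked = [t.strip().strip("`'\"") for t in response.content.split(",")]
        picked = [t for t in picked if t in valid]
        return picked or all_tables  # fall back to all tables if nothing usable
    except Exception:
        return all_tables
```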
---
#### `_get_sample_data(table_name: str, schema_name: str = None, row_limit: int = 3) -> list`
**Code Location**: Lines ~155-176
**What it does**:
Fetches sample rows from a table to understand actual data patterns, formats, and relationships.
**How it works**:
1. Builds SQL query: `SELECT * FROM table_name LIMIT {row_limit}`
2. Uses `run_query_json` MCP tool to execute query
3. Returns list of sample rows (dictionaries)
**Parameters**:
- `table_name`: Name of the table
- `schema_name`: Optional schema name for schema-prefixed queries
- `row_limit`: Number of rows to fetch (default: 3, can be 1-5)
**Benefits**:
- Understands actual data types and formats
- Better value generation in queries (sees real examples)
- Sees actual relationships in data
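A sketch of this helper (the `ainvoke` payload shape follows the `run_query` example earlier in this guide):

```python
import json

async def _get_sample_data(self, table_name: str, schema_name: str = None, row_limit: int = 3) -> list:
    qualified = f"`{schema_name}`.`{table_name}`" if schema_name else f"`{table_name}`"
    tool = await self._get_tool("run_query_json")
    if tool is None:
        return []
    result = await tool.ainvoke({"input": {"sql": f"SELECT * FROM {qualified} LIMIT {row_limit}"}})
    # The tool is assumed to return JSON text; parse into a list of row dicts
    return json.loads(result) if isinstance(result, str) else result
```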
---
### Query Validation
#### `_validate_query_is_database_question(query: str) -> tuple[bool, str]`
**Code Location**: Lines ~368-442
**What it does**:
Validates if the user query is actually a database question before generating SQL. Uses fast heuristics first, then LLM only for ambiguous cases.
**How it works**:
1. **Fast Heuristic Checks** (avoid LLM call when possible):
- Checks query length (minimum 3 characters)
- Checks for database-related keywords (show, find, list, count, etc.) using set-based O(1) lookups
- Checks if query looks like SQL
- Returns immediately if clearly valid or invalid
2. **LLM Validation** (only for ambiguous cases):
- Uses LLM with timeout protection to validate if query is a meaningful database question
- Falls back to heuristics if LLM call fails or times out
**Returns**: Tuple of (is_valid: bool, reason: str)
**Used in**: `_generate_sql_node()` before SQL generation (first attempt only)
**Optimization**: Uses set-based keyword matching for O(1) lookups, avoids LLM call for 90%+ of cases
**Error Handling**: Handles LLM timeouts and failures gracefully with fallback to heuristics
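The fast heuristic pass might look like this (the keyword set is abbreviated for illustration):

```python
DB_KEYWORDS = {"show", "find", "list", "count", "select", "how", "many", "which", "top", "average"}

def _looks_like_database_question(query: str) -> bool:
    """Heuristic pre-check; ambiguous inputs fall through to the LLM."""
    if not query or len(query.strip()) < 3:
        return False
    words = set(query.lower().split())
    return bool(words & DB_KEYWORDS)  # set intersection: O(1) per-keyword lookups
```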
---
### Confidence Scoring and Validation
#### `_score_and_analyze_query(query: str, sql: str, schema_info: dict, query_results: list = None, query_error: str = None) -> tuple[float, str]`
**Code Location**: Lines ~444-603
**What it does**:
Calculates a confidence score (0-1) and analysis for a generated SQL query. **Explicitly checks if SQL answers the question**, not just syntax correctness.
**How it works**:
1. If query error exists: Scores based on error type and severity
2. If query results exist: **Explicitly checks if results answer the question correctly**
- Verifies if query retrieves the requested information
- Checks if multi-part questions handle ALL parts
- Validates result relevance to the question
- **Confidence MUST be low (< 0.6) if query doesn't answer the question**
3. If no results: Scores based on query structure and schema alignment
4. Uses LLM with timeout protection to evaluate confidence considering:
- Column name correctness
- JOIN correctness
- **Whether it answers the question (CRITICAL)**
- Query logic and structure
**Returns**: Tuple of (confidence: float, analysis: str)
**Used in**: `_generate_sql_node()` and `_refine_sql_node()` after executing test query
**Key Features**:
- Confidence scoring explicitly validates that SQL answers the question, not just that it's syntactically correct
- Handles None/empty LLM responses gracefully
- Timeout protection prevents hanging on slow LLM calls
- Fallback to default confidence (0.5) if parsing fails
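The fallback path is concrete enough to sketch. The `CONFIDENCE:` label below is an assumed response format, not the agent's actual prompt contract:
```python
import re

# Compiled once, matching the agent's pattern of precompiling regexes
_CONF_RE = re.compile(r"CONFIDENCE:\s*([01](?:\.\d+)?)", re.IGNORECASE)

def parse_confidence(response_text: str) -> float:
    """Extract a 0-1 score, falling back to 0.5 on any parse failure."""
    if not response_text:
        return 0.5                       # None/empty LLM response
    match = _CONF_RE.search(response_text)
    if not match:
        return 0.5                       # label missing => default
    try:
        return min(max(float(match.group(1)), 0.0), 1.0)
    except ValueError:
        return 0.5

print(parse_confidence("CONFIDENCE: 0.85\nANALYSIS: joins look correct"))  # 0.85
print(parse_confidence(""))                                                # 0.5
```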
---
#### `_analyze_query_confidence(query: str, sql: str, schema_info: dict, confidence: float, query_results: list = None, query_error: str = None) -> str`
**Code Location**: Lines ~511-583
**What it does**:
Analyzes the query and provides reasoning/analysis based on actual query execution results. Always included in output regardless of confidence score.
**How it works**:
1. If query error: Analyzes the error and explains what went wrong
2. If query results: Analyzes if results answer the question correctly
3. If no results: Analyzes query structure against schema
4. Provides a 2-3 sentence analysis explaining query quality
**Returns**: String with analysis/reasoning
**Used in**: `_generate_sql_node()` after confidence scoring
---
#### `_validate_sql_syntax(sql: str) -> bool`
**Code Location**: Lines ~583-602
**What it does**:
Validates basic SQL syntax to ensure the query is a valid SELECT statement.
**How it works**:
1. Checks if SQL is None or empty (returns False)
2. Checks if SQL starts with SELECT
3. Validates that the query has a FROM clause (for most queries)
4. Returns True if valid, False otherwise
**Returns**: Boolean indicating if SQL is syntactically valid
**Used in**: `_generate_sql_node()` after SQL extraction
**Error Handling**: Handles None/empty SQL gracefully
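A minimal sketch of these checks, under the assumption that the agent only ever emits SELECT statements (a real dialect-aware validator would use a SQL parser):
```python
def validate_sql_syntax(sql: str) -> bool:
    """Cheap sanity check: non-empty, SELECT-only, usually has a FROM."""
    if not sql or not sql.strip():
        return False                 # None/empty handled gracefully
    normalized = sql.strip().upper()
    if not normalized.startswith("SELECT"):
        return False                 # only SELECT statements are accepted
    # "for most queries": constant selects like SELECT 1 would need a
    # carve-out in a real implementation
    return "FROM" in normalized

assert validate_sql_syntax("SELECT id FROM users") is True
assert validate_sql_syntax("DROP TABLE users") is False
assert validate_sql_syntax(None) is False
```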
---
#### `_has_critical_issues(analysis: str, query_error: str = None, sql_query: str = None) -> bool`
**Code Location**: Lines ~604-638
**What it does**:
Detects if there are critical issues that require query refinement, such as missing keywords, syntax errors, or incomplete queries.
**How it works**:
1. Checks if analysis is None or empty (returns False)
2. Checks analysis text for critical keywords (missing, incomplete, syntax error, etc.)
3. Checks query error for syntax errors or error codes
4. Validates SQL query itself for missing SELECT keyword
5. Returns True if critical issues detected
**Returns**: Boolean indicating if critical issues exist
**Used in**: `_generate_sql_node()` to determine if refinement is needed
**Error Handling**: Handles None/empty analysis gracefully
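An illustrative version of the detector; the marker tuple is a guess at the kind of keywords checked, not the agent's exact set:
```python
CRITICAL_MARKERS = ("missing", "incomplete", "syntax error", "does not exist")

def has_critical_issues(analysis: str, query_error: str = None,
                        sql_query: str = None) -> bool:
    if analysis:                                   # None/empty => no signal here
        lowered = analysis.lower()
        if any(marker in lowered for marker in CRITICAL_MARKERS):
            return True
    if query_error:
        err = query_error.lower()
        if "syntax" in err or "1064" in err:       # MySQL 1064 = syntax error
            return True
    if sql_query and "select" not in sql_query.lower():
        return True                                # SQL lost its SELECT keyword
    return False

assert has_critical_issues("The query is missing a JOIN on orders") is True
assert has_critical_issues("Looks correct and answers the question") is False
```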
---
### SQL Refinement
#### `_refine_sql_with_analysis(query: str, sql: str, schema_info: dict, analysis: str, query_results: list = None, query_error: str = None) -> str`
**Code Location**: Lines ~737-908
**What it does**:
Generates a corrected SQL query using the analysis/reasoning from confidence scoring and actual query results.
**How it works**:
1. Builds refinement prompt with:
- Original question and SQL
- Analysis of issues
- Query results or error message
- Window function guidance for appropriate scenarios (ranking, comparison, aggregation within groups)
- Tie handling guidance for "most/least" queries
2. Includes full schema information in system prompt
3. Calls LLM with timeout protection to generate corrected SQL
4. Handles None/empty responses gracefully (returns original SQL)
5. Cleans up refined SQL (removes markdown, parameter placeholders)
6. Returns refined SQL or original SQL if refinement fails or times out
**Returns**: String with refined SQL query
**Used in**: `_refine_sql_node()` when confidence is low or issues detected
**Error Handling**:
- Handles LLM timeouts gracefully (returns original SQL)
- Handles empty/malformed LLM responses (returns original SQL)
- Comprehensive error logging for debugging
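The prompt text itself is model-specific, but the fallback contract is easy to sketch. `call_llm` below is a hypothetical async helper standing in for `_call_llm_with_timeout`:
```python
import asyncio

async def refine_sql(original_sql: str, prompt_messages: list,
                     call_llm, timeout: float = 60.0) -> str:
    """Return refined SQL, or the original SQL on timeout/empty output."""
    try:
        response = await asyncio.wait_for(call_llm(prompt_messages), timeout)
    except Exception:                    # includes asyncio.TimeoutError
        return original_sql              # never lose the working candidate
    text = getattr(response, "content", None)
    if not text or not text.strip():
        return original_sql              # empty/malformed LLM response
    # Strip markdown fences, mirroring the cleanup step described above
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").removeprefix("sql").strip()
    return cleaned
```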
---
#### `_fetch_missing_table_info(schema_info: dict, query_error: str = None, analysis: str = None, sql_query: str = None)`
**Code Location**: Lines ~640-735
**What it does**:
Selectively fetches schema information for tables that are mentioned in errors/analysis/SQL but not in the cached schema info.
**How it works**:
1. Extracts table names from:
- SQL query (FROM/JOIN clauses) - uses compiled regex
- Error messages (e.g., "Table 'X' doesn't exist") - uses compiled regex
- Analysis text (mentioned table names) - uses compiled regex
2. Compares with cached table descriptions (uses set for O(1) lookup)
3. Identifies missing tables
4. Fetches table descriptions and foreign keys in parallel for missing tables only (with timeout protection)
5. Updates `schema_info` and cache with new data
6. Manages cache size to prevent unbounded growth
**Performance Optimizations**:
- Uses compiled regex patterns for table extraction
- Uses set for O(1) table name lookups
- Parallel execution of description and FK fetches
- Timeout protection prevents hanging on slow database calls
**Error Handling**:
- Handles tool call failures gracefully (continues with other tables)
- Logs warnings for failed fetches (if logging enabled)
- Continues processing even if some table fetches fail
- Safe regex extraction with try/except (handles extraction failures)
- Validates SQL query and error message are strings before regex operations
- Filters out empty regex matches before adding to mentioned_tables set
**Benefits**:
- Only fetches what's needed (performance)
- Ensures all required tables are available for refinement
- Updates cache for future queries
- Prevents cache from growing unbounded
**Used in**: `_refine_sql_node()` before refinement
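A sketch of the extraction step; both regexes are simplified approximations of the patterns the agent compiles in `__init__`:
```python
import re

FROM_JOIN_RE = re.compile(r"\b(?:FROM|JOIN)\s+([A-Za-z_][\w.]*)", re.IGNORECASE)
ERR_TABLE_RE = re.compile(r"Table '([^']+)' doesn't exist")

def extract_mentioned_tables(sql_query=None, query_error=None) -> set:
    mentioned = set()
    if isinstance(sql_query, str):        # validate type before regex
        mentioned.update(m for m in FROM_JOIN_RE.findall(sql_query) if m)
    if isinstance(query_error, str):
        mentioned.update(m for m in ERR_TABLE_RE.findall(query_error) if m)
    return mentioned                      # set => O(1) diff against the cache

cached = {"users"}
mentioned = extract_mentioned_tables("SELECT * FROM orders JOIN users ON ...")
print(mentioned - cached)                 # {'orders'} => fetch only this table
```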
---
### Cache Management
#### `_manage_schema_cache_size()`
**Code Location**: Lines ~120-133
**What it does**:
Manages schema cache size using LRU eviction to prevent unbounded growth.
**How it works**:
1. Checks if schema cache exists
2. Gets table descriptions dictionary
3. If cache exceeds `max_schema_cache_size`, removes oldest entries
4. Logs eviction events (if logging enabled)
**Used in**: `_explore_schema_node()` after updating cache
---
#### `_manage_column_cache_size()`
**Code Location**: Lines ~135-141
**What it does**:
Manages column cache size using LRU eviction to prevent unbounded growth.
**How it works**:
1. Uses OrderedDict for LRU behavior
2. Removes the oldest entries (insertion-order/FIFO eviction, approximating LRU) when cache exceeds `max_column_cache_size`
3. Logs cache size warnings when approaching limit (if logging enabled)
**Used in**: `_generate_sql_node()` after extracting column names
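Both cache managers share the same eviction shape. A minimal sketch, assuming the cache is an `OrderedDict` and `max_size` mirrors the configured limits:
```python
from collections import OrderedDict

def evict_oldest(cache: OrderedDict, max_size: int) -> None:
    """Drop the oldest entries until the cache fits its limit."""
    while len(cache) > max_size:
        cache.popitem(last=False)   # FIFO eviction; becomes true LRU only
                                    # if readers also move_to_end() on access

cache = OrderedDict((f"col{i}", i) for i in range(5))
evict_oldest(cache, max_size=3)
print(list(cache))                  # ['col2', 'col3', 'col4']
```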
---
### Timeout Protection
#### `_call_llm_with_timeout(messages: list, timeout: Optional[int] = None) -> BaseMessage`
**Code Location**: Lines ~143-165
**What it does**:
Calls LLM with timeout protection to prevent hanging on slow API calls.
**How it works**:
1. Uses `asyncio.wait_for()` to enforce timeout
2. Logs LLM call start/completion (if logging enabled)
3. Raises `TimeoutError` if call exceeds timeout
4. Handles other exceptions with comprehensive logging
**Returns**: LLM response message
**Used in**: All LLM calls throughout the agent
**Configuration**: Timeout defaults to `llm_timeout` (60s) but can be overridden per call
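A minimal sketch of the wrapper, assuming a LangChain-style chat model with an async `ainvoke` method; the 60-second default mirrors `llm_timeout`:
```python
import asyncio
from typing import Optional

async def call_llm_with_timeout(llm, messages: list,
                                timeout: Optional[int] = None):
    timeout = timeout or 60              # fall back to the configured default
    try:
        return await asyncio.wait_for(llm.ainvoke(messages), timeout=timeout)
    except asyncio.TimeoutError:
        # Surface a clear error instead of hanging indefinitely
        raise TimeoutError(f"LLM call exceeded {timeout}s") from None
```
`_call_tool_with_timeout` (next) follows the same `asyncio.wait_for` pattern, with `query_timeout` (30s) as its default.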
---
#### `_call_tool_with_timeout(tool: BaseTool, args: dict, timeout: Optional[int] = None) -> Any`
**Code Location**: Lines ~167-189
**What it does**:
Calls MCP tool with timeout protection to prevent hanging on slow database operations.
**How it works**:
1. Uses `asyncio.wait_for()` to enforce timeout
2. Logs tool call start/completion (if logging enabled)
3. Raises `TimeoutError` if call exceeds timeout
4. Handles other exceptions with comprehensive logging
**Returns**: Tool execution result
**Used in**: All MCP tool calls throughout the agent
**Configuration**: Timeout defaults to `query_timeout` (30s) but can be overridden per call
---
### Error Handling
#### `_parse_sql_error(error_msg: str) -> dict`
**Code Location**: Lines ~1048-1080
**What it does**:
Parses SQL error messages to extract actionable information such as error type, column names, and table names.
**How it works**:
1. Handles None/empty error messages gracefully
2. Parses "Unknown column" errors (uses compiled regex):
- Extracts column name
- Validates regex group is not empty before use
- Handles `table.column` format safely (validates split result has 2 parts)
3. Parses "Table doesn't exist" errors (uses compiled regex):
- Extracts table name
- Validates regex group is not empty before use
4. Parses syntax errors:
- Detects MySQL error code 1064 or "syntax error" text
5. Returns structured error info dictionary
**Performance Optimizations**:
- Uses compiled regex patterns for faster parsing
- Single-pass parsing with optimized string operations
**Error Handling**:
- Returns default error info if parsing fails
- Validates regex group matches are not empty
- Safe string splitting with length validation
**Returns**: Dictionary with:
- `type`: Error type (unknown_column, unknown_table, syntax_error, unknown)
- `column`: Column name (if applicable)
- `table`: Table name (if applicable)
- `message`: Original error message
**Used in**: `_refine_query_node()` and `_refine_sql_with_analysis()` to provide targeted error guidance
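A sketch of the structured parse; both regexes approximate MySQL's error wording and are compiled once, per the agent's conventions:
```python
import re

UNKNOWN_COL_RE = re.compile(r"Unknown column '([^']+)'")
UNKNOWN_TBL_RE = re.compile(r"Table '([^']+)' doesn't exist")

def parse_sql_error(error_msg: str) -> dict:
    info = {"type": "unknown", "column": None, "table": None,
            "message": error_msg or ""}
    if not error_msg:
        return info                              # None/empty handled gracefully
    if (m := UNKNOWN_COL_RE.search(error_msg)) and m.group(1):
        parts = m.group(1).split(".")
        if len(parts) == 2:                      # table.column form
            info["table"], info["column"] = parts
        else:
            info["column"] = m.group(1)
        info["type"] = "unknown_column"
    elif (m := UNKNOWN_TBL_RE.search(error_msg)) and m.group(1):
        info["type"], info["table"] = "unknown_table", m.group(1)
    elif "1064" in error_msg or "syntax error" in error_msg.lower():
        info["type"] = "syntax_error"            # MySQL error code 1064
    return info

print(parse_sql_error("Unknown column 'users.emial' in 'field list'"))
# {'type': 'unknown_column', 'column': 'emial', 'table': 'users', ...}
```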
---
### Helper Methods
#### `_extract_sql_from_messages(messages: list) -> Optional[str]`
**Code Location**: Lines ~2030-2069
**What it does**:
Extracts SQL query from AIMessage content using multiple extraction methods with validation.
**How it works**:
1. Validates messages list is not empty
2. Iterates through messages in reverse order (most recent first)
3. Tries multiple extraction methods:
- SQL code blocks (```sql ... ```)
- Generic code blocks (``` ... ```)
- Regex pattern matching for SELECT statements
4. Validates extracted SQL is not None/empty before returning
5. Strips whitespace from extracted SQL
**Error Handling**:
- Handles None/empty message content gracefully
- Validates regex group matches are not empty
- Returns None if no SQL found
**Returns**: Extracted SQL string or None if not found
**Used in**: `_execute_query_node()`, `_refine_query_node()` as fallback when SQL not in state
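A sketch of the extraction cascade, assuming LangChain-style messages with a string `.content` attribute; the patterns are illustrative simplifications:
```python
import re
from typing import Optional

FENCE = "`" * 3  # keep the literal fence out of the source for readability
SQL_BLOCK_RE = re.compile(FENCE + r"sql\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)
ANY_BLOCK_RE = re.compile(FENCE + r"\s*(.*?)" + FENCE, re.DOTALL)
SELECT_RE = re.compile(r"(SELECT\b.*?)(?:;|\Z)", re.DOTALL | re.IGNORECASE)

def extract_sql(messages: list) -> Optional[str]:
    """Try fenced SQL blocks first, then any code block, then a bare SELECT."""
    if not messages:
        return None
    for msg in reversed(messages):               # most recent message first
        content = getattr(msg, "content", None)
        if not isinstance(content, str) or not content:
            continue                             # skip None/empty content
        for pattern in (SQL_BLOCK_RE, ANY_BLOCK_RE, SELECT_RE):
            match = pattern.search(content)
            if match and match.group(1) and match.group(1).strip():
                return match.group(1).strip()    # validated, whitespace-stripped
    return None
```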
---
#### `get_final_answer(result: dict) -> str`
**Code Location**: Lines ~2400-2446
**What it does**:
Extracts the final answer from the agent result with optimized message access.
**How it works**:
1. Gets messages from result dictionary
2. Returns early if no messages found
3. Searches in reverse order for successful query results:
- Looks for ToolMessage with "successfully" in content
- Checks if last AIMessage is the final message (uses `is` comparison for O(1) check)
4. Falls back to last message content if no successful result found
5. Validates content is not None/empty before returning
**Performance Optimizations**:
- Uses `is` comparison (`msg is messages[-1]`) instead of O(n) `index()` call
- Early exit if no messages
**Error Handling**:
- Handles empty messages list
- Validates message content is not None/empty
- Safe attribute access with `hasattr()` checks
**Returns**: Final answer string or "No answer found" if unavailable
**Used in**: User-facing code to extract readable answer from agent result
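A sketch of the reverse scan; the `"successfully"` marker and the message class names follow the description above (LangChain `ToolMessage`/`AIMessage` assumed):
```python
def get_final_answer(result: dict) -> str:
    messages = result.get("messages") or []
    if not messages:
        return "No answer found"                       # early exit
    last = messages[-1]
    for msg in reversed(messages):                     # newest first
        content = getattr(msg, "content", None)
        if not content:
            continue                                   # skip None/empty content
        if type(msg).__name__ == "ToolMessage" and "successfully" in content:
            return content                             # successful query result
        if type(msg).__name__ == "AIMessage" and msg is last:
            return content     # identity check is O(1); list.index() is O(n)
    return getattr(last, "content", None) or "No answer found"
```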
---
## Architecture Principles
### Node-Edge Separation
The agent follows LangGraph best practices with clear separation:
- **Nodes**: Pure state processors - they take state, process it, return updates
- **Edges**: Orchestration logic - they make routing decisions based on state flags
- **No Conditionals in Nodes**: All conditional logic moved to edge functions
### Benefits
1. **Clear Workflow**: Graph structure shows the workflow visually
2. **Easy Extension**: Add new nodes/paths without modifying existing nodes
3. **Better Testing**: Test nodes and edges independently
4. **Maintainability**: Each node has single responsibility
## Performance Optimizations
The agent includes several performance optimizations, along with the defensive validation and error handling that keeps them safe:
1. **Compiled Regex Patterns**: All regex patterns are compiled once in `__init__` for faster text processing
2. **Schema Caching**: Schema information is cached across queries (session-based)
3. **Column Name Caching**: Extracted column names are cached to avoid redundant parsing
4. **Parallel Execution**: Schema exploration operations (descriptions, foreign keys) execute in parallel
5. **Intelligent Table Selection**: Skips LLM call for small databases (≤3 tables)
6. **Set-Based Lookups**: Uses sets instead of lists for O(1) table/column lookups
7. **State-Based Storage**: Stores SQL and test results in state to avoid re-extraction/re-execution
8. **Result Reuse**: Reuses test query results when they contain all data (avoids redundant execution)
9. **LIMIT Respect**: When reusing test results, respects original LIMIT clause (e.g., LIMIT 1 returns 1 row, not all test results)
10. **Early Exits**: Skips unnecessary operations when data is already cached
11. **Combined LLM Calls**: Confidence scoring and analysis combined into single LLM call
12. **Fast Query Validation**: Uses set-based keyword matching for O(1) lookups, avoids LLM call for 90%+ of cases
13. **Edge-Based Routing**: Decisions made in edge functions (no redundant checks in nodes)
14. **LRU Cache Management**: Schema and column caches use LRU eviction to prevent unbounded growth
15. **Timeout Protection**: All LLM and database calls have configurable timeouts to prevent hanging (validates timeout > 0)
16. **Input Validation**: Validates query input (None, empty, type checks) before processing
17. **Comprehensive Logging**: Optional logging for debugging, performance monitoring, and error tracking
18. **Graceful Error Handling**: Handles None/empty responses, timeouts, and malformed data with fallbacks
19. **Type Safety**: Safe type conversions with try/except, validates regex group matches before use
20. **Cache Structure Validation**: Validates cache structure (dict types) before operations to handle corruption
21. **Empty SQL Validation**: Validates SQL is not empty after extraction and cleaning
22. **Safe String Parsing**: Safe parsing for table resources, error messages, and analysis text with error handling
23. **Regex Group Validation**: Validates regex groups are not empty before using them
24. **Tool Call Validation**: Validates tool_call structure (dict, has name) before processing
25. **Optimized Message Access**: Uses an identity check (`is` comparison) instead of O(n) `index()` calls
26. **Edge Case Handling**: Comprehensive validation for None/empty values, type safety, and malformed data
27. **Safe Parsing**: All string parsing operations have try/except blocks and validate results before use
28. **Code Cleanliness**: No duplicate imports, uses compiled regex patterns consistently, clean structure
## Next Steps
1. Review the current workflow graph
2. Identify where you want to add functionality
3. Follow the step-by-step guides above
4. Test incrementally
5. Refer to examples for common patterns
For improvement ideas that haven't been implemented yet, see [IMPROVEMENTS.md](IMPROVEMENTS.md).