"""
LangGraph Agent for Text-to-SQL
This module uses LangGraph to build a text-to-SQL agent that:
1. Explores database schema when needed
2. Generates SQL queries from natural language
3. Executes queries and handles errors
4. Refines queries based on feedback
Architecture:
- All database operations (queries, schema exploration) are performed via MCP tools
- This agent contains orchestration logic, LLM interactions, and query analysis
- No direct database connections - all data access goes through the MCP server
"""
from typing import TypedDict, Annotated, Literal, Optional
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.tools import BaseTool
import json
import asyncio
from datetime import date
import re
import logging
from collections import OrderedDict
# ========================================================================
# Constants
# ========================================================================
DEFAULT_MAX_ATTEMPTS = 3
TEST_QUERY_LIMIT = 3
SMALL_DB_THRESHOLD = 3 # Skip LLM call for table selection if ≤3 tables
LOW_CONFIDENCE_THRESHOLD = 0.6
HIGH_CONFIDENCE_THRESHOLD = 0.8
VERY_HIGH_CONFIDENCE_THRESHOLD = 0.9
MAX_SCHEMA_CACHE_SIZE = 1000 # Maximum number of table descriptions to cache
MAX_COLUMN_CACHE_SIZE = 500 # Maximum number of column name extractions to cache
DEFAULT_LLM_TIMEOUT = 60 # Default timeout for LLM calls in seconds
DEFAULT_QUERY_TIMEOUT = 30 # Default timeout for database queries in seconds
# ========================================================================
# Logging Setup
# ========================================================================
logger = logging.getLogger(__name__)
class AgentState(TypedDict):
"""State for the text-to-SQL agent"""
messages: Annotated[list[BaseMessage], add_messages]
schema_info: dict # Cached schema information
query_attempts: int # Number of query attempts
max_attempts: int # Maximum query attempts before giving up
relevant_tables: list # Tables identified as relevant to the query
previous_error: Optional[str] # Optional: previous query error for refinement context
previous_sql: Optional[str] # Optional: previous SQL query for refinement context
final_sql: Optional[str] # Optional: final SQL query to avoid re-extraction
test_query_results: Optional[list] # Optional: test query results to avoid re-execution
# Refinement decision flags (set by nodes, used by edges)
should_refine: Optional[bool] # Whether SQL needs refinement
refine_reason: Optional[str] # Reason for refinement
confidence: Optional[float] # Confidence score
sql_is_valid: Optional[bool] # Whether SQL syntax is valid
has_critical_issues: Optional[bool] # Whether critical issues detected
confidence_reasoning: Optional[str] # Analysis/reasoning from confidence scoring
query_error: Optional[str] # Query execution error if any
# Schema exploration flags
schema_complete: Optional[bool] # Whether schema exploration is complete
class TextToSQLAgent:
"""LangGraph agent for converting natural language to SQL queries.
Architecture:
- Uses LangGraph for state machine orchestration
- All database operations via MCP tools (no direct DB connections)
- Nodes process state, edges make routing decisions
- Supports schema caching, intelligent table selection, and auto-refinement
"""
def __init__(
self,
mcp_client: MultiServerMCPClient,
llm: ChatOpenAI,
max_query_attempts: int = DEFAULT_MAX_ATTEMPTS,
llm_timeout: int = DEFAULT_LLM_TIMEOUT,
query_timeout: int = DEFAULT_QUERY_TIMEOUT,
max_schema_cache_size: int = MAX_SCHEMA_CACHE_SIZE,
max_column_cache_size: int = MAX_COLUMN_CACHE_SIZE,
enable_logging: bool = True
):
self.mcp_client = mcp_client
self.llm = llm
self.max_attempts = max_query_attempts
self.llm_timeout = llm_timeout
self.query_timeout = query_timeout
self.max_schema_cache_size = max_schema_cache_size
self.max_column_cache_size = max_column_cache_size
self.enable_logging = enable_logging
# Cache for schema information to avoid repeated expensive calls
        # Oldest table descriptions are evicted once max_schema_cache_size is exceeded
self.schema_cache: Optional[dict] = None
# Cache for extracted column names to avoid redundant parsing
# Use OrderedDict for LRU eviction
self.column_cache: OrderedDict[str, list[str]] = OrderedDict()
# Get tools from MCP client
self.tools = None
self.tool_dict = None # Dictionary for O(1) tool lookup
if self.enable_logging:
logger.info(f"TextToSQLAgent initialized with max_attempts={max_query_attempts}, "
f"llm_timeout={llm_timeout}s, query_timeout={query_timeout}s, "
f"max_schema_cache={max_schema_cache_size}, max_column_cache={max_column_cache_size}")
# Compile regex patterns once for performance (reused many times)
self._compile_regex_patterns()
# Build the graph
self.graph = self._build_graph()
def _manage_schema_cache_size(self):
"""Manage schema cache size using LRU eviction."""
if not self.schema_cache:
return
# Validate cache structure
if not isinstance(self.schema_cache, dict):
if self.enable_logging:
logger.warning("Schema cache is not a dictionary, resetting")
self.schema_cache = {}
return
table_descriptions = self.schema_cache.get("table_descriptions", {})
# Validate table_descriptions is a dict
if not isinstance(table_descriptions, dict):
if self.enable_logging:
logger.warning("table_descriptions is not a dictionary, resetting")
self.schema_cache["table_descriptions"] = {}
return
if len(table_descriptions) > self.max_schema_cache_size:
# Remove oldest entries (simple approach: remove excess)
excess = len(table_descriptions) - self.max_schema_cache_size
keys_to_remove = list(table_descriptions.keys())[:excess]
for key in keys_to_remove:
del table_descriptions[key]
if self.enable_logging:
logger.debug(f"Evicted {excess} table descriptions from schema cache")
def _manage_column_cache_size(self):
"""Manage column cache size using LRU eviction."""
while len(self.column_cache) > self.max_column_cache_size:
            # Remove the least recently used entry (entries are moved to the end on access)
self.column_cache.popitem(last=False)
if self.enable_logging and len(self.column_cache) > self.max_column_cache_size * 0.9:
logger.debug(f"Column cache size: {len(self.column_cache)}/{self.max_column_cache_size}")
async def _call_llm_with_timeout(self, messages: list, timeout: Optional[int] = None) -> BaseMessage:
"""Call LLM with timeout handling."""
timeout = timeout or self.llm_timeout
# Validate timeout is positive
        if timeout <= 0:
            if self.enable_logging:
                logger.warning(f"Invalid timeout {timeout}, using default {DEFAULT_LLM_TIMEOUT}s")
            timeout = DEFAULT_LLM_TIMEOUT
try:
if self.enable_logging:
logger.debug(f"Calling LLM with timeout={timeout}s")
response = await asyncio.wait_for(
self.llm.ainvoke(messages),
timeout=timeout
)
if self.enable_logging:
logger.debug("LLM call completed successfully")
return response
except asyncio.TimeoutError:
error_msg = f"LLM call timed out after {timeout} seconds"
if self.enable_logging:
logger.error(error_msg)
raise TimeoutError(error_msg)
except Exception as e:
error_msg = f"LLM call failed: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
raise
async def _call_tool_with_timeout(self, tool: BaseTool, args: dict, timeout: Optional[int] = None) -> any:
"""Call MCP tool with timeout handling."""
timeout = timeout or self.query_timeout
# Validate timeout is positive
        if timeout <= 0:
            if self.enable_logging:
                logger.warning(f"Invalid timeout {timeout}, using default {DEFAULT_QUERY_TIMEOUT}s")
            timeout = DEFAULT_QUERY_TIMEOUT
try:
if self.enable_logging:
logger.debug(f"Calling tool {tool.name} with timeout={timeout}s")
result = await asyncio.wait_for(
tool.ainvoke(args),
timeout=timeout
)
if self.enable_logging:
logger.debug(f"Tool {tool.name} completed successfully")
return result
except asyncio.TimeoutError:
error_msg = f"Tool {tool.name} call timed out after {timeout} seconds"
if self.enable_logging:
logger.error(error_msg)
raise TimeoutError(error_msg)
except Exception as e:
error_msg = f"Tool {tool.name} call failed: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
raise
def _compile_regex_patterns(self):
"""Compile regex patterns once for performance (reused many times)"""
self._regex_patterns = {
'sql_code_block': re.compile(r'```sql\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE),
'sql_marker': re.compile(r'SQL Query:\s*(.*?)(?:\n\n|$)', re.DOTALL | re.IGNORECASE),
'select_statement': re.compile(r'(SELECT.*?)(?:\n\n|$)', re.DOTALL | re.IGNORECASE),
'confidence_score': re.compile(r'CONFIDENCE:\s*(\d+\.?\d*)', re.IGNORECASE),
'analysis_section': re.compile(r'ANALYSIS:\s*(.+?)(?:\n\n|$)', re.IGNORECASE | re.DOTALL),
'unknown_column': re.compile(r"Unknown column ['\"]?([^'\"]+)['\"]?", re.IGNORECASE),
'unknown_table': re.compile(r"Table ['\"]?([^'\"]+)['\"]? doesn't exist", re.IGNORECASE),
'limit_replace': re.compile(r'\s+LIMIT\s+\d+', re.IGNORECASE),
'param_and_or': re.compile(r'\s+(AND|OR)\s+[^\s]+\s*[=<>!]+\s*\?', re.IGNORECASE),
'param_where': re.compile(r'\s+WHERE\s+[^\s]+\s*[=<>!]+\s*\?', re.IGNORECASE),
'param_standalone': re.compile(r'\s+\?\s+'),
'param_end': re.compile(r'\s+\?$'),
'param_start': re.compile(r'^\?\s+'),
'trailing_clauses': re.compile(r'\s+(WHERE|AND|OR)\s+$', re.IGNORECASE),
'from_table': re.compile(r'\bFROM\s+`?(\w+)`?', re.IGNORECASE),
'join_table': re.compile(r'\bJOIN\s+`?(\w+)`?', re.IGNORECASE),
'table_in_error': re.compile(r"Table\s+['\"]?(\w+)['\"]?", re.IGNORECASE),
'table_in_quotes': re.compile(r"['\"`](\w+)['\"`]"),
}
async def _initialize_tools(self):
"""Initialize tools from MCP client and cache static schema data"""
if self.tools is None:
tools_list = await self.mcp_client.get_tools()
self.tools = tools_list
# Create a dictionary for O(1) lookup instead of O(n) linear search
self.tool_dict = {t.name: t for t in tools_list}
            # Pre-fetch static schema information that never changes during a session
# This avoids repeated network calls for list_tables / list_table_resources
list_tables_tool = await self._get_tool("list_tables")
list_table_resources_tool = await self._get_tool("list_table_resources")
if list_tables_tool:
try:
tables_result = await self._call_tool_with_timeout(list_tables_tool, {})
self.schema_cache = self.schema_cache or {}
self.schema_cache["tables"] = tables_result
except Exception as e:
if self.enable_logging:
logger.warning(f"Failed to fetch tables list: {str(e)}")
if list_table_resources_tool:
try:
resources = await self._call_tool_with_timeout(list_table_resources_tool, {})
self.schema_cache = self.schema_cache or {}
self.schema_cache["table_resources"] = resources
except Exception as e:
if self.enable_logging:
logger.warning(f"Failed to fetch table resources: {str(e)}")
return self.tools
async def _get_tool(self, tool_name: str):
"""Get a tool by name using O(1) dictionary lookup"""
if not hasattr(self, 'tool_dict') or self.tool_dict is None:
await self._initialize_tools()
return self.tool_dict.get(tool_name)
def _build_graph(self) -> StateGraph:
"""Build the LangGraph state graph"""
workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("explore_schema", self._explore_schema_node)
workflow.add_node("generate_sql", self._generate_sql_node)
workflow.add_node("refine_sql", self._refine_sql_node)
workflow.add_node("execute_query", self._execute_query_node)
workflow.add_node("refine_query", self._refine_query_node)
workflow.add_node("tools", self._tools_node)
# Set entry point
workflow.set_entry_point("explore_schema")
# Add conditional edge: check if schema exploration is complete
workflow.add_conditional_edges(
"explore_schema",
self._should_continue_schema_exploration,
{
"complete": "generate_sql",
"continue": "explore_schema" # Will fetch missing schema
}
)
# Add conditional edge: after SQL generation, decide if refinement is needed
workflow.add_conditional_edges(
"generate_sql",
self._should_refine_sql,
{
"refine": "refine_sql",
"execute": "execute_query",
"use_tools": "tools",
"end": END
}
)
# After refinement, always go to execute_query
workflow.add_edge("refine_sql", "execute_query")
# Add conditional edge: after execution, decide next step
workflow.add_conditional_edges(
"execute_query",
self._check_query_result,
{
"success": END,
"retry": "refine_query",
"error": END # Give up after max attempts
}
)
workflow.add_edge("refine_query", "generate_sql")
workflow.add_edge("tools", "generate_sql")
return workflow.compile()
async def _identify_relevant_tables(self, user_query: str, all_tables: list) -> list:
"""Use LLM to identify which tables are relevant to the query.
Args:
user_query: The natural language query
all_tables: List of all available table names
Returns:
List of relevant table names (or all tables if identification fails)
"""
if not all_tables:
return []
# Skip LLM call for small databases (optimization)
if len(all_tables) <= SMALL_DB_THRESHOLD:
return all_tables
prompt = f"""Given this database question: "{user_query}"
Available tables: {', '.join(all_tables)}
Which tables are likely needed to answer this question?
Return ONLY the table names, comma-separated. If unsure, include all potentially relevant tables.
Relevant tables:"""
try:
response = await self._call_llm_with_timeout([HumanMessage(content=prompt)])
relevant_str = response.content
# Handle None or empty content
if not relevant_str:
if self.enable_logging:
logger.warning("LLM returned empty content for table identification, using all tables")
return all_tables
relevant_str = relevant_str.strip()
# Parse comma-separated list (optimized: single strip chain)
relevant = [t.strip(' `"\'') for t in relevant_str.split(',') if t.strip()]
# Filter to only include actual table names (use set for O(1) lookup)
tables_set = set(all_tables)
relevant = [t for t in relevant if t in tables_set]
return relevant if relevant else all_tables # Fallback to all if none identified
except Exception as e:
# If LLM call fails, return all tables as fallback
if self.enable_logging:
logger.warning(f"Table identification failed: {str(e)}, using all tables")
return all_tables
async def _get_sample_data(self, table_name: str, schema_name: str = None, row_limit: int = 3) -> list:
"""Get sample data from a table to understand data patterns"""
try:
# Use run_query_json as it's more reliable and doesn't require schema parameter
run_query_tool = await self._get_tool("run_query_json")
if run_query_tool:
# Build SQL query - use schema prefix if available
if schema_name:
sql = f"SELECT * FROM `{schema_name}`.`{table_name}` LIMIT {row_limit}"
else:
sql = f"SELECT * FROM `{table_name}` LIMIT {row_limit}"
result = await self._call_tool_with_timeout(
run_query_tool,
{
"input": {
"sql": sql,
"row_limit": row_limit
}
}
)
                return result if isinstance(result, list) else []
            return []  # Tool unavailable; sample data is optional
except Exception as e:
# If query fails, return empty list (sample data is optional)
return []
async def _validate_query_is_database_question(self, query: str) -> tuple[bool, str]:
"""Validate if the user query is actually a database question.
Uses fast heuristics first, then LLM only if needed for ambiguous cases.
Returns (is_valid, reason) tuple.
"""
query_stripped = query.strip()
query_lower = query_stripped.lower()
# Fast heuristic checks first (avoid LLM call if possible)
# Check if query is too short
if len(query_stripped) < 3:
return False, "Query is too short to be a valid database question."
# Check for common database question patterns (optimized with set for O(1) lookup)
db_keywords = {
'show', 'find', 'get', 'list', 'count', 'select', 'what', 'which', 'who',
'how many', 'how much', 'where', 'when', 'display', 'retrieve', 'query',
'books', 'authors', 'members', 'loans', 'employees', 'customers', 'orders',
'table', 'tables', 'data', 'database', 'records', 'rows'
}
# Check if query contains any database keywords
query_words = set(query_lower.split())
has_db_keyword = bool(query_words & db_keywords) or any(kw in query_lower for kw in ['how many', 'how much'])
        # Check whether the query already contains SQL keywords (also treated as a database question)
sql_keywords = {'select', 'from', 'where', 'join', 'insert', 'update', 'delete'}
looks_like_sql = bool(query_words & sql_keywords)
# Fast path: If clearly valid or clearly invalid, return immediately
if has_db_keyword or looks_like_sql:
return True, "Query contains database-related keywords."
# If query is very short and has no keywords, likely invalid
if len(query_stripped) < 5 and not has_db_keyword:
return False, "Query does not appear to be a database question."
# Ambiguous case: Use LLM for validation (only when heuristics are unclear)
validation_prompt = f"""Is this a valid database query question?
Question: "{query}"
A valid database question should:
- Ask about data in the database (tables, records, relationships)
- Request information that can be retrieved via SQL
- Be clear and meaningful
Examples of VALID questions:
- "Show me all authors"
- "How many books are there?"
- "Find books by J.K. Rowling"
Examples of INVALID questions:
- "sns" (gibberish)
- "hello" (greeting, not a query)
- "test" (too vague)
- Random characters or single words
Respond with ONLY:
VALID: <yes or no>
REASON: <brief explanation>"""
try:
response = await self._call_llm_with_timeout([HumanMessage(content=validation_prompt)])
response_text = response.content.strip()
# Parse response
is_valid = "valid: yes" in response_text.lower() or "yes" in response_text.lower()[:50]
reason = response_text
return is_valid, reason
except Exception as e:
            # If the LLM call fails, err on the side of rejecting the query
return False, "Query does not appear to be a database question."
async def _score_and_analyze_query(
self, query: str, sql: str, schema_info: dict,
query_results: list = None, query_error: str = None
) -> tuple[float, str]:
"""Score confidence and analyze query in a single LLM call (optimization).
Confidence scoring now explicitly checks if the SQL answers the question.
"""
if query_error:
combined_prompt = f"""Analyze this SQL query and provide both a confidence score (0-1) and brief analysis.
Question: {query}
SQL: {sql}
Query Error: {query_error}
Tasks:
1. Rate confidence (0-1) - consider error type and severity
2. Provide brief analysis (2-3 sentences) explaining the issue
Format your response as:
CONFIDENCE: <number between 0.0 and 1.0>
ANALYSIS: <2-3 sentence analysis>"""
elif query_results is not None:
results_str = json.dumps(query_results[:5], indent=2, default=str)
# Check if results are empty
is_empty = len(query_results) == 0
empty_note = " NOTE: Query returned NO results (empty array). This may indicate a logical error in the query." if is_empty else ""
combined_prompt = f"""Analyze this SQL query and provide both a confidence score (0-1) and brief analysis.
Question: {query}
SQL: {sql}
Query Results (sample, up to 5 rows):
{results_str}{empty_note}
CRITICAL: When scoring confidence, you MUST consider:
1. Does the SQL query actually answer the question asked?
- If the question asks for specific information, does the query retrieve it?
- If the question has multiple parts (e.g., "find X, then find Y"), does the query handle ALL parts?
- Are the results relevant to what was asked?
2. If the question asks for "most/least/maximum/minimum", does the query handle ties correctly?
- Questions like "who has the most books" should return ALL entities with the maximum count, not just one
- Using LIMIT 1 when there are ties is INCORRECT - confidence should be low
- Should use HAVING COUNT = (SELECT MAX(...)) pattern to handle ties
3. If empty results, is this expected (no matching data) or a logical error?
4. Are GROUP BY, HAVING, JOINs, and subqueries used correctly?
5. For multi-part questions (with "and then"), does the query handle ALL parts?
IMPORTANT: If the query does NOT answer the question, confidence MUST be low (< 0.6).
If the query asks for "most/least" and uses LIMIT 1 (which doesn't handle ties), confidence MUST be low (< 0.6).
If the query answers the question correctly, confidence should be high (≥ 0.8).
Tasks:
1. Rate confidence (0-1) - MUST consider if results answer the question
2. Provide brief analysis (2-3 sentences) on query quality and whether it answers the question
Format your response as:
CONFIDENCE: <number between 0.0 and 1.0>
ANALYSIS: <2-3 sentence analysis>"""
else:
combined_prompt = f"""Analyze this SQL query and provide both a confidence score (0-1) and brief analysis.
Question: {query}
SQL: {sql}
CRITICAL: When scoring confidence, consider:
1. Does the SQL query structure suggest it will answer the question?
2. Are the correct tables and columns selected?
3. Are JOINs, WHERE clauses, and filters appropriate for the question?
IMPORTANT: If the query structure does NOT match the question, confidence MUST be low (< 0.6).
Tasks:
1. Rate confidence (0-1) - consider if query structure matches the question
2. Provide brief analysis (2-3 sentences) on potential issues
Format your response as:
CONFIDENCE: <number between 0.0 and 1.0>
ANALYSIS: <2-3 sentence analysis>"""
try:
response = await self._call_llm_with_timeout([HumanMessage(content=combined_prompt)])
content = response.content
# Handle None or empty content
if not content:
if self.enable_logging:
logger.warning("LLM returned empty content for confidence scoring")
return 0.5, "Unable to analyze query: LLM returned empty response."
content = content.strip()
# Extract confidence score (use compiled regex)
confidence_match = self._regex_patterns['confidence_score'].search(content)
if confidence_match:
try:
score_str = confidence_match.group(1)
if score_str: # Ensure group is not empty
score = float(score_str)
if score > 1.0:
score = score / 100.0
confidence = max(0.0, min(1.0, score))
else:
raise ValueError("Empty confidence score match")
except (ValueError, TypeError) as e:
if self.enable_logging:
logger.warning(f"Could not parse confidence score: {str(e)}, trying fallback")
confidence = None # Will trigger fallback
else:
confidence = None # Will trigger fallback
# Fallback: try to find any number if regex didn't work
if confidence is None:
match = re.search(r'(\d+\.?\d*)', content)
if match:
try:
score_str = match.group(1)
if score_str: # Ensure group is not empty
score = float(score_str)
if score > 1.0:
score = score / 100.0
confidence = max(0.0, min(1.0, score))
else:
raise ValueError("Empty score match")
except (ValueError, TypeError) as e:
if self.enable_logging:
logger.warning(f"Could not parse fallback confidence score: {str(e)}, using default 0.5")
confidence = 0.5
else:
if self.enable_logging:
logger.warning("Could not extract confidence score from LLM response, using default 0.5")
confidence = 0.5
# Extract analysis (use compiled regex)
analysis_match = self._regex_patterns['analysis_section'].search(content)
if analysis_match:
analysis_group = analysis_match.group(1)
analysis = analysis_group.strip() if analysis_group else "Analysis unavailable."
else:
# Fallback: use everything after CONFIDENCE line
lines = content.split('\n')
analysis_lines = []
found_conf = False
for line in lines:
if 'CONFIDENCE:' in line.upper():
found_conf = True
continue
if found_conf:
analysis_lines.append(line)
analysis = ' '.join(analysis_lines).strip() if analysis_lines else "Analysis unavailable."
if self.enable_logging:
logger.debug(f"Confidence score: {confidence:.2f}")
return confidence, analysis
except TimeoutError:
error_msg = "Confidence scoring timed out"
if self.enable_logging:
logger.error(error_msg)
return 0.3, f"{error_msg}. Using default low confidence."
except Exception as e:
# Fallback to separate calls if combined fails
error_msg = f"Error in combined confidence scoring: {str(e)}"
if self.enable_logging:
logger.warning(error_msg, exc_info=True)
try:
confidence = await self._score_query_confidence(query, sql, schema_info, query_results, query_error)
analysis = await self._analyze_query_confidence(query, sql, schema_info, confidence, query_results, query_error)
return confidence, analysis
except Exception as fallback_error:
if self.enable_logging:
logger.error(f"Fallback confidence scoring also failed: {str(fallback_error)}", exc_info=True)
return 0.3, f"Error analyzing query: {str(e)}"
async def _score_query_confidence(
self, query: str, sql: str, schema_info: dict,
query_results: list = None, query_error: str = None
) -> float:
"""Score confidence in the generated query (0-1) - kept for backward compatibility"""
confidence, _ = await self._score_and_analyze_query(query, sql, schema_info, query_results, query_error)
return confidence
def _validate_sql_syntax(self, sql: str) -> bool:
"""Validate basic SQL syntax - check if it's a valid SELECT query"""
if not sql or not sql.strip():
return False
sql_upper = sql.upper().strip()
# Must start with SELECT
if not sql_upper.startswith('SELECT'):
return False
# Basic validation: should have SELECT and FROM (for most queries)
# Allow for subqueries and complex queries
        if 'FROM' not in sql_upper and 'UNION' not in sql_upper:
            # Allow simple constant queries (e.g. SELECT 1, SELECT NOW()); everything else should have a FROM clause
            if not any(keyword in sql_upper for keyword in ['SELECT 1', 'SELECT NOW()']):
                return False
return True
def _has_critical_issues(self, analysis: str, query_error: str = None, sql_query: str = None) -> bool:
"""Check if analysis or error indicates critical syntax/structural issues that require refinement"""
if not analysis:
return False
analysis_lower = analysis.lower()
# Check for critical syntax/structural keywords only (not incomplete answers - those are handled by confidence score)
critical_keywords = [
'missing select',
'syntax error',
'missing keyword',
'incorrect syntax',
'malformed',
'invalid syntax',
'invalid query structure'
]
if any(keyword in analysis_lower for keyword in critical_keywords):
return True
# Check query error for critical syntax issues
if query_error:
error_lower = query_error.lower()
if any(keyword in error_lower for keyword in ['syntax error', '1064', 'invalid syntax', 'malformed']):
return True
# Check SQL query itself for obvious structural issues
if sql_query:
sql_upper = sql_query.upper().strip()
# Missing SELECT keyword
if not sql_upper.startswith('SELECT') and any(kw in sql_upper for kw in ['FROM', 'WHERE', 'JOIN']):
return True
return False
async def _fetch_missing_table_info(
self, schema_info: dict, query_error: str = None, analysis: str = None, sql_query: str = None
):
"""Selectively fetch schema info for tables that are mentioned but not in schema_info"""
# Get all available table names
all_tables = schema_info.get("all_tables", [])
if not all_tables:
return # Can't fetch if we don't know available tables
# Get currently cached table descriptions
cached_tables = set(schema_info.get("table_descriptions", {}).keys())
# Extract table names from error, analysis, and SQL query
mentioned_tables = set()
# Extract from SQL query (FROM, JOIN clauses) - use compiled regex
if sql_query:
from_matches = self._regex_patterns['from_table'].findall(sql_query)
join_matches = self._regex_patterns['join_table'].findall(sql_query)
mentioned_tables.update(from_matches)
mentioned_tables.update(join_matches)
# Extract from error message (e.g., "Table 'X' doesn't exist") - use compiled regex
if query_error and isinstance(query_error, str):
try:
table_matches = self._regex_patterns['table_in_error'].findall(query_error)
# Filter out empty matches
mentioned_tables.update([t for t in table_matches if t])
except Exception as e:
if self.enable_logging:
logger.debug(f"Could not extract tables from error message: {str(e)}")
# Extract from analysis text (might mention table names) - use compiled regex
if analysis and isinstance(analysis, str):
try:
table_matches = self._regex_patterns['table_in_quotes'].findall(analysis)
# Filter to only include actual table names (use set for O(1) lookup)
tables_set = set(all_tables)
mentioned_tables.update([t for t in table_matches if t and t in tables_set])
except Exception as e:
if self.enable_logging:
logger.debug(f"Could not extract tables from analysis: {str(e)}")
# Find tables that are mentioned but not in cache
missing_tables = [t for t in mentioned_tables if t in all_tables and t not in cached_tables]
if not missing_tables:
return # No missing tables
# Fetch schema info for missing tables
describe_table_tool = await self._get_tool("describe_table")
get_foreign_keys_tool = await self._get_tool("get_foreign_keys")
# Get schema name if available
schema_name = None
if schema_info.get("tables"):
table_resources = self.schema_cache.get("table_resources", []) if self.schema_cache else []
if table_resources and "://" in table_resources[0]:
schema_name = table_resources[0].split("://")[1].split("/")[0]
# Fetch table descriptions in parallel
if describe_table_tool:
async def describe_table(table_name: str):
try:
desc = await self._call_tool_with_timeout(describe_table_tool, {"table_name": table_name})
return table_name, desc, None
except Exception as e:
if self.enable_logging:
logger.debug(f"Failed to fetch description for {table_name}: {str(e)}")
return table_name, None, str(e)
desc_tasks = [describe_table(t) for t in missing_tables]
desc_results = await asyncio.gather(*desc_tasks, return_exceptions=True)
table_descriptions = schema_info.get("table_descriptions", {})
for result in desc_results:
if not isinstance(result, Exception):
tbl, desc, _ = result
if desc:
table_descriptions[tbl] = desc
schema_info["table_descriptions"] = table_descriptions
# Update cache
if self.schema_cache:
existing_descriptions = self.schema_cache.get("table_descriptions", {})
existing_descriptions.update(table_descriptions)
self.schema_cache["table_descriptions"] = existing_descriptions
# Fetch foreign keys for missing tables in parallel
if get_foreign_keys_tool:
async def get_foreign_keys_for_table(table_name: str):
try:
fk_info = await self._call_tool_with_timeout(get_foreign_keys_tool, {"table_name": table_name})
return table_name, fk_info, None
except Exception as e:
if self.enable_logging:
logger.debug(f"Failed to fetch foreign keys for {table_name}: {str(e)}")
return table_name, None, str(e)
fk_tasks = [get_foreign_keys_for_table(t) for t in missing_tables]
fk_results = await asyncio.gather(*fk_tasks, return_exceptions=True)
foreign_keys = schema_info.get("foreign_keys", {})
for result in fk_results:
if not isinstance(result, Exception):
tbl, fk_info, _ = result
if fk_info:
foreign_keys[tbl] = fk_info
schema_info["foreign_keys"] = foreign_keys
async def _refine_sql_with_analysis(
self, query: str, sql: str, schema_info: dict, analysis: str,
query_results: list = None, query_error: str = None
) -> str:
"""Refine SQL query using the analysis/reasoning from confidence scoring"""
# Check if error is a simple column name error
is_simple_column_error = False
wrong_column = None
if query_error:
parsed_error = self._parse_sql_error(query_error)
if parsed_error.get("type") == "unknown_column":
is_simple_column_error = True
wrong_column = parsed_error.get("column")
# Build refinement prompt with the analysis
if is_simple_column_error:
# For simple column errors, emphasize preserving structure
refinement_prompt = f"""The SQL query below has a column name error. Fix ONLY the column name, keep everything else the same.
Original Question: {query}
Original SQL: {sql}
Query Error: {query_error}
CRITICAL INSTRUCTIONS:
1. Keep the EXACT same query structure - only change the column name
2. The column '{wrong_column}' does not exist - find the correct column name from the schema below
3. Replace '{wrong_column}' with the correct column name
4. Do NOT change the query logic, JOINs, subqueries, WHERE clauses, or any other part
5. The query structure is correct - only the column name needs to be fixed
"""
else:
refinement_prompt = f"""The SQL query below has issues identified in the analysis. Generate a corrected version.
Original Question: {query}
Original SQL: {sql}
Analysis of Issues: {analysis}
"""
if query_error and not is_simple_column_error:
refinement_prompt += f"Query Error: {query_error}\n\n"
elif query_results is not None:
results_str = json.dumps(query_results[:5], indent=2, default=str)
refinement_prompt += f"Query Results (sample):\n{results_str}\n\n"
# Check if original question has multiple parts
has_multi_part = False
if query:
query_lower = query.lower()
multi_part_indicators = ["and then", "then find", "also find", "also show", "additionally", "and also"]
has_multi_part = any(indicator in query_lower for indicator in multi_part_indicators)
if is_simple_column_error:
# For simple column errors, just fix the column name
refinement_prompt += """Find the correct column name in the schema below and replace the wrong column name in the SQL.
Keep everything else exactly the same - only change the column name.
Return ONLY the corrected SQL query (no explanations, no markdown formatting):"""
elif has_multi_part:
refinement_prompt += """Based on the analysis above, generate a corrected SQL query that:
1. Fixes the issues identified in the analysis
2. Uses EXACT column names from the schema
3. PRESERVES the original query structure - only fix the specific issues
4. Correctly answers the ORIGINAL QUESTION IN FULL - this question has multiple parts (e.g., "and then")
5. For multi-part questions like "find X and then find Y":
- Use a subquery or CTE to first find X
- Then use that result to find Y in the main query
- Example pattern: "find entities with most related items, then find their items after a date":
* First: Identify ALL entities with max count (subquery) - use MAX() to find max count, then HAVING COUNT = (SELECT MAX(...)) to get ALL entities with that count (not LIMIT 1)
* Then: Find related items by those entities matching the second condition
* Use WHERE foreign_key IN (subquery) or JOIN with subquery
6. Has proper JOIN conditions if needed
7. Includes all necessary WHERE clauses and filters
8. Ensure GROUP BY is correct - don't group by columns that should be in SELECT for aggregation
9. Handle ties correctly - when finding "entities with most X", find ALL entities who have the maximum count, not just one (use HAVING COUNT = (SELECT MAX(...)), not LIMIT 1)
CRITICAL: The question has multiple parts. Your SQL MUST:
- First solve the first part (e.g., "entities with most related items") - find ALL entities with max count, not just one
- Then use that result to solve the second part (e.g., "related items by those entities after a date")
- Use subqueries, CTEs, or IN clauses to connect the parts
- Handle ties: if multiple entities have the same max count, include ALL of them
- PRESERVE the query structure - only fix the specific issues mentioned
Return ONLY the corrected SQL query (no explanations, no markdown formatting):"""
else:
refinement_prompt += """Based on the analysis above, generate a corrected SQL query that:
1. Fixes the issues identified in the analysis
2. Uses EXACT column names from the schema
3. PRESERVES the original query structure - only fix the specific issues
4. Correctly answers the ORIGINAL QUESTION IN FULL
5. If the question asks for entities with the "most", "least", "maximum", "minimum", "highest", "lowest", "top", or "bottom" of something:
- Handle ties correctly by returning ALL entities that match the maximum/minimum value
- Use a subquery to find the MAX or MIN value first
- Then use HAVING COUNT = (SELECT MAX(...)) or HAVING COUNT = (SELECT MIN(...)) to find ALL entities with that value
- DO NOT use LIMIT 1 - this would only return one entity even if there are ties
6. Has proper JOIN conditions if needed
7. Includes all necessary WHERE clauses and filters
8. Ensure GROUP BY is correct - don't group by columns that should be in SELECT for aggregation
Return ONLY the corrected SQL query (no explanations, no markdown formatting):"""
# Build system prompt with schema info
system_prompt = """You are an expert SQL query generator for MySQL databases.
Your task is to fix SQL queries based on error analysis and feedback.
IMPORTANT RULES:
1. Only generate SELECT queries (read-only operations)
2. Use MySQL syntax with backticks for identifiers when needed
3. Use proper JOIN syntax when querying multiple tables
4. Always include appropriate WHERE clauses when filtering
5. Use LIMIT to restrict result sets when appropriate
6. Return ONLY the SQL query, no explanations or markdown formatting
7. DO NOT use parameterized queries (no ? placeholders) - embed all values directly in the SQL
8. CRITICAL: Use EXACT column names from the table structures below - DO NOT guess or assume column names
Database Schema Information:
"""
if schema_info.get("table_descriptions"):
system_prompt += "\n=== TABLE STRUCTURES (USE EXACT COLUMN NAMES) ===\n"
for table_name, desc in schema_info["table_descriptions"].items():
system_prompt += f"\n--- Table: {table_name} ---\n{desc}\n"
if schema_info.get("foreign_keys"):
system_prompt += "\n=== TABLE RELATIONSHIPS (FOR JOINs) ===\n"
for table_name, fk_info in schema_info["foreign_keys"].items():
if fk_info:
system_prompt += f"\n{table_name} relationships:\n{fk_info}\n"
try:
messages_list = [
SystemMessage(content=system_prompt),
HumanMessage(content=refinement_prompt)
]
response = await self._call_llm_with_timeout(messages_list)
refined_sql = response.content
# Handle None or empty content
if not refined_sql:
if self.enable_logging:
logger.warning("LLM returned empty SQL during refinement")
refined_sql = sql # Fallback to original SQL
else:
refined_sql = refined_sql.strip()
# Clean up refined SQL
if refined_sql.startswith("```sql"):
refined_sql = refined_sql[6:]
if refined_sql.startswith("```"):
refined_sql = refined_sql[3:]
if refined_sql.endswith("```"):
refined_sql = refined_sql[:-3]
refined_sql = refined_sql.strip()
# Extract just the SQL if there's extra text (use compiled regex)
if not refined_sql.upper().startswith('SELECT'):
select_match = self._regex_patterns['select_statement'].search(refined_sql)
if select_match:
extracted = select_match.group(1)
if extracted:
refined_sql = extracted.strip()
# Validate extracted SQL is not empty
if not refined_sql:
if self.enable_logging:
logger.warning("Extracted SQL is empty, using original SQL")
refined_sql = sql # Fallback to original
# Remove parameterized query placeholders (use compiled regex)
if "?" in refined_sql:
refined_sql = self._regex_patterns['param_and_or'].sub('', refined_sql)
refined_sql = self._regex_patterns['param_where'].sub('', refined_sql)
refined_sql = self._regex_patterns['param_standalone'].sub(' ', refined_sql)
refined_sql = self._regex_patterns['param_end'].sub('', refined_sql)
refined_sql = self._regex_patterns['param_start'].sub('', refined_sql)
refined_sql = refined_sql.replace('?', '')
refined_sql = self._regex_patterns['trailing_clauses'].sub('', refined_sql)
refined_sql = refined_sql.strip()
return refined_sql if refined_sql else sql
except TimeoutError:
error_msg = "SQL refinement timed out"
if self.enable_logging:
logger.error(error_msg)
return sql # Return original SQL if refinement times out
except Exception as e:
# If refinement fails, return original SQL
error_msg = f"SQL refinement failed: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
return sql
async def _analyze_query_confidence(
self, query: str, sql: str, schema_info: dict, confidence: float,
query_results: list = None, query_error: str = None
) -> str:
"""Analyze the query and provide reasoning based on actual query results (always included)"""
if query_error:
# If there was an error, analyze it
analysis_prompt = f"""Analyze this SQL query and explain why it failed or succeeded.
Question: {query}
SQL: {sql}
Query Error: {query_error}
Confidence Score: {confidence:.2f}
Analyze the error and identify:
1. What specific issue caused the error?
2. Is it a column name problem, table name problem, or syntax issue?
3. What needs to be fixed in the query?
4. How does this error relate to the confidence score?
Provide a brief analysis (2-3 sentences) explaining the issue and how to fix it:"""
elif query_results is not None:
# If we have results, analyze if they make sense
results_str = json.dumps(query_results[:5], indent=2, default=str)
analysis_prompt = f"""Analyze this SQL query and explain why it might be correct or incorrect.
Question: {query}
SQL: {sql}
Query Results (sample, up to 5 rows):
{results_str}
Confidence Score: {confidence:.2f}
Analyze if the results make sense:
1. Do the results answer the question correctly?
2. Are the column names and data types correct?
3. Is the query logic correct (JOINs, filters, aggregations)?
4. Are there any issues with the data returned?
5. How does the confidence score relate to the actual results?
Provide a brief analysis (2-3 sentences) explaining the query quality and any potential issues:"""
else:
# No results available, analyze based on schema only
analysis_prompt = f"""Analyze this SQL query and explain why it might be correct or incorrect.
Question: {query}
SQL: {sql}
Confidence Score: {confidence:.2f}
Available schema information:
- Relevant tables: {', '.join(schema_info.get('relevant_tables', []))}
- Column names available in schema
Identify specific issues:
1. Are there any column names that might be incorrect?
2. Are JOIN conditions correct?
3. Does the query structure match the question?
4. Are there any missing WHERE clauses or filters?
5. Are there any syntax issues?
6. How does the confidence score relate to the query structure?
Provide a brief analysis (2-3 sentences) explaining the potential issues:"""
try:
response = await self._call_llm_with_timeout([HumanMessage(content=analysis_prompt)])
content = response.content
if not content:
if self.enable_logging:
logger.warning("LLM returned empty content for analysis")
content = "Analysis unavailable."
return content.strip()
except TimeoutError:
error_msg = "Analysis timed out"
if self.enable_logging:
logger.error(error_msg)
if query_error:
return f"Query error: {query_error}. Confidence: {confidence:.2f}. {error_msg}."
elif query_results is not None:
return f"Query returned {len(query_results)} rows. Confidence: {confidence:.2f}. {error_msg}."
else:
return f"Confidence: {confidence:.2f}. {error_msg}."
except Exception as e:
error_msg = f"Error generating analysis: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
if query_error:
return f"Query error: {query_error}. Confidence: {confidence:.2f}. Please review the query carefully."
elif query_results is not None:
return f"Query returned {len(query_results)} rows. Confidence: {confidence:.2f}. Please review the results carefully."
else:
return f"Confidence: {confidence:.2f}. Please review the query carefully against the schema information."
def _parse_sql_error(self, error_msg: str) -> dict:
"""Parse SQL error to extract actionable information (uses compiled regex for performance)"""
if not error_msg:
return {"type": "unknown", "message": error_msg}
error_info = {"type": "unknown", "message": error_msg}
# Parse "Unknown column" errors (use compiled regex)
col_match = self._regex_patterns['unknown_column'].search(error_msg)
if col_match:
error_info["type"] = "unknown_column"
col_name = col_match.group(1)
if col_name: # Ensure group is not empty
# Handle table.column format
if '.' in col_name:
parts = col_name.split('.', 1)
if len(parts) == 2:
error_info["table"] = parts[0].strip(' `"\'')
error_info["column"] = parts[1].strip(' `"\'')
else:
error_info["column"] = col_name.strip(' `"\'')
else:
error_info["column"] = col_name.strip(' `"\'')
else:
# Empty match, keep as unknown
error_info["type"] = "unknown"
# Parse "Table doesn't exist" errors (use compiled regex)
table_match = self._regex_patterns['unknown_table'].search(error_msg)
if table_match:
table_name = table_match.group(1)
if table_name: # Ensure group is not empty
error_info["type"] = "unknown_table"
error_info["table"] = table_name.strip(' `"\'')
else:
# Empty match, keep as unknown
error_info["type"] = "unknown"
# Parse syntax errors
if "syntax error" in error_msg.lower() or "1064" in error_msg:
error_info["type"] = "syntax_error"
return error_info
# ========================================================================
# Node Functions (State Processors)
# ========================================================================
async def _explore_schema_node(self, state: AgentState) -> dict:
"""Explore database schema to understand available tables and structure.
Fetches table descriptions, foreign keys, and caches schema information.
Sets schema_complete flag for edge routing.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
relevant_tables = state.get("relevant_tables", [])
# Use cached schema data if available (but still need to identify relevant tables for this query)
# IMPORTANT: Don't copy relevant_tables from cache - they're query-specific
if self.schema_cache:
# Use cached table descriptions, table names, etc.
if self.schema_cache.get("table_descriptions"):
schema_info["table_descriptions"] = self.schema_cache.get("table_descriptions", {}).copy()
if self.schema_cache.get("table_names"):
schema_info["table_names"] = self.schema_cache.get("table_names", [])[:] # Copy list
if self.schema_cache.get("all_tables"):
schema_info["all_tables"] = self.schema_cache.get("all_tables", [])[:] # Copy list
if self.schema_cache.get("tables"):
schema_info["tables"] = self.schema_cache.get("tables")
# Copy foreign keys if cached
if self.schema_cache.get("foreign_keys"):
schema_info["foreign_keys"] = self.schema_cache.get("foreign_keys", {}).copy()
# Always reset relevant_tables for new query (don't use cached ones)
relevant_tables = [] # Reset - will be identified for this specific query
# Retrieve the user's question (needed for identifying relevant tables)
user_query = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
user_query = msg.content
break
if not user_query:
if self.enable_logging:
logger.warning("No user query found in messages for schema exploration")
return {"schema_info": schema_info}
# CRITICAL: Always identify relevant tables for THIS query (even if schema is cached)
# This ensures each query gets fresh relevant_tables based on the current question
# Get table names from cache first
table_names = schema_info.get("table_names", [])
if not table_names and self.schema_cache:
table_names = self.schema_cache.get("table_names", [])
# If we have table names, identify relevant ones for this query NOW
if table_names and not relevant_tables:
relevant_tables = await self._identify_relevant_tables(user_query, table_names)
schema_info["relevant_tables"] = relevant_tables
# Check if schema exploration is complete (decision made by edge)
# Schema is complete if we have descriptions for all relevant tables
schema_complete = False
if schema_info and schema_info.get("table_descriptions") and relevant_tables:
cached_descriptions = self.schema_cache.get("table_descriptions", {}) if self.schema_cache else {}
# Use set for O(1) lookup
cached_tables_set = set(cached_descriptions.keys())
if all(tbl in cached_tables_set for tbl in relevant_tables):
# Also check if we have foreign keys for all relevant tables
cached_fks = schema_info.get("foreign_keys", {})
if cached_fks or not relevant_tables: # If no FKs needed or already have them
schema_complete = True
elif schema_info and schema_info.get("table_descriptions") and not relevant_tables:
# If no relevant tables identified but we have all descriptions, schema is complete
schema_complete = True
# If schema is complete AND we've identified relevant tables, return early
if schema_complete and relevant_tables:
return {
"schema_info": schema_info,
"relevant_tables": relevant_tables, # Include identified relevant tables
"schema_complete": True
}
        # If we don't have table names yet, they are fetched below
# Schema not complete - continue fetching (edge will route back to explore_schema)
# Initialize tools (cached where possible)
await self._initialize_tools()
list_tables_tool = await self._get_tool("list_tables")
list_table_resources_tool = await self._get_tool("list_table_resources")
describe_table_tool = await self._get_tool("describe_table")
# Get tables list (use cache if present)
if self.schema_cache and self.schema_cache.get("tables"):
tables_result = self.schema_cache["tables"]
elif list_tables_tool:
try:
tables_result = await self._call_tool_with_timeout(list_tables_tool, {})
self.schema_cache = self.schema_cache or {}
self.schema_cache["tables"] = tables_result
except Exception as e:
if self.enable_logging:
logger.warning(f"Failed to fetch tables list: {str(e)}")
tables_result = None
else:
tables_result = None
if tables_result:
schema_info["tables"] = tables_result
# Get table resources to extract table names (use cache if present)
if self.schema_cache and self.schema_cache.get("table_resources"):
table_resources = self.schema_cache["table_resources"]
elif list_table_resources_tool:
try:
table_resources = await self._call_tool_with_timeout(list_table_resources_tool, {})
self.schema_cache = self.schema_cache or {}
self.schema_cache["table_resources"] = table_resources
except Exception as e:
if self.enable_logging:
logger.warning(f"Failed to fetch table resources: {str(e)}")
table_resources = []
else:
table_resources = []
# Extract table names safely
table_names = []
for r in table_resources:
if r and isinstance(r, str) and r.startswith("table://"):
try:
table_name = r.split("/")[-1]
if table_name: # Ensure we got a valid name
table_names.append(table_name)
except (IndexError, AttributeError) as e:
if self.enable_logging:
logger.debug(f"Could not parse table resource '{r}': {str(e)}")
continue
# Extract schema name safely
schema_name = None
if table_resources:
try:
first_resource = table_resources[0]
if isinstance(first_resource, str) and "://" in first_resource:
schema_name = first_resource.split("://")[1].split("/")[0]
except (IndexError, AttributeError) as e:
if self.enable_logging:
logger.debug(f"Could not extract schema name: {str(e)}")
schema_name = None
# Update schema_info with table names
schema_info["table_names"] = table_names
schema_info["all_tables"] = table_names
# Determine relevant tables via LLM (if not already identified above)
if not relevant_tables and table_names:
relevant_tables = await self._identify_relevant_tables(user_query, table_names)
schema_info["relevant_tables"] = relevant_tables
elif not relevant_tables:
relevant_tables = table_names
schema_info["relevant_tables"] = relevant_tables
# Describe relevant tables (parallel fetching)
# Check cache first - only fetch descriptions for tables not in cache
table_descriptions = {}
cached_descriptions = self.schema_cache.get("table_descriptions", {}) if self.schema_cache else {}
# Get descriptions from cache for tables that are already cached
tables_to_fetch = []
for table_name in relevant_tables:
if table_name in cached_descriptions:
table_descriptions[table_name] = cached_descriptions[table_name]
else:
tables_to_fetch.append(table_name)
# Only fetch descriptions for tables not in cache
if describe_table_tool and tables_to_fetch:
async def describe_table(table_name: str):
try:
desc = await self._call_tool_with_timeout(describe_table_tool, {"table_name": table_name})
return table_name, desc, None
except Exception as e:
if self.enable_logging:
logger.debug(f"Failed to fetch description for {table_name}: {str(e)}")
return table_name, None, str(e)
desc_tasks = [describe_table(t) for t in tables_to_fetch]
desc_results = await asyncio.gather(*desc_tasks, return_exceptions=True)
for result in desc_results:
if not isinstance(result, Exception):
tbl, desc, _ = result
if desc:
table_descriptions[tbl] = desc
schema_info["table_descriptions"] = table_descriptions
schema_info["table_names"] = table_names
schema_info["all_tables"] = table_names # Keep full list for reference
# Get foreign key relationships for relevant tables (parallel fetching)
get_foreign_keys_tool = await self._get_tool("get_foreign_keys")
foreign_keys = {}
if get_foreign_keys_tool and relevant_tables:
async def get_foreign_keys_for_table(table_name: str):
try:
fk_info = await self._call_tool_with_timeout(get_foreign_keys_tool, {"table_name": table_name})
return table_name, fk_info, None
except Exception as e:
if self.enable_logging:
logger.debug(f"Failed to fetch foreign keys for {table_name}: {str(e)}")
return table_name, None, str(e)
# Fetch foreign keys in parallel for performance
fk_tasks = [get_foreign_keys_for_table(t) for t in relevant_tables]
fk_results = await asyncio.gather(*fk_tasks, return_exceptions=True)
for result in fk_results:
if not isinstance(result, Exception):
tbl, fk_info, _ = result
if fk_info:
foreign_keys[tbl] = fk_info
schema_info["foreign_keys"] = foreign_keys
# Cache heavy schema parts for future queries
# Merge with existing cache to preserve all table descriptions
self.schema_cache = self.schema_cache or {}
existing_descriptions = self.schema_cache.get("table_descriptions", {})
existing_descriptions.update(table_descriptions) # Merge new descriptions with existing
self.schema_cache.update({
"table_descriptions": existing_descriptions,
"table_names": table_names,
"all_tables": table_names,
})
# Manage cache size
self._manage_schema_cache_size()
# Schema fetching complete - return with flag (edge will route to generate_sql)
return {
"schema_info": schema_info,
"relevant_tables": relevant_tables,
"schema_complete": True
}
async def _generate_sql_node(self, state: AgentState) -> dict:
"""Generate SQL query from natural language using LLM.
Generates SQL, executes test query, calculates confidence, and sets
refinement flags. Does NOT perform refinement (handled by refine_sql node).
Validates that the user query is actually a database question before generating SQL.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
query_attempts = state.get("query_attempts", 0)
previous_error = state.get("previous_error") # Optional: error from previous attempt
previous_sql = state.get("previous_sql") # Optional: SQL from previous attempt
# Retrieve the user's question
user_query = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
user_query = msg.content
break
if not user_query:
return {
"messages": [AIMessage(content="Error: No user question found in messages.")],
"should_refine": False
}
# Validate that this is actually a database question (skip validation on retries)
if previous_error is None and previous_sql is None: # Only validate on first attempt
try:
is_valid_query, validation_reason = await self._validate_query_is_database_question(user_query)
if not is_valid_query:
error_msg = f"I cannot generate a SQL query for this request. {validation_reason}\n\nPlease ask a question about the database, such as:\n- \"Show me all authors\"\n- \"How many books are there?\"\n- \"Find books by a specific author\""
if self.enable_logging:
logger.warning(f"Query validation failed: {validation_reason}")
return {
"messages": [AIMessage(content=error_msg)],
"should_refine": False,
"confidence": 0.0,
"sql_is_valid": False,
"has_critical_issues": True
}
except Exception as e:
error_msg = f"Error validating query: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
return {
"messages": [AIMessage(content=f"Error: {error_msg}. Please try again.")],
"should_refine": False,
"confidence": 0.0,
"sql_is_valid": False,
"has_critical_issues": True
}
# Build system prompt with schema information (optimized for size)
system_prompt = """You are an expert SQL query generator for MySQL databases.
RULES:
1. Only SELECT queries (read-only)
2. Use MySQL syntax with backticks when needed
3. Use proper JOIN syntax for multiple tables
4. Include WHERE clauses when filtering
5. Use LIMIT when appropriate (but see rule 11 for exceptions)
6. Return ONLY SQL, no explanations or markdown
7. NO parameterized queries (no ? placeholders) - embed values directly
8. CRITICAL: Use EXACT column names from structures below - DO NOT guess
9. Check column names summary for available columns
10. For employee questions, always select employee name to identify them
11. IMPORTANT: If the question asks for entities with the "most", "least", "maximum", "minimum", "highest", "lowest", "top", or "bottom" of something:
- You MUST handle ties correctly by returning ALL entities that match the maximum/minimum value
- Use a subquery to find the MAX or MIN value first
- Then use HAVING COUNT = (SELECT MAX(...)) or HAVING COUNT = (SELECT MIN(...)) to find ALL entities with that value
- DO NOT use LIMIT 1 for these queries - it would only return one entity even if there are ties
- Example: For "who has the most books", if two authors both have 2 books (the maximum), return BOTH authors
Notes:
- `emp_employee_id` = eecode (unique employee identifier)
- Today: {date.today().isoformat()}
Schema:
"""
if schema_info.get("tables"):
system_prompt += f"\nAvailable Tables:\n{schema_info['tables']}\n"
# Show relevant tables (concise)
relevant_tables = schema_info.get("relevant_tables", [])
if relevant_tables:
system_prompt += f"Relevant tables: {', '.join(relevant_tables)}\n\n"
# Add table structures (only for relevant tables to reduce prompt size)
if schema_info.get("table_descriptions"):
for table_name, desc in schema_info["table_descriptions"].items():
if not relevant_tables or table_name in relevant_tables:
system_prompt += f"\n{table_name}:\n{desc}\n"
# Add foreign key relationships (concise)
if schema_info.get("foreign_keys"):
system_prompt += "\nRelationships:\n"
for table_name, fk_info in schema_info["foreign_keys"].items():
if fk_info and (not relevant_tables or table_name in relevant_tables):
system_prompt += f"{table_name}: {fk_info}\n"
# Extract and list column names explicitly for better visibility (cached for performance)
system_prompt += "\n=== COLUMN NAMES SUMMARY (USE THESE EXACT NAMES) ===\n"
for table_name, desc in schema_info["table_descriptions"].items():
# Check cache first (move to end for LRU)
if table_name in self.column_cache:
columns = self.column_cache[table_name]
# Move to end (most recently used)
self.column_cache.move_to_end(table_name)
else:
# Extract column names from description
lines_desc = desc.split('\n')
columns = []
header_found = False
for line in lines_desc:
line = line.strip()
if not line or line.startswith('---'):
continue
if '|' in line:
parts = [p.strip() for p in line.split('|') if p.strip()]
if parts and not header_found:
if 'Field' in parts[0] or 'field' in parts[0].lower():
header_found = True
continue
elif header_found and parts:
col_name = parts[0].strip('`').strip()
if col_name and col_name not in ['Field', '---', '']:
columns.append(col_name)
# Cache the extracted columns (with LRU management)
self.column_cache[table_name] = columns
self._manage_column_cache_size()
if columns:
system_prompt += f"\n{table_name}: {', '.join(columns)}\n"
else:
# Fallback: show abbreviated description (not full)
system_prompt += f"\n{table_name}: (see table structure above)\n"
# Optional: Add error context from previous attempt if available
if previous_error:
error_info = f"\n\nPrevious query error: {previous_error}"
problematic_column = None
# Use compiled regex pattern for better performance
col_match = self._regex_patterns['unknown_column'].search(previous_error)
if col_match:
problematic_column = col_match.group(1)
error_guidance = f"\n\nThis is attempt {query_attempts + 1}. Previous query failed.{error_info}"
if problematic_column:
error_guidance += f"\n\n⚠️ CRITICAL: The column '{problematic_column}' does NOT exist! Check the column names summary above for the correct column name."
if previous_sql:
error_guidance += f"\n\nPrevious SQL (for reference - fix the error but preserve the query structure):\n```sql\n{previous_sql}\n```"
error_guidance += "\n\nIMPORTANT REMINDERS: \n- Pay close attention to the EXACT column names from the table structures above. Do NOT guess column names.\n- Use the exact column names shown in the table structures - do not assume generic names like 'id' if the actual column has a different name.\n- Do NOT use parameterized queries with ? placeholders - embed all values directly in the SQL.\n- If the error mentions '?' or 'syntax error', make sure your SQL has no ? placeholders.\n- If fixing a column name error, keep the query structure the same - only change the column name."
system_prompt += error_guidance
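# Example of the MySQL error shape the compiled 'unknown_column' pattern above is expected
# to match (illustrative; the column name is hypothetical):
#   "Unknown column 'athor_name' in 'field list'"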
# Get the latest user message
user_query = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
content = msg.content
if content: # Ensure content is not None or empty
user_query = content
break
if not user_query:
return {"messages": [AIMessage(content="I need a question to generate a SQL query.")]}
# Check if query has multiple parts (e.g., "and then", "also find")
has_multi_part = False
if user_query:
user_lower = user_query.lower()
multi_part_indicators = ["and then", "then find", "also find", "also show", "additionally", "and also"]
has_multi_part = any(indicator in user_lower for indicator in multi_part_indicators)
# Generate SQL using LLM (optimized prompt - shorter but still effective)
if has_multi_part:
prompt = f"""Convert to SQL query. IMPORTANT: This question has multiple parts separated by "and then" or similar - you must answer ALL parts.
For multi-part questions like "find X and then find Y":
- Use a subquery or CTE to first find X
- Then use that result to find Y
- Example pattern: "find entities with the most related items, then find their items after a date" should:
1. First identify ALL entities with max count (subquery/CTE) - use MAX() to find the max count, then find ALL entities with that count (not LIMIT 1)
2. Then find related items by those entities matching the second condition
3. Join or filter using the first result
IMPORTANT: When finding "entities with the most X" or similar maximum count queries:
- Find the MAX count first using a subquery
- Then find ALL entities who have that count (use HAVING COUNT = (SELECT MAX(...)), not LIMIT 1)
- This handles ties correctly (multiple entities with same max count)
Steps:
1. Break down the question into parts
2. Create subquery/CTE for the first part (handle ties - find ALL matches, not just one)
3. Use that result in the main query for the second part
4. Use EXACT column names from structures
5. Write complete SQL
Question: {user_query}
SQL (must answer ALL parts using subqueries/CTEs, and handle ties correctly):"""
else:
prompt = f"""Convert to SQL query.
IMPORTANT: Analyze the question carefully. If the question asks for entities with the "most", "least", "maximum", "minimum", "highest", "lowest", "top", or "bottom" of something:
- You MUST handle ties correctly by returning ALL entities that match the maximum/minimum value
- Use a subquery to find the MAX or MIN value first
- Then use HAVING COUNT = (SELECT MAX(...)) or HAVING COUNT = (SELECT MIN(...)) to find ALL entities with that value
- DO NOT use LIMIT 1 - this would only return one entity even if there are ties
- Return ALL entities that match the maximum/minimum value
Example: For "who has the most books", if two authors both have 2 books (the maximum), return BOTH authors, not just one.
Steps:
1. Identify needed tables
2. Select columns (use EXACT names from structures)
3. Add JOINs (use relationships if available)
4. If question asks for most/least/maximum/minimum, use subquery + HAVING pattern to handle ties
5. Add filters/aggregations as needed
6. Write SQL
Question: {user_query}
SQL:"""
messages_list = [
SystemMessage(content=system_prompt),
HumanMessage(content=prompt)
]
response = await self._call_llm_with_timeout(messages_list)
sql_query = response.content
# Handle None or empty content
if not sql_query:
if self.enable_logging:
logger.warning("LLM returned empty SQL query")
return {
"messages": [AIMessage(content="Error: LLM returned empty SQL query. Please try rephrasing your question.")],
"should_refine": False,
"sql_is_valid": False,
"has_critical_issues": True,
"confidence": 0.0
}
sql_query = sql_query.strip()
# Clean up SQL query (remove markdown code blocks if present)
if sql_query.startswith("```sql"):
sql_query = sql_query[6:]
if sql_query.startswith("```"):
sql_query = sql_query[3:]
if sql_query.endswith("```"):
sql_query = sql_query[:-3]
sql_query = sql_query.strip()
# Extract just the SQL if the response includes reasoning (use compiled regex)
# First, try to extract SQL from code blocks
code_block_match = self._regex_patterns['sql_code_block'].search(sql_query)
if code_block_match:
extracted = code_block_match.group(1)
if extracted:
sql_query = extracted.strip()
else:
# Try to find SQL after "SQL Query:" marker
sql_marker_match = self._regex_patterns['sql_marker'].search(sql_query)
if sql_marker_match:
extracted = sql_marker_match.group(1)
if extracted:
sql_query = extracted.strip()
else:
# Try to extract SELECT statement
select_match = self._regex_patterns['select_statement'].search(sql_query)
if select_match:
extracted = select_match.group(1)
if extracted:
sql_query = extracted.strip()
# Validate SQL is not empty after extraction
if not sql_query or not sql_query.strip():
if self.enable_logging:
logger.warning("SQL query is empty after extraction")
return {
"messages": [AIMessage(content="Error: Could not extract valid SQL query from LLM response. Please try rephrasing your question.")],
"should_refine": False,
"sql_is_valid": False,
"has_critical_issues": True,
"confidence": 0.0
}
# If SQL doesn't start with SELECT (or WITH, for CTEs), try to fix it
if not sql_query.upper().strip().startswith(('SELECT', 'WITH')):
# Check if it looks like a SELECT query without the keyword
sql_upper = sql_query.upper()
if any(keyword in sql_upper for keyword in ['FROM', 'WHERE', 'JOIN', 'GROUP BY', 'ORDER BY', 'HAVING']):
# Looks like a SELECT query missing the SELECT keyword
sql_query = f"SELECT {sql_query}"
elif sql_upper.startswith('COUNT') or sql_upper.startswith('SUM') or sql_upper.startswith('AVG') or sql_upper.startswith('MAX') or sql_upper.startswith('MIN'):
# Aggregation function without SELECT
sql_query = f"SELECT {sql_query}"
else:
# SQL doesn't look like a valid SELECT query
if self.enable_logging:
logger.warning(f"SQL doesn't start with SELECT and doesn't match known patterns: {sql_query[:100]}")
return {
"messages": [AIMessage(content="Error: Generated SQL does not appear to be a valid SELECT query. Please try rephrasing your question.")],
"should_refine": False,
"sql_is_valid": False,
"has_critical_issues": True,
"confidence": 0.0
}
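# Examples of the repair above (heuristic only, not a full SQL parser):
#   "COUNT(*) FROM books"        -> "SELECT COUNT(*) FROM books"
#   "title FROM books WHERE ..." -> "SELECT title FROM books WHERE ..."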
# Remove any parameterized query placeholders (?) - use compiled regex for performance
if "?" in sql_query:
# Remove ? placeholders that are used as parameter markers
sql_query = self._regex_patterns['param_and_or'].sub('', sql_query)
sql_query = self._regex_patterns['param_where'].sub('', sql_query)
sql_query = self._regex_patterns['param_standalone'].sub(' ', sql_query)
sql_query = self._regex_patterns['param_end'].sub('', sql_query)
sql_query = self._regex_patterns['param_start'].sub('', sql_query)
# Clean up any remaining ? characters (safety fallback)
sql_query = sql_query.replace('?', '')
# Clean up trailing WHERE/AND/OR
sql_query = self._regex_patterns['trailing_clauses'].sub('', sql_query)
sql_query = sql_query.strip()
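# Plausible outcome of the cleanup above, inferred from the pattern names (the exact
# behavior depends on the compiled regexes defined elsewhere in this class):
#   "SELECT * FROM books WHERE year > 2000 AND author_id = ?"
#   -> "SELECT * FROM books WHERE year > 2000"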
# Execute test query with LIMIT to get sample results for confidence scoring
query_results = None
query_error = None
try:
# Add LIMIT if not already present
test_sql = sql_query
if "LIMIT" not in test_sql.upper():
test_sql = f"{sql_query.rstrip(';')} LIMIT {TEST_QUERY_LIMIT}"
else:
# Replace existing LIMIT
test_sql = self._regex_patterns['limit_replace'].sub(f' LIMIT {TEST_QUERY_LIMIT}', test_sql)
# Execute the query to get sample results
run_query_tool = await self._get_tool("run_query_json")
if run_query_tool:
try:
result = await self._call_tool_with_timeout(
run_query_tool,
{
"input": {
"sql": test_sql,
"row_limit": TEST_QUERY_LIMIT
}
}
)
except TimeoutError:
query_error = f"Query execution timed out after {self.query_timeout} seconds"
if self.enable_logging:
logger.error(query_error)
except Exception as e:
query_error = f"Query execution failed: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
else:
# Process result only if no exception occurred
if 'result' in locals():
if isinstance(result, list):
query_results = result
elif isinstance(result, str) and ("error" in result.lower() or "query error" in result.lower()):
query_error = result
except Exception as e:
query_error = f"Unexpected error during test query execution: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
# Calculate confidence score and analysis in a single LLM call (optimization)
confidence, confidence_reasoning = await self._score_and_analyze_query(
user_query, sql_query, schema_info, query_results, query_error
)
# Validate SQL syntax - check if it's a valid SELECT query
sql_is_valid = self._validate_sql_syntax(sql_query)
# Check if analysis indicates critical syntax/structural issues (incomplete answers handled by confidence score)
has_critical_issues = self._has_critical_issues(confidence_reasoning, query_error, sql_query)
# Determine if refinement is needed (decision made by edge)
# Refine if: error exists OR low confidence OR invalid SQL OR critical issues
# Skip if: very high confidence with valid SQL and no error
has_error = query_error is not None
has_low_confidence = confidence < LOW_CONFIDENCE_THRESHOLD
should_refine = (has_error or has_low_confidence or not sql_is_valid or has_critical_issues)
# Skip refinement if confidence is very high and SQL is valid
if should_refine:
is_very_high_confidence = confidence >= VERY_HIGH_CONFIDENCE_THRESHOLD
is_high_confidence_no_error = (previous_error is None and
confidence >= HIGH_CONFIDENCE_THRESHOLD and
sql_is_valid)
if (is_very_high_confidence and sql_is_valid and not has_error) or is_high_confidence_no_error:
should_refine = False
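# Worked example: a query flagged only for critical issues but scoring 0.92 with valid SQL
# and no execution error clears VERY_HIGH_CONFIDENCE_THRESHOLD (0.9), so refinement is
# skipped; a query scoring 0.55 stays below LOW_CONFIDENCE_THRESHOLD (0.6) and is refined.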
# Determine refine reason if needed
refine_reason = None
if should_refine:
if not sql_is_valid:
refine_reason = "SQL syntax is invalid"
elif has_critical_issues:
refine_reason = "Critical issues detected in analysis"
elif has_low_confidence:
refine_reason = f"Low confidence ({confidence:.2f})"
elif has_error:
refine_reason = f"Query execution error: {query_error}"
# Build message with confidence score, sample results, and reasoning
confidence_msg = f"Generated SQL query (confidence: {confidence:.2f}):\n```sql\n{sql_query}\n```"
if query_error:
confidence_msg += f"\n\nQuery Error:\n{query_error}"
elif query_results is not None:
confidence_msg += f"\n\nSample Results (first {len(query_results)} rows):\n{json.dumps(query_results, indent=2, default=str)}"
# Always include reasoning
confidence_msg += f"\n\nAnalysis:\n{confidence_reasoning}"
new_messages = [AIMessage(content=confidence_msg)]
return {
"messages": new_messages,
"query_attempts": query_attempts + 1,
"previous_error": None, # Clear previous error after generating new SQL
"previous_sql": None, # Clear previous SQL after generating new SQL
"final_sql": sql_query, # Store final SQL to avoid re-extraction
"test_query_results": query_results if query_error is None else None, # Store test results if successful
# Set refinement decision flags for edge
"should_refine": should_refine,
"refine_reason": refine_reason,
"confidence": confidence,
"sql_is_valid": sql_is_valid,
"has_critical_issues": has_critical_issues,
"confidence_reasoning": confidence_reasoning,
"query_error": query_error,
"final_sql": sql_query # Store SQL for refinement node
}
async def _refine_sql_node(self, state: AgentState) -> dict:
"""Refine SQL query based on confidence score and analysis.
Called when should_refine flag is True. Fetches missing schema if needed,
refines SQL, re-executes test query, and recalculates confidence.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
user_query = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
user_query = msg.content
break
if not user_query:
return {}
# Get SQL and analysis from state
sql_query = state.get("final_sql")
if not sql_query:
sql_query = self._extract_sql_from_messages(messages)
if not sql_query:
return {}
confidence_reasoning = state.get("confidence_reasoning", "")
query_error = state.get("query_error")
query_results = state.get("test_query_results")
refine_reason = state.get("refine_reason", "Low confidence or errors detected")
# Check if we need to fetch schema info for new tables mentioned in error/analysis
await self._fetch_missing_table_info(schema_info, query_error, confidence_reasoning, sql_query)
# Refine the SQL using the analysis/reasoning
refined_sql = await self._refine_sql_with_analysis(
user_query, sql_query, schema_info, confidence_reasoning, query_results, query_error
)
# Validate refined SQL is not empty
if not refined_sql or not refined_sql.strip():
if self.enable_logging:
logger.warning("Refined SQL is empty, using original SQL")
refined_sql = sql_query
if refined_sql == sql_query:
# Refinement didn't change SQL, return original
return {
"final_sql": sql_query,
"should_refine": False
}
# Re-execute the refined query to get new results
query_results = None
query_error = None
try:
test_sql = refined_sql
if "LIMIT" not in test_sql.upper():
test_sql = f"{refined_sql.rstrip(';')} LIMIT {TEST_QUERY_LIMIT}"
else:
test_sql = self._regex_patterns['limit_replace'].sub(f' LIMIT {TEST_QUERY_LIMIT}', test_sql)
run_query_tool = await self._get_tool("run_query_json")
if run_query_tool:
try:
result = await self._call_tool_with_timeout(
run_query_tool,
{
"input": {
"sql": test_sql,
"row_limit": TEST_QUERY_LIMIT
}
}
)
except TimeoutError:
query_error = f"Query execution timed out after {self.query_timeout} seconds"
if self.enable_logging:
logger.error(query_error)
except Exception as e:
query_error = f"Query execution failed: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
else:
# Process result only if no exception occurred
if 'result' in locals():
if isinstance(result, list):
query_results = result
elif isinstance(result, str) and ("error" in result.lower() or "query error" in result.lower()):
query_error = result
except Exception as e:
query_error = str(e)
# Recalculate confidence and analysis for refined query
confidence, confidence_reasoning = await self._score_and_analyze_query(
user_query, refined_sql, schema_info, query_results, query_error
)
# Validate refined SQL
sql_is_valid = self._validate_sql_syntax(refined_sql)
has_critical_issues = self._has_critical_issues(confidence_reasoning, query_error, refined_sql)
# Update message with refined SQL
confidence_msg = f"Generated SQL query (confidence: {confidence:.2f}):\n```sql\n{refined_sql}\n```"
if query_error:
confidence_msg += f"\n\nQuery Error:\n{query_error}"
elif query_results is not None:
confidence_msg += f"\n\nSample Results (first {len(query_results)} rows):\n{json.dumps(query_results, indent=2, default=str)}"
confidence_msg += f"\n\nAnalysis:\n{confidence_reasoning}"
confidence_msg += f"\n\n⚠️ Query was refined due to: {refine_reason}"
return {
"messages": [AIMessage(content=confidence_msg)],
"final_sql": refined_sql,
"test_query_results": query_results if query_error is None else None,
"confidence": confidence,
"sql_is_valid": sql_is_valid,
"has_critical_issues": has_critical_issues,
"confidence_reasoning": confidence_reasoning,
"query_error": query_error,
"should_refine": False # Clear refinement flag after refining
}
# ========================================================================
# Edge Decision Functions (Orchestration Logic)
# ========================================================================
def _should_continue_schema_exploration(self, state: AgentState) -> Literal["complete", "continue"]:
"""Decide whether schema exploration is complete.
Returns:
"complete" if schema is ready, "continue" if more fetching needed
"""
return "complete" if state.get("schema_complete", False) else "continue"
def _should_refine_sql(self, state: AgentState) -> Literal["refine", "execute", "use_tools", "end"]:
"""Decide routing after SQL generation.
Returns:
"refine" if SQL needs refinement
"execute" if SQL is ready to execute
"use_tools" if LLM wants to use tools
"end" if no SQL found
"""
# Check if refinement is needed (set by generate_sql node)
if state.get("should_refine", False):
return "refine"
# Check for tool calls or SQL in message
messages = state["messages"]
last_msg = messages[-1] if messages else None
if last_msg and isinstance(last_msg, AIMessage):
if hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
return "use_tools"
content = last_msg.content
if content: # Ensure content is not None or empty
has_sql = "SELECT" in content.upper() or ("sql" in content.lower() and "```" in content)
if has_sql:
return "execute"
return "end"
# ========================================================================
# Helper Methods
# ========================================================================
def _extract_sql_from_messages(self, messages: list) -> Optional[str]:
"""Extract SQL from AIMessage content.
Tries multiple extraction methods: code blocks, SQL markers, SELECT statements.
"""
if not messages:
return None
for msg in reversed(messages):
if isinstance(msg, AIMessage):
content = msg.content
if not content:
continue
# Try SQL code block first
if "```sql" in content:
parts = content.split("```sql")
if len(parts) > 1:
sql = parts[1].split("```")[0].strip()
if sql:
return sql
# Try generic code block
elif "```" in content:
parts = content.split("```")
if len(parts) > 1:
sql = parts[1].strip()
if sql:
return sql
# Try regex extraction
else:
match = self._regex_patterns['select_statement'].search(content)
if match:
sql = match.group(1)
if sql:
sql = sql.strip()
if sql:
return sql
return None
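# Illustrative inputs the extractor above handles (hypothetical message contents):
#   "Here is the query:\n```sql\nSELECT * FROM books\n```" -> "SELECT * FROM books"
#   "```\nSELECT COUNT(*) FROM authors\n```"               -> "SELECT COUNT(*) FROM authors"
#   "SELECT title FROM books LIMIT 5"                      -> matched via the SELECT regex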
async def _execute_query_node(self, state: AgentState) -> dict:
"""Execute the generated SQL query.
Uses stored SQL from state if available. Reuses test query results
when they contain all data to avoid redundant execution.
"""
messages = state["messages"]
# Use stored SQL from state if available (avoids re-extraction)
sql_query = state.get("final_sql")
test_results = state.get("test_query_results")
# If SQL not in state, extract from message (fallback)
if not sql_query:
sql_query = self._extract_sql_from_messages(messages)
if not sql_query:
return {
"messages": [ToolMessage(
content="Error: Could not extract SQL query from the response.",
tool_call_id="sql_extraction_error"
)]
}
# Reuse test query results if they contain all data (avoids redundant execution)
needs_full_execution = True
result = None
if test_results is not None and len(test_results) < TEST_QUERY_LIMIT:
sql_upper = sql_query.upper()
if "LIMIT" in sql_upper:
limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper)
if limit_match:
try:
original_limit = int(limit_match.group(1))
if original_limit > 0 and original_limit <= TEST_QUERY_LIMIT:
# Test query already got all results, reuse them
# But respect the original LIMIT - only return as many rows as requested
results_to_return = test_results[:original_limit] if original_limit <= len(test_results) else test_results
needs_full_execution = False
result = json.dumps(results_to_return, indent=2, default=str)
except (ValueError, TypeError) as e:
if self.enable_logging:
logger.warning(f"Could not parse LIMIT value: {str(e)}, executing full query")
# Continue with full execution
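# Worked example of the reuse check above: if the generated SQL ends with "LIMIT 2" and the
# test run (capped at TEST_QUERY_LIMIT = 3) returned only 1 row, the test already saw every
# row the full query could return, so the cached rows are reused instead of re-querying.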
# Execute the query only if needed
if needs_full_execution:
try:
run_query_tool = await self._get_tool("run_query")
if not run_query_tool:
return {
"messages": [ToolMessage(
content="Error: run_query tool not available.",
tool_call_id="run_query_error"
)]
}
try:
result = await self._call_tool_with_timeout(
run_query_tool,
{
"input": {
"sql": sql_query,
"format": "markdown",
"row_limit": 100
}
}
)
except TimeoutError:
error_msg = f"Query execution timed out after {self.query_timeout} seconds"
if self.enable_logging:
logger.error(error_msg)
return {
"messages": [ToolMessage(
content=error_msg,
tool_call_id="run_query_timeout"
)]
}
except Exception as e:
error_msg = f"Error executing query: {str(e)}"
return {
"messages": [ToolMessage(
content=error_msg,
tool_call_id="run_query_error"
)]
}
# Check if there was an error
if result is None:
return {
"messages": [ToolMessage(
content="Error: Query execution returned no result.",
tool_call_id="run_query_error"
)]
}
if isinstance(result, str):
result_lower = result.lower()
if "error" in result_lower or "query error" in result_lower or "permission" in result_lower:
return {
"messages": [ToolMessage(
content=f"Query execution failed: {result}",
tool_call_id="run_query_error"
)]
}
# Success!
return {
"messages": [ToolMessage(
content=f"Query executed successfully:\n\n{result}",
tool_call_id="run_query_success"
)]
}
def _check_query_result(self, state: AgentState) -> Literal["success", "retry", "error"]:
"""Check query execution result and decide next step.
Returns:
"success" if query executed successfully
"retry" if error occurred and attempts remain
"error" if max attempts reached
"""
messages = state["messages"]
query_attempts = state.get("query_attempts", 0)
max_attempts = state.get("max_attempts", self.max_attempts)
last_msg = messages[-1] if messages else None
if last_msg and isinstance(last_msg, ToolMessage):
content = last_msg.content.lower()
if "successfully" in content:
return "success"
elif "error" in content or "failed" in content:
return "retry" if query_attempts < max_attempts else "error"
return "success" # Default to success if unclear
async def _refine_query_node(self, state: AgentState) -> dict:
"""Refine the SQL query based on error feedback.
Called when query execution fails. Parses error, updates schema if needed,
and passes error context to generate_sql node for regeneration.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
# Get the error message
error_msg = None
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
content = msg.content
if content and isinstance(content, str) and "error" in content.lower():
error_msg = content
break
# Parse the error to extract actionable information
parsed_error = self._parse_sql_error(error_msg) if error_msg else None
# If error mentions unknown column, try to get fresh schema info for relevant tables
if parsed_error and parsed_error["type"] == "unknown_column":
# Extract table names that might be involved
table_names = schema_info.get("table_names", [])
describe_table_tool = await self._get_tool("describe_table")
if describe_table_tool and table_names:
# Re-fetch descriptions for all tables to ensure we have latest info
table_descriptions = {}
for table_name in table_names:
try:
desc = await self._call_tool_with_timeout(
describe_table_tool,
{"table_name": table_name}
)
table_descriptions[table_name] = desc
except Exception as e:
if self.enable_logging:
logger.warning(f"Failed to fetch description for table {table_name}: {str(e)}")
continue
# Update schema info with fresh descriptions
schema_info["table_descriptions"] = table_descriptions
# Build refined error message with parsed error details
if parsed_error:
if parsed_error["type"] == "unknown_column":
error_guidance = f"\n\n⚠️ ERROR: Column '{parsed_error['column']}' does not exist!"
if parsed_error.get("table"):
error_guidance += f" (in table '{parsed_error['table']}')"
error_guidance += "\n\nPlease check the column names summary in the schema information above and use the EXACT column names shown."
elif parsed_error["type"] == "unknown_table":
error_guidance = f"\n\n⚠️ ERROR: Table '{parsed_error['table']}' does not exist!"
error_guidance += "\n\nPlease check the available tables list and use the correct table name."
elif parsed_error["type"] == "syntax_error":
error_guidance = f"\n\n⚠️ ERROR: SQL syntax error detected!"
error_guidance += "\n\nPlease review the SQL syntax carefully. Make sure there are no ? placeholders, and all values are properly embedded in the query."
else:
error_guidance = f"\n\n⚠️ ERROR: {parsed_error['message']}"
else:
error_guidance = f"\n\nPrevious query failed. Error: {error_msg}\n\nPlease check the table structures carefully and use the EXACT column names shown."
# Get stored SQL from state or extract from messages
original_sql = state.get("previous_sql") or self._extract_sql_from_messages(messages)
# Get the original user query
user_query = None
for msg in messages:
if isinstance(msg, HumanMessage):
content = msg.content
if content: # Ensure content is not None or empty
user_query = content
break
# Pass error context to generate_sql node (it will optionally use it)
# Just add a message indicating we're retrying with error context
refine_msg = f"Previous query failed.{error_guidance}\n\nRegenerating query with error context..."
return {
"messages": [HumanMessage(content=refine_msg)],
"schema_info": schema_info, # Include updated schema info
"previous_error": error_msg, # Pass error to generate_sql as optional context
"previous_sql": original_sql # Pass previous SQL to generate_sql as optional context
}
async def _tools_node(self, state: AgentState) -> dict:
"""Handle tool calls from the LLM.
Processes tool calls from AI messages and executes them via MCP tools.
"""
messages = state["messages"]
last_msg = messages[-1] if messages else None
if last_msg and hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
await self._initialize_tools() # Ensure tools are initialized
tool_results = []
for tool_call in last_msg.tool_calls:
if not isinstance(tool_call, dict):
continue
tool_name = tool_call.get("name")
if not tool_name:
tool_results.append(ToolMessage(
content="Error: Tool call missing name.",
tool_call_id=tool_call.get("id", "unknown")
))
continue
tool_args = tool_call.get("args", {})
if not isinstance(tool_args, dict):
tool_args = {}
try:
tool = await self._get_tool(tool_name)
if tool:
try:
result = await self._call_tool_with_timeout(tool, tool_args)
tool_results.append(ToolMessage(
content=str(result),
tool_call_id=tool_call["id"]
))
except TimeoutError:
tool_results.append(ToolMessage(
content=f"Error: Tool {tool_name} call timed out after {self.query_timeout} seconds",
tool_call_id=tool_call["id"]
))
except Exception as e:
tool_results.append(ToolMessage(
content=f"Error calling {tool_name}: {str(e)}",
tool_call_id=tool_call["id"]
))
else:
tool_results.append(ToolMessage(
content=f"Error: Tool {tool_name} not found.",
tool_call_id=tool_call["id"]
))
except Exception as e:
tool_results.append(ToolMessage(
content=f"Error calling {tool_name}: {str(e)}",
tool_call_id=tool_call["id"]
))
return {"messages": tool_results}
return {}
async def query(self, question: str, config: Optional[dict] = None) -> dict:
"""Execute a text-to-SQL query.
Each query starts with a fresh state. Schema cache persists across queries
for performance, but query-specific state (relevant_tables, previous_error, etc.)
is reset for each new query.
Args:
question: The natural language question to convert to SQL
config: Optional LangGraph configuration
Returns:
AgentState dictionary with query results
Raises:
ValueError: If question is None, empty, or not a string
"""
# Input validation
if question is None:
raise ValueError("Question cannot be None. Please provide a valid database question.")
if not isinstance(question, str):
raise ValueError(f"Question must be a string, got {type(question).__name__}")
question = question.strip()
if not question:
raise ValueError("Question cannot be empty. Please provide a valid database question.")
if self.enable_logging:
logger.info(f"Starting query: {question[:100]}...")
initial_state: AgentState = {
"messages": [HumanMessage(content=question)],
"schema_info": {}, # Fresh schema_info for this query
"query_attempts": 0,
"max_attempts": self.max_attempts,
"relevant_tables": [], # Will be populated by explore_schema node
# Clear all query-specific flags for fresh query
"previous_error": None,
"previous_sql": None,
"final_sql": None,
"test_query_results": None,
"should_refine": None,
"refine_reason": None,
"confidence": None,
"sql_is_valid": None,
"has_critical_issues": None,
"confidence_reasoning": None,
"query_error": None,
"schema_complete": None
}
config = config or {}
try:
result = await self.graph.ainvoke(initial_state, config)
if self.enable_logging:
logger.info("Query completed successfully")
return result
except Exception as e:
error_msg = f"Query execution failed: {str(e)}"
if self.enable_logging:
logger.error(error_msg, exc_info=True)
raise
def get_final_answer(self, result: dict) -> str:
"""Extract the final answer from the agent result"""
messages = result.get("messages", [])
if not messages:
return "No answer found"
# Look for the last successful query result (check in reverse order)
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
content = msg.content
if content and "successfully" in content.lower():
return content
elif isinstance(msg, AIMessage):
# Check if this is the last message (more efficient than index())
if msg is messages[-1]:
content = msg.content
return content if content else "No answer found"
# Fallback: return last message content
last_msg = messages[-1]
if hasattr(last_msg, 'content'):
content = last_msg.content
return str(content) if content else "No answer found"
return "No answer found"