import asyncio
import json
import logging
import re
from collections import OrderedDict
from typing import Annotated, Any, Literal, Optional, TypedDict
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.tools import BaseTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from prompts import (
generate_sql_prompt,
get_table_identification_prompt,
get_validation_prompt,
get_combined_confidence_prompt,
get_refinement_prompt,
get_analysis_prompt,
get_sql_generation_prompt,
)
# ========================================================================
# Constants
# ========================================================================
DEFAULT_MAX_ATTEMPTS = 3
DEFAULT_MAX_REFINEMENTS = 3 # Maximum number of refinement cycles in refine_sql node
TEST_QUERY_LIMIT = 3
SMALL_DB_THRESHOLD = 3 # Skip LLM call for table selection if ≤3 tables
LOW_CONFIDENCE_THRESHOLD = 0.6
HIGH_CONFIDENCE_THRESHOLD = 0.8
VERY_HIGH_CONFIDENCE_THRESHOLD = 0.9
MAX_SCHEMA_CACHE_SIZE = 1000 # Maximum number of table descriptions to cache
MAX_COLUMN_CACHE_SIZE = 500 # Maximum number of column name extractions to cache
DEFAULT_LLM_TIMEOUT = 60 # Default timeout for LLM calls in seconds
DEFAULT_QUERY_TIMEOUT = 30 # Default timeout for database queries in seconds
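# Confidence bands, as applied by the refinement logic in the nodes below:
#   confidence < LOW_CONFIDENCE_THRESHOLD                      -> refine
#   confidence >= HIGH_CONFIDENCE_THRESHOLD, valid SQL,
#       no error, first attempt                                -> skip refinement
#   confidence >= VERY_HIGH_CONFIDENCE_THRESHOLD, valid SQL,
#       no error                                               -> always skip refinement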
# ========================================================================
# Logging Setup
# ========================================================================
logger = logging.getLogger(__name__)
class AgentState(TypedDict):
"""State for the text-to-SQL agent"""
messages: Annotated[list[BaseMessage], add_messages]
schema_info: dict # Cached schema information
query_attempts: int # Number of query attempts
max_attempts: int # Maximum query attempts before giving up
relevant_tables: list # Tables identified as relevant to the query
previous_error: Optional[
str
] # Optional: previous query error for refinement context
previous_sql: Optional[str] # Optional: previous SQL query for refinement context
final_sql: Optional[str] # Optional: final SQL query to avoid re-extraction
test_query_results: Optional[
list
] # Optional: test query results to avoid re-execution
    # Refinement decision flags (set by nodes, used by edges)
    should_refine: Optional[bool]  # Whether SQL needs refinement
    refine_reason: Optional[str]  # Reason for refinement
    refinement_attempts: Optional[int]  # Refinement cycles completed (read/written by refine_sql node)
confidence: Optional[float] # Confidence score
sql_is_valid: Optional[bool] # Whether SQL syntax is valid
has_critical_issues: Optional[bool] # Whether critical issues detected
confidence_reasoning: Optional[str] # Analysis/reasoning from confidence scoring
query_error: Optional[str] # Query execution error if any
# Schema exploration flags
schema_complete: Optional[bool] # Whether schema exploration is complete
class TextToSQLAgent:
"""LangGraph agent for converting natural language to SQL queries.
Architecture:
- Uses LangGraph for state machine orchestration
- All database operations via MCP tools (no direct DB connections)
- Nodes process state, edges make routing decisions
- Supports schema caching, intelligent table selection, and auto-refinement
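    Example (illustrative sketch; assumes a running MCP server exposing the
    MySQL tools referenced below and a reachable OpenAI-compatible endpoint):
        client = MultiServerMCPClient({...})   # server config elided
        llm = ChatOpenAI(model="gpt-4o-mini")  # any ChatOpenAI model works
        agent = TextToSQLAgent(client, llm)
        result = await agent.graph.ainvoke(
            {"messages": [HumanMessage(content="How many books are there?")]}
        )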
"""
def __init__(
self,
mcp_client: MultiServerMCPClient,
llm: ChatOpenAI,
max_query_attempts: int = DEFAULT_MAX_ATTEMPTS,
max_refinements: int = DEFAULT_MAX_REFINEMENTS,
llm_timeout: int = DEFAULT_LLM_TIMEOUT,
query_timeout: int = DEFAULT_QUERY_TIMEOUT,
max_schema_cache_size: int = MAX_SCHEMA_CACHE_SIZE,
max_column_cache_size: int = MAX_COLUMN_CACHE_SIZE,
enable_logging: bool = True,
):
self.mcp_client = mcp_client
self.llm = llm
self.max_attempts = max_query_attempts
self.max_refinements = max_refinements
self.llm_timeout = llm_timeout
self.query_timeout = query_timeout
self.max_schema_cache_size = max_schema_cache_size
self.max_column_cache_size = max_column_cache_size
self.enable_logging = enable_logging
# Cache for schema information to avoid repeated expensive calls
        # (plain dict; size is bounded by _manage_schema_cache_size)
        self.schema_cache: Optional[dict] = None
# Cache for extracted column names to avoid redundant parsing
# Use OrderedDict for LRU eviction
self.column_cache: OrderedDict[str, list[str]] = OrderedDict()
# Get tools from MCP client
self.tools = None
self.tool_dict = None # Dictionary for O(1) tool lookup
self._log("info", f"TextToSQLAgent initialized with max_attempts={max_query_attempts}, "
f"llm_timeout={llm_timeout}s, query_timeout={query_timeout}s, "
f"max_schema_cache={max_schema_cache_size}, max_column_cache={max_column_cache_size}")
# Compile regex patterns once for performance (reused many times)
self._compile_regex_patterns()
# Build the graph
self.graph = self._build_graph()
# ========================================================================
# Graph Building
# ========================================================================
def _build_graph(self) -> StateGraph:
"""Build the LangGraph state graph"""
workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("explore_schema", self._explore_schema_node)
workflow.add_node("generate_sql", self._generate_sql_node)
workflow.add_node("refine_sql", self._refine_sql_node)
workflow.add_node("execute_query", self._execute_query_node)
workflow.add_node("refine_query", self._refine_query_node)
workflow.add_node("tools", self._tools_node)
# Set entry point
workflow.set_entry_point("explore_schema")
# Add conditional edge: check if schema exploration is complete
workflow.add_conditional_edges(
"explore_schema",
self._should_continue_schema_exploration,
{
"complete": "generate_sql",
"continue": "explore_schema", # Will fetch missing schema
},
)
# Add conditional edge: after SQL generation, decide if refinement is needed
workflow.add_conditional_edges(
"generate_sql",
self._should_refine_sql,
{
"refine": "refine_sql",
"execute": "execute_query",
"use_tools": "tools",
"end": END,
},
)
# After refinement, check if we need to refine again or execute
workflow.add_conditional_edges(
"refine_sql",
self._should_refine_sql, # Reuse same edge function
{
"refine": "refine_sql", # Loop back if confidence still low
"execute": "execute_query", # Execute if confidence is good
"use_tools": "tools",
"end": END,
},
)
# Add conditional edge: after execution, decide next step
workflow.add_conditional_edges(
"execute_query",
self._check_query_result,
{
"success": END,
"retry": "refine_query",
"error": END, # Give up after max attempts
},
)
workflow.add_edge("refine_query", "generate_sql")
workflow.add_edge("tools", "generate_sql")
return workflow.compile()
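    # Compiled topology (informal sketch of the edges added above):
    #   explore_schema -(continue)-> explore_schema        (loop until complete)
    #   explore_schema -(complete)-> generate_sql
    #   generate_sql / refine_sql -(refine)-> refine_sql   (refine_sql may loop)
    #   generate_sql / refine_sql -(execute)-> execute_query
    #   generate_sql / refine_sql -(use_tools)-> tools -> generate_sql
    #   execute_query -(retry)-> refine_query -> generate_sql
    #   execute_query -(success | error)-> END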
# ========================================================================
# Helper Methods
# ========================================================================
def _log(self, level: str, message: str, exc_info: bool = False):
"""Helper method for conditional logging."""
if self.enable_logging:
getattr(logger, level)(message, exc_info=exc_info)
def _has_multi_part(self, query: str) -> bool:
"""Check if query has multiple parts (e.g., 'and then', 'also find')."""
if not query:
return False
indicators = ["and then", "then find", "also find", "also show", "additionally", "and also"]
return any(indicator in query.lower() for indicator in indicators)
    def _clean_sql(self, sql: str, fallback: Optional[str] = None) -> str:
"""Clean SQL by removing markdown code blocks and extracting SELECT statements."""
if not sql:
return fallback or ""
sql = sql.strip()
# Remove markdown code blocks
for prefix in ["```sql", "```"]:
if sql.startswith(prefix):
sql = sql[len(prefix):]
if sql.endswith("```"):
sql = sql[:-3]
sql = sql.strip()
# Extract SELECT statement if needed
if sql and not (sql.upper().startswith("SELECT") or sql.upper().startswith("WITH")):
match = self._regex_patterns["select_statement"].search(sql)
if match:
extracted = match.group(1)
if extracted:
sql = extracted.strip()
        # Remove parameter placeholders
        if "?" in sql:
            for pattern_name in ["param_and_or", "param_where", "param_end", "param_start"]:
                sql = self._regex_patterns[pattern_name].sub("", sql)
            # Standalone "?" tokens are whitespace-delimited; substitute a single
            # space so the surrounding tokens are not joined together
            sql = self._regex_patterns["param_standalone"].sub(" ", sql)
            sql = sql.replace("?", "")
            sql = self._regex_patterns["trailing_clauses"].sub("", sql)
            sql = sql.strip()
return sql if sql else (fallback or "")
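    # Illustration of _clean_sql on a typical LLM reply (expected behaviour):
    #   _clean_sql("```sql\nSELECT * FROM books WHERE id = ?\n```")
    #   -> "SELECT * FROM books"
    # (the markdown fence and the parameterized WHERE clause are both stripped)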
def _get_user_query(self, messages: list) -> Optional[str]:
"""Extract user query from messages."""
for msg in reversed(messages):
if isinstance(msg, HumanMessage) and msg.content:
return msg.content
return None
    def _build_system_prompt_with_schema(self, schema_info: dict, include_error_context: bool = False,
                                         previous_error: Optional[str] = None, previous_sql: Optional[str] = None,
                                         query_attempts: int = 0) -> str:
"""Build system prompt with schema information."""
system_prompt = generate_sql_prompt
# Add table names
table_names = schema_info.get("table_names", [])
if table_names:
system_prompt += f"\nAvailable Tables: {', '.join(table_names)}\n"
# Add relevant tables
relevant_tables = schema_info.get("relevant_tables", [])
if relevant_tables:
system_prompt += f"Relevant tables: {', '.join(relevant_tables)}\n\n"
# Add table structures
if schema_info.get("table_descriptions"):
for table_name, desc in schema_info["table_descriptions"].items():
if not relevant_tables or table_name in relevant_tables:
system_prompt += f"\n{table_name}:\n{desc}\n"
# Add error context if needed
if include_error_context and previous_error:
error_info = f"\n\nPrevious query error: {previous_error}"
col_match = self._regex_patterns["unknown_column"].search(previous_error)
problematic_column = col_match.group(1) if col_match else None
error_guidance = f"\n\nThis is attempt {query_attempts + 1}. Previous query failed.{error_info}"
if problematic_column:
error_guidance += f"\n\n⚠️ CRITICAL: The column '{problematic_column}' does NOT exist! Check the column names summary above for the correct column name."
if previous_sql:
error_guidance += f"\n\nPrevious SQL (for reference - fix the error but preserve the query structure):\n```sql\n{previous_sql}\n```"
error_guidance += "\n\nIMPORTANT REMINDERS: \n- Pay close attention to the EXACT column names from the table structures above. Do NOT guess column names.\n- Use the exact column names shown in the table structures - do not assume generic names like 'id' if the actual column has a different name.\n- Do NOT use parameterized queries with ? placeholders - embed all values directly in the SQL.\n- If the error mentions '?' or 'syntax error', make sure your SQL has no ? placeholders.\n- If fixing a column name error, keep the query structure the same - only change the column name."
system_prompt += error_guidance
return system_prompt
    def _manage_schema_cache_size(self):
        """Bound the schema cache by evicting the oldest table descriptions (insertion order)."""
if not self.schema_cache:
return
# Validate cache structure
if not isinstance(self.schema_cache, dict):
self._log("warning", "Schema cache is not a dictionary, resetting")
self.schema_cache = {}
return
table_descriptions = self.schema_cache.get("table_descriptions", {})
if not isinstance(table_descriptions, dict):
self._log("warning", "table_descriptions is not a dictionary, resetting")
self.schema_cache["table_descriptions"] = {}
return
if len(table_descriptions) > self.max_schema_cache_size:
excess = len(table_descriptions) - self.max_schema_cache_size
for key in list(table_descriptions.keys())[:excess]:
del table_descriptions[key]
self._log("debug", f"Evicted {excess} table descriptions from schema cache")
    def _manage_column_cache_size(self):
        """Bound the column cache using LRU eviction."""
        while len(self.column_cache) > self.max_column_cache_size:
            # Remove the least recently used entry (entries are moved to the end
            # on access, so the front of the OrderedDict holds the LRU item)
            self.column_cache.popitem(last=False)
if (
self.enable_logging
and len(self.column_cache) > self.max_column_cache_size * 0.9
):
logger.debug(
f"Column cache size: {len(self.column_cache)}/{self.max_column_cache_size}"
)
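    # Example: with max_column_cache_size=2, inserting "a" then "b", touching
    # "a" via move_to_end, and inserting "c" evicts "b" (the LRU entry).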
    async def _call_with_timeout(
        self, coro, timeout: Optional[int], default_timeout: int,
        name: str = "operation"
    ) -> Any:
        """Generic timeout wrapper for async operations."""
        timeout = timeout or default_timeout
        if timeout <= 0:
            self._log("warning", f"Invalid timeout {timeout}, using default {default_timeout}s")
            timeout = default_timeout
try:
self._log("debug", f"Calling {name} with timeout={timeout}s")
result = await asyncio.wait_for(coro, timeout=timeout)
self._log("debug", f"{name} completed successfully")
return result
except asyncio.TimeoutError:
error_msg = f"{name} timed out after {timeout} seconds"
self._log("error", error_msg)
raise TimeoutError(error_msg)
except Exception as e:
error_msg = f"{name} failed: {str(e)}"
self._log("error", error_msg, exc_info=True)
raise
async def _call_llm_with_timeout(
self, messages: list, timeout: Optional[int] = None
) -> BaseMessage:
"""Call LLM with timeout handling."""
        return await self._call_with_timeout(
            self.llm.ainvoke(messages), timeout, self.llm_timeout, "LLM"
        )
async def _call_tool_with_timeout(
self, tool: BaseTool, args: dict, timeout: Optional[int] = None
) -> Any:
"""Call MCP tool with timeout handling."""
        return await self._call_with_timeout(
            tool.ainvoke(args), timeout, self.query_timeout, f"Tool {tool.name}"
        )
def _compile_regex_patterns(self):
"""Compile regex patterns once for performance (reused many times)"""
self._regex_patterns = {
"sql_code_block": re.compile(
r"```sql\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE
),
"sql_marker": re.compile(
r"SQL Query:\s*(.*?)(?:\n\n|$)", re.DOTALL | re.IGNORECASE
),
"select_statement": re.compile(
r"((?:WITH\s+\w+\s+AS\s*\([^)]*(?:\([^)]*\)[^)]*)*\)\s*,?\s*)*SELECT.*?)(?:\n\n|$)",
re.DOTALL | re.IGNORECASE,
),
"confidence_score": re.compile(r"CONFIDENCE:\s*(\d+\.?\d*)", re.IGNORECASE),
"analysis_section": re.compile(
r"ANALYSIS:\s*(.+?)(?:\n\n|$)", re.IGNORECASE | re.DOTALL
),
"unknown_column": re.compile(
r"Unknown column ['\"]?([^'\"]+)['\"]?", re.IGNORECASE
),
"unknown_table": re.compile(
r"Table ['\"]?([^'\"]+)['\"]? doesn't exist", re.IGNORECASE
),
"limit_replace": re.compile(r"\s+LIMIT\s+\d+", re.IGNORECASE),
"param_and_or": re.compile(
r"\s+(AND|OR)\s+[^\s]+\s*[=<>!]+\s*\?", re.IGNORECASE
),
"param_where": re.compile(
r"\s+WHERE\s+[^\s]+\s*[=<>!]+\s*\?", re.IGNORECASE
),
"param_standalone": re.compile(r"\s+\?\s+"),
"param_end": re.compile(r"\s+\?$"),
"param_start": re.compile(r"^\?\s+"),
"trailing_clauses": re.compile(r"\s+(WHERE|AND|OR)\s+$", re.IGNORECASE),
"from_table": re.compile(r"\bFROM\s+`?(\w+)`?", re.IGNORECASE),
"join_table": re.compile(r"\bJOIN\s+`?(\w+)`?", re.IGNORECASE),
"table_in_error": re.compile(r"Table\s+['\"]?(\w+)['\"]?", re.IGNORECASE),
"table_in_quotes": re.compile(r"['\"`](\w+)['\"`]"),
}
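    # Illustrative matches for two of the patterns above:
    #   unknown_column on "Unknown column 'loans.user_id' in 'where clause'"
    #       -> group(1) == "loans.user_id"
    #   sql_code_block on "```sql\nSELECT 1\n```" -> group(1) == "SELECT 1"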
async def _initialize_tools(self):
"""Initialize tools from MCP client and cache static schema data"""
if self.tools is None:
tools_list = await self.mcp_client.get_tools()
self.tools = tools_list
# Create a dictionary for O(1) lookup instead of O(n) linear search
self.tool_dict = {t.name: t for t in tools_list}
            # Pre-fetch static schema information that never changes during a session
# This avoids repeated network calls for list_tables
list_tables_tool = await self._get_tool("list_tables_mysql")
if list_tables_tool:
try:
tables_result = await self._call_tool_with_timeout(
list_tables_tool, {}
)
# Extract table names immediately
table_names = self._extract_table_names_from_result(tables_result)
self.schema_cache = self.schema_cache or {}
self.schema_cache["table_names"] = table_names
except Exception as e:
self._log("warning", f"Failed to fetch tables list: {str(e)}")
return self.tools
def _extract_table_names_from_result(self, tables_result: str) -> list:
"""Extract table names from list_tables result (JSON string or markdown format)"""
table_names = []
if not tables_result:
return table_names
try:
# Parse JSON string to extract table names
            tables_data = json.loads(tables_result)
if isinstance(tables_data, list):
for row in tables_data:
if isinstance(row, dict):
# Get the first value from the dict (table name)
table_name = list(row.values())[0] if row.values() else None
if table_name:
table_names.append(table_name)
elif isinstance(row, list) and row:
table_name = row[0]
if table_name:
table_names.append(table_name)
except (json.JSONDecodeError, AttributeError, IndexError):
            # Fallback: try to match table names from markdown or text format
            table_matches = re.findall(r"`?(\w+)`?", str(tables_result))
if table_matches:
table_names = list(set(table_matches)) # Remove duplicates
return table_names
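    # Example (assumed JSON shape returned by list_tables_mysql):
    #   '[{"Tables_in_library": "authors"}, {"Tables_in_library": "books"}]'
    #   -> ["authors", "books"]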
async def _get_tool(self, tool_name: str):
"""Get a tool by name using O(1) dictionary lookup"""
if not hasattr(self, "tool_dict") or self.tool_dict is None:
await self._initialize_tools()
return self.tool_dict.get(tool_name)
def _extract_sql_from_messages(self, messages: list) -> Optional[str]:
"""Extract SQL from AIMessage content."""
if not messages:
return None
for msg in reversed(messages):
if isinstance(msg, AIMessage) and msg.content:
# Try SQL code block first
if "```sql" in msg.content:
sql = msg.content.split("```sql")[1].split("```")[0].strip()
if sql:
return sql
# Try generic code block
elif "```" in msg.content:
sql = msg.content.split("```")[1].strip()
if sql:
return sql
# Try regex extraction
else:
match = self._regex_patterns["select_statement"].search(msg.content)
if match and match.group(1):
return match.group(1).strip()
return None
async def _identify_relevant_tables(
self, user_query: str, all_tables: list
) -> list:
"""Use LLM to identify which tables are relevant to the query.
Args:
user_query: The natural language query
all_tables: List of all available table names
Returns:
List of relevant table names (or all tables if identification fails)
"""
if not all_tables:
return []
# Skip LLM call for small databases (optimization)
if len(all_tables) <= SMALL_DB_THRESHOLD:
return all_tables
prompt = get_table_identification_prompt(user_query, all_tables)
try:
response = await self._call_llm_with_timeout([HumanMessage(content=prompt)])
relevant_str = response.content
if not relevant_str:
self._log("warning", "LLM returned empty content for table identification, using all tables")
return all_tables
relevant = [t.strip(" `\"'") for t in relevant_str.strip().split(",") if t.strip()]
tables_set = set(all_tables)
relevant = [t for t in relevant if t in tables_set]
return relevant if relevant else all_tables
except Exception as e:
self._log("warning", f"Table identification failed: {str(e)}, using all tables")
return all_tables
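    # Example: an LLM reply of "`books`, authors , loans" is split on commas,
    # stripped of quotes/backticks to ["books", "authors", "loans"], and then
    # filtered against the known table names.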
async def _validate_query_is_database_question(
self, query: str
) -> tuple[bool, str]:
"""Validate if the user query is actually a database question.
Uses fast heuristics first, then LLM only if needed for ambiguous cases.
Returns (is_valid, reason) tuple.
"""
query_stripped = query.strip()
query_lower = query_stripped.lower()
# Fast heuristic checks first (avoid LLM call if possible)
# Check if query is too short
if len(query_stripped) < 3:
return False, "Query is too short to be a valid database question."
# Check for common database question patterns (optimized with set for O(1) lookup)
db_keywords = {
"show",
"find",
"get",
"list",
"count",
"select",
"what",
"which",
"who",
"how many",
"how much",
"where",
"when",
"display",
"retrieve",
"query",
"books",
"authors",
"members",
"loans",
"employees",
"customers",
"orders",
"table",
"tables",
"data",
"database",
"records",
"rows",
}
# Check if query contains any database keywords
query_words = set(query_lower.split())
has_db_keyword = bool(query_words & db_keywords) or any(
kw in query_lower for kw in ["how many", "how much"]
)
        # SQL keywords also indicate a database question (e.g. the user pasted
        # SQL directly); treat them as valid input
        sql_keywords = {"select", "from", "where", "join", "insert", "update", "delete"}
        looks_like_sql = bool(query_words & sql_keywords)
        # Fast path: if clearly database-related, return immediately
        if has_db_keyword or looks_like_sql:
            return True, "Query contains database-related keywords."
# If query is very short and has no keywords, likely invalid
if len(query_stripped) < 5 and not has_db_keyword:
return False, "Query does not appear to be a database question."
# Ambiguous case: Use LLM for validation (only when heuristics are unclear)
validation_prompt = get_validation_prompt(query)
try:
response = await self._call_llm_with_timeout(
[HumanMessage(content=validation_prompt)]
)
            response_text = response.content.strip()
            # Parse response
            response_lower = response_text.lower()
            is_valid = "valid: yes" in response_lower or "yes" in response_lower[:50]
            return is_valid, response_text
        except Exception as e:
            # LLM validation failed; reject the ambiguous query rather than guessing
            self._log("warning", f"Query validation LLM call failed: {str(e)}")
            return False, "Query does not appear to be a database question."
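    # Examples: "How many books are there?" passes the fast keyword path;
    # "hello" (5+ chars, no database keywords) falls through to the LLM check.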
async def _score_and_analyze_query(
self,
query: str,
sql: str,
schema_info: dict,
query_results: list = None,
query_error: str = None,
) -> tuple[float, str]:
"""Score confidence and analyze query in a single LLM call (optimization).
Confidence scoring now explicitly checks if the SQL answers the question.
"""
combined_prompt = get_combined_confidence_prompt(
query, sql, query_error, query_results
)
try:
response = await self._call_llm_with_timeout(
[HumanMessage(content=combined_prompt)]
)
content = response.content
if not content:
self._log("warning", "LLM returned empty content for confidence scoring")
return 0.5, "Unable to analyze query: LLM returned empty response."
content = content.strip()
# Extract confidence score (use compiled regex)
confidence_match = self._regex_patterns["confidence_score"].search(content)
if confidence_match:
try:
score_str = confidence_match.group(1)
if score_str: # Ensure group is not empty
score = float(score_str)
if score > 1.0:
score = score / 100.0
confidence = max(0.0, min(1.0, score))
else:
raise ValueError("Empty confidence score match")
except (ValueError, TypeError) as e:
self._log("warning", f"Could not parse confidence score: {str(e)}, trying fallback")
confidence = None
else:
confidence = None # Will trigger fallback
# Fallback: try to find any number if regex didn't work
if confidence is None:
match = re.search(r"(\d+\.?\d*)", content)
if match:
try:
score_str = match.group(1)
if score_str: # Ensure group is not empty
score = float(score_str)
if score > 1.0:
score = score / 100.0
confidence = max(0.0, min(1.0, score))
else:
raise ValueError("Empty score match")
except (ValueError, TypeError) as e:
self._log("warning", f"Could not parse fallback confidence score: {str(e)}, using default 0.5")
confidence = 0.5
else:
self._log("warning", "Could not extract confidence score from LLM response, using default 0.5")
confidence = 0.5
# Extract analysis (use compiled regex)
analysis_match = self._regex_patterns["analysis_section"].search(content)
if analysis_match:
analysis_group = analysis_match.group(1)
analysis = (
analysis_group.strip()
if analysis_group
else "Analysis unavailable."
)
else:
# Fallback: use everything after CONFIDENCE line
lines = content.split("\n")
analysis_lines = []
found_conf = False
for line in lines:
if "CONFIDENCE:" in line.upper():
found_conf = True
continue
if found_conf:
analysis_lines.append(line)
analysis = (
" ".join(analysis_lines).strip()
if analysis_lines
else "Analysis unavailable."
)
self._log("debug", f"Confidence score: {confidence:.2f}")
return confidence, analysis
except TimeoutError:
self._log("error", "Confidence scoring timed out")
return 0.3, "Confidence scoring timed out. Using default low confidence."
        except Exception as e:
            self._log("warning", f"Error in combined confidence scoring: {str(e)}", exc_info=True)
            # Fall back to a standalone analysis call with a default low confidence.
            # NOTE: do not call _score_query_confidence here - it delegates back to
            # this method and would recurse indefinitely on a persistent failure.
            try:
                confidence = 0.3
                analysis = await self._analyze_query_confidence(
                    query, sql, schema_info, confidence, query_results, query_error
                )
                return confidence, analysis
            except Exception as fallback_error:
                self._log(
                    "error",
                    f"Fallback confidence analysis also failed: {str(fallback_error)}",
                    exc_info=True,
                )
                return 0.3, f"Error analyzing query: {str(e)}"
async def _score_query_confidence(
self,
query: str,
sql: str,
schema_info: dict,
query_results: list = None,
query_error: str = None,
) -> float:
"""Score confidence in the generated query (0-1) - kept for backward compatibility"""
confidence, _ = await self._score_and_analyze_query(
query, sql, schema_info, query_results, query_error
)
return confidence
def _validate_sql_syntax(self, sql: str) -> bool:
"""Validate basic SQL syntax - check if it's a valid read-only query (SELECT, WITH, SHOW, EXPLAIN, DESCRIBE, etc.)"""
if not sql or not sql.strip():
return False
sql_upper = sql.upper().strip()
# Allow all read-only SQL query types
readonly_keywords = [
"SELECT",
"WITH",
"SHOW",
"EXPLAIN",
"DESCRIBE",
"DESC",
"VALUES",
]
# Check if query starts with any read-only keyword
starts_with_readonly = any(sql_upper.startswith(kw) for kw in readonly_keywords)
if not starts_with_readonly:
return False
        # For SELECT/WITH queries we deliberately skip deeper structural checks:
        # subqueries, CTEs, and FROM-less statements (e.g. SELECT 1) make simple
        # heuristics unreliable. Other read-only statements (SHOW, EXPLAIN, ...)
        # need no additional validation.
        return True
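    # Examples:
    #   _validate_sql_syntax("SELECT 1")          -> True
    #   _validate_sql_syntax("SHOW TABLES")       -> True
    #   _validate_sql_syntax("DROP TABLE books")  -> False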
def _has_critical_issues(
self, analysis: str, query_error: str = None, sql_query: str = None
) -> bool:
"""Check if analysis or error indicates critical syntax/structural issues that require refinement"""
if not analysis:
return False
analysis_lower = analysis.lower()
# Check for critical syntax/structural keywords only (not incomplete answers - those are handled by confidence score)
critical_keywords = [
"missing select",
"syntax error",
"missing keyword",
"incorrect syntax",
"malformed",
"invalid syntax",
"invalid query structure",
]
if any(keyword in analysis_lower for keyword in critical_keywords):
return True
# Check query error for critical syntax issues
if query_error:
error_lower = query_error.lower()
if any(
keyword in error_lower
for keyword in ["syntax error", "1064", "invalid syntax", "malformed"]
):
return True
# Check SQL query itself for obvious structural issues
if sql_query:
sql_upper = sql_query.upper().strip()
# Missing SELECT keyword
if not sql_upper.startswith("SELECT") and any(
kw in sql_upper for kw in ["FROM", "WHERE", "JOIN"]
):
return True
return False
async def _fetch_missing_table_info(
self,
schema_info: dict,
query_error: str = None,
analysis: str = None,
sql_query: str = None,
):
"""Selectively fetch schema info for tables that are mentioned but not in schema_info"""
# Get all available table names
all_tables = schema_info.get("table_names", [])
if not all_tables:
return # Can't fetch if we don't know available tables
# Get currently cached table descriptions
cached_tables = set(schema_info.get("table_descriptions", {}).keys())
# Extract table names from error, analysis, and SQL query
mentioned_tables = set()
# Extract from SQL query (FROM, JOIN clauses) - use compiled regex
if sql_query:
from_matches = self._regex_patterns["from_table"].findall(sql_query)
join_matches = self._regex_patterns["join_table"].findall(sql_query)
mentioned_tables.update(from_matches)
mentioned_tables.update(join_matches)
# Extract from error message (e.g., "Table 'X' doesn't exist") - use compiled regex
if query_error and isinstance(query_error, str):
try:
table_matches = self._regex_patterns["table_in_error"].findall(
query_error
)
# Filter out empty matches
mentioned_tables.update([t for t in table_matches if t])
except Exception as e:
self._log("debug", f"Could not extract tables from error message: {str(e)}")
# Extract from analysis text (might mention table names) - use compiled regex
if analysis and isinstance(analysis, str):
try:
table_matches = self._regex_patterns["table_in_quotes"].findall(
analysis
)
# Filter to only include actual table names (use set for O(1) lookup)
tables_set = set(all_tables)
mentioned_tables.update(
[t for t in table_matches if t and t in tables_set]
)
except Exception as e:
self._log("debug", f"Could not extract tables from analysis: {str(e)}")
# Find tables that are mentioned but not in cache
missing_tables = [
t for t in mentioned_tables if t in all_tables and t not in cached_tables
]
if not missing_tables:
return # No missing tables
# Fetch schema info for missing tables
describe_table_tool = await self._get_tool("describe_table")
# Fetch table descriptions in parallel
if describe_table_tool:
async def describe_table(table_name: str):
try:
desc = await self._call_tool_with_timeout(
describe_table_tool, {"table_name": table_name}
)
return table_name, desc, None
                except Exception as e:
                    self._log("debug", f"Failed to fetch description for {table_name}: {str(e)}")
                    return table_name, None, str(e)
desc_tasks = [describe_table(t) for t in missing_tables]
desc_results = await asyncio.gather(*desc_tasks, return_exceptions=True)
table_descriptions = schema_info.get("table_descriptions", {})
for result in desc_results:
if not isinstance(result, Exception):
tbl, desc, _ = result
if desc:
table_descriptions[tbl] = desc
schema_info["table_descriptions"] = table_descriptions
# Update cache
if self.schema_cache:
existing_descriptions = self.schema_cache.get("table_descriptions", {})
existing_descriptions.update(table_descriptions)
self.schema_cache["table_descriptions"] = existing_descriptions
async def _refine_sql_with_analysis(
self,
query: str,
sql: str,
schema_info: dict,
analysis: str,
query_results: list = None,
query_error: str = None,
) -> str:
"""Refine SQL query using the analysis/reasoning from confidence scoring"""
# Check if error is a simple column name error
is_simple_column_error = False
wrong_column = None
if query_error:
parsed_error = self._parse_sql_error(query_error)
if parsed_error.get("type") == "unknown_column":
is_simple_column_error = True
wrong_column = parsed_error.get("column")
# Build refinement prompt with the analysis
has_multi_part = self._has_multi_part(query)
refinement_prompt = get_refinement_prompt(
query=query,
sql=sql,
analysis=analysis,
query_error=query_error,
query_results=query_results,
is_simple_column_error=is_simple_column_error,
wrong_column=wrong_column,
has_multi_part=has_multi_part,
)
# Build system prompt with schema info
system_prompt = generate_sql_prompt
if schema_info.get("table_descriptions"):
system_prompt += "\n=== TABLE STRUCTURES (USE EXACT COLUMN NAMES) ===\n"
for table_name, desc in schema_info["table_descriptions"].items():
system_prompt += f"\n--- Table: {table_name} ---\n{desc}\n"
try:
messages_list = [
SystemMessage(content=system_prompt),
HumanMessage(content=refinement_prompt),
]
response = await self._call_llm_with_timeout(messages_list)
refined_sql = self._clean_sql(response.content, fallback=sql)
return refined_sql if refined_sql else sql
        except Exception as e:
            # A TimeoutError message is already descriptive; attach a traceback
            # only for unexpected errors
            self._log("error", f"SQL refinement failed: {str(e)}", exc_info=not isinstance(e, TimeoutError))
            return sql
async def _analyze_query_confidence(
self,
query: str,
sql: str,
schema_info: dict,
confidence: float,
query_results: list = None,
query_error: str = None,
) -> str:
"""Analyze the query and provide reasoning based on actual query results (always included)"""
analysis_prompt = get_analysis_prompt(
query=query,
sql=sql,
confidence=confidence,
query_error=query_error,
query_results=query_results,
schema_info=schema_info,
)
try:
response = await self._call_llm_with_timeout(
[HumanMessage(content=analysis_prompt)]
)
content = response.content
if not content:
self._log("warning", "LLM returned empty content for analysis")
content = "Analysis unavailable."
return content.strip()
except TimeoutError:
error_msg = "Analysis timed out"
self._log("error", error_msg)
if query_error:
return f"Query error: {query_error}. Confidence: {confidence:.2f}. {error_msg}."
elif query_results is not None:
return f"Query returned {len(query_results)} rows. Confidence: {confidence:.2f}. {error_msg}."
else:
return f"Confidence: {confidence:.2f}. {error_msg}."
except Exception as e:
self._log("error", f"Error generating analysis: {str(e)}", exc_info=True)
if query_error:
return f"Query error: {query_error}. Confidence: {confidence:.2f}. Please review the query carefully."
elif query_results is not None:
return f"Query returned {len(query_results)} rows. Confidence: {confidence:.2f}. Please review the results carefully."
else:
return f"Confidence: {confidence:.2f}. Please review the query carefully against the schema information."
def _parse_sql_error(self, error_msg: str) -> dict:
"""Parse SQL error to extract actionable information (uses compiled regex for performance)"""
if not error_msg:
return {"type": "unknown", "message": error_msg}
error_info = {"type": "unknown", "message": error_msg}
# Parse "Unknown column" errors (use compiled regex)
col_match = self._regex_patterns["unknown_column"].search(error_msg)
if col_match:
error_info["type"] = "unknown_column"
col_name = col_match.group(1)
if col_name: # Ensure group is not empty
# Handle table.column format
if "." in col_name:
parts = col_name.split(".", 1)
if len(parts) == 2:
error_info["table"] = parts[0].strip(" `\"'")
error_info["column"] = parts[1].strip(" `\"'")
else:
error_info["column"] = col_name.strip(" `\"'")
else:
error_info["column"] = col_name.strip(" `\"'")
else:
# Empty match, keep as unknown
error_info["type"] = "unknown"
# Parse "Table doesn't exist" errors (use compiled regex)
table_match = self._regex_patterns["unknown_table"].search(error_msg)
if table_match:
table_name = table_match.group(1)
if table_name: # Ensure group is not empty
error_info["type"] = "unknown_table"
error_info["table"] = table_name.strip(" `\"'")
else:
# Empty match, keep as unknown
error_info["type"] = "unknown"
# Parse syntax errors
if "syntax error" in error_msg.lower() or "1064" in error_msg:
error_info["type"] = "syntax_error"
return error_info
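    # Example: _parse_sql_error("Unknown column 'b.title' in 'field list'")
    #   -> {"type": "unknown_column", "message": <original error string>,
    #       "table": "b", "column": "title"}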
# ========================================================================
# Node Methods
# ========================================================================
async def _explore_schema_node(self, state: AgentState) -> dict:
"""Explore database schema to understand available tables and structure.
Fetches table descriptions and caches schema information.
Sets schema_complete flag for edge routing.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
relevant_tables = state.get("relevant_tables", [])
# Use cached schema data if available (but still need to identify relevant tables for this query)
# IMPORTANT: Don't copy relevant_tables from cache - they're query-specific
if self.schema_cache:
# Use cached table descriptions, table names, etc.
if self.schema_cache.get("table_descriptions"):
schema_info["table_descriptions"] = self.schema_cache.get(
"table_descriptions", {}
).copy()
if self.schema_cache.get("table_names"):
schema_info["table_names"] = self.schema_cache.get("table_names", [])[
:
] # Copy list
# Always reset relevant_tables for new query (don't use cached ones)
relevant_tables = [] # Reset - will be identified for this specific query
        # Retrieve the user's question (needed for identifying relevant tables)
        user_query = self._get_user_query(messages)
if not user_query:
self._log("warning", "No user query found in messages for schema exploration")
return {"schema_info": schema_info}
# CRITICAL: Always identify relevant tables for THIS query (even if schema is cached)
# This ensures each query gets fresh relevant_tables based on the current question
# Get table names from cache first
table_names = schema_info.get("table_names", [])
if not table_names and self.schema_cache:
table_names = self.schema_cache.get("table_names", [])
# If we have table names, identify relevant ones for this query NOW
if table_names and not relevant_tables:
relevant_tables = await self._identify_relevant_tables(
user_query, table_names
)
schema_info["relevant_tables"] = relevant_tables
# Check if schema exploration is complete (decision made by edge)
# Schema is complete if we have descriptions for all relevant tables
schema_complete = False
if schema_info and schema_info.get("table_descriptions") and relevant_tables:
cached_descriptions = (
self.schema_cache.get("table_descriptions", {})
if self.schema_cache
else {}
)
# Use set for O(1) lookup
cached_tables_set = set(cached_descriptions.keys())
if all(tbl in cached_tables_set for tbl in relevant_tables):
schema_complete = True
elif (
schema_info
and schema_info.get("table_descriptions")
and not relevant_tables
):
# If no relevant tables identified but we have all descriptions, schema is complete
schema_complete = True
# If schema is complete AND we've identified relevant tables, return early
if schema_complete and relevant_tables:
return {
"schema_info": schema_info,
"relevant_tables": relevant_tables, # Include identified relevant tables
"schema_complete": True,
}
        # Schema not complete (or table names still unknown) - continue fetching
        # below; the conditional edge will route back here if needed
# Initialize tools (cached where possible)
await self._initialize_tools()
list_tables_tool = await self._get_tool("list_tables_mysql")
describe_table_tool = await self._get_tool("describe_table")
# Get table names (use cache if present)
if self.schema_cache and self.schema_cache.get("table_names"):
table_names = self.schema_cache["table_names"]
elif list_tables_tool:
try:
tables_result = await self._call_tool_with_timeout(list_tables_tool, {})
table_names = self._extract_table_names_from_result(tables_result)
self.schema_cache = self.schema_cache or {}
self.schema_cache["table_names"] = table_names
except Exception as e:
self._log("warning", f"Failed to fetch tables list: {str(e)}")
table_names = []
else:
table_names = []
# Update schema_info with table names
schema_info["table_names"] = table_names
# Determine relevant tables via LLM (if not already identified above)
if not relevant_tables and table_names:
relevant_tables = await self._identify_relevant_tables(
user_query, table_names
)
schema_info["relevant_tables"] = relevant_tables
elif not relevant_tables:
relevant_tables = table_names
schema_info["relevant_tables"] = relevant_tables
# Describe relevant tables (parallel fetching)
# Check cache first - only fetch descriptions for tables not in cache
table_descriptions = {}
cached_descriptions = (
self.schema_cache.get("table_descriptions", {}) if self.schema_cache else {}
)
# Get descriptions from cache for tables that are already cached
tables_to_fetch = []
for table_name in relevant_tables:
if table_name in cached_descriptions:
table_descriptions[table_name] = cached_descriptions[table_name]
else:
tables_to_fetch.append(table_name)
# Only fetch descriptions for tables not in cache
if describe_table_tool and tables_to_fetch:
async def describe_table(table_name: str):
try:
desc = await self._call_tool_with_timeout(
describe_table_tool, {"table_name": table_name}
)
return table_name, desc, None
except Exception as e:
self._log("debug", f"Failed to fetch description for {table_name}: {str(e)}")
return table_name, None, str(e)
desc_tasks = [describe_table(t) for t in tables_to_fetch]
desc_results = await asyncio.gather(*desc_tasks, return_exceptions=True)
for result in desc_results:
if not isinstance(result, Exception):
tbl, desc, _ = result
if desc:
table_descriptions[tbl] = desc
schema_info["table_descriptions"] = table_descriptions
schema_info["table_names"] = table_names
# Cache heavy schema parts for future queries
# Merge with existing cache to preserve all table descriptions
self.schema_cache = self.schema_cache or {}
existing_descriptions = self.schema_cache.get("table_descriptions", {})
existing_descriptions.update(
table_descriptions
) # Merge new descriptions with existing
self.schema_cache.update(
{
"table_descriptions": existing_descriptions,
"table_names": table_names,
}
)
# Manage cache size
self._manage_schema_cache_size()
# Schema fetching complete - return with flag (edge will route to generate_sql)
return {
"schema_info": schema_info,
"relevant_tables": relevant_tables,
"schema_complete": True,
}
async def _generate_sql_node(self, state: AgentState) -> dict:
"""Generate SQL query from natural language using LLM.
Generates SQL, executes test query, calculates confidence, and sets
refinement flags. Does NOT perform refinement (handled by refine_sql node).
Validates that the user query is actually a database question before generating SQL.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
query_attempts = state.get("query_attempts", 0)
previous_error = state.get("previous_error")
previous_sql = state.get("previous_sql")
# Retrieve the user's question
user_query = self._get_user_query(messages)
if not user_query:
return {
"messages": [AIMessage(content="Error: No user question found in messages.")],
"should_refine": False,
}
# Validate that this is actually a database question (skip validation on retries)
if previous_error is None and previous_sql is None:
try:
is_valid_query, validation_reason = await self._validate_query_is_database_question(user_query)
if not is_valid_query:
error_msg = f'I cannot generate a SQL query for this request. {validation_reason}\n\nPlease ask a question about the database, such as:\n- "Show me all authors"\n- "How many books are there?"\n- "Find books by a specific author"'
self._log("warning", f"Query validation failed: {validation_reason}")
return {
"messages": [AIMessage(content=error_msg)],
"should_refine": False,
"confidence": 0.0,
"sql_is_valid": False,
"has_critical_issues": True,
}
except Exception as e:
self._log("error", f"Error validating query: {str(e)}", exc_info=True)
return {
"messages": [AIMessage(content=f"Error: {str(e)}. Please try again.")],
"should_refine": False,
"confidence": 0.0,
"sql_is_valid": False,
"has_critical_issues": True,
}
# Build system prompt with schema information
system_prompt = self._build_system_prompt_with_schema(
schema_info,
include_error_context=bool(previous_error),
previous_error=previous_error,
previous_sql=previous_sql,
query_attempts=query_attempts
)
# Add column names summary
system_prompt += "\n=== COLUMN NAMES SUMMARY (USE THESE EXACT NAMES) ===\n"
for table_name in schema_info.get("table_descriptions", {}):
if table_name in self.column_cache:
column_descriptions = self.column_cache[table_name]
self.column_cache.move_to_end(table_name)
            else:
                fetch_tool = await self._get_tool("fetch_column_descriptions")
                column_descriptions = None
                if fetch_tool:
                    # NOTE: this MCP tool expects a schema-qualified table name
                    column_descriptions = await self._call_tool_with_timeout(
                        fetch_tool, {"table_name": "edl." + table_name}
                    )
                if column_descriptions:
                    self.column_cache[table_name] = column_descriptions
                    self._manage_column_cache_size()
if column_descriptions:
system_prompt += f"\n{table_name}: {', '.join(str(r) for r in column_descriptions)}\n"
else:
system_prompt += f"\n{table_name}: (see table structure above)\n"
has_multi_part = self._has_multi_part(user_query)
# Generate SQL using LLM
prompt = get_sql_generation_prompt(user_query, has_multi_part)
messages_list = [
SystemMessage(content=system_prompt),
HumanMessage(content=prompt),
]
response = await self._call_llm_with_timeout(messages_list)
sql_query = response.content
# Handle None or empty content
if not sql_query:
self._log("warning", "LLM returned empty SQL query")
return {
"messages": [AIMessage(content="Error: LLM returned empty SQL query. Please try rephrasing your question.")],
"should_refine": False,
"sql_is_valid": False,
"has_critical_issues": True,
"confidence": 0.0,
}
sql_query = self._clean_sql(sql_query)
# Validate SQL is not empty after extraction
if not sql_query or not sql_query.strip():
            self._log("warning", "SQL query is empty after extraction")
return {
"messages": [
AIMessage(
content="Error: Could not extract valid SQL query from LLM response. Please try rephrasing your question."
)
],
"should_refine": False,
"sql_is_valid": False,
"has_critical_issues": True,
"confidence": 0.0,
}
# Remove any parameterized query placeholders (?) - use compiled regex for performance
if "?" in sql_query:
# Remove ? placeholders that are used as parameter markers
sql_query = self._regex_patterns["param_and_or"].sub("", sql_query)
sql_query = self._regex_patterns["param_where"].sub("", sql_query)
sql_query = self._regex_patterns["param_standalone"].sub(" ", sql_query)
sql_query = self._regex_patterns["param_end"].sub("", sql_query)
sql_query = self._regex_patterns["param_start"].sub("", sql_query)
# Clean up any remaining ? characters (safety fallback)
sql_query = sql_query.replace("?", "")
# Clean up trailing WHERE/AND/OR
sql_query = self._regex_patterns["trailing_clauses"].sub("", sql_query)
sql_query = sql_query.strip()
# Execute test query with LIMIT to get sample results for confidence scoring
query_results = None
query_error = None
try:
# Add LIMIT if not already present
test_sql = sql_query
if "LIMIT" not in test_sql.upper():
test_sql = f"{sql_query.rstrip(';')} LIMIT {TEST_QUERY_LIMIT}"
else:
# Replace existing LIMIT
test_sql = self._regex_patterns["limit_replace"].sub(
f" LIMIT {TEST_QUERY_LIMIT}", test_sql
)
# Execute the query to get sample results
            run_query_tool = await self._get_tool("run_query_json")
            if run_query_tool:
                try:
                    result = await self._call_tool_with_timeout(
                        run_query_tool,
                        {"input": {"sql": test_sql, "row_limit": TEST_QUERY_LIMIT}},
                    )
except TimeoutError:
query_error = (
f"Query execution timed out after {self.query_timeout} seconds"
)
if self.enable_logging:
logger.error(query_error)
except Exception as e:
query_error = f"Query execution failed: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
                else:
                    # No exception occurred: inspect the result
                    if isinstance(result, list):
                        query_results = result
                    elif isinstance(result, str) and (
                        "error" in result.lower() or "query error" in result.lower()
                    ):
                        query_error = result
except Exception as e:
query_error = f"Unexpected error during test query execution: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
# Calculate confidence score and analysis in a single LLM call (optimization)
confidence, confidence_reasoning = await self._score_and_analyze_query(
user_query, sql_query, schema_info, query_results, query_error
)
# Validate SQL syntax - check if it's a valid SELECT query
sql_is_valid = self._validate_sql_syntax(sql_query)
# Check if analysis indicates critical syntax/structural issues (incomplete answers handled by confidence score)
has_critical_issues = self._has_critical_issues(
confidence_reasoning, query_error, sql_query
)
# Determine if refinement is needed (decision made by edge)
# Refine if: error exists OR low confidence OR invalid SQL OR critical issues
# Skip if: very high confidence with valid SQL and no error
has_error = query_error is not None
has_low_confidence = confidence < LOW_CONFIDENCE_THRESHOLD
should_refine = (
has_error or has_low_confidence or not sql_is_valid or has_critical_issues
)
# Skip refinement if confidence is very high and SQL is valid
if should_refine:
is_very_high_confidence = confidence >= VERY_HIGH_CONFIDENCE_THRESHOLD
is_high_confidence_no_error = (
previous_error is None
and confidence >= HIGH_CONFIDENCE_THRESHOLD
and sql_is_valid
)
if (
is_very_high_confidence and sql_is_valid and not has_error
) or is_high_confidence_no_error:
should_refine = False
# Determine refine reason if needed
refine_reason = None
if should_refine:
if not sql_is_valid:
refine_reason = "SQL syntax is invalid"
elif has_critical_issues:
refine_reason = "Critical issues detected in analysis"
elif has_low_confidence:
refine_reason = f"Low confidence ({confidence:.2f})"
# Build message with confidence score, sample results, and reasoning
confidence_msg = f"Generated SQL query (confidence: {confidence:.2f}):\n```sql\n{sql_query}\n```"
if query_error:
confidence_msg += f"\n\nQuery Error:\n{query_error}"
elif query_results is not None:
confidence_msg += f"\n\nSample Results (first {len(query_results)} rows):\n{json.dumps(query_results, indent=2, default=str)}"
# Always include reasoning
confidence_msg += f"\n\nAnalysis:\n{confidence_reasoning}"
new_messages = [AIMessage(content=confidence_msg)]
        return {
            "messages": new_messages,
            "query_attempts": query_attempts + 1,
            "previous_error": None,  # Clear previous error after generating new SQL
            "previous_sql": None,  # Clear previous SQL after generating new SQL
            # Store final SQL once (also read by refine/execute nodes; avoids re-extraction)
            "final_sql": sql_query,
            "test_query_results": (
                query_results if query_error is None else None
            ),  # Store test results if successful
            # Set refinement decision flags for edge
            "should_refine": should_refine,
            "refine_reason": refine_reason,
            "confidence": confidence,
            "sql_is_valid": sql_is_valid,
            "has_critical_issues": has_critical_issues,
            "confidence_reasoning": confidence_reasoning,
            "query_error": query_error,
        }
async def _refine_sql_node(self, state: AgentState) -> dict:
"""Refine SQL query based on confidence score and analysis.
Called when should_refine flag is True. Fetches missing schema if needed,
refines SQL, re-executes test query, and recalculates confidence.
Can loop back to itself if confidence is still low (up to max_refinements).
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
refinement_attempts = state.get("refinement_attempts", 0)
# Check if we've exceeded max refinements
if refinement_attempts >= self.max_refinements:
if self.enable_logging:
logger.warning(
f"Max refinements ({self.max_refinements}) reached, proceeding to execution"
)
return {
"should_refine": False, # Stop refining
"refinement_attempts": refinement_attempts,
}
        user_query = self._get_user_query(messages)
        if not user_query:
            return {}
# Get SQL and analysis from state
sql_query = state.get("final_sql")
if not sql_query:
sql_query = self._extract_sql_from_messages(messages)
if not sql_query:
return {}
confidence_reasoning = state.get("confidence_reasoning", "")
query_error = state.get("query_error")
query_results = state.get("test_query_results")
refine_reason = state.get("refine_reason", "Low confidence or errors detected")
# Check if we need to fetch schema info for new tables mentioned in error/analysis
await self._fetch_missing_table_info(
schema_info, query_error, confidence_reasoning, sql_query
)
# Refine the SQL using the analysis/reasoning
refined_sql = await self._refine_sql_with_analysis(
user_query,
sql_query,
schema_info,
confidence_reasoning,
query_results,
query_error,
)
# Validate refined SQL is not empty
if not refined_sql or not refined_sql.strip():
self._log("warning", "Refined SQL is empty, using original SQL")
refined_sql = sql_query
if refined_sql == sql_query:
# Refinement didn't change SQL, return original
return {"final_sql": sql_query, "should_refine": False}
# Re-execute the refined query to get new results
query_results = None
query_error = None
try:
test_sql = refined_sql
if "LIMIT" not in test_sql.upper():
test_sql = f"{refined_sql.rstrip(';')} LIMIT {TEST_QUERY_LIMIT}"
else:
test_sql = self._regex_patterns["limit_replace"].sub(
f" LIMIT {TEST_QUERY_LIMIT}", test_sql
)
            run_query_tool = await self._get_tool("run_query_json")
            if run_query_tool:
                try:
                    result = await self._call_tool_with_timeout(
                        run_query_tool,
                        {"input": {"sql": test_sql, "row_limit": TEST_QUERY_LIMIT}},
                    )
except TimeoutError:
query_error = (
f"Query execution timed out after {self.query_timeout} seconds"
)
if self.enable_logging:
logger.error(query_error)
except Exception as e:
query_error = f"Query execution failed: {str(e)}"
if self.enable_logging:
logger.error(query_error, exc_info=True)
                else:
                    # No exception occurred: inspect the result
                    if isinstance(result, list):
                        query_results = result
                    elif isinstance(result, str) and (
                        "error" in result.lower() or "query error" in result.lower()
                    ):
                        query_error = result
except Exception as e:
query_error = str(e)
# Recalculate confidence and analysis for refined query
confidence, confidence_reasoning = await self._score_and_analyze_query(
user_query, refined_sql, schema_info, query_results, query_error
)
# Validate refined SQL
sql_is_valid = self._validate_sql_syntax(refined_sql)
has_critical_issues = self._has_critical_issues(
confidence_reasoning, query_error, refined_sql
)
# Check if we need to refine again (same logic as generate_sql node)
has_error = query_error is not None
has_low_confidence = confidence < LOW_CONFIDENCE_THRESHOLD
should_refine_again = (
has_error or has_low_confidence or not sql_is_valid or has_critical_issues
)
# Skip further refinement if confidence is very high and SQL is valid
if should_refine_again:
is_very_high_confidence = confidence >= VERY_HIGH_CONFIDENCE_THRESHOLD
is_high_confidence_no_error = (
query_error is None
and confidence >= HIGH_CONFIDENCE_THRESHOLD
and sql_is_valid
)
if (
is_very_high_confidence and sql_is_valid and not has_error
) or is_high_confidence_no_error:
should_refine_again = False
# Determine refine reason if needed
refine_reason_again = None
if should_refine_again:
if not sql_is_valid:
refine_reason_again = "SQL syntax is invalid"
elif has_critical_issues:
refine_reason_again = "Critical issues detected in analysis"
elif has_low_confidence:
refine_reason_again = f"Low confidence ({confidence:.2f})"
elif has_error:
refine_reason_again = "Query execution error"
# Update message with refined SQL
confidence_msg = f"Generated SQL query (confidence: {confidence:.2f}):\n```sql\n{refined_sql}\n```"
if query_error:
confidence_msg += f"\n\nQuery Error:\n{query_error}"
elif query_results is not None:
confidence_msg += f"\n\nSample Results (first {len(query_results)} rows):\n{json.dumps(query_results, indent=2, default=str)}"
confidence_msg += f"\n\nAnalysis:\n{confidence_reasoning}"
confidence_msg += f"\n\n⚠️ Query was refined due to: {refine_reason}"
# Check if we've reached max refinements
if should_refine_again and refinement_attempts + 1 >= self.max_refinements:
should_refine_again = False
confidence_msg += f"\n\n⚠️ Max refinements ({self.max_refinements}) reached, proceeding to execution despite low confidence."
elif should_refine_again:
confidence_msg += f"\n\n⚠️ Confidence still low, will refine again... (attempt {refinement_attempts + 1}/{self.max_refinements})"
return {
"messages": [AIMessage(content=confidence_msg)],
"final_sql": refined_sql,
"test_query_results": query_results if query_error is None else None,
"confidence": confidence,
"sql_is_valid": sql_is_valid,
"has_critical_issues": has_critical_issues,
"confidence_reasoning": confidence_reasoning,
"query_error": query_error,
"should_refine": should_refine_again, # Keep refining if confidence still low
"refine_reason": refine_reason_again if should_refine_again else None,
"refinement_attempts": refinement_attempts
+ 1, # Increment refinement counter
}
async def _execute_query_node(self, state: AgentState) -> dict:
"""Execute the generated SQL query.
Uses stored SQL from state if available. Reuses test query results
when they contain all data to avoid redundant execution.
"""
messages = state["messages"]
# Use stored SQL from state if available (avoids re-extraction)
sql_query = state.get("final_sql")
test_results = state.get("test_query_results")
# If SQL not in state, extract from message (fallback)
if not sql_query:
sql_query = self._extract_sql_from_messages(messages)
if not sql_query:
return {
"messages": [
ToolMessage(
content="Error: Could not extract SQL query from the response.",
tool_call_id="sql_extraction_error",
)
]
}
# Reuse test query results if they contain all data (avoids redundant execution)
needs_full_execution = True
result = None
if test_results is not None and len(test_results) < TEST_QUERY_LIMIT:
sql_upper = sql_query.upper()
if "LIMIT" in sql_upper:
limit_match = re.search(r"LIMIT\s+(\d+)", sql_upper)
if limit_match:
try:
original_limit = int(limit_match.group(1))
                        if 0 < original_limit <= TEST_QUERY_LIMIT:
                            # Test query already fetched all requested rows.
                            # Respect the original LIMIT by slicing; the slice
                            # is a no-op when fewer rows than the limit were
                            # cached.
                            results_to_return = test_results[:original_limit]
needs_full_execution = False
result = json.dumps(
results_to_return, indent=2, default=str
)
except (ValueError, TypeError) as e:
if self.enable_logging:
logger.warning(
f"Could not parse LIMIT value: {str(e)}, executing full query"
)
# Continue with full execution
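        # Worked example for the reuse path above (illustrative values):
        #   sql_query    = "SELECT name FROM users LIMIT 3"   # hypothetical
        #   test_results = [{"name": "a"}, {"name": "b"}]     # test run exhausted the table
        # len(test_results) < TEST_QUERY_LIMIT and the parsed LIMIT (3) is
        # within TEST_QUERY_LIMIT, so the cached rows are returned as-is and
        # the full query is never re-executed.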
# Execute the query only if needed
if needs_full_execution:
try:
run_query_tool = await self._get_tool("run_query")
if not run_query_tool:
return {
"messages": [
ToolMessage(
content="Error: run_query tool not available.",
tool_call_id="run_query_error",
)
]
}
try:
result = await self._call_tool_with_timeout(
run_query_tool,
{
"input": {
"sql": sql_query,
"format": "markdown",
"row_limit": 100,
}
},
)
except TimeoutError:
error_msg = (
f"Query execution timed out after {self.query_timeout} seconds"
)
if self.enable_logging:
logger.error(error_msg)
return {
"messages": [
ToolMessage(
content=error_msg, tool_call_id="run_query_timeout"
)
]
}
except Exception as e:
error_msg = f"Error executing query: {str(e)}"
return {
"messages": [
ToolMessage(content=error_msg, tool_call_id="run_query_error")
]
}
# Check if there was an error
if result is None:
return {
"messages": [
ToolMessage(
content="Error: Query execution returned no result.",
tool_call_id="run_query_error",
)
]
}
        if isinstance(result, str):
            # Heuristic: the run_query tool reports failures as plain text, so
            # scan for common failure keywords (may rarely false-positive on
            # legitimate result data containing these words)
            result_lower = result.lower()
if (
"error" in result_lower
or "query error" in result_lower
or "permission" in result_lower
):
return {
"messages": [
ToolMessage(
content=f"Query execution failed: {result}",
tool_call_id="run_query_error",
)
]
}
# Success!
return {
"messages": [
ToolMessage(
content=f"Query executed successfully:\n\n{result}",
tool_call_id="run_query_success",
)
]
}
async def _refine_query_node(self, state: AgentState) -> dict:
"""Refine the SQL query based on error feedback.
Called when query execution fails. Parses error, updates schema if needed,
and passes error context to generate_sql node for regeneration.
"""
messages = state["messages"]
schema_info = state.get("schema_info", {})
# Get the error message
error_msg = None
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
content = msg.content
if content and isinstance(content, str) and "error" in content.lower():
error_msg = content
break
# Parse the error to extract actionable information
parsed_error = self._parse_sql_error(error_msg) if error_msg else None
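        # _parse_sql_error (defined elsewhere in this class, not shown here) is
        # assumed to return a dict shaped roughly like:
        #   {"type": "unknown_column", "column": "usrname", "table": "users",
        #    "message": "..."}
        # where "type" drives the guidance branches below; the field values
        # here are illustrative.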
# If error mentions unknown column, try to get fresh schema info for relevant tables
if parsed_error and parsed_error["type"] == "unknown_column":
            # Refresh descriptions for every known table; the parsed error does
            # not always tell us which table the bad column belongs to
            table_names = schema_info.get("table_names", [])
describe_table_tool = await self._get_tool("describe_table")
if describe_table_tool and table_names:
# Re-fetch descriptions for all tables to ensure we have latest info
table_descriptions = {}
for table_name in table_names:
try:
desc = await self._call_tool_with_timeout(
describe_table_tool, {"table_name": table_name}
)
table_descriptions[table_name] = desc
except Exception as e:
if self.enable_logging:
logger.warning(
f"Failed to fetch description for table {table_name}: {str(e)}"
)
continue
# Update schema info with fresh descriptions
schema_info["table_descriptions"] = table_descriptions
# Build refined error message with parsed error details
if parsed_error:
if parsed_error["type"] == "unknown_column":
error_guidance = (
f"\n\n⚠️ ERROR: Column '{parsed_error['column']}' does not exist!"
)
if parsed_error.get("table"):
error_guidance += f" (in table '{parsed_error['table']}')"
error_guidance += "\n\nPlease check the column names summary in the schema information above and use the EXACT column names shown."
elif parsed_error["type"] == "unknown_table":
error_guidance = (
f"\n\n⚠️ ERROR: Table '{parsed_error['table']}' does not exist!"
)
error_guidance += "\n\nPlease check the available tables list and use the correct table name."
elif parsed_error["type"] == "syntax_error":
                error_guidance = "\n\n⚠️ ERROR: SQL syntax error detected!"
error_guidance += "\n\nPlease review the SQL syntax carefully. Make sure there are no ? placeholders, and all values are properly embedded in the query."
else:
error_guidance = f"\n\n⚠️ ERROR: {parsed_error['message']}"
else:
error_guidance = f"\n\nPrevious query failed. Error: {error_msg}\n\nPlease check the table structures carefully and use the EXACT column names shown."
# Get stored SQL from state or extract from messages
original_sql = state.get("previous_sql") or self._extract_sql_from_messages(
messages
)
# Pass error context to generate_sql node (it will optionally use it)
# Just add a message indicating we're retrying with error context
refine_msg = f"Previous query failed.{error_guidance}\n\nRegenerating query with error context..."
return {
"messages": [HumanMessage(content=refine_msg)],
"schema_info": schema_info, # Include updated schema info
"previous_error": error_msg, # Pass error to generate_sql as optional context
"previous_sql": original_sql, # Pass previous SQL to generate_sql as optional context
}
async def _tools_node(self, state: AgentState) -> dict:
"""Handle tool calls from the LLM.
Processes tool calls from AI messages and executes them via MCP tools.
"""
messages = state["messages"]
last_msg = messages[-1] if messages else None
if last_msg and hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
await self._initialize_tools() # Ensure tools are initialized
tool_results = []
for tool_call in last_msg.tool_calls:
if not isinstance(tool_call, dict):
continue
tool_name = tool_call.get("name")
if not tool_name:
tool_results.append(
ToolMessage(
content="Error: Tool call missing name.",
tool_call_id=tool_call.get("id", "unknown"),
)
)
continue
tool_args = tool_call.get("args", {})
if not isinstance(tool_args, dict):
tool_args = {}
try:
tool = await self._get_tool(tool_name)
if tool:
try:
result = await self._call_tool_with_timeout(tool, tool_args)
tool_results.append(
ToolMessage(
                                    content=str(result),
                                    tool_call_id=tool_call.get("id", "unknown"),
)
)
except TimeoutError:
tool_results.append(
ToolMessage(
content=f"Error: Tool {tool_name} call timed out after {self.query_timeout} seconds",
                                    tool_call_id=tool_call.get("id", "unknown"),
)
)
except Exception as e:
tool_results.append(
ToolMessage(
content=f"Error calling {tool_name}: {str(e)}",
                                    tool_call_id=tool_call.get("id", "unknown"),
)
)
else:
tool_results.append(
ToolMessage(
content=f"Error: Tool {tool_name} not found.",
                                tool_call_id=tool_call.get("id", "unknown"),
)
)
                except Exception as e:
                    # Reached when tool lookup itself fails (call-time errors
                    # are already handled by the inner try/except above)
tool_results.append(
ToolMessage(
content=f"Error calling {tool_name}: {str(e)}",
                            tool_call_id=tool_call.get("id", "unknown"),
)
)
return {"messages": tool_results}
return {}
# ========================================================================
# Edge Methods (Routing Decisions)
# ========================================================================
def _should_continue_schema_exploration(
self, state: AgentState
) -> Literal["complete", "continue"]:
"""Decide whether schema exploration is complete.
Returns:
"complete" if schema is ready, "continue" if more fetching needed
"""
return "complete" if state.get("schema_complete", False) else "continue"
def _should_refine_sql(
self, state: AgentState
) -> Literal["refine", "execute", "use_tools", "end"]:
"""Decide routing after SQL generation.
Returns:
"refine" if SQL needs refinement
"execute" if SQL is ready to execute
"use_tools" if LLM wants to use tools
"end" if no SQL found
"""
# Check if refinement is needed (set by generate_sql node)
if state.get("should_refine", False):
return "refine"
# Check for tool calls or SQL in message
messages = state["messages"]
last_msg = messages[-1] if messages else None
if last_msg and isinstance(last_msg, AIMessage):
if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
return "use_tools"
content = last_msg.content
if content: # Ensure content is not None or empty
has_sql = "SELECT" in content.upper() or (
"sql" in content.lower() and "```" in content
)
if has_sql:
return "execute"
return "end"
def _check_query_result(
self, state: AgentState
) -> Literal["success", "retry", "error"]:
"""Check query execution result and decide next step.
Returns:
"success" if query executed successfully
"retry" if error occurred and attempts remain
"error" if max attempts reached
"""
messages = state["messages"]
query_attempts = state.get("query_attempts", 0)
max_attempts = state.get("max_attempts", self.max_attempts)
last_msg = messages[-1] if messages else None
if last_msg and isinstance(last_msg, ToolMessage):
            content = (last_msg.content or "").lower()  # Guard against None content
if "successfully" in content:
return "success"
elif "error" in content or "failed" in content:
return "retry" if query_attempts < max_attempts else "error"
return "success" # Default to success if unclear
# ========================================================================
# Public Methods
# ========================================================================
async def query(self, question: str, config: Optional[dict] = None) -> dict:
"""Execute a text-to-SQL query.
Each query starts with a fresh state. Schema cache persists across queries
for performance, but query-specific state (relevant_tables, previous_error, etc.)
is reset for each new query.
Args:
question: The natural language question to convert to SQL
config: Optional LangGraph configuration
Returns:
AgentState dictionary with query results
Raises:
ValueError: If question is None, empty, or not a string
"""
# Input validation
if question is None:
raise ValueError(
"Question cannot be None. Please provide a valid database question."
)
if not isinstance(question, str):
raise ValueError(
f"Question must be a string, got {type(question).__name__}"
)
question = question.strip()
if not question:
raise ValueError(
"Question cannot be empty. Please provide a valid database question."
)
self._log("info", f"Starting query: {question[:100]}...")
initial_state: AgentState = {
"messages": [HumanMessage(content=question)],
"schema_info": {}, # Fresh schema_info for this query
"query_attempts": 0,
"max_attempts": self.max_attempts,
"relevant_tables": [], # Will be populated by explore_schema node
# Clear all query-specific flags for fresh query
"previous_error": None,
"previous_sql": None,
"final_sql": None,
"test_query_results": None,
"should_refine": None,
"refine_reason": None,
"confidence": None,
"sql_is_valid": None,
"has_critical_issues": None,
"confidence_reasoning": None,
"query_error": None,
"schema_complete": None,
"refinement_attempts": 0, # Reset refinement counter for new query
}
config = config or {}
try:
result = await self.graph.ainvoke(initial_state, config)
self._log("info", "Query completed successfully")
return result
except Exception as e:
self._log("error", f"Query execution failed: {str(e)}", exc_info=True)
raise
def get_final_answer(self, result: dict) -> str:
"""Extract the final answer from the agent result"""
messages = result.get("messages", [])
if not messages:
return "No answer found"
# Look for the last successful query result (check in reverse order)
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
content = msg.content
if content and "successfully" in content.lower():
return content
elif isinstance(msg, AIMessage):
# Check if this is the last message (more efficient than index())
if msg is messages[-1]:
content = msg.content
return content if content else "No answer found"
# Fallback: return last message content
last_msg = messages[-1]
if hasattr(last_msg, "content"):
content = last_msg.content
return str(content) if content else "No answer found"
return "No answer found"