"""
MCP Tools Service
Provides PostgreSQL active operations as MCP Tools (actions that DO things)
"""
import json
import logging
import re
from typing import Dict, Any, List, Optional
from pathlib import Path
from sqlalchemy import create_engine, text
from postgres_integration import PostgreSQLIntegration
logger = logging.getLogger(__name__)
class MCPToolsService:
"""Service that exposes PostgreSQL active operations as MCP Tools"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.db_integrations: Dict[str, PostgreSQLIntegration] = {}
self._load_tool_definitions()
def _load_tool_definitions(self):
"""Load MCP tool definitions from configuration"""
try:
tools_file = Path(__file__).parent.parent / "resources" / "lists" / "mcp_resources_and_tools.json"
with open(tools_file, 'r') as f:
definitions = json.load(f)
self.tool_definitions = definitions.get("mcp_tools", {})
except Exception as e:
logger.warning(f"Could not load tool definitions: {e}")
self.tool_definitions = {}
def get_database_integration(self, database: str = "db3") -> PostgreSQLIntegration:
"""Get or create PostgreSQL integration for a database"""
if database not in self.db_integrations:
            # Try to import the project config; fall back to defaults if it is
            # not available or does not define the expected binds
            try:
                import sys
                # Add the parent directory to sys.path temporarily so "config"
                # can be resolved from this package
                parent_dir = str(Path(__file__).parent.parent)
                if parent_dir not in sys.path:
                    sys.path.insert(0, parent_dir)
                from config import Config
                connection_string = Config.SQLALCHEMY_BINDS.get(database)
            except (ImportError, AttributeError):
# Fallback to default connection strings
connection_strings = {
'db1': 'postgresql://admin:password@192.168.230.101/defaultdb?connect_timeout=1',
'db2': 'postgresql://admin:password@192.168.230.102/defaultdb?connect_timeout=1',
'db3': 'postgresql://postgres:postgres@localhost/postgres?connect_timeout=1'
}
connection_string = connection_strings.get(database)
if not connection_string:
raise ValueError(f"Database '{database}' not configured")
engine = create_engine(connection_string)
self.db_integrations[database] = PostgreSQLIntegration(engine, database)
return self.db_integrations[database]
async def execute_safe_sql(self, sql_query: str, database: str = "db3", limit: int = 100) -> Dict[str, Any]:
"""TOOL: Execute SQL query safely with validation"""
try:
integration = self.get_database_integration(database)
            # Validate SQL safety: only read-only SELECT statements are executed
            query = sql_query.strip().upper()
            if not query.startswith('SELECT'):
                return {
                    "success": False,
                    "error": "Only SELECT statements are allowed for safety",
                    "query": sql_query,
                    "operation": "sql_execution"
                }
            # Strip any trailing semicolon so an appended LIMIT stays valid SQL,
            # then add a LIMIT clause if one is not already present (simple
            # substring heuristic)
            sql_query = sql_query.strip().rstrip(';')
            if 'LIMIT' not in sql_query.upper():
                sql_query += f" LIMIT {limit}"
# Execute the query using safe_run_sql
result = integration.safe_run_sql(sql_query)
return {
"success": True,
"database": database,
"query": sql_query,
"result": result,
"operation": "sql_execution"
}
except Exception as e:
logger.error(f"Error executing SQL: {e}")
return {
"success": False,
"error": str(e),
"database": database,
"query": sql_query,
"operation": "sql_execution"
}
async def validate_sql_syntax(self, sql_query: str) -> Dict[str, Any]:
"""TOOL: Validate SQL syntax without executing"""
try:
# Basic SQL validation
query = sql_query.strip()
            # Check for write/DDL keywords using word-boundary matching so that
            # identifiers such as "created_at" do not trigger false positives
            dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE', 'TRUNCATE']
            query_upper = query.upper()
            for keyword in dangerous_keywords:
                if re.search(rf"\b{keyword}\b", query_upper):
return {
"success": False,
"valid": False,
"error": f"Dangerous operation detected: {keyword}",
"query": sql_query,
"operation": "sql_validation"
}
# Check basic SQL structure
if not query_upper.startswith('SELECT'):
return {
"success": False,
"valid": False,
"error": "Only SELECT statements are supported",
"query": sql_query,
"operation": "sql_validation"
}
            # Build a SQLAlchemy TextClause as a lightweight sanity check. Note
            # that text() does not parse the SQL, so this only catches gross
            # construction errors, not genuine syntax problems
            try:
                text(sql_query)
return {
"success": True,
"valid": True,
"query": sql_query,
"operation": "sql_validation"
}
except Exception as syntax_error:
return {
"success": False,
"valid": False,
"error": f"Syntax error: {str(syntax_error)}",
"query": sql_query,
"operation": "sql_validation"
}
except Exception as e:
logger.error(f"Error validating SQL: {e}")
return {
"success": False,
"error": str(e),
"query": sql_query,
"operation": "sql_validation"
}
async def search_tables_for_concept(self, concept: str, database: str = "db3", table_names: Optional[List[str]] = None) -> Dict[str, Any]:
"""TOOL: Actively search across tables for a business concept"""
try:
integration = self.get_database_integration(database)
# Implement search functionality using available methods
all_tables = integration.get_all_table_names()
if table_names:
search_tables = [t for t in all_tables if t in table_names]
else:
search_tables = all_tables
results = []
concept_lower = concept.lower()
# Search table names
for table in search_tables:
if concept_lower in table.lower():
results.append({
"table": table,
"match_type": "table_name",
"relevance": "high"
})
continue
# Search column names in table schema
try:
schema = integration.get_table_schema(table)
columns = schema.get('columns', [])
for col in columns:
col_name = col.get('column_name', '')
if concept_lower in col_name.lower():
results.append({
"table": table,
"match_type": "column_name",
"column": col_name,
"relevance": "medium"
})
break
except Exception:
# Skip tables that can't be accessed
continue
return {
"success": True,
"database": database,
"concept": concept,
"search_results": results,
"tables_searched": table_names or "all",
"operation": "concept_search"
}
except Exception as e:
logger.error(f"Error searching for concept '{concept}': {e}")
return {
"success": False,
"error": str(e),
"database": database,
"concept": concept,
"operation": "concept_search"
}
async def find_related_data(self, base_table: str, relationship_type: str = "all", database: str = "db3") -> Dict[str, Any]:
"""TOOL: Actively find and analyze relationships between data"""
try:
integration = self.get_database_integration(database)
# Get table schema to understand relationships
base_schema = integration.get_table_schema(base_table)
all_tables = integration.get_all_table_names()
relationships = {
"foreign_key_relationships": [],
"semantic_relationships": [],
"data_pattern_relationships": []
}
# Analyze foreign key relationships
if relationship_type in ["all", "foreign_key"]:
for constraint in base_schema.get("constraints", []):
if constraint.get("type") == "FOREIGN KEY":
relationships["foreign_key_relationships"].append({
"related_table": constraint.get("referenced_table"),
"local_column": constraint.get("column_name"),
"foreign_column": constraint.get("referenced_column"),
"relationship_strength": "strong"
})
# Analyze semantic relationships (tables with similar column names)
if relationship_type in ["all", "semantic"]:
base_columns = [col["column_name"] for col in base_schema.get("columns", [])]
for table in all_tables:
if table != base_table:
try:
table_schema = integration.get_table_schema(table)
table_columns = [col["column_name"] for col in table_schema.get("columns", [])]
# Find common columns
common_columns = set(base_columns) & set(table_columns)
if common_columns:
relationships["semantic_relationships"].append({
"related_table": table,
"common_columns": list(common_columns),
"relationship_strength": "medium" if len(common_columns) > 1 else "weak"
})
                        except Exception:
                            # Skip tables whose schema cannot be read
                            continue
return {
"success": True,
"database": database,
"base_table": base_table,
"relationship_type": relationship_type,
"relationships": relationships,
"operation": "relationship_analysis"
}
except Exception as e:
logger.error(f"Error finding related data for '{base_table}': {e}")
return {
"success": False,
"error": str(e),
"database": database,
"base_table": base_table,
"operation": "relationship_analysis"
}
async def analyze_query_performance(self, sql_query: str, database: str = "db3") -> Dict[str, Any]:
"""TOOL: Analyze SQL query performance characteristics"""
try:
integration = self.get_database_integration(database)
# Get query execution plan
explain_query = f"EXPLAIN (ANALYZE false, FORMAT JSON) {sql_query}"
with integration.engine.connect() as connection:
result = connection.execute(text(explain_query))
execution_plan = result.fetchone()[0]
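                # EXPLAIN (FORMAT JSON) returns one row whose single column is a
                # JSON array; each element carries a top-level "Plan" dict with
                # child nodes nested under "Plans".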
# Analyze the plan for performance insights
analysis = {
"estimated_cost": None,
"estimated_rows": None,
"scan_types": [],
"join_types": [],
"performance_warnings": []
}
            def analyze_plan_node(node):
                if isinstance(node, dict):
                    # Record cost/row estimates from the root plan node only, so
                    # recursion into child plans does not overwrite them
                    if analysis["estimated_cost"] is None:
                        analysis["estimated_cost"] = node.get("Total Cost", 0)
                        analysis["estimated_rows"] = node.get("Plan Rows", 0)
                    node_type = node.get("Node Type", "")
                    if "Scan" in node_type:
                        analysis["scan_types"].append(node_type)
                    if "Join" in node_type:
                        analysis["join_types"].append(node_type)
                    # Performance warnings
                    if "Seq Scan" in node_type and node.get("Plan Rows", 0) > 10000:
                        analysis["performance_warnings"].append("Large sequential scan detected")
                    # Recursively analyze child plans
                    for child in node.get("Plans", []):
                        analyze_plan_node(child)
            if execution_plan and len(execution_plan) > 0:
                analyze_plan_node(execution_plan[0].get("Plan", {}))
return {
"success": True,
"database": database,
"query": sql_query,
"execution_plan": execution_plan,
"analysis": analysis,
"operation": "performance_analysis"
}
except Exception as e:
logger.error(f"Error analyzing query performance: {e}")
return {
"success": False,
"error": str(e),
"database": database,
"query": sql_query,
"operation": "performance_analysis"
}
async def analyze_data_patterns(self, table_name: str, sample_size: int = 1000, database: str = "db3") -> Dict[str, Any]:
"""TOOL: Actively analyze data patterns and distributions"""
try:
integration = self.get_database_integration(database)
# Get table schema first
schema = integration.get_table_schema(table_name)
columns = schema.get("columns", [])
            # Sample data for analysis. Quote the table identifier so unusual
            # names cannot break out of the statement; TABLESAMPLE SYSTEM(10)
            # reads roughly 10% of the table's pages, so small tables may
            # return only a few rows
            safe_table = table_name.replace('"', '""')
            sample_query = f'SELECT * FROM "{safe_table}" TABLESAMPLE SYSTEM(10) LIMIT {int(sample_size)}'
with integration.engine.connect() as connection:
result = connection.execute(text(sample_query))
sample_data = result.fetchall()
# Analyze patterns
patterns = {
"column_analysis": {},
"data_quality": {
"null_rates": {},
"unique_rates": {},
"data_types_consistent": True
},
"sample_size": len(sample_data),
"patterns_detected": []
}
if sample_data:
column_names = list(sample_data[0]._mapping.keys())
for col_name in column_names:
values = [row._mapping[col_name] for row in sample_data]
# Calculate null rate
null_count = sum(1 for v in values if v is None)
null_rate = null_count / len(values) if values else 0
# Calculate unique rate
unique_count = len(set(str(v) for v in values if v is not None))
unique_rate = unique_count / len(values) if values else 0
patterns["data_quality"]["null_rates"][col_name] = null_rate
patterns["data_quality"]["unique_rates"][col_name] = unique_rate
# Detect patterns
if null_rate > 0.5:
patterns["patterns_detected"].append(f"High null rate in {col_name} ({null_rate:.1%})")
if unique_rate == 1.0 and len(values) > 10:
patterns["patterns_detected"].append(f"All unique values in {col_name} (possible ID column)")
return {
"success": True,
"database": database,
"table_name": table_name,
"patterns": patterns,
"operation": "data_pattern_analysis"
}
except Exception as e:
logger.error(f"Error analyzing data patterns for '{table_name}': {e}")
return {
"success": False,
"error": str(e),
"database": database,
"table_name": table_name,
"operation": "data_pattern_analysis"
}
async def get_postgresql_logs(self, log_type: str = "recent", lines: int = 100, database: str = "db3") -> Dict[str, Any]:
"""TOOL: Safely retrieve PostgreSQL log files for troubleshooting"""
try:
integration = self.get_database_integration(database)
# Use PostgreSQL's built-in logging views and functions rather than file system access
with integration.engine.connect() as connection:
log_data = {
"log_entries": [],
"summary": {},
"log_type": log_type,
"lines_requested": lines
}
if log_type == "recent":
# Get recent activity from pg_stat_activity
result = connection.execute(text("""
SELECT
pid,
usename,
application_name,
client_addr,
backend_start,
state,
query_start,
state_change,
query
FROM pg_stat_activity
WHERE state IS NOT NULL
ORDER BY backend_start DESC
LIMIT :lines
"""), {"lines": lines})
for row in result:
log_data["log_entries"].append({
"timestamp": row.backend_start.isoformat() if row.backend_start else None,
"pid": row.pid,
"user": row.usename,
"application": row.application_name,
"client_addr": str(row.client_addr) if row.client_addr else None,
"state": row.state,
"query_start": row.query_start.isoformat() if row.query_start else None,
"state_change": row.state_change.isoformat() if row.state_change else None,
"query": row.query[:200] + "..." if row.query and len(row.query) > 200 else row.query
})
elif log_type == "errors":
# Check for recent errors in pg_stat_database_conflicts
result = connection.execute(text("""
SELECT
datname,
confl_tablespace,
confl_lock,
confl_snapshot,
confl_bufferpin,
confl_deadlock
FROM pg_stat_database_conflicts
WHERE datname = current_database()
"""))
for row in result:
if any([row.confl_tablespace, row.confl_lock, row.confl_snapshot, row.confl_bufferpin, row.confl_deadlock]):
log_data["log_entries"].append({
"database": row.datname,
"tablespace_conflicts": row.confl_tablespace,
"lock_conflicts": row.confl_lock,
"snapshot_conflicts": row.confl_snapshot,
"buffer_pin_conflicts": row.confl_bufferpin,
"deadlock_conflicts": row.confl_deadlock
})
elif log_type == "slow_queries":
# Get slow queries from pg_stat_statements if available
try:
result = connection.execute(text("""
SELECT
query,
calls,
total_exec_time,
mean_exec_time,
max_exec_time,
rows
FROM pg_stat_statements
WHERE mean_exec_time > 1000 -- Queries taking more than 1 second on average
ORDER BY mean_exec_time DESC
LIMIT :lines
"""), {"lines": lines})
for row in result:
log_data["log_entries"].append({
"query": row.query[:300] + "..." if len(row.query) > 300 else row.query,
"calls": row.calls,
"total_time_ms": float(row.total_exec_time),
"mean_time_ms": float(row.mean_exec_time),
"max_time_ms": float(row.max_exec_time),
"rows_affected": row.rows
})
except Exception:
# pg_stat_statements extension might not be installed
log_data["log_entries"].append({
"note": "pg_stat_statements extension not available - cannot retrieve slow query data"
})
elif log_type == "connections":
# Get connection statistics
result = connection.execute(text("""
SELECT
datname,
numbackends,
xact_commit,
xact_rollback,
blks_read,
blks_hit,
tup_returned,
tup_fetched,
tup_inserted,
tup_updated,
tup_deleted,
conflicts,
temp_files,
temp_bytes,
deadlocks,
stats_reset
FROM pg_stat_database
WHERE datname = current_database()
"""))
for row in result:
log_data["log_entries"].append({
"database": row.datname,
"active_connections": row.numbackends,
"committed_transactions": row.xact_commit,
"rolled_back_transactions": row.xact_rollback,
"blocks_read": row.blks_read,
"blocks_hit": row.blks_hit,
"tuples_returned": row.tup_returned,
"tuples_fetched": row.tup_fetched,
"tuples_inserted": row.tup_inserted,
"tuples_updated": row.tup_updated,
"tuples_deleted": row.tup_deleted,
"conflicts": row.conflicts,
"temp_files": row.temp_files,
"temp_bytes": row.temp_bytes,
"deadlocks": row.deadlocks,
"stats_reset": row.stats_reset.isoformat() if row.stats_reset else None
})
# Get summary statistics
summary_result = connection.execute(text("""
SELECT
(SELECT setting FROM pg_settings WHERE name = 'log_destination') as log_destination,
(SELECT setting FROM pg_settings WHERE name = 'logging_collector') as logging_collector,
(SELECT setting FROM pg_settings WHERE name = 'log_directory') as log_directory,
(SELECT setting FROM pg_settings WHERE name = 'log_filename') as log_filename,
(SELECT COUNT(*) FROM pg_stat_activity WHERE state = 'active') as active_connections,
(SELECT COUNT(*) FROM pg_stat_activity) as total_connections
"""))
summary_row = summary_result.fetchone()
log_data["summary"] = {
"log_destination": summary_row.log_destination,
"logging_collector": summary_row.logging_collector,
"log_directory": summary_row.log_directory,
"log_filename": summary_row.log_filename,
"active_connections": summary_row.active_connections,
"total_connections": summary_row.total_connections,
"entries_returned": len(log_data["log_entries"])
}
return {
"success": True,
"database": database,
"log_data": log_data,
"operation": "postgresql_log_retrieval"
}
except Exception as e:
logger.error(f"Error retrieving PostgreSQL logs: {e}")
return {
"success": False,
"error": str(e),
"database": database,
"log_type": log_type,
"operation": "postgresql_log_retrieval"
}
async def handle_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Handle MCP tool calls by routing to appropriate methods"""
tool_handlers = {
"execute_safe_sql": self.execute_safe_sql,
"validate_sql_syntax": self.validate_sql_syntax,
"search_tables_for_concept": self.search_tables_for_concept,
"find_related_data": self.find_related_data,
"analyze_query_performance": self.analyze_query_performance,
"analyze_data_patterns": self.analyze_data_patterns,
"get_postgresql_logs": self.get_postgresql_logs,
"read_resource": self.read_resource
}
handler = tool_handlers.get(tool_name)
if not handler:
return {
"success": False,
"error": f"Unknown tool: {tool_name}",
"available_tools": list(tool_handlers.keys())
}
try:
return await handler(**arguments)
except Exception as e:
logger.error(f"Error handling tool call {tool_name}: {e}")
return {
"success": False,
"error": str(e),
"tool_name": tool_name,
"arguments": arguments
}
def get_available_tools(self) -> Dict[str, Any]:
"""Get list of all available MCP tools"""
tools = []
for category, tool_list in self.tool_definitions.items():
for tool in tool_list:
tools.append({
"name": tool["name"],
"description": tool["description"],
"category": category,
"parameters": tool["parameters"]
})
return {
"tools": tools,
"categories": self.tool_definitions
}
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Alias for handle_tool_call to maintain compatibility with MCP SDK"""
return await self.handle_tool_call(tool_name, arguments)
async def read_resource(self, uri: str, **kwargs) -> Dict[str, Any]:
"""Read a resource by URI - bridge between MCP tools and resources"""
try:
# Import the resources service
from services.mcp_resources_service import MCPResourcesService
# Create a resources service instance
resources_service = MCPResourcesService(self.config)
# Call the read_resource method
result = await resources_service.read_resource(uri)
return {
"success": True,
"uri": uri,
"resource_data": result,
"operation": "resource_read"
}
except Exception as e:
logger.error(f"Error reading resource {uri}: {e}")
return {
"success": False,
"error": str(e),
"uri": uri,
"operation": "resource_read"
}
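

# Minimal usage sketch (illustrative, not part of the service API). It assumes
# a reachable "db3" PostgreSQL instance and the config layout used by
# get_database_integration above; adjust the tool arguments to your schema.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        service = MCPToolsService(config={})
        # List the tools loaded from mcp_resources_and_tools.json
        print([tool["name"] for tool in service.get_available_tools()["tools"]])
        # Route a call through the same dispatcher an MCP server would use
        result = await service.call_tool(
            "execute_safe_sql",
            {"sql_query": "SELECT 1 AS ping", "database": "db3", "limit": 10},
        )
        print(result["success"], result.get("result"))

    asyncio.run(_demo())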