"""
MCP Database Tools Service
Provides PostgreSQL functions as MCP tools for LLM interaction
"""
import json
import logging
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from sqlalchemy import create_engine

from postgres_integration import PostgreSQLIntegration
# Module-level logger used by MCPDatabaseToolsService for load warnings and tool errors.
logger = logging.getLogger(__name__)
class MCPDatabaseToolsService:
    """Service that exposes PostgreSQL integration functions as MCP tools.

    Each public ``async`` tool method returns a dict with a ``"success"`` flag.
    Failures are logged and returned as ``{"success": False, "error": ...}``
    instead of being raised, so the LLM caller always gets a structured reply.
    """

    # Whole-word, case-insensitive LIMIT keyword. A plain substring test would
    # treat identifiers such as ``unlimited`` as an existing LIMIT clause.
    _LIMIT_KEYWORD = re.compile(r"\bLIMIT\b", re.IGNORECASE)

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and load the tool catalogue from disk.

        Args:
            config: Service configuration, kept as ``self.config``.
        """
        self.config = config
        # One integration per bind name, created lazily by get_database_integration().
        self.db_integrations: Dict[str, PostgreSQLIntegration] = {}
        self._load_tool_definitions()

    def _load_tool_definitions(self):
        """Load MCP tool definitions into ``self.tool_definitions``.

        Reads ``resources/lists/mcp_database_tools.json`` relative to the parent
        of this module's directory. Loading is best-effort: a missing,
        unreadable or malformed file is logged and an empty catalogue
        (``{"database_tools": {}}``) is used so the service still starts.
        """
        tools_file = (
            Path(__file__).parent.parent / "resources" / "lists" / "mcp_database_tools.json"
        )
        try:
            with open(tools_file, "r", encoding="utf-8") as f:
                definitions = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Could not load tool definitions: %s", e)
            definitions = {"database_tools": {}}
        if not isinstance(definitions, dict):
            # get_available_tools() calls .get() on this value, so a top-level
            # JSON array or scalar would otherwise break tool listing later.
            logger.warning(
                "Could not load tool definitions: expected a JSON object, got %s",
                type(definitions).__name__,
            )
            definitions = {"database_tools": {}}
        self.tool_definitions = definitions

    def get_database_integration(self, database: str = "db3") -> "PostgreSQLIntegration":
        """Return the cached PostgreSQL integration for ``database``, creating it on first use.

        Args:
            database: Key into ``Config.SQLALCHEMY_BINDS``.

        Returns:
            A PostgreSQLIntegration wrapping a SQLAlchemy engine for that bind.

        Raises:
            ValueError: If ``database`` has no (or an empty) connection string.
        """
        if database not in self.db_integrations:
            # Imported here rather than at module level, presumably to defer
            # loading application config until first use; confirm before moving.
            from config import Config

            connection_string = Config.SQLALCHEMY_BINDS.get(database)
            if not connection_string:
                raise ValueError(f"Database '{database}' not configured")
            engine = create_engine(connection_string)
            self.db_integrations[database] = PostgreSQLIntegration(engine, database)
        return self.db_integrations[database]

    def _run_tool(
        self,
        database: str,
        operation: Callable[[Any], Dict[str, Any]],
        error_message: str,
        **context: Any,
    ) -> Dict[str, Any]:
        """Run one tool operation and wrap the outcome in the standard response dict.

        Args:
            database: Bind name passed to ``get_database_integration``.
            operation: Called with the integration; returns the result fields
                to merge into a successful response.
            error_message: Log prefix used when anything raises.
            **context: Request fields echoed in both success and error responses.

        Returns:
            ``{"success": True, "database": ..., **context, **result}`` on success
            (result fields override same-named context fields), or
            ``{"success": False, "error": str(e), "database": ..., **context}``.

        NOTE(review): the integration calls are synchronous, so they run on the
        event-loop thread of the async tool methods. Consider
        ``asyncio.to_thread`` if they are slow, after confirming that
        PostgreSQLIntegration is thread-safe.
        """
        try:
            integration = self.get_database_integration(database)
            result_fields = operation(integration)
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of raising.
            logger.exception("%s: %s", error_message, e)
            return {"success": False, "error": str(e), "database": database, **context}
        return {"success": True, "database": database, **context, **result_fields}

    @staticmethod
    def _apply_row_limit(sql_query: str, limit: int) -> str:
        """Return ``sql_query`` stripped, with ``LIMIT <limit>`` appended to an uncapped SELECT.

        Only statements that start with ``SELECT`` (case-insensitive) and contain
        no whole-word ``LIMIT`` are changed. Trailing semicolons are removed first
        so the clause lands inside the statement rather than after it.

        Raises:
            ValueError / TypeError: If a limit must be appended and ``limit`` is
                not a non-negative integer. The value comes from tool arguments
                and is interpolated into SQL text, so only a plain int may reach it.
        """
        query = sql_query.strip()
        if not query.upper().startswith("SELECT"):
            return query
        if MCPDatabaseToolsService._LIMIT_KEYWORD.search(query):
            return query
        row_limit = int(limit)
        if row_limit < 0:
            raise ValueError(f"limit must be a non-negative integer, got {limit!r}")
        query = query.rstrip("; \t\r\n\f\v")
        return f"{query} LIMIT {row_limit}"

    async def get_all_table_names(self, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get all table names from the database.

        Success fields: ``tables`` and ``count`` (``len(tables)``).
        """
        def fetch(integration):
            tables = integration.get_all_table_names()
            return {"tables": tables, "count": len(tables)}

        return self._run_tool(database, fetch, "Error getting table names")

    async def get_table_schema(self, table_name: str, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get detailed schema for a specific table (success field: ``schema``)."""
        return self._run_tool(
            database,
            lambda integration: {"schema": integration.get_table_schema(table_name)},
            f"Error getting table schema for {table_name}",
            table_name=table_name,
        )

    async def get_database_schema(self, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get complete database schema (success field: ``schema``)."""
        return self._run_tool(
            database,
            lambda integration: {"schema": integration.get_database_schema()},
            "Error getting database schema",
        )

    async def discover_table_semantics(self, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Discover table relationships and semantics (success field: ``semantics``)."""
        # NOTE(review): relies on the integration's private _discover_table_semantics().
        return self._run_tool(
            database,
            lambda integration: {"semantics": integration._discover_table_semantics()},
            "Error discovering table semantics",
        )

    async def get_table_row_count(self, table_name: str, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get row count for a specific table (success field: ``row_count``)."""
        return self._run_tool(
            database,
            lambda integration: {"row_count": integration.get_table_row_count(table_name)},
            f"Error getting row count for {table_name}",
            table_name=table_name,
        )

    async def get_table_size(self, table_name: str, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get storage size information for a table (success field: ``size_info``)."""
        return self._run_tool(
            database,
            lambda integration: {"size_info": integration.get_table_size(table_name)},
            f"Error getting table size for {table_name}",
            table_name=table_name,
        )

    async def get_database_stats(self, database: str = "db3") -> Dict[str, Any]:
        """MCP Tool: Get overall database statistics (success field: ``stats``)."""
        return self._run_tool(
            database,
            lambda integration: {"stats": integration.get_database_stats()},
            "Error getting database stats",
        )

    async def execute_safe_sql(self, sql_query: str, database: str = "db3", limit: int = 100) -> Dict[str, Any]:
        """MCP Tool: Execute SQL query safely with validation.

        A ``SELECT`` without a ``LIMIT`` clause is capped at ``limit`` rows before
        execution. Statement-level safety checks are delegated to
        ``PostgreSQLIntegration.execute_safe_sql`` (not visible here).

        Success fields: ``query`` (the statement actually run) and ``result``.
        On error, ``query`` echoes the original ``sql_query``.
        """
        def run(integration):
            query = self._apply_row_limit(sql_query, limit)
            return {"query": query, "result": integration.execute_safe_sql(query)}

        return self._run_tool(database, run, "Error executing SQL", query=sql_query)

    async def search_tables_for_concept(self, concept: str, database: str = "db3", table_names: Optional[List[str]] = None) -> Dict[str, Any]:
        """MCP Tool: Search tables for a business concept (success field: ``search_results``).

        ``table_names`` is forwarded unchanged to the integration.
        """
        return self._run_tool(
            database,
            lambda integration: {
                "search_results": integration.search_tables_for_concept(concept, table_names)
            },
            f"Error searching for concept '{concept}'",
            concept=concept,
        )

    def get_available_tools(self) -> Dict[str, Any]:
        """Get list of all available MCP tools from the loaded definitions.

        Returns:
            ``{"tools": [...], "categories": ..., "safety_features": ...}``. Each
            tool entry carries ``name``, ``description``, ``category``,
            ``parameters`` and ``returns``. ``categories`` and
            ``safety_features`` come from ``tool_categories`` and
            ``safety_features`` in the definitions (``{}`` if absent).

        Raises:
            KeyError: If a tool definition lacks one of the copied keys.
        """
        tools = [
            {
                "name": tool["name"],
                "description": tool["description"],
                "category": category,
                "parameters": tool["parameters"],
                "returns": tool["returns"],
            }
            for category, tool_list in self.tool_definitions.get("database_tools", {}).items()
            for tool in tool_list
        ]
        return {
            "tools": tools,
            "categories": self.tool_definitions.get("tool_categories", {}),
            "safety_features": self.tool_definitions.get("safety_features", {}),
        }

    async def handle_tool_call(self, tool_name: str, arguments: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Handle MCP tool calls by routing to the matching tool method.

        Args:
            tool_name: Name of the tool to invoke.
            arguments: Keyword arguments for the tool. ``None`` or an empty
                value means "use the tool's defaults".

        Returns:
            The tool's response dict, or an error dict for an unknown tool or
            arguments that don't match the tool's signature.
        """
        tool_handlers = {
            "get_all_table_names": self.get_all_table_names,
            "get_table_schema": self.get_table_schema,
            "get_database_schema": self.get_database_schema,
            "discover_table_semantics": self.discover_table_semantics,
            "get_table_row_count": self.get_table_row_count,
            "get_table_size": self.get_table_size,
            "get_database_stats": self.get_database_stats,
            "execute_safe_sql": self.execute_safe_sql,
            "search_tables_for_concept": self.search_tables_for_concept,
        }
        handler = tool_handlers.get(tool_name)
        if handler is None:
            return {
                "success": False,
                "error": f"Unknown tool: {tool_name}",
                "available_tools": list(tool_handlers.keys()),
            }
        try:
            return await handler(**(arguments or {}))
        except Exception as e:
            # e.g. a TypeError when the arguments don't match the handler's signature.
            logger.exception("Error handling tool call %s: %s", tool_name, e)
            return {
                "success": False,
                "error": str(e),
                "tool_name": tool_name,
                "arguments": arguments,
            }