import re
import sqlite3
import warnings
from contextlib import closing
from pathlib import Path
from typing import List, Dict, Any

#from mcp.server.fastmcp import FastMCP
from fastmcp import FastMCP
from loguru import logger
# Silence the RuntimeWarning whose message starts with
# "async generator ignored GeneratorExit".
# NOTE(review): presumably emitted by the async server transport on shutdown;
# confirm this filter is not hiding a real cleanup problem.
warnings.filterwarnings("ignore", category=RuntimeWarning, message="async generator ignored GeneratorExit")

# Initialize MCP server; the functions below decorated with @mcp.tool() are
# exposed through this instance.
mcp = FastMCP("SQLite MCP Server")

# SQLite database configuration.
# This is a relative path, so it is resolved against the process's current
# working directory, not against this file's location.
# NOTE(review): the server must be started from the directory this path
# expects; consider anchoring it to __file__ if that is the intent.
DB_PATH = "../data/chinook.db"
# Security helpers
def validate_table_name(table_name: str) -> bool:
    """Validate a table name before it is interpolated into SQL.

    Only plain identifiers are accepted: an ASCII letter or underscore,
    followed by any number of ASCII letters, digits, or underscores.

    Args:
        table_name (str): The name of the table to validate.

    Returns:
        bool: True if ``table_name`` is a plain identifier, False otherwise
        (including for the empty string and non-string input).
    """
    if not isinstance(table_name, str):
        return False
    # fullmatch instead of match(...$): "$" also matches just before a
    # trailing newline, which would let "users\n" slip through.
    return re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table_name) is not None
def is_select_query(query: str) -> bool:
    """Report whether a query begins with the SELECT keyword.

    This is a prefix check only: surrounding whitespace is ignored and the
    comparison is case-insensitive. Queries that open with anything else
    (including WITH or a comment) are rejected.

    Args:
        query (str): The SQL query to validate.

    Returns:
        bool: True when the stripped, lower-cased query starts with
        "select", False otherwise.
    """
    keyword = "select"
    head = query.strip().lower()
    return head[:len(keyword)] == keyword
# MCP Tools
@mcp.tool()
def list_tables() -> List[str]:
    """Retrieve a list of all table names in the SQLite database.

    This tool queries the SQLite catalog table ``sqlite_master`` for every
    entry whose type is 'table' and returns their names. The database is
    opened read-only, so a missing database file is reported as an error
    instead of being silently created empty.

    Args:
        None

    Returns:
        List[str]: A list of table names as strings.
            Example: ["users", "orders", "products"]

    Errors:
        If the database is inaccessible or an error occurs, returns a string
        starting with "Error: " followed by the error message.
            Example: "Error: unable to open database file"
    """
    logger.info("Listing tables in the database")
    try:
        # mode=ro: never create or modify the database file.
        # closing(): sqlite3's own context manager only ends the transaction;
        # it does not close the connection.
        db_uri = Path(DB_PATH).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [row[0] for row in cursor.fetchall()]
        # Log instead of print(): stdout may be the protocol channel under
        # some MCP transports.
        logger.debug(f"Tables found: {tables}")
        return tables
    except Exception as e:
        logger.error(f"Error listing tables: {str(e)}")
        return f"Error: {str(e)}"
@mcp.tool()
def get_table_schema(table_name: str) -> List[Dict[str, str]]:
    """Retrieve the schema of a specified table, including column names and data types.

    This tool reads SQLite's ``pragma_table_info`` metadata, the table-valued
    form of ``PRAGMA table_info``. The table name is passed as a bound
    parameter, and the database is opened read-only.

    Args:
        table_name (str): The name of the table to retrieve the schema for.
            Example: "users"

    Returns:
        List[Dict[str, str]]: One dictionary per column, in column order, each containing:
            - "column_name": The name of the column (string).
            - "data_type": The SQLite data type of the column (string, e.g., "INTEGER", "TEXT").
            Example: [
                {"column_name": "id", "data_type": "INTEGER"},
                {"column_name": "name", "data_type": "TEXT"}
            ]

    Errors:
        - If the table name is invalid (contains unsafe characters), returns "Error: Invalid table name".
        - If the table does not exist, returns "Error: Table '<table_name>' does not exist".
            Example: "Error: Table 'users' does not exist"
        - If any other error occurs, returns a string starting with "Error: " followed by the error message.
    """
    if not validate_table_name(table_name):
        logger.error(f"Invalid table name: {table_name}")
        return "Error: Invalid table name"
    logger.info(f"Retrieving schema for table: {table_name}")
    try:
        # mode=ro: never create or modify the database file.
        # closing(): sqlite3's own context manager only ends the transaction;
        # it does not close the connection.
        db_uri = Path(DB_PATH).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            rows = conn.execute(
                "SELECT name, type FROM pragma_table_info(?) ORDER BY cid",
                (table_name,),
            ).fetchall()
    except Exception as e:
        logger.error(f"Error retrieving schema for {table_name}: {str(e)}")
        return f"Error: {str(e)}"
    if not rows:
        # table_info yields zero rows (it does not raise) for an unknown table,
        # and every SQLite table has at least one column.
        logger.error(f"Table does not exist: {table_name}")
        return f"Error: Table '{table_name}' does not exist"
    return [{"column_name": name, "data_type": col_type} for name, col_type in rows]
@mcp.tool()
def count_rows(table_name: str) -> int:
    """Count the number of rows in a specified table.

    This tool executes a SELECT COUNT(*) query against the given table on a
    read-only connection.

    Args:
        table_name (str): The name of the table to count rows for.
            Example: "orders"

    Returns:
        int: The number of rows in the table.
            Example: 42

    Errors:
        - If the table name is invalid (contains unsafe characters), returns "Error: Invalid table name".
        - If the table does not exist or an error occurs, returns a string starting with "Error: "
          followed by the SQLite error message.
            Example: "Error: no such table: orders"
    """
    if not validate_table_name(table_name):
        logger.error(f"Invalid table name: {table_name}")
        return "Error: Invalid table name"
    logger.info(f"Counting rows in table: {table_name}")
    try:
        # mode=ro: never create or modify the database file.
        # closing(): sqlite3's own context manager only ends the transaction;
        # it does not close the connection.
        db_uri = Path(DB_PATH).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            # Identifiers cannot be bound as parameters. The name is
            # double-quoted so tables named after SQL keywords (e.g. "Order")
            # still parse. Quoting is safe because validate_table_name
            # admits no quote characters.
            count = conn.execute(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[0]
        return count
    except Exception as e:
        logger.error(f"Error counting rows in {table_name}: {str(e)}")
        return f"Error: {str(e)}"
@mcp.tool()
def execute_query(query: str) -> Dict[str, Any]:
    """Execute a read-only SQL SELECT query and return the results.

    This tool executes a user-provided SQL query, restricted to statements
    starting with SELECT. As a second line of defense, the database is opened
    in SQLite read-only mode, so nothing that slips past the prefix check can
    write to it.

    Args:
        query (str): The SQL SELECT query to execute.
            Example: "SELECT name, email FROM users WHERE id = 1"

    Returns:
        Dict[str, Any]: A dictionary containing:
            - "columns": List of column names (List[str]).
            - "rows": List of rows, where each row is a list of values (List[List[Any]]).
            Example: {
                "columns": ["name", "email"],
                "rows": [["Alice", "alice@example.com"]]
            }

    Errors:
        - If the query is not a SELECT statement, returns {"error": "Only SELECT queries are allowed"}.
        - If the query is invalid or an error occurs, returns {"error": <error message>}.
            Example: {"error": "no such table: users"}
    """
    if not is_select_query(query):
        logger.error("Only SELECT queries are allowed")
        return {"error": "Only SELECT queries are allowed"}
    logger.info(f"Executing SQL query: {query}")
    try:
        # mode=ro: any attempted write fails with a read-only error.
        # closing(): sqlite3's own context manager only ends the transaction;
        # it does not close the connection.
        db_uri = Path(DB_PATH).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            cursor = conn.execute(query)
            columns = [desc[0] for desc in cursor.description]
            rows = [list(row) for row in cursor.fetchall()]
        return {"columns": columns, "rows": rows}
    except Exception as e:
        logger.error(f"Error executing query: {str(e)}")
        return {"error": str(e)}
if __name__ == "__main__":
    # Serve the registered tools over streamable HTTP on localhost:8080.
    server_options = {"transport": "streamable-http", "host": "localhost", "port": 8080}
    logger.info("Starting MCP server with streamable-http transport...")
    mcp.run(**server_options)