use anyhow::Result;
use clap::Parser;
use rmcp::{
model::{ErrorData, Content, CallToolResult},
handler::server::{tool::ToolRouter, tool::Parameters, ServerHandler},
transport::stdio,
ServiceExt,
tool_router, tool, tool_handler,
model::*,
};
use tracing_subscriber::{self, EnvFilter};
use serde::Deserialize;
use schemars::JsonSchema;
use std::future::Future;
mod database;
use database::DatabaseManager;
// Command-line options for the server binary.
// Plain `//` comments are used on purpose: clap turns `///` doc comments into
// help text, which would change the `--help` output.
#[derive(Parser)]
#[command(name = "sqlx-mcp", about = "A Model Context Protocol server for SQL databases")]
struct Args {
    // Default database URL for tool calls that do not pass `database_url`.
    #[arg(short, long, help = "Database URL")]
    database_url: Option<String>,
}
// Tool arguments for calls that only need a connection (get_database_info, list_tables).
// NOTE: `///` comments inside this struct become JSON-schema descriptions sent to
// MCP clients, so explanatory notes for maintainers use `//` instead.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct DatabaseInfoArgs {
/// Database URL to connect to (optional, falls back to command line or environment variable)
pub database_url: Option<String>,
}
// Tool arguments for the query tools (execute_readonly_query, execute_write_query).
// NOTE: `///` comments inside this struct become JSON-schema descriptions sent to
// MCP clients, so explanatory notes for maintainers use `//` instead.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct QueryArgs {
/// Database URL to connect to (optional, falls back to command line or environment variable)
pub database_url: Option<String>,
/// SQL query to execute
pub query: String,
}
// Tool arguments for get_table_structure.
// NOTE: `///` comments inside this struct become JSON-schema descriptions sent to
// MCP clients, so explanatory notes for maintainers use `//` instead.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TableStructureArgs {
/// Database URL to connect to (optional, falls back to command line or environment variable)
pub database_url: Option<String>,
/// Table name to get structure for
pub table_name: String,
}
// Tool arguments for get_table_ddl.
// NOTE: `///` comments inside this struct become JSON-schema descriptions sent to
// MCP clients, so explanatory notes for maintainers use `//` instead.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TableDDLArgs {
/// Database URL to connect to (optional, falls back to command line or environment variable)
pub database_url: Option<String>,
/// Table name to get DDL for
pub table_name: String,
}
/// MCP server exposing SQL database inspection and query tools.
#[derive(Clone)]
pub struct SqlxMcpServer {
    /// URL given via `--database-url`; consulted when a tool call has no `database_url`.
    default_database_url: Option<String>,
    /// Dispatch table for the `#[tool]` methods, built by `Self::tool_router()`.
    tool_router: ToolRouter<Self>,
}
fn create_error_response(msg: &str) -> CallToolResult {
CallToolResult::success(vec![Content::text(format!("Error: {}", msg))])
}
impl SqlxMcpServer {
/// Resolve database URL with priority: user input > command line > environment variable
fn resolve_database_url(&self, user_input: Option<String>) -> Result<String, String> {
if let Some(url) = user_input {
return Ok(url);
}
if let Some(url) = &self.default_database_url {
return Ok(url.clone());
}
if let Ok(url) = std::env::var("DATABASE_URL") {
return Ok(url);
}
Err("No database URL provided. Please provide via parameter, command line --database-url, or DATABASE_URL environment variable.".to_string())
}
/// Mask sensitive information in database URL for display purposes
fn mask_database_url(url: &str) -> String {
if let Ok(parsed) = url::Url::parse(url) {
let mut masked = parsed.clone();
if parsed.password().is_some() {
let _ = masked.set_password(Some("***"));
}
masked.to_string()
} else {
// If URL parsing fails, just show the scheme and host if possible
if let Some(scheme_end) = url.find("://") {
let scheme = &url[..scheme_end + 3];
if let Some(at_pos) = url.find('@') {
if let Some(host_start) = url[at_pos..].find('/') {
format!("{}***@{}", scheme, &url[at_pos + host_start..])
} else {
format!("{}***@{}", scheme, &url[at_pos + 1..])
}
} else {
url.to_string()
}
} else {
"***".to_string()
}
}
}
/// Check if a query is read-only based on SQL keywords
fn is_readonly_query(query: &str) -> bool {
let query_trimmed = query.trim().to_lowercase();
// Allow SELECT, WITH (CTE), SHOW, DESCRIBE, EXPLAIN statements
let readonly_keywords = [
"select", "with", "show", "describe", "explain", "desc"
];
// Check if query starts with any readonly keyword
for keyword in &readonly_keywords {
if query_trimmed.starts_with(keyword) {
return true;
}
}
false
}
}
#[tool_router]
impl SqlxMcpServer {
pub fn new(default_database_url: Option<String>) -> Self {
Self {
tool_router: Self::tool_router(),
default_database_url,
}
}
#[tool(description = "Get database basic information")]
async fn get_database_info(
&self,
Parameters(args): Parameters<DatabaseInfoArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.get_database_info().await {
Ok(info) => {
let content = serde_json::to_string_pretty(&info)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to get database info: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
#[tool(description = "List all tables with metadata (name, comment, row count, etc.)")]
async fn list_tables(
&self,
Parameters(args): Parameters<DatabaseInfoArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.list_tables().await {
Ok(tables) => {
let content = serde_json::to_string_pretty(&tables)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to list tables: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
#[tool(description = "Get table structure information")]
async fn get_table_structure(
&self,
Parameters(args): Parameters<TableStructureArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.get_table_structure(&args.table_name).await {
Ok(tables) => {
let content = serde_json::to_string_pretty(&tables)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to get table structure: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
#[tool(description = "Execute read-only SQL query (SELECT, WITH, SHOW, DESCRIBE, EXPLAIN)")]
async fn execute_readonly_query(
&self,
Parameters(args): Parameters<QueryArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
// Validate that the query is read-only
if !Self::is_readonly_query(&args.query) {
return Ok(create_error_response(
"Query is not read-only. Only SELECT, WITH, SHOW, DESCRIBE, EXPLAIN statements are allowed."
));
}
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.execute_readonly_query(&args.query).await {
Ok(result) => {
let content = serde_json::to_string_pretty(&result)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to execute query: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
#[tool(description = "Execute write SQL query (INSERT, UPDATE, DELETE, CREATE, DROP, ALTER)")]
async fn execute_write_query(
&self,
Parameters(args): Parameters<QueryArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.execute_write_query(&args.query).await {
Ok(result) => {
let content = serde_json::to_string_pretty(&result)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to execute query: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
#[tool(description = "Get table DDL (CREATE TABLE statement) for the specified table")]
async fn get_table_ddl(
&self,
Parameters(args): Parameters<TableDDLArgs>,
) -> Result<CallToolResult, ErrorData> {
let database_url = match self.resolve_database_url(args.database_url) {
Ok(url) => url,
Err(e) => return Ok(create_error_response(&e)),
};
match DatabaseManager::new(&database_url).await {
Ok(manager) => {
match manager.get_table_ddl(&args.table_name).await {
Ok(ddl) => {
let content = serde_json::to_string_pretty(&ddl)
.unwrap_or_else(|e| format!("Serialization error: {}", e));
Ok(CallToolResult::success(vec![Content::text(content)]))
},
Err(e) => {
Ok(create_error_response(&format!("Failed to get table DDL: {}", e)))
}
}
},
Err(e) => {
Ok(create_error_response(&format!("Failed to connect to database: {}", e)))
}
}
}
}
#[tool_handler]
impl ServerHandler for SqlxMcpServer {
    /// Advertise the server: tool capability, build information, and usage
    /// instructions that say where (if anywhere) a default database URL comes from.
    /// Any URL shown there has its credentials masked.
    fn get_info(&self) -> ServerInfo {
        let db_config_info = match self.default_database_url.as_deref() {
            Some(cli_url) => format!(
                "Database configured via command line: {}",
                Self::mask_database_url(cli_url)
            ),
            None => match std::env::var("DATABASE_URL") {
                Ok(env_url) => format!(
                    "Database configured via environment variable: {}",
                    Self::mask_database_url(&env_url)
                ),
                Err(_) => "No database configured. Please provide database_url parameter in tool calls or set DATABASE_URL environment variable.".to_string(),
            },
        };

        ServerInfo {
            protocol_version: ProtocolVersion::V_2024_11_05,
            capabilities: ServerCapabilities::builder().enable_tools().build(),
            server_info: Implementation::from_build_env(),
            instructions: Some(format!(
                "This server provides SQL database tools. Tools: get_database_info, list_tables, get_table_structure, get_table_ddl, execute_readonly_query, execute_write_query.\n\nCurrent database configuration: {}\n\nIf a database is already configured, you can omit the database_url parameter in tool calls.",
                db_config_info
            )),
        }
    }
}
/// Entry point: parse CLI options, set up logging, and serve MCP over stdio.
#[tokio::main]
async fn main() -> Result<()> {
    // Parse arguments first so `--help` and usage errors exit before any setup.
    let args = Args::parse();

    // Log to stderr because stdout carries the stdio transport.
    // When RUST_LOG is set (and valid) it fully controls verbosity. Adding a global
    // DEBUG directive on top of it, as before, overwrote any global level RUST_LOG
    // specified, so logging could never be made quieter.
    let env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("debug"));
    tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .init();
    tracing::info!("Starting SQLx MCP server");

    // Report where the default database URL comes from, with credentials masked.
    let db_config_info = if let Some(url) = &args.database_url {
        format!("Database configured via command line: {}", SqlxMcpServer::mask_database_url(url))
    } else if let Ok(url) = std::env::var("DATABASE_URL") {
        format!("Database configured via environment variable: {}", SqlxMcpServer::mask_database_url(&url))
    } else {
        "No database configured. Will require database_url parameter in tool calls.".to_string()
    };
    tracing::info!("Database configuration: {}", db_config_info);

    // The command-line URL is no longer needed here, so it moves into the server.
    let server = SqlxMcpServer::new(args.database_url);

    // Serve using the stdio transport until the client disconnects.
    let service = server.serve(stdio()).await.inspect_err(|e| {
        tracing::error!("serving error: {:?}", e);
    })?;
    service.waiting().await?;
    Ok(())
}