use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use sqlx::{Column, Row, TypeInfo, ValueRef};
use std::collections::HashMap;
use std::time::Instant;
/// Which SQL backend a `DatabaseManager` is connected to.
///
/// Chosen from the scheme of the connection URL when the manager is created.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub enum DatabaseType {
    /// PostgreSQL server.
    PostgreSQL,
    /// MySQL server.
    MySQL,
    /// SQLite database file.
    SQLite,
}
/// Backend-specific sqlx connection pool held by a `DatabaseManager`.
#[derive(Debug)]
pub enum DatabasePool {
    /// Pool of PostgreSQL connections.
    Postgres(sqlx::Pool<sqlx::Postgres>),
    /// Pool of MySQL connections.
    MySQL(sqlx::Pool<sqlx::MySql>),
    /// Pool of SQLite connections.
    SQLite(sqlx::Pool<sqlx::Sqlite>),
}
/// A single result row, tagged with the backend that produced it.
#[derive(Debug)]
pub enum DatabaseRow {
    /// Row returned by PostgreSQL.
    Postgres(sqlx::postgres::PgRow),
    /// Row returned by MySQL.
    MySQL(sqlx::mysql::MySqlRow),
    /// Row returned by SQLite.
    SQLite(sqlx::sqlite::SqliteRow),
}
impl DatabaseRow {
pub fn columns(&self) -> &[Box<dyn Column<Database = ()>>] {
match self {
DatabaseRow::Postgres(row) => row.columns(),
DatabaseRow::MySQL(row) => row.columns(),
DatabaseRow::SQLite(row) => row.columns(),
}
}
pub fn try_get<T>(&self, index: &str) -> Result<T, sqlx::Error>
where
T: for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
T: for<'r> sqlx::Decode<'r, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
{
match self {
DatabaseRow::Postgres(row) => row.try_get(index),
DatabaseRow::MySQL(row) => row.try_get(index),
DatabaseRow::SQLite(row) => row.try_get(index),
}
}
pub fn len(&self) -> usize {
match self {
DatabaseRow::Postgres(row) => row.len(),
DatabaseRow::MySQL(row) => row.len(),
DatabaseRow::SQLite(row) => row.len(),
}
}
}
/// Description of one table: identity, columns and (best-effort) size.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TableInfo {
    /// Table name.
    pub name: String,
    /// Owning schema/database, when the backend has one.
    pub schema: Option<String>,
    /// Column metadata for the table.
    pub columns: Vec<ColumnInfo>,
    /// Number of rows, or `None` when it could not be determined.
    pub row_count: Option<i64>,
}
/// Metadata for a single table column as read from the backend's catalog.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Data type as reported by the backend.
    pub data_type: String,
    /// Whether the column accepts `NULL`.
    pub is_nullable: bool,
    /// Default expression text, if any.
    pub default_value: Option<String>,
    /// Whether the column is part of the primary key.
    pub is_primary_key: bool,
    /// Whether the column is flagged as a foreign key (exact meaning depends
    /// on the catalog query used for each backend).
    pub is_foreign_key: bool,
    /// Maximum character length, when the backend reports one.
    pub character_maximum_length: Option<i32>,
    /// Position of the column within the table, as reported by the backend.
    pub ordinal_position: i32,
}
/// A named object in the database catalog (schema, table, view, routine, ...).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SchemaObject {
    /// Object name.
    pub name: String,
    /// Kind of object, e.g. `TABLE`, `VIEW`, `FUNCTION`, `PROCEDURE`, `SCHEMA`.
    pub object_type: String,
    /// Schema that contains the object, if applicable.
    pub schema: Option<String>,
    /// Free-form description, if available.
    pub description: Option<String>,
    /// Names of other tables or objects this object depends on.
    pub dependencies: Vec<String>,
}
/// Facade over an sqlx pool for PostgreSQL, MySQL or SQLite.
pub struct DatabaseManager {
    /// Live connection pool; `new` always builds the variant matching `database_type`.
    pool: DatabasePool,
    /// Backend detected from the connection URL.
    database_type: DatabaseType,
}
impl DatabaseManager {
pub async fn new(connection_url: &str) -> Result<Self> {
let database_type = Self::detect_database_type(connection_url)?;
let pool = match database_type {
DatabaseType::PostgreSQL => {
let pg_pool = sqlx::PgPool::connect(connection_url).await?;
DatabasePool::Postgres(pg_pool)
}
DatabaseType::MySQL => {
let mysql_pool = sqlx::MySqlPool::connect(connection_url).await?;
DatabasePool::MySQL(mysql_pool)
}
DatabaseType::SQLite => {
let sqlite_pool = sqlx::SqlitePool::connect(connection_url).await?;
DatabasePool::SQLite(sqlite_pool)
}
};
Ok(Self {
pool,
database_type,
})
}
/// Maps a connection URL's scheme to the backend it selects.
///
/// Recognised prefixes: `postgres://`, `postgresql://`, `mysql://` and
/// `sqlite:` (which also covers `sqlite://`).
///
/// # Errors
/// Returns an error for any other scheme. Only the scheme is echoed back:
/// connection URLs usually embed `user:password@host`, and the full URL must
/// not end up in error messages or logs.
fn detect_database_type(url: &str) -> Result<DatabaseType> {
    if url.starts_with("postgres://") || url.starts_with("postgresql://") {
        Ok(DatabaseType::PostgreSQL)
    } else if url.starts_with("mysql://") {
        Ok(DatabaseType::MySQL)
    } else if url.starts_with("sqlite:") {
        Ok(DatabaseType::SQLite)
    } else {
        // Everything before the first ':' is the scheme; never include the rest.
        let scheme = url.split_once(':').map_or("<none>", |(scheme, _)| scheme);
        Err(anyhow!("Unsupported database URL scheme: {}", scheme))
    }
}
/// Runs `query` and collects every result row, tagged with its backend.
async fn execute_query(&self, query: &str) -> Result<Vec<DatabaseRow>> {
    let rows: Vec<DatabaseRow> = match &self.pool {
        DatabasePool::Postgres(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::Postgres)
            .collect(),
        DatabasePool::MySQL(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::MySQL)
            .collect(),
        DatabasePool::SQLite(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::SQLite)
            .collect(),
    };
    Ok(rows)
}
/// Runs `query` and returns its single expected row, tagged with its backend.
///
/// # Errors
/// Propagates any error reported by sqlx's `fetch_one`.
async fn execute_query_one(&self, query: &str) -> Result<DatabaseRow> {
    let row = match &self.pool {
        DatabasePool::Postgres(pool) => DatabaseRow::Postgres(sqlx::query(query).fetch_one(pool).await?),
        DatabasePool::MySQL(pool) => DatabaseRow::MySQL(sqlx::query(query).fetch_one(pool).await?),
        DatabasePool::SQLite(pool) => DatabaseRow::SQLite(sqlx::query(query).fetch_one(pool).await?),
    };
    Ok(row)
}
/// Returns the backend name and server version.
///
/// The map has exactly two keys: `database_type` (`"PostgreSQL"`, `"MySQL"`
/// or `"SQLite"`) and `version` (the string reported by the server).
///
/// # Errors
/// Fails when the version query fails or its result is not a string.
pub async fn get_database_info(&self) -> Result<HashMap<String, serde_json::Value>> {
    let (label, version_sql) = match self.database_type {
        DatabaseType::PostgreSQL => ("PostgreSQL", "SELECT version()"),
        DatabaseType::MySQL => ("MySQL", "SELECT VERSION() as version"),
        DatabaseType::SQLite => ("SQLite", "SELECT sqlite_version() as version"),
    };
    let version: String = self.execute_query_one(version_sql).await?.try_get("version")?;

    let mut info = HashMap::new();
    info.insert("database_type".to_string(), serde_json::Value::String(label.to_string()));
    info.insert("version".to_string(), serde_json::Value::String(version));
    Ok(info)
}
pub async fn execute_sql(&self, query: &str) -> Result<serde_json::Value> {
let start_time = Instant::now();
// 修剪查询字符串
let query = query.trim();
if query.is_empty() {
return Err(anyhow!("Empty query"));
}
// 检查是否是查询语句
let is_select = query.to_lowercase().trim_start().starts_with("select")
|| query.to_lowercase().trim_start().starts_with("with")
|| query.to_lowercase().trim_start().starts_with("show")
|| query.to_lowercase().trim_start().starts_with("describe")
|| query.to_lowercase().trim_start().starts_with("explain");
if is_select {
// 执行查询
let rows = self.execute_query(query).await?;
let execution_time = start_time.elapsed();
// 转换为 JSON
let mut result_data = Vec::new();
for row in &rows {
let mut row_data = HashMap::new();
match row {
DatabaseRow::Postgres(pg_row) => {
for (i, column) in pg_row.columns().iter().enumerate() {
let column_name = column.name();
let value = self.convert_postgres_value_to_json(pg_row, i)?;
row_data.insert(column_name.to_string(), value);
}
}
DatabaseRow::MySQL(mysql_row) => {
for (i, column) in mysql_row.columns().iter().enumerate() {
let column_name = column.name();
let value = self.convert_mysql_value_to_json(mysql_row, i)?;
row_data.insert(column_name.to_string(), value);
}
}
DatabaseRow::SQLite(sqlite_row) => {
for (i, column) in sqlite_row.columns().iter().enumerate() {
let column_name = column.name();
let value = self.convert_sqlite_value_to_json(sqlite_row, i)?;
row_data.insert(column_name.to_string(), value);
}
}
}
result_data.push(serde_json::Value::Object(row_data.into_iter().collect()));
}
Ok(serde_json::json!({
"type": "query",
"rows": result_data,
"row_count": rows.len(),
"execution_time_ms": execution_time.as_millis()
}))
} else {
// 执行非查询语句 (INSERT, UPDATE, DELETE, etc.)
let affected_rows = match &self.pool {
DatabasePool::Postgres(pool) => {
let result = sqlx::query(query).execute(pool).await?;
result.rows_affected()
}
DatabasePool::MySQL(pool) => {
let result = sqlx::query(query).execute(pool).await?;
result.rows_affected()
}
DatabasePool::SQLite(pool) => {
let result = sqlx::query(query).execute(pool).await?;
result.rows_affected()
}
};
let execution_time = start_time.elapsed();
Ok(serde_json::json!({
"type": "execute",
"affected_rows": affected_rows,
"execution_time_ms": execution_time.as_millis()
}))
}
}
// PostgreSQL 值转换
fn convert_postgres_value_to_json(&self, row: &sqlx::postgres::PgRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// 根据类型转换
match type_info.name() {
"BOOL" => Ok(serde_json::Value::Bool(row.try_get(index)?)),
"INT2" | "SMALLINT" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i16, _>(index)?))),
"INT4" | "INTEGER" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i32, _>(index)?))),
"INT8" | "BIGINT" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i64, _>(index)?))),
"FLOAT4" | "REAL" => {
let val: f32 = row.try_get(index)?;
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val as f64).unwrap_or(serde_json::Number::from(0))))
},
"FLOAT8" | "DOUBLE PRECISION" => {
let val: f64 = row.try_get(index)?;
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
},
_ => {
// 默认转换为字符串
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
}
}
// MySQL 值转换
fn convert_mysql_value_to_json(&self, row: &sqlx::mysql::MySqlRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// 根据类型转换
match type_info.name() {
"TINYINT(1)" => Ok(serde_json::Value::Bool(row.try_get::<bool, _>(index)?)),
"TINYINT" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i8, _>(index)?))),
"SMALLINT" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i16, _>(index)?))),
"INT" | "INTEGER" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i32, _>(index)?))),
"BIGINT" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i64, _>(index)?))),
"FLOAT" => {
let val: f32 = row.try_get(index)?;
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val as f64).unwrap_or(serde_json::Number::from(0))))
},
"DOUBLE" => {
let val: f64 = row.try_get(index)?;
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
},
_ => {
// 默认转换为字符串
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
}
}
// SQLite 值转换
fn convert_sqlite_value_to_json(&self, row: &sqlx::sqlite::SqliteRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// SQLite 的类型系统比较简单
match type_info.name() {
"BOOLEAN" => Ok(serde_json::Value::Bool(row.try_get(index)?)),
"INTEGER" => Ok(serde_json::Value::Number(serde_json::Number::from(row.try_get::<i64, _>(index)?))),
"REAL" => {
let val: f64 = row.try_get(index)?;
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
},
_ => {
// 默认转换为字符串
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
}
}
pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
let query = match self.database_type {
DatabaseType::PostgreSQL => {
r#"
SELECT
t.table_name as name,
t.table_schema as schema
FROM information_schema.tables t
WHERE t.table_type = 'BASE TABLE'
AND t.table_schema NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
ORDER BY t.table_schema, t.table_name
"#
}
DatabaseType::MySQL => {
r#"
SELECT
table_name as name,
table_schema as schema
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
ORDER BY table_schema, table_name
"#
}
DatabaseType::SQLite => {
r#"
SELECT
name,
NULL as schema
FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name
"#
}
};
let rows = self.execute_query(query).await?;
let mut tables = Vec::new();
for row in rows {
let name: String = match &row {
DatabaseRow::Postgres(r) => r.try_get("name")?,
DatabaseRow::MySQL(r) => r.try_get("name")?,
DatabaseRow::SQLite(r) => r.try_get("name")?,
};
let schema: Option<String> = match &row {
DatabaseRow::Postgres(r) => r.try_get("schema")?,
DatabaseRow::MySQL(r) => r.try_get("schema")?,
DatabaseRow::SQLite(_) => None,
};
// 获取表的列信息
let columns = self.get_table_columns(&name, schema.as_deref()).await?;
// 获取行数
let count_query = if let Some(schema) = &schema {
format!("SELECT COUNT(*) as count FROM \"{}\".\"{}\"", schema, name)
} else {
format!("SELECT COUNT(*) as count FROM \"{}\"", name)
};
let row_count = match self.execute_query_one(&count_query).await {
Ok(count_row) => {
match &count_row {
DatabaseRow::Postgres(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::MySQL(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::SQLite(r) => r.try_get::<i64, _>("count").ok(),
}
}
Err(_) => None,
};
tables.push(TableInfo {
name,
schema,
columns,
row_count,
});
}
Ok(tables)
}
async fn get_table_columns(&self, table_name: &str, schema_name: Option<&str>) -> Result<Vec<ColumnInfo>> {
let query = match self.database_type {
DatabaseType::PostgreSQL => {
let schema = schema_name.unwrap_or("public");
format!(
r#"
SELECT
c.column_name as name,
c.data_type,
c.is_nullable::boolean as is_nullable,
c.column_default as default_value,
CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END as is_primary_key,
CASE WHEN fk.column_name IS NOT NULL THEN true ELSE false END as is_foreign_key,
c.character_maximum_length,
c.ordinal_position
FROM information_schema.columns c
LEFT JOIN (
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name
WHERE tc.table_name = '{}' AND tc.table_schema = '{}' AND tc.constraint_type = 'PRIMARY KEY'
) pk ON c.column_name = pk.column_name
LEFT JOIN (
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name
WHERE tc.table_name = '{}' AND tc.table_schema = '{}' AND tc.constraint_type = 'FOREIGN KEY'
) fk ON c.column_name = fk.column_name
WHERE c.table_name = '{}' AND c.table_schema = '{}'
ORDER BY c.ordinal_position
"#,
table_name, schema, table_name, schema, table_name, schema
)
}
DatabaseType::MySQL => {
let schema = schema_name.unwrap_or("information_schema");
format!(
r#"
SELECT
c.COLUMN_NAME as name,
c.DATA_TYPE as data_type,
CASE WHEN c.IS_NULLABLE = 'YES' THEN true ELSE false END as is_nullable,
c.COLUMN_DEFAULT as default_value,
CASE WHEN c.COLUMN_KEY = 'PRI' THEN true ELSE false END as is_primary_key,
CASE WHEN c.COLUMN_KEY = 'MUL' THEN true ELSE false END as is_foreign_key,
c.CHARACTER_MAXIMUM_LENGTH as character_maximum_length,
c.ORDINAL_POSITION as ordinal_position
FROM information_schema.COLUMNS c
WHERE c.TABLE_NAME = '{}' AND c.TABLE_SCHEMA = '{}'
ORDER BY c.ORDINAL_POSITION
"#,
table_name, schema
)
}
DatabaseType::SQLite => {
format!("PRAGMA table_info('{}')", table_name)
}
};
let rows = self.execute_query(&query).await?;
let mut columns = Vec::new();
for row in rows {
let column_info = match self.database_type {
DatabaseType::PostgreSQL => {
let row = match &row {
DatabaseRow::Postgres(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("data_type")?,
is_nullable: row.try_get("is_nullable")?,
default_value: row.try_get("default_value")?,
is_primary_key: row.try_get("is_primary_key")?,
is_foreign_key: row.try_get("is_foreign_key")?,
character_maximum_length: row.try_get("character_maximum_length")?,
ordinal_position: row.try_get("ordinal_position")?,
}
}
DatabaseType::MySQL => {
let row = match &row {
DatabaseRow::MySQL(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("data_type")?,
is_nullable: row.try_get("is_nullable")?,
default_value: row.try_get("default_value")?,
is_primary_key: row.try_get("is_primary_key")?,
is_foreign_key: row.try_get("is_foreign_key")?,
character_maximum_length: row.try_get("character_maximum_length")?,
ordinal_position: row.try_get("ordinal_position")?,
}
}
DatabaseType::SQLite => {
let row = match &row {
DatabaseRow::SQLite(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("type")?,
is_nullable: !row.try_get::<bool, _>("notnull")?,
default_value: row.try_get("dflt_value")?,
is_primary_key: row.try_get::<i32, _>("pk")? > 0,
is_foreign_key: false, // SQLite PRAGMA doesn't include FK info
character_maximum_length: None,
ordinal_position: row.try_get("cid")?,
}
}
};
columns.push(column_info);
}
Ok(columns)
}
pub async fn describe_table(&self, table_name: &str, schema_name: Option<&str>) -> Result<TableInfo> {
let columns = self.get_table_columns(table_name, schema_name).await?;
// 获取行数
let count_query = if let Some(schema) = schema_name {
format!("SELECT COUNT(*) as count FROM \"{}\".\"{}\"", schema, table_name)
} else {
format!("SELECT COUNT(*) as count FROM \"{}\"", table_name)
};
let row_count = match self.execute_query_one(&count_query).await {
Ok(count_row) => {
match &count_row {
DatabaseRow::Postgres(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::MySQL(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::SQLite(r) => r.try_get::<i64, _>("count").ok(),
}
}
Err(_) => None,
};
Ok(TableInfo {
name: table_name.to_string(),
schema: schema_name.map(|s| s.to_string()),
columns,
row_count,
})
}
/// Lists user-visible schemas (databases, on MySQL) as `SCHEMA` objects.
///
/// System schemas are excluded; on PostgreSQL per-session temporary
/// schemas (`pg_temp_*`, `pg_toast_temp_*`) are excluded too. SQLite has no
/// schemas, so an empty list is returned.
///
/// # Errors
/// Fails when the catalog query fails or a name cannot be decoded.
pub async fn list_schemas(&self) -> Result<Vec<SchemaObject>> {
    // Only the name is selected: the other SchemaObject fields are constant,
    // and an alias such as `schema` is a reserved word in MySQL.
    let query = match self.database_type {
        DatabaseType::PostgreSQL => {
            r#"
            SELECT schema_name::text AS name
            FROM information_schema.schemata
            WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
              AND schema_name !~ '^pg_(toast_)?temp_'
            ORDER BY schema_name
            "#
        }
        DatabaseType::MySQL => {
            r#"
            SELECT schema_name AS name
            FROM information_schema.schemata
            WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
            ORDER BY schema_name
            "#
        }
        DatabaseType::SQLite => return Ok(Vec::new()),
    };

    let rows = self.execute_query(query).await?;
    rows.iter()
        .map(|row| -> Result<SchemaObject> {
            Ok(SchemaObject {
                name: row.try_get("name")?,
                object_type: "SCHEMA".to_string(),
                schema: None,
                description: None,
                dependencies: Vec::new(),
            })
        })
        .collect()
}
/// Returns connection details as a JSON-ready map.
///
/// Always contains `database_type` and `version`. PostgreSQL and MySQL also
/// report `current_database`; on MySQL it is `null` when no database is
/// selected.
///
/// # Errors
/// Fails when any of the metadata queries fails or returns an undecodable value.
pub async fn get_connection_info(&self) -> Result<HashMap<String, serde_json::Value>> {
    let (label, version_sql) = match self.database_type {
        DatabaseType::PostgreSQL => ("PostgreSQL", "SELECT version()"),
        DatabaseType::MySQL => ("MySQL", "SELECT VERSION() as version"),
        DatabaseType::SQLite => ("SQLite", "SELECT sqlite_version() as version"),
    };

    let mut info = HashMap::new();
    info.insert("database_type".to_string(), serde_json::Value::String(label.to_string()));

    let version: String = self.execute_query_one(version_sql).await?.try_get("version")?;
    info.insert("version".to_string(), serde_json::Value::String(version));

    match self.database_type {
        DatabaseType::PostgreSQL => {
            let current_db: String = self
                .execute_query_one("SELECT current_database()")
                .await?
                .try_get("current_database")?;
            info.insert("current_database".to_string(), serde_json::Value::String(current_db));
        }
        DatabaseType::MySQL => {
            let current_db: Option<String> = self
                .execute_query_one("SELECT DATABASE() as current_database")
                .await?
                .try_get("current_database")?;
            info.insert(
                "current_database".to_string(),
                current_db.map_or(serde_json::Value::Null, serde_json::Value::String),
            );
        }
        DatabaseType::SQLite => {}
    }

    Ok(info)
}
}