use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use sqlx::{Column, Row, TypeInfo, ValueRef};
use std::collections::HashMap;
use std::time::Instant;
use rust_decimal::Decimal;
/// Database backend kinds supported by this module.
///
/// Chosen from the connection URL scheme by `DatabaseManager::detect_database_type`
/// (`postgres://`/`postgresql://`, `mysql://`, `sqlite:`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabaseType {
PostgreSQL,
MySQL,
SQLite,
}
/// Connection pool for whichever backend `DatabaseManager::new` connected to.
///
/// Each variant wraps the corresponding sqlx pool type; the variant always agrees with
/// the manager's `database_type` because both are set together in `new`.
#[derive(Debug)]
pub enum DatabasePool {
Postgres(sqlx::PgPool),
MySQL(sqlx::MySqlPool),
SQLite(sqlx::SqlitePool),
}
/// A single result row from any supported backend.
///
/// Rows are tagged with the backend of the pool that produced them (see
/// `DatabaseManager::execute_query`), so code that knows the backend can match on the
/// variant directly. `Debug` is implemented by hand and prints only the variant name.
pub enum DatabaseRow {
Postgres(sqlx::postgres::PgRow),
MySQL(sqlx::mysql::MySqlRow),
SQLite(sqlx::sqlite::SqliteRow),
}
impl std::fmt::Debug for DatabaseRow {
    /// Writes `DatabaseRow::<Backend>(...)`; the row's column values are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            DatabaseRow::Postgres(_) => "DatabaseRow::Postgres(...)",
            DatabaseRow::MySQL(_) => "DatabaseRow::MySQL(...)",
            DatabaseRow::SQLite(_) => "DatabaseRow::SQLite(...)",
        };
        f.write_str(label)
    }
}
impl DatabaseRow {
    /// Returns the number of columns in this row.
    pub fn column_count(&self) -> usize {
        match self {
            DatabaseRow::Postgres(row) => row.len(),
            DatabaseRow::MySQL(row) => row.len(),
            DatabaseRow::SQLite(row) => row.len(),
        }
    }

    /// Returns the name of the column at zero-based `index`, or `None` if out of range.
    pub fn column_name(&self, index: usize) -> Option<&str> {
        match self {
            DatabaseRow::Postgres(row) => row.columns().get(index).map(|c| c.name()),
            DatabaseRow::MySQL(row) => row.columns().get(index).map(|c| c.name()),
            DatabaseRow::SQLite(row) => row.columns().get(index).map(|c| c.name()),
        }
    }

    /// Decodes the column *named* `index` as `T`.
    ///
    /// The backend is only known at runtime, so `T` must be decodable by all three
    /// backends (e.g. `String`, `i64`, `Option<String>`).
    ///
    /// # Errors
    /// Returns the driver error if the column does not exist or cannot be decoded as `T`.
    pub fn try_get<T>(&self, index: &str) -> Result<T, sqlx::Error>
    where
        T: for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
        T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
        T: for<'r> sqlx::Decode<'r, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
    {
        match self {
            DatabaseRow::Postgres(row) => row.try_get(index),
            DatabaseRow::MySQL(row) => row.try_get(index),
            DatabaseRow::SQLite(row) => row.try_get(index),
        }
    }

    /// Alias for [`DatabaseRow::column_count`], mirroring `sqlx::Row::len`.
    pub fn len(&self) -> usize {
        self.column_count()
    }

    /// Returns `true` if the row has no columns; companion to [`DatabaseRow::len`].
    pub fn is_empty(&self) -> bool {
        self.column_count() == 0
    }
}
/// Description of one table: identity, column metadata and a best-effort row count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableInfo {
/// Table name.
pub name: String,
/// Owning schema (PostgreSQL schema / MySQL database); `None` for SQLite or when no schema was given.
pub schema: Option<String>,
/// Column metadata as returned by `DatabaseManager::get_table_columns`.
pub columns: Vec<ColumnInfo>,
/// Result of `SELECT COUNT(*)`; `None` when that query failed or could not be decoded.
pub row_count: Option<i64>,
}
/// Metadata for a single table column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnInfo {
/// Column name.
pub name: String,
/// Backend-reported type name (information_schema `data_type`, or the declared type on SQLite).
pub data_type: String,
/// Whether the column accepts NULL.
pub is_nullable: bool,
/// Default expression as text, if the column has one.
pub default_value: Option<String>,
/// Whether the column is part of the primary key.
pub is_primary_key: bool,
/// Whether the column is flagged as a foreign key; SQLite always reports `false`.
/// NOTE(review): the derivation is backend-specific — check `get_table_columns`.
pub is_foreign_key: bool,
/// Maximum character length where the backend reports one; always `None` on SQLite.
pub character_maximum_length: Option<i32>,
/// Position of the column within its table.
/// NOTE(review): information_schema positions are 1-based; confirm SQLite uses the same base.
pub ordinal_position: i32,
}
/// A named database object such as a schema, table, view or routine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaObject {
pub name: String,
pub object_type: String, // TABLE, VIEW, FUNCTION, PROCEDURE, SCHEMA, etc.
pub schema: Option<String>,
pub description: Option<String>,
pub dependencies: Vec<String>, // names of other tables/objects this object depends on
}
/// Backend-agnostic entry point: owns a connection pool plus the backend it belongs to.
///
/// Construct with `DatabaseManager::new(url)`; `database_type` selects the SQL dialect
/// used by the introspection queries.
pub struct DatabaseManager {
pool: DatabasePool,
database_type: DatabaseType,
}
impl DatabaseManager {
pub async fn new(connection_url: &str) -> Result<Self> {
let database_type = Self::detect_database_type(connection_url)?;
let pool = match database_type {
DatabaseType::PostgreSQL => {
let pg_pool = sqlx::PgPool::connect(connection_url).await?;
DatabasePool::Postgres(pg_pool)
}
DatabaseType::MySQL => {
let mysql_pool = sqlx::MySqlPool::connect(connection_url).await?;
DatabasePool::MySQL(mysql_pool)
}
DatabaseType::SQLite => {
let sqlite_pool = sqlx::SqlitePool::connect(connection_url).await?;
DatabasePool::SQLite(sqlite_pool)
}
};
Ok(Self {
pool,
database_type,
})
}
/// Maps a connection URL to its backend by inspecting the scheme prefix.
///
/// Recognised prefixes: `postgres://`, `postgresql://`, `mysql://` and `sqlite:`
/// (the latter also covers `sqlite://`).
///
/// # Errors
/// Returns an error for any other scheme. The message names only the scheme, never the
/// full URL, because connection URLs routinely embed user names and passwords.
fn detect_database_type(url: &str) -> Result<DatabaseType> {
    if url.starts_with("postgres://") || url.starts_with("postgresql://") {
        Ok(DatabaseType::PostgreSQL)
    } else if url.starts_with("mysql://") {
        Ok(DatabaseType::MySQL)
    } else if url.starts_with("sqlite:") {
        Ok(DatabaseType::SQLite)
    } else {
        // Only the part before "://" is safe to echo back (it cannot contain credentials).
        let scheme = url.split_once("://").map_or("<missing>", |(scheme, _)| scheme);
        Err(anyhow!("Unsupported database URL scheme: '{}'", scheme))
    }
}
/// Runs `query` on the pool and returns every row, tagged with the pool's backend.
///
/// # Errors
/// Propagates any driver error from preparing or executing the statement.
async fn execute_query(&self, query: &str) -> Result<Vec<DatabaseRow>> {
    let tagged: Vec<DatabaseRow> = match &self.pool {
        DatabasePool::Postgres(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::Postgres)
            .collect(),
        DatabasePool::MySQL(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::MySQL)
            .collect(),
        DatabasePool::SQLite(pool) => sqlx::query(query)
            .fetch_all(pool)
            .await?
            .into_iter()
            .map(DatabaseRow::SQLite)
            .collect(),
    };
    Ok(tagged)
}
/// Runs `query` and returns exactly one row, tagged with the pool's backend.
///
/// # Errors
/// Propagates driver errors, including sqlx's "row not found" when nothing is returned.
async fn execute_query_one(&self, query: &str) -> Result<DatabaseRow> {
    let single = match &self.pool {
        DatabasePool::Postgres(pool) => {
            DatabaseRow::Postgres(sqlx::query(query).fetch_one(pool).await?)
        }
        DatabasePool::MySQL(pool) => {
            DatabaseRow::MySQL(sqlx::query(query).fetch_one(pool).await?)
        }
        DatabasePool::SQLite(pool) => {
            DatabaseRow::SQLite(sqlx::query(query).fetch_one(pool).await?)
        }
    };
    Ok(single)
}
/// Reports the backend name and version.
///
/// The map holds `database_type` ("PostgreSQL", "MySQL" or "SQLite") and `version`, the text
/// returned by `version()`, `VERSION()` or `sqlite_version()` respectively.
///
/// # Errors
/// Fails if the version query fails or its result cannot be decoded as a string.
pub async fn get_database_info(&self) -> Result<HashMap<String, serde_json::Value>> {
    let (label, version_sql) = match self.database_type {
        DatabaseType::PostgreSQL => ("PostgreSQL", "SELECT version()"),
        DatabaseType::MySQL => ("MySQL", "SELECT VERSION() as version"),
        DatabaseType::SQLite => ("SQLite", "SELECT sqlite_version() as version"),
    };
    let version: String = self.execute_query_one(version_sql).await?.try_get("version")?;

    let mut info = HashMap::new();
    info.insert("database_type".to_string(), serde_json::Value::String(label.to_string()));
    info.insert("version".to_string(), serde_json::Value::String(version));
    Ok(info)
}
/// Executes an arbitrary SQL statement and returns the outcome as JSON.
///
/// A statement is treated as row-returning when its first keyword is SELECT, WITH, SHOW,
/// DESCRIBE/DESC, EXPLAIN, PRAGMA, VALUES or TABLE. The keyword is found after skipping
/// whitespace, `--`/`#` line comments, `/* */` block comments and opening parentheses.
/// Such statements are run with `fetch_all` and produce
/// `{"type": "query", "rows": [...], "row_count": n, "execution_time_ms": t}`. Each row is an
/// object keyed by column name, built in column order; a later duplicate column name
/// overwrites an earlier one. All other statements are run with `execute` and produce
/// `{"type": "execute", "affected_rows": n, "execution_time_ms": t}`.
///
/// # Errors
/// `"Empty query"` for blank input; otherwise any driver or value-conversion error.
pub async fn execute_sql(&self, query: &str) -> Result<serde_json::Value> {
    const ROW_RETURNING_KEYWORDS: [&str; 9] = [
        "select", "with", "show", "describe", "desc", "explain", "pragma", "values", "table",
    ];

    // Lower-cased first keyword of `sql`, skipping whitespace, comments and '(' so that
    // e.g. "-- note\nSELECT ..." or "(SELECT ...)" are still recognised as queries.
    fn leading_keyword(sql: &str) -> String {
        let mut rest = sql;
        loop {
            rest = rest.trim_start();
            if rest.starts_with("--") || rest.starts_with('#') {
                rest = rest.split_once('\n').map_or("", |(_, tail)| tail);
            } else if let Some(after) = rest.strip_prefix("/*") {
                rest = after.split_once("*/").map_or("", |(_, tail)| tail);
            } else if let Some(after) = rest.strip_prefix('(') {
                rest = after;
            } else {
                break;
            }
        }
        rest.chars()
            .take_while(char::is_ascii_alphabetic)
            .collect::<String>()
            .to_ascii_lowercase()
    }

    let start_time = Instant::now();
    let query = query.trim();
    if query.is_empty() {
        return Err(anyhow!("Empty query"));
    }

    let returns_rows = ROW_RETURNING_KEYWORDS.contains(&leading_keyword(query).as_str());

    if !returns_rows {
        // Non-query statements (INSERT, UPDATE, DELETE, DDL, ...): report affected rows only.
        let affected_rows = match &self.pool {
            DatabasePool::Postgres(pool) => sqlx::query(query).execute(pool).await?.rows_affected(),
            DatabasePool::MySQL(pool) => sqlx::query(query).execute(pool).await?.rows_affected(),
            DatabasePool::SQLite(pool) => sqlx::query(query).execute(pool).await?.rows_affected(),
        };
        let execution_time = start_time.elapsed();
        return Ok(serde_json::json!({
            "type": "execute",
            "affected_rows": affected_rows,
            "execution_time_ms": execution_time.as_millis()
        }));
    }

    let rows = self.execute_query(query).await?;
    let execution_time = start_time.elapsed();

    let mut result_data = Vec::with_capacity(rows.len());
    for row in &rows {
        // Insert straight into a JSON map so column order survives when serde_json
        // preserves insertion order.
        let mut row_data = serde_json::Map::new();
        match row {
            DatabaseRow::Postgres(pg_row) => {
                for (i, column) in pg_row.columns().iter().enumerate() {
                    let value = self.convert_postgres_value_to_json(pg_row, i)?;
                    row_data.insert(column.name().to_string(), value);
                }
            }
            DatabaseRow::MySQL(mysql_row) => {
                for (i, column) in mysql_row.columns().iter().enumerate() {
                    let value = self.convert_mysql_value_to_json(mysql_row, i)?;
                    row_data.insert(column.name().to_string(), value);
                }
            }
            DatabaseRow::SQLite(sqlite_row) => {
                for (i, column) in sqlite_row.columns().iter().enumerate() {
                    let value = self.convert_sqlite_value_to_json(sqlite_row, i)?;
                    row_data.insert(column.name().to_string(), value);
                }
            }
        }
        result_data.push(serde_json::Value::Object(row_data));
    }

    Ok(serde_json::json!({
        "type": "query",
        "rows": result_data,
        "row_count": rows.len(),
        "execution_time_ms": execution_time.as_millis()
    }))
}
// PostgreSQL 值转换
fn convert_postgres_value_to_json(&self, row: &sqlx::postgres::PgRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// 根据类型转换 - 使用健壮的转换策略
match type_info.name() {
"BOOL" => {
if let Ok(val) = row.try_get::<bool, _>(index) {
Ok(serde_json::Value::Bool(val))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"INT2" | "SMALLINT" => {
if let Ok(val) = row.try_get::<i16, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"INT4" | "INTEGER" => {
if let Ok(val) = row.try_get::<i32, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"INT8" | "BIGINT" => {
if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"FLOAT4" | "REAL" => {
if let Ok(val) = row.try_get::<f32, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val as f64).unwrap_or(serde_json::Number::from(0))))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"FLOAT8" | "DOUBLE PRECISION" => {
if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"NUMERIC" | "DECIMAL" => {
// PostgreSQL的NUMERIC/DECIMAL类型,先尝试字符串
if let Ok(val) = row.try_get::<String, _>(index) {
// 尝试解析为数字
if let Ok(num_val) = val.parse::<f64>() {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(num_val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(val))
}
} else if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String("ERROR: Unable to convert NUMERIC".to_string()))
}
},
_ => {
// 对于所有其他类型,优先尝试字符串转换
if let Ok(val) = row.try_get::<String, _>(index) {
Ok(serde_json::Value::String(val))
} else if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(format!("Unsupported type: {}", type_info.name())))
}
}
}
}
// MySQL 值转换
fn convert_mysql_value_to_json(&self, row: &sqlx::mysql::MySqlRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// 根据类型转换 - 使用更智能的转换策略
match type_info.name() {
"TINYINT(1)" => {
// 先尝试bool,失败则尝试数字
if let Ok(val) = row.try_get::<bool, _>(index) {
Ok(serde_json::Value::Bool(val))
} else if let Ok(val) = row.try_get::<i8, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"TINYINT" => {
if let Ok(val) = row.try_get::<i8, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"SMALLINT" => {
if let Ok(val) = row.try_get::<i16, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"INT" | "INTEGER" => {
if let Ok(val) = row.try_get::<i32, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"BIGINT" => {
if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"FLOAT" => {
if let Ok(val) = row.try_get::<f32, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val as f64).unwrap_or(serde_json::Number::from(0))))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"DOUBLE" => {
if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"DECIMAL" | "NUMERIC" => {
// DECIMAL类型先尝试rust_decimal::Decimal
if let Ok(val) = row.try_get::<Decimal, _>(index) {
// 将Decimal转换为f64
if let Ok(float_val) = val.to_string().parse::<f64>() {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(float_val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(val.to_string()))
}
} else if let Ok(val) = row.try_get::<String, _>(index) {
// 尝试解析为数字
if let Ok(num_val) = val.parse::<f64>() {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(num_val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(val))
}
} else if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
// 最后的备选方案
Ok(serde_json::Value::String("ERROR: Unable to convert DECIMAL".to_string()))
}
},
_ => {
// 对于所有其他类型,先尝试字符串,失败则尝试数字类型
if let Ok(val) = row.try_get::<String, _>(index) {
Ok(serde_json::Value::String(val))
} else if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(format!("Unsupported type: {}", type_info.name())))
}
}
}
}
// SQLite 值转换
fn convert_sqlite_value_to_json(&self, row: &sqlx::sqlite::SqliteRow, index: usize) -> Result<serde_json::Value> {
let column = &row.columns()[index];
let type_info = column.type_info();
if row.try_get_raw(index)?.is_null() {
return Ok(serde_json::Value::Null);
}
// SQLite 的类型系统比较简单,但仍需要健壮的转换
match type_info.name() {
"BOOLEAN" => {
if let Ok(val) = row.try_get::<bool, _>(index) {
Ok(serde_json::Value::Bool(val))
} else if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Bool(val != 0))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"INTEGER" => {
if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
"REAL" => {
if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
let val: String = row.try_get(index)?;
Ok(serde_json::Value::String(val))
}
},
_ => {
// SQLite的默认策略:先尝试字符串,然后尝试数字
if let Ok(val) = row.try_get::<String, _>(index) {
Ok(serde_json::Value::String(val))
} else if let Ok(val) = row.try_get::<i64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from(val)))
} else if let Ok(val) = row.try_get::<f64, _>(index) {
Ok(serde_json::Value::Number(serde_json::Number::from_f64(val).unwrap_or(serde_json::Number::from(0))))
} else {
Ok(serde_json::Value::String(format!("Unsupported type: {}", type_info.name())))
}
}
}
}
/// Lists user tables together with their column metadata and row counts.
///
/// System schemas are excluded:
/// * PostgreSQL: `information_schema`, `pg_catalog`, `pg_toast`.
/// * MySQL: `information_schema`, `mysql`, `performance_schema`, `sys`.
/// * SQLite: tables named `sqlite_%`.
///
/// Each table is then described with [`DatabaseManager::describe_table`], so column lookup
/// and the best-effort `COUNT(*)` follow exactly the same rules as a single-table describe.
/// SQLite tables have no schema. This issues two extra queries per table.
///
/// # Errors
/// Fails if the listing query, decoding of a table name, or any column lookup fails. A
/// failing row count only yields `row_count: None`.
pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
    let query = match self.database_type {
        DatabaseType::PostgreSQL => {
            r#"
            SELECT
                t.table_name as name,
                t.table_schema as schema
            FROM information_schema.tables t
            WHERE t.table_type = 'BASE TABLE'
            AND t.table_schema NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
            ORDER BY t.table_schema, t.table_name
            "#
        }
        DatabaseType::MySQL => {
            r#"
            SELECT
                table_name as name,
                table_schema as table_schema
            FROM information_schema.tables
            WHERE table_type = 'BASE TABLE'
            AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
            ORDER BY table_schema, table_name
            "#
        }
        DatabaseType::SQLite => {
            r#"
            SELECT
                name,
                NULL as schema
            FROM sqlite_master
            WHERE type = 'table'
            AND name NOT LIKE 'sqlite_%'
            ORDER BY name
            "#
        }
    };

    let rows = self.execute_query(query).await?;
    let mut tables = Vec::with_capacity(rows.len());
    for row in &rows {
        let (name, schema): (String, Option<String>) = match row {
            DatabaseRow::Postgres(r) => (r.try_get("name")?, r.try_get("schema")?),
            DatabaseRow::MySQL(r) => (r.try_get("name")?, r.try_get("table_schema")?),
            DatabaseRow::SQLite(r) => (r.try_get("name")?, None),
        };
        // Reuse the single-table path so identifier quoting and counting live in one place.
        tables.push(self.describe_table(&name, schema.as_deref()).await?);
    }
    Ok(tables)
}
async fn get_table_columns(&self, table_name: &str, schema_name: Option<&str>) -> Result<Vec<ColumnInfo>> {
let query = match self.database_type {
DatabaseType::PostgreSQL => {
let schema = schema_name.unwrap_or("public");
format!(
r#"
SELECT
c.column_name as name,
c.data_type,
c.is_nullable::boolean as is_nullable,
c.column_default as default_value,
CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END as is_primary_key,
CASE WHEN fk.column_name IS NOT NULL THEN true ELSE false END as is_foreign_key,
c.character_maximum_length,
c.ordinal_position
FROM information_schema.columns c
LEFT JOIN (
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name
WHERE tc.table_name = '{}' AND tc.table_schema = '{}' AND tc.constraint_type = 'PRIMARY KEY'
) pk ON c.column_name = pk.column_name
LEFT JOIN (
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name
WHERE tc.table_name = '{}' AND tc.table_schema = '{}' AND tc.constraint_type = 'FOREIGN KEY'
) fk ON c.column_name = fk.column_name
WHERE c.table_name = '{}' AND c.table_schema = '{}'
ORDER BY c.ordinal_position
"#,
table_name, schema, table_name, schema, table_name, schema
)
}
DatabaseType::MySQL => {
let schema = if let Some(s) = schema_name {
s.to_string()
} else {
// 如果没有指定schema,使用当前数据库
match self.execute_query_one("SELECT DATABASE() as current_db").await {
Ok(db_row) => {
match &db_row {
DatabaseRow::MySQL(r) => {
match r.try_get::<Option<String>, _>("current_db")? {
Some(db) => db,
None => return Err(anyhow!("No database selected")),
}
},
_ => unreachable!(),
}
},
Err(_) => return Err(anyhow!("Failed to get current database")),
}
};
format!(
r#"
SELECT
c.COLUMN_NAME as name,
c.DATA_TYPE as data_type,
CASE WHEN c.IS_NULLABLE = 'YES' THEN true ELSE false END as is_nullable,
c.COLUMN_DEFAULT as default_value,
CASE WHEN c.COLUMN_KEY = 'PRI' THEN true ELSE false END as is_primary_key,
CASE WHEN c.COLUMN_KEY = 'MUL' THEN true ELSE false END as is_foreign_key,
CAST(c.CHARACTER_MAXIMUM_LENGTH AS SIGNED) as character_maximum_length,
CAST(c.ORDINAL_POSITION AS SIGNED) as ordinal_position
FROM information_schema.COLUMNS c
WHERE c.TABLE_NAME = '{}' AND c.TABLE_SCHEMA = '{}'
ORDER BY c.ORDINAL_POSITION
"#,
table_name, schema
)
}
DatabaseType::SQLite => {
format!("PRAGMA table_info('{}')", table_name)
}
};
let rows = self.execute_query(&query).await?;
let mut columns = Vec::new();
for row in rows {
let column_info = match self.database_type {
DatabaseType::PostgreSQL => {
let row = match &row {
DatabaseRow::Postgres(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("data_type")?,
is_nullable: row.try_get("is_nullable")?,
default_value: row.try_get("default_value")?,
is_primary_key: row.try_get("is_primary_key")?,
is_foreign_key: row.try_get("is_foreign_key")?,
character_maximum_length: row.try_get("character_maximum_length")?,
ordinal_position: row.try_get("ordinal_position")?,
}
}
DatabaseType::MySQL => {
let row = match &row {
DatabaseRow::MySQL(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("data_type")?,
is_nullable: row.try_get("is_nullable")?,
default_value: row.try_get("default_value")?,
is_primary_key: row.try_get("is_primary_key")?,
is_foreign_key: row.try_get("is_foreign_key")?,
character_maximum_length: row.try_get("character_maximum_length")?,
ordinal_position: row.try_get("ordinal_position")?,
}
}
DatabaseType::SQLite => {
let row = match &row {
DatabaseRow::SQLite(r) => r,
_ => unreachable!(),
};
ColumnInfo {
name: row.try_get("name")?,
data_type: row.try_get("type")?,
is_nullable: !row.try_get::<bool, _>("notnull")?,
default_value: row.try_get("dflt_value")?,
is_primary_key: row.try_get::<i32, _>("pk")? > 0,
is_foreign_key: false, // SQLite PRAGMA doesn't include FK info
character_maximum_length: None,
ordinal_position: row.try_get("cid")?,
}
}
};
columns.push(column_info);
}
Ok(columns)
}
pub async fn describe_table(&self, table_name: &str, schema_name: Option<&str>) -> Result<TableInfo> {
let columns = self.get_table_columns(table_name, schema_name).await?;
// 获取行数
let count_query = if let Some(schema) = schema_name {
format!("SELECT COUNT(*) as count FROM \"{}\".\"{}\"", schema, table_name)
} else {
format!("SELECT COUNT(*) as count FROM \"{}\"", table_name)
};
let row_count = match self.execute_query_one(&count_query).await {
Ok(count_row) => {
match &count_row {
DatabaseRow::Postgres(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::MySQL(r) => r.try_get::<i64, _>("count").ok(),
DatabaseRow::SQLite(r) => r.try_get::<i64, _>("count").ok(),
}
}
Err(_) => None,
};
Ok(TableInfo {
name: table_name.to_string(),
schema: schema_name.map(|s| s.to_string()),
columns,
row_count,
})
}
/// Lists schemas (PostgreSQL) or databases (MySQL), excluding system ones.
///
/// Every entry has `object_type` "SCHEMA", no parent schema, no description and no
/// dependencies. SQLite has no schemas, so it returns an empty list without querying.
///
/// # Errors
/// Fails if the catalog query fails or a name cannot be decoded.
pub async fn list_schemas(&self) -> Result<Vec<SchemaObject>> {
    let sql = match self.database_type {
        // SQLite has no schemas: answer without touching the database.
        DatabaseType::SQLite => return Ok(Vec::new()),
        DatabaseType::PostgreSQL => r#"
SELECT
schema_name as name,
'SCHEMA' as object_type,
NULL as schema,
NULL as description
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
ORDER BY schema_name
"#,
        DatabaseType::MySQL => r#"
SELECT
schema_name as name,
'SCHEMA' as object_type,
NULL as schema,
NULL as description
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
ORDER BY schema_name
"#,
    };

    self.execute_query(sql)
        .await?
        .iter()
        .map(|row| -> Result<SchemaObject> {
            let name: String = match row {
                DatabaseRow::Postgres(r) => r.try_get("name")?,
                DatabaseRow::MySQL(r) => r.try_get("name")?,
                DatabaseRow::SQLite(_) => unreachable!(),
            };
            Ok(SchemaObject {
                name,
                object_type: "SCHEMA".to_string(),
                schema: None,
                description: None,
                dependencies: Vec::new(),
            })
        })
        .collect()
}
/// Reports backend name, version and (where supported) the current database.
///
/// Keys are as follows:
/// * `database_type` is always present.
/// * `version` comes from `version()`, `VERSION()` or `sqlite_version()`.
/// * `current_database` is present only for PostgreSQL (`current_database()`) and MySQL
///   (`DATABASE()`, JSON `null` when no database is selected).
///
/// # Errors
/// Fails if any of the underlying queries fail or cannot be decoded.
pub async fn get_connection_info(&self) -> Result<HashMap<String, serde_json::Value>> {
    let mut info: HashMap<String, serde_json::Value> = HashMap::new();

    let (label, version_sql) = match self.database_type {
        DatabaseType::PostgreSQL => ("PostgreSQL", "SELECT version()"),
        DatabaseType::MySQL => ("MySQL", "SELECT VERSION() as version"),
        DatabaseType::SQLite => ("SQLite", "SELECT sqlite_version() as version"),
    };
    info.insert("database_type".to_string(), serde_json::Value::String(label.to_string()));

    let version: String = self.execute_query_one(version_sql).await?.try_get("version")?;
    info.insert("version".to_string(), serde_json::Value::String(version));

    let current_database = match self.database_type {
        DatabaseType::PostgreSQL => {
            let name: String = self
                .execute_query_one("SELECT current_database()")
                .await?
                .try_get("current_database")?;
            Some(serde_json::Value::String(name))
        }
        DatabaseType::MySQL => {
            let name: Option<String> = self
                .execute_query_one("SELECT DATABASE() as current_database")
                .await?
                .try_get("current_database")?;
            Some(name.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null))
        }
        DatabaseType::SQLite => None,
    };
    if let Some(value) = current_database {
        info.insert("current_database".to_string(), value);
    }

    Ok(info)
}
// Backward-compatibility wrappers kept for older callers.

/// Describes `table_name` with no explicit schema; same as `describe_table(table_name, None)`.
pub async fn get_table_structure(&self, table_name: &str) -> Result<TableInfo> {
    let table = self.describe_table(table_name, None).await?;
    Ok(table)
}
pub async fn execute_readonly_query(&self, query: &str) -> Result<serde_json::Value> {
self.execute_sql(query).await
}
/// Executes `query` through `execute_sql`; kept for backward compatibility.
///
/// No check is made that the statement actually writes.
pub async fn execute_write_query(&self, query: &str) -> Result<serde_json::Value> {
    let outcome = self.execute_sql(query).await?;
    Ok(outcome)
}
pub async fn get_table_ddl(&self, table_name: &str) -> Result<String> {
match self.database_type {
DatabaseType::PostgreSQL => {
// PostgreSQL 的 DDL 获取比较复杂,这里提供一个简化版本
let query = format!(
r#"
SELECT 'CREATE TABLE ' || table_name || ' (' || string_agg(
column_name || ' ' || data_type ||
CASE
WHEN character_maximum_length IS NOT NULL
THEN '(' || character_maximum_length || ')'
ELSE ''
END ||
CASE
WHEN is_nullable = 'NO' THEN ' NOT NULL'
ELSE ''
END ||
CASE
WHEN column_default IS NOT NULL
THEN ' DEFAULT ' || column_default
ELSE ''
END,
', '
) || ');' as ddl
FROM information_schema.columns
WHERE table_name = '{}'
GROUP BY table_name
"#,
table_name
);
let row = self.execute_query_one(&query).await?;
match &row {
DatabaseRow::Postgres(r) => Ok(r.try_get("ddl")?),
_ => unreachable!(),
}
}
DatabaseType::MySQL => {
let query = format!("SHOW CREATE TABLE `{}`", table_name);
let row = self.execute_query_one(&query).await?;
match &row {
DatabaseRow::MySQL(r) => {
// SHOW CREATE TABLE 返回两列:Table 和 Create Table
// 我们需要第二列
Ok(r.try_get::<String, usize>(1)?)
},
_ => unreachable!(),
}
}
DatabaseType::SQLite => {
let query = format!(
"SELECT sql FROM sqlite_master WHERE type='table' AND name='{}'",
table_name
);
let row = self.execute_query_one(&query).await?;
match &row {
DatabaseRow::SQLite(r) => Ok(r.try_get("sql")?),
_ => unreachable!(),
}
}
}
}
}