database.py•5.58 kB
import os
import mysql.connector
from typing import Optional, Dict, Any, List
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
class DatabaseManager:
"""数据库管理类,负责MySQL数据库的连接和查询"""
def __init__(self):
self.host = os.getenv('DB_HOST', 'localhost')
self.port = int(os.getenv('DB_PORT', 3306))
self.user = os.getenv('DB_USER')
self.password = os.getenv('DB_PASSWORD')
self.database = os.getenv('DB_NAME')
self.connection = None
def connect(self) -> bool:
"""建立数据库连接"""
try:
self.connection = mysql.connector.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
charset='utf8mb4'
)
print("✅ 成功连接到数据库: {}".format(self.database))
return True
except Exception as e:
print("❌ 数据库连接失败: {}".format(str(e)))
return False
def execute_query(self, query: str) -> Optional[List[Dict[str, Any]]]:
"""执行SQL查询并返回字典列表"""
try:
if not self.connection:
print("❌ 数据库未连接")
return None
cursor = self.connection.cursor(dictionary=True)
cursor.execute(query)
# 获取查询结果
results = cursor.fetchall()
cursor.close()
if results:
# 转换数据类型
converted_results = [self._convert_row_types(row) for row in results]
print("✅ 查询成功,返回 {} 行数据".format(len(converted_results)))
return converted_results
else:
print("✅ 查询成功,但没有返回数据")
return []
except Exception as e:
print("❌ 查询执行失败: {}".format(str(e)))
return None
def _convert_row_types(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""转换行数据中的特殊类型为JSON可序列化的类型"""
converted = {}
for key, value in row.items():
if value is None:
converted[key] = None
elif isinstance(value, (int, float, str, bool)):
converted[key] = value
elif hasattr(value, 'isoformat'): # datetime objects
converted[key] = value.isoformat()
elif isinstance(value, bytes):
converted[key] = value.decode('utf-8', errors='ignore')
else:
converted[key] = str(value)
return converted
def get_table_info(self, table_name: str) -> Dict[str, Any]:
"""获取表结构信息"""
try:
# 获取表结构
structure_query = f"DESCRIBE {table_name}"
structure_data = self.execute_query(structure_query)
# 获取表数据样本
sample_query = f"SELECT * FROM {table_name} LIMIT 5"
sample_data = self.execute_query(sample_query)
# 获取表统计信息
count_query = f"SELECT COUNT(*) as total_rows FROM {table_name}"
count_data = self.execute_query(count_query)
total_rows = count_data[0]['total_rows'] if count_data and len(count_data) > 0 else 0
return {
'structure': structure_data,
'sample_data': sample_data,
'total_rows': total_rows
}
except Exception as e:
print("❌ 获取表信息失败: {}".format(str(e)))
return {}
def get_all_tables(self) -> List[str]:
"""获取数据库中所有表名"""
try:
query = "SHOW TABLES"
results = self.execute_query(query)
if results:
# 提取表名(通常是结果中的第一个字段)
table_names = [list(row.values())[0] for row in results]
return table_names
return []
except Exception as e:
print("❌ 获取表列表失败: {}".format(str(e)))
return []
def close(self):
"""关闭数据库连接"""
if self.connection:
self.connection.close()
print("✅ 数据库连接已关闭")
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
# 示例用法
if __name__ == "__main__":
# 使用上下文管理器
with DatabaseManager() as db:
# 获取所有表
tables = db.get_all_tables()
print(f"数据库中的表: {tables}")
# 如果有表,获取第一个表的信息
if tables:
table_info = db.get_table_info(tables[0])
print(f"\n表 {tables[0]} 的信息:")
print(f"总行数: {table_info.get('total_rows', 0)}")
if 'structure' in table_info:
print("\n表结构:")
for row in table_info['structure']:
print(row)
if 'sample_data' in table_info:
print("\n样本数据:")
for row in table_info['sample_data']:
print(row)