"""
存储管理模块
支持MySQL数据库的会话和表单数据存储
"""
import json
import asyncio
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, asdict
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Text, DateTime, Integer, JSON, select, delete, func
from loguru import logger
from config import get_config
class Base(DeclarativeBase):
    """Declarative base class for the ORM models in this module.

    ``Base.metadata`` is the table collection passed to ``create_all`` by
    ``StorageManager.initialize``.
    """
class FormSession(Base):
    """Form session table: one row per in-progress form-filling session.

    All timestamp columns hold naive datetimes; the Python-side defaults use
    ``datetime.utcnow``.
    """
    __tablename__ = "form_sessions"
    # Caller-supplied session identifier (up to 64 characters), primary key.
    session_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Name of the form template this session is filling in.
    template_name: Mapped[str] = mapped_column(String(50), nullable=False)
    # Field values collected so far, stored as a JSON document (defaults to {}).
    collected_fields: Mapped[Dict[str, Any]] = mapped_column(JSON, default=dict)
    # Completion flag; defaults to False on insert.
    is_complete: Mapped[bool] = mapped_column(default=False)
    # Insert time (naive UTC via datetime.utcnow when not supplied).
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    # Last-modified time; SQLAlchemy refreshes it on UPDATEs it emits (onupdate).
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    # Absolute expiry time; must be supplied on insert (no default, NOT NULL).
    expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class FormSubmission(Base):
    """Form submission record table: one row per submitted form.

    The complete payload lives in ``form_data``; the nullable columns at the
    end are a per-column copy of selected fields of the dispute-mediation form.
    """
    __tablename__ = "form_submissions"
    # Auto-increment surrogate key, returned to callers as the submission ID.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # ID of the session that produced this submission (no FK constraint declared).
    session_id: Mapped[str] = mapped_column(String(64), nullable=False)
    # Name of the form template that was submitted.
    template_name: Mapped[str] = mapped_column(String(50), nullable=False)
    # Full submitted form payload as a JSON document.
    form_data: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False)
    # Submission time (naive UTC via datetime.utcnow when not supplied).
    submitted_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    # Dispute-mediation form specific fields (all nullable)
    applicant_name: Mapped[Optional[str]] = mapped_column(String(100))
    contact_phone: Mapped[Optional[str]] = mapped_column(String(20))
    contact_address: Mapped[Optional[str]] = mapped_column(Text)
    # Stored as free text (String(100)), not as a DateTime.
    incident_time: Mapped[Optional[str]] = mapped_column(String(100))
    incident_location: Mapped[Optional[str]] = mapped_column(String(200))
    incident_description: Mapped[Optional[str]] = mapped_column(Text)
    involved_parties_count: Mapped[Optional[int]] = mapped_column(Integer)
@dataclass
class SessionData:
    """ORM-independent snapshot of a single form session.

    The field order below is also the positional constructor order.
    """

    session_id: str  # session identifier
    template_name: str  # name of the form template being filled
    collected_fields: Dict[str, Any]  # field values collected so far
    is_complete: bool  # completion flag
    created_at: datetime
    updated_at: datetime
    expires_at: datetime

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a new dict.

        ``dataclasses.asdict`` rebuilds nested dicts, so the returned
        ``collected_fields`` is a copy; datetimes stay ``datetime`` objects.
        """
        snapshot: Dict[str, Any] = asdict(self)
        return snapshot
class StorageManager:
    """Async MySQL-backed storage for form sessions and form submissions.

    Wraps a SQLAlchemy async engine and session factory built from the
    ``database`` section of the application config. All timestamps are naive
    UTC values from ``datetime.utcnow()``, matching the ORM column defaults.
    """

    def __init__(self):
        """Build the async engine and session factory from configuration.

        Tables are not created here; call :meth:`initialize` for that.
        """
        self.config = get_config()
        self.db_config = self.config["database"]
        # Async engine with a bounded, recycled connection pool.
        self.engine = create_async_engine(
            self.db_config.url,
            pool_size=self.db_config.pool_size,
            max_overflow=self.db_config.max_overflow,
            pool_timeout=self.db_config.pool_timeout,
            pool_recycle=self.db_config.pool_recycle,
            echo=False,  # SQL echo disabled for production
        )
        # expire_on_commit=False keeps ORM attributes readable after commit;
        # the methods below read them after committing to build return values.
        self.async_session = sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        logger.info("MySQL存储管理器初始化完成")

    @staticmethod
    def _to_session_data(row: FormSession) -> SessionData:
        """Copy a loaded ``FormSession`` row into a detached ``SessionData``."""
        return SessionData(
            session_id=row.session_id,
            template_name=row.template_name,
            collected_fields=row.collected_fields,
            is_complete=row.is_complete,
            created_at=row.created_at,
            updated_at=row.updated_at,
            expires_at=row.expires_at,
        )

    @staticmethod
    def _coerce_int(value: Any) -> Optional[int]:
        """Best-effort conversion of a form value to ``int``.

        Returns ``None`` when the value is missing or not a whole number, so an
        unparsable value (e.g. free text) does not make the INSERT fail; the raw
        value is still preserved in ``form_data``.
        """
        if value is None:
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, float):
            return int(value) if value.is_integer() else None
        try:
            return int(str(value).strip())
        except (TypeError, ValueError):
            return None

    async def initialize(self):
        """Create all tables registered on ``Base.metadata`` (existing ones are kept).

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            async with self.engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("数据库表初始化完成")
        except Exception as e:
            logger.exception(f"数据库初始化失败: {e}")
            raise

    async def create_session(self, session_id: str, template_name: str, timeout_seconds: int = 3600) -> SessionData:
        """Create a session, or reset an existing session with the same ID.

        An existing row keeps its ``created_at``; its template, collected
        fields, completion flag, ``updated_at`` and ``expires_at`` are reset.

        Args:
            session_id: Session identifier (primary key).
            template_name: Name of the form template.
            timeout_seconds: Lifetime from now, in seconds.

        Returns:
            SessionData: Snapshot of the stored session.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            now = datetime.utcnow()
            expires_at = now + timedelta(seconds=timeout_seconds)
            async with self.async_session() as session:
                result = await session.execute(
                    select(FormSession).where(FormSession.session_id == session_id)
                )
                form_session = result.scalar_one_or_none()
                if form_session is not None:
                    # Reset the existing session in place.
                    form_session.template_name = template_name
                    form_session.collected_fields = {}
                    form_session.is_complete = False
                    form_session.updated_at = now
                    form_session.expires_at = expires_at
                else:
                    # NOTE(review): two concurrent creates for the same new ID can
                    # both reach this branch; the losing commit fails on the
                    # primary key and the error is re-raised below.
                    form_session = FormSession(
                        session_id=session_id,
                        template_name=template_name,
                        collected_fields={},
                        is_complete=False,
                        created_at=now,
                        updated_at=now,
                        expires_at=expires_at,
                    )
                    session.add(form_session)
                await session.commit()
                return self._to_session_data(form_session)
        except Exception as e:
            logger.exception(f"创建会话失败: {e}")
            raise

    async def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return the session snapshot, or ``None`` if missing, expired or on error.

        An expired session is deleted as a side effect of reading it.

        Args:
            session_id: Session identifier.

        Returns:
            Optional[SessionData]: The session, or ``None``.
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(FormSession).where(FormSession.session_id == session_id)
                )
                form_session = result.scalar_one_or_none()
                if form_session is None:
                    return None
                if form_session.expires_at < datetime.utcnow():
                    # Delete within this session instead of calling delete_session(),
                    # which would check out a second pooled connection meanwhile.
                    await session.delete(form_session)
                    await session.commit()
                    return None
                return self._to_session_data(form_session)
        except Exception as e:
            logger.exception(f"获取会话失败: {e}")
            return None

    async def update_session(self, session_id: str, collected_fields: Dict[str, Any], is_complete: bool = False) -> bool:
        """Replace a session's collected fields and completion flag.

        ``expires_at`` is not extended by this call.

        Args:
            session_id: Session identifier.
            collected_fields: New collected-field mapping (replaces the old one).
            is_complete: New completion flag.

        Returns:
            bool: ``True`` if updated; ``False`` if the session is missing,
            already expired, or a database error occurred.
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(FormSession).where(FormSession.session_id == session_id)
                )
                form_session = result.scalar_one_or_none()
                if form_session is None:
                    return False
                if form_session.expires_at < datetime.utcnow():
                    # Consistent with get_session(): an expired session counts as gone.
                    return False
                form_session.collected_fields = collected_fields
                form_session.is_complete = is_complete
                form_session.updated_at = datetime.utcnow()
                await session.commit()
                return True
        except Exception as e:
            logger.exception(f"更新会话失败: {e}")
            return False

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session by ID (deleting a missing ID is not an error).

        Args:
            session_id: Session identifier.

        Returns:
            bool: ``True`` unless a database error occurred.
        """
        try:
            async with self.async_session() as session:
                await session.execute(
                    delete(FormSession).where(FormSession.session_id == session_id)
                )
                await session.commit()
                return True
        except Exception as e:
            logger.exception(f"删除会话失败: {e}")
            return False

    async def submit_form(self, session_id: str, form_data: Dict[str, Any], template_name: str) -> int:
        """Store a form submission and delete its session in one transaction.

        For the ``"mediation"`` template, selected keys of ``form_data`` are
        also copied into dedicated columns.

        Args:
            session_id: Session identifier of the submitting session.
            form_data: Submitted form payload.
            template_name: Name of the form template.

        Returns:
            int: ID of the new submission record.

        Raises:
            Exception: Any database error is logged and re-raised; nothing is
            committed in that case.
        """
        try:
            async with self.async_session() as session:
                submission = FormSubmission(
                    session_id=session_id,
                    template_name=template_name,
                    form_data=form_data,
                    submitted_at=datetime.utcnow(),
                )
                if template_name == "mediation":
                    submission.applicant_name = form_data.get("applicant_name")
                    submission.contact_phone = form_data.get("contact_phone")
                    submission.contact_address = form_data.get("contact_address")
                    submission.incident_time = form_data.get("incident_time")
                    submission.incident_location = form_data.get("incident_location")
                    submission.incident_description = form_data.get("incident_description")
                    # Integer column: form values may arrive as strings such as "3".
                    submission.involved_parties_count = self._coerce_int(
                        form_data.get("involved_parties_count")
                    )
                session.add(submission)
                # Remove the finished session in the same transaction, so the
                # submission and the session removal succeed or fail together.
                await session.execute(
                    delete(FormSession).where(FormSession.session_id == session_id)
                )
                await session.commit()
                # The autoincrement PK is populated during flush; expire_on_commit=False
                # keeps it readable after commit.
                submission_id = submission.id
            logger.info(f"表单提交成功,记录ID: {submission_id}")
            return submission_id
        except Exception as e:
            logger.exception(f"提交表单失败: {e}")
            raise

    async def get_submission(self, submission_id: int) -> Optional[Dict[str, Any]]:
        """Return a submission record as a plain dict.

        Args:
            submission_id: Submission record ID.

        Returns:
            Optional[Dict[str, Any]]: Record data with ``submitted_at`` as an ISO
            string (or ``None``), or ``None`` if missing or on error.
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(FormSubmission).where(FormSubmission.id == submission_id)
                )
                submission = result.scalar_one_or_none()
                if submission is None:
                    return None
                submitted_at = submission.submitted_at
                return {
                    "id": submission.id,
                    "session_id": submission.session_id,
                    "template_name": submission.template_name,
                    "form_data": submission.form_data,
                    "submitted_at": submitted_at.isoformat() if submitted_at is not None else None,
                    "applicant_name": submission.applicant_name,
                    "contact_phone": submission.contact_phone,
                    "contact_address": submission.contact_address,
                    "incident_time": submission.incident_time,
                    "incident_location": submission.incident_location,
                    "incident_description": submission.incident_description,
                    "involved_parties_count": submission.involved_parties_count,
                }
        except Exception as e:
            logger.exception(f"获取提交记录失败: {e}")
            return None

    async def get_statistics(self) -> Dict[str, Any]:
        """Return session/submission counters.

        Returns:
            Dict[str, Any]: ``active_sessions``, ``total_submissions``,
            ``today_submissions`` (current UTC day), ``template_statistics``
            (template name -> count) and ``timestamp``. On error the counters
            are zero and an ``error`` key holds the message.
        """
        try:
            async with self.async_session() as session:
                now = datetime.utcnow()

                # Active sessions: not yet expired.
                active_sessions = (await session.execute(
                    select(func.count(FormSession.session_id)).where(
                        FormSession.expires_at > now
                    )
                )).scalar()

                total_submissions = (await session.execute(
                    select(func.count(FormSubmission.id))
                )).scalar()

                # Today's submissions as a half-open UTC day range, which (unlike
                # DATE(submitted_at) = today) can use an index on submitted_at.
                day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
                day_end = day_start + timedelta(days=1)
                today_submissions = (await session.execute(
                    select(func.count(FormSubmission.id)).where(
                        FormSubmission.submitted_at >= day_start,
                        FormSubmission.submitted_at < day_end,
                    )
                )).scalar()

                template_rows = await session.execute(
                    select(
                        FormSubmission.template_name,
                        func.count(FormSubmission.id),
                    ).group_by(FormSubmission.template_name)
                )
                # Unpack positionally: Row is a Sequence, so ``row.count`` would be
                # the tuple-style count() method rather than the aggregate value.
                template_stats = {name: count for name, count in template_rows.all()}

                return {
                    "active_sessions": active_sessions,
                    "total_submissions": total_submissions,
                    "today_submissions": today_submissions,
                    "template_statistics": template_stats,
                    "timestamp": datetime.utcnow().isoformat(),
                }
        except Exception as e:
            logger.exception(f"获取统计信息失败: {e}")
            return {
                "active_sessions": 0,
                "total_submissions": 0,
                "today_submissions": 0,
                "template_statistics": {},
                "timestamp": datetime.utcnow().isoformat(),
                "error": str(e),
            }

    async def cleanup_expired_sessions(self) -> int:
        """Delete every session whose ``expires_at`` is in the past.

        Returns:
            int: Number of rows deleted (as reported by the driver), or 0 on error.
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    delete(FormSession).where(FormSession.expires_at < datetime.utcnow())
                )
                await session.commit()
                cleaned_count = result.rowcount
                if cleaned_count > 0:
                    logger.info(f"清理了 {cleaned_count} 个过期会话")
                return cleaned_count
        except Exception as e:
            logger.exception(f"清理过期会话失败: {e}")
            return 0

    async def close(self):
        """Dispose of the engine's connection pool; errors are logged, not raised."""
        try:
            await self.engine.dispose()
            logger.info("数据库连接已关闭")
        except Exception as e:
            logger.exception(f"关闭数据库连接失败: {e}")
# Global storage manager instance; created lazily by get_storage_manager().
_storage_manager: Optional["StorageManager"] = None
# Guards first-time creation; created lazily so it is made inside a running loop.
_storage_init_lock: Optional[asyncio.Lock] = None


async def get_storage_manager() -> StorageManager:
    """Return the process-wide StorageManager, creating and initializing it once.

    The instance is published only after ``initialize()`` succeeds, so a failed
    initialization is retried on the next call instead of leaving a
    half-initialized manager cached. A lock prevents concurrent first callers
    from each building their own engine (there is an ``await`` between the
    check and the assignment).

    Raises:
        Exception: Whatever ``StorageManager()`` or ``initialize()`` raises.
    """
    global _storage_manager, _storage_init_lock
    if _storage_manager is not None:
        return _storage_manager
    if _storage_init_lock is None:
        # No await between the check and the assignment, so this is race-free
        # within a single event loop.
        _storage_init_lock = asyncio.Lock()
    async with _storage_init_lock:
        if _storage_manager is None:
            manager = StorageManager()
            try:
                await manager.initialize()
            except Exception:
                # Release the pool of the failed instance before propagating.
                await manager.close()
                raise
            _storage_manager = manager
    return _storage_manager
async def cleanup_task(*, interval_seconds: float = 300, retry_seconds: float = 60):
    """Periodically delete expired form sessions.

    Loops forever; stop it by cancelling the task that runs it.

    Args:
        interval_seconds: Delay after a successful pass (default 5 minutes).
        retry_seconds: Delay before retrying after a failure (default 1 minute).
    """
    while True:
        try:
            storage = await get_storage_manager()
            await storage.cleanup_expired_sessions()
        except Exception as e:
            logger.exception(f"清理任务执行失败: {e}")
            await asyncio.sleep(retry_seconds)  # back off briefly, then retry
        else:
            await asyncio.sleep(interval_seconds)