"""
表单收集器模块
智能表单信息收集和处理
"""
import json
import uuid
from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger
from config import FormTemplate, FORM_TEMPLATES
from validators import FormValidator
from llm_service import LLMService
from storage import StorageManager, SessionData
class FormCollector:
    """Collect, validate and submit form data for one form template.

    The collector ties together three collaborators supplied by the caller:
    a form template (field definitions), an LLM service (field extraction and
    response text) and a storage manager (session persistence and submission).
    Every public coroutine returns a plain ``dict`` with a ``"success"`` flag
    instead of raising, so callers can forward the payload directly.
    """

    def __init__(self, template: "FormTemplate", llm_service: "LLMService", storage_manager: "StorageManager"):
        """Store the collaborators and index the template's fields.

        Args:
            template: Form template whose ``fields`` define the form.
            llm_service: Service used for extraction and message generation.
            storage_manager: Backend used to load, update and submit sessions.
        """
        self.template = template
        self.llm_service = llm_service
        self.storage_manager = storage_manager
        # Field definitions keyed by field name, as expected by FormValidator.
        self.form_fields = {field.name: field for field in template.fields}
        # Names of the fields flagged as required, in template order.
        self.required_fields = [field.name for field in template.fields if field.required]
        logger.info(f"表单收集器初始化完成 - 模板: {template.name}")

    @staticmethod
    def _session_not_found(message: str) -> Dict[str, Any]:
        """Build the failure payload returned when a session cannot be loaded."""
        return {
            "success": False,
            "error": "会话不存在",
            "message": message
        }

    def _progress_percentage(self, collected_fields: Dict[str, Any]) -> float:
        """Return the share of required fields present, as a percentage.

        Only required fields are counted, so extra keys in ``collected_fields``
        cannot push the value above 100. A template with no required fields is
        reported as 100% rather than dividing by zero.
        """
        if not self.required_fields:
            return 100.0
        collected = sum(1 for name in self.required_fields if name in collected_fields)
        return round(collected / len(self.required_fields) * 100, 1)

    async def collect_info(self, session_id: str, user_input: str, use_stream: bool = True) -> Dict[str, Any]:
        """Extract fields from one user message and merge them into the session.

        Args:
            session_id: Identifier of the session to update.
            user_input: Raw text entered by the user.
            use_stream: When true, the reply is produced via the streaming
                generator and concatenated; otherwise the non-streaming call
                is used.

        Returns:
            On success, a dict with the extracted and accumulated fields, the
            missing required fields, completion flag, progress percentage,
            validation errors, the generated message, a status table and a
            timestamp. On failure, a dict with ``success=False``, ``error``
            and ``message``.
        """
        try:
            logger.info(f"开始收集信息 - 会话: {session_id}")
            session_data = await self.storage_manager.get_session(session_id)
            if not session_data:
                return self._session_not_found("未找到指定的会话,请重新开始。")

            # Ask the LLM service to pull required-field values out of the text.
            extracted_fields = await self.llm_service.extract_fields(user_input, self.required_fields)
            logger.info(f"提取到字段: {list(extracted_fields.keys())}")

            # Merge into a copy so the loaded session object is not mutated.
            current_fields = session_data.collected_fields.copy()
            current_fields.update(extracted_fields)

            validation_results = FormValidator.validate_form_data(self.form_fields, current_fields)
            validation_summary = FormValidator.get_validation_summary(validation_results)

            # Complete means: every required field present and all values valid.
            missing_fields = [field for field in self.required_fields if field not in current_fields]
            is_complete = len(missing_fields) == 0 and validation_summary["is_all_valid"]

            await self.storage_manager.update_session(session_id, current_fields, is_complete)

            if use_stream:
                # Drain the async generator and join the chunks into one reply.
                response_parts = [
                    chunk
                    async for chunk in self.llm_service.generate_response_stream(
                        current_fields, missing_fields, self.required_fields, is_complete
                    )
                ]
                response_message = "".join(response_parts)
            else:
                response_message = await self.llm_service.generate_response(
                    current_fields, missing_fields, self.required_fields, is_complete
                )

            result = {
                "success": True,
                "session_id": session_id,
                "extracted_fields": extracted_fields,
                "collected_fields": current_fields,
                "missing_fields": missing_fields,
                "is_complete": is_complete,
                "progress_percentage": self._progress_percentage(current_fields),
                "validation_errors": validation_summary["errors"],
                "message": response_message,
                "status_table": self.llm_service._build_status_table(current_fields, self.required_fields),
                # Naive UTC timestamp; kept as-is so the output format is unchanged.
                "timestamp": datetime.utcnow().isoformat()
            }
            logger.info(f"信息收集完成 - 进度: {result['progress_percentage']}%")
            return result
        except Exception as e:
            # Boundary handler: report the failure to the caller instead of raising.
            logger.error(f"收集信息失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": "处理您的信息时出现错误,请重试。"
            }

    async def get_status(self, session_id: str) -> Dict[str, Any]:
        """Report the collection progress of a session.

        Args:
            session_id: Identifier of the session to inspect.

        Returns:
            On success, a dict with the template name, collected and missing
            fields, progress percentage, completion flag, status table and the
            session's created/updated/expiry timestamps in ISO format. On
            failure, a dict with ``success=False``, ``error`` and ``message``.
        """
        try:
            session_data = await self.storage_manager.get_session(session_id)
            if not session_data:
                return self._session_not_found("未找到指定的会话。")

            collected_fields = session_data.collected_fields
            missing_fields = [field for field in self.required_fields if field not in collected_fields]
            status_table = self.llm_service._build_status_table(collected_fields, self.required_fields)
            return {
                "success": True,
                "session_id": session_id,
                "template_name": session_data.template_name,
                "collected_fields": collected_fields,
                "missing_fields": missing_fields,
                "progress_percentage": self._progress_percentage(collected_fields),
                "is_complete": session_data.is_complete,
                "status_table": status_table,
                "created_at": session_data.created_at.isoformat(),
                "updated_at": session_data.updated_at.isoformat(),
                "expires_at": session_data.expires_at.isoformat()
            }
        except Exception as e:
            logger.error(f"获取状态失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": "获取状态时出现错误。"
            }

    async def validate_data(self, session_id: str) -> Dict[str, Any]:
        """Validate the fields collected so far for a session.

        Args:
            session_id: Identifier of the session to validate.

        Returns:
            On success, a dict with ``is_valid``, ``errors``,
            ``missing_fields``, the full ``validation_summary`` and a
            pass/fail ``message``. On failure, a dict with ``success=False``,
            ``error`` and ``message``.
        """
        try:
            session_data = await self.storage_manager.get_session(session_id)
            if not session_data:
                return self._session_not_found("未找到指定的会话。")

            collected_fields = session_data.collected_fields
            validation_results = FormValidator.validate_form_data(self.form_fields, collected_fields)
            missing_fields = FormValidator.get_missing_required_fields(self.form_fields, collected_fields)
            validation_summary = FormValidator.get_validation_summary(validation_results)
            is_valid = validation_summary["is_all_valid"]
            return {
                "success": True,
                "session_id": session_id,
                "is_valid": is_valid,
                "errors": validation_summary["errors"],
                "missing_fields": missing_fields,
                "validation_summary": validation_summary,
                "message": "表单验证通过" if is_valid else "表单验证失败"
            }
        except Exception as e:
            logger.error(f"验证数据失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": "验证数据时出现错误。"
            }

    async def submit_form(self, session_id: str) -> Dict[str, Any]:
        """Run a final validation and submit the session's data.

        Args:
            session_id: Identifier of the session to submit.

        Returns:
            On success, a dict with the ``submission_id``, template name,
            submitted data, submission timestamp and a confirmation message.
            When validation fails, a dict with ``success=False``,
            ``validation_errors`` and ``missing_fields``. On any other
            failure, a dict with ``success=False``, ``error`` and ``message``.
        """
        try:
            session_data = await self.storage_manager.get_session(session_id)
            if not session_data:
                return self._session_not_found("未找到指定的会话。")

            collected_fields = session_data.collected_fields

            # Final validation through the same FormValidator API that
            # validate_data() uses; this class has no per-instance validator.
            validation_results = FormValidator.validate_form_data(self.form_fields, collected_fields)
            validation_summary = FormValidator.get_validation_summary(validation_results)
            missing_fields = FormValidator.get_missing_required_fields(self.form_fields, collected_fields)
            if missing_fields or not validation_summary["is_all_valid"]:
                return {
                    "success": False,
                    "error": "数据验证失败",
                    "validation_errors": validation_summary["errors"],
                    "missing_fields": missing_fields,
                    "message": "表单数据不完整或有误,无法提交。"
                }

            submission_id = await self.storage_manager.submit_form(
                session_id,
                collected_fields,
                session_data.template_name
            )
            confirmation_message = await self.llm_service.generate_confirmation_message(
                collected_fields
            )
            return {
                "success": True,
                "submission_id": submission_id,
                "session_id": session_id,
                "template_name": session_data.template_name,
                "submitted_data": collected_fields,
                # Naive UTC timestamp; kept as-is so the output format is unchanged.
                "submitted_at": datetime.utcnow().isoformat(),
                "message": confirmation_message
            }
        except Exception as e:
            logger.error(f"提交表单失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": "提交表单时出现错误,请稍后重试。"
            }

    def get_field_info(self, field_name: str) -> Optional[Dict[str, Any]]:
        """Describe a single template field.

        Args:
            field_name: Name of the field to look up.

        Returns:
            A dict with the field's name, label, type, required flag, length
            limits, pattern, placeholder and help text, or ``None`` when the
            template has no field with that name.
        """
        # Scan in template order so the first field with this name wins.
        field = next((f for f in self.template.fields if f.name == field_name), None)
        if field is None:
            return None
        return {
            "name": field.name,
            "label": field.label,
            "type": field.field_type.value,
            "required": field.required,
            "min_length": field.min_length,
            "max_length": field.max_length,
            "pattern": field.pattern,
            "placeholder": field.placeholder,
            "help_text": field.help_text
        }

    def get_template_info(self) -> Dict[str, Any]:
        """Describe the template handled by this collector.

        Returns:
            A dict with the template's name, title and description, a summary
            of each field (name, label, type, required flag, help text) and
            the list of required field names.
        """
        return {
            "name": self.template.name,
            "title": self.template.title,
            "description": self.template.description,
            "fields": [
                {
                    "name": field.name,
                    "label": field.label,
                    "type": field.field_type.value,
                    "required": field.required,
                    "help_text": field.help_text
                }
                for field in self.template.fields
            ],
            "required_fields": self.required_fields
        }
# Module-level cached collector, created lazily by get_form_collector().
_form_collector = None


def get_form_collector(template: FormTemplate, llm_service: LLMService, storage_manager: StorageManager) -> FormCollector:
    """Return the shared collector for the given template and services.

    The collector is built on first use and reused while the same template,
    LLM service and storage manager objects are passed in. If any of them
    differs (by identity) from what the cached collector was built with, a
    new collector is created, so the returned instance always matches the
    arguments instead of silently staying bound to the first call's values.

    Args:
        template: Form template the collector should handle.
        llm_service: LLM service the collector should use.
        storage_manager: Storage backend the collector should use.

    Returns:
        A ``FormCollector`` bound to the supplied arguments.
    """
    global _form_collector
    stale = (
        _form_collector is None
        or _form_collector.template is not template
        or _form_collector.llm_service is not llm_service
        or _form_collector.storage_manager is not storage_manager
    )
    if stale:
        _form_collector = FormCollector(template, llm_service, storage_manager)
    return _form_collector