"""
大模型服务模块
支持DeepSeek模型和流式输出
"""
import json
import asyncio
from typing import Dict, Any, Optional, List, AsyncGenerator
from openai import AsyncOpenAI
import httpx
from loguru import logger
from config import get_config, SYSTEM_PROMPTS
class LLMService:
"""大模型服务类,支持DeepSeek和流式输出"""
def __init__(self):
self.config = get_config()
self.llm_config = self.config["llm"]
# 初始化DeepSeek客户端
self.client = AsyncOpenAI(
api_key=self.llm_config.api_key,
base_url=self.llm_config.base_url
)
logger.info(f"LLM服务初始化完成 - 模型: {self.llm_config.model}")
async def extract_fields(self, user_input: str, template_fields: List[str]) -> Dict[str, Any]:
"""
从用户输入中提取字段信息
Args:
user_input: 用户输入文本
template_fields: 模板字段列表
Returns:
提取的字段信息字典
"""
try:
system_prompt = SYSTEM_PROMPTS["extraction"]
user_prompt = f"""
用户输入:{user_input}
请从上述输入中提取以下字段的信息:
{', '.join(template_fields)}
请返回JSON格式的结果,只包含能够确定提取的字段。
如果某个字段无法确定,请不要包含在结果中。
示例格式:
{{
"applicant_name": "张三",
"contact_phone": "13812345678",
"contact_address": "北京市朝阳区建国路1号"
}}
"""
response = await self.client.chat.completions.create(
model=self.llm_config.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.1, # 降低温度以提高准确性
max_tokens=1000,
stream=False # 信息提取不使用流式
)
content = response.choices[0].message.content.strip()
logger.debug(f"LLM提取响应: {content}")
# 尝试解析JSON
try:
# 提取JSON部分
if "```json" in content:
json_start = content.find("```json") + 7
json_end = content.find("```", json_start)
json_content = content[json_start:json_end].strip()
elif "{" in content and "}" in content:
json_start = content.find("{")
json_end = content.rfind("}") + 1
json_content = content[json_start:json_end]
else:
json_content = content
extracted_fields = json.loads(json_content)
logger.info(f"成功提取字段: {list(extracted_fields.keys())}")
return extracted_fields
except json.JSONDecodeError as e:
logger.warning(f"JSON解析失败: {e}, 原始内容: {content}")
return {}
except Exception as e:
logger.error(f"字段提取失败: {e}")
return {}
async def generate_response(self,
collected_fields: Dict[str, Any],
missing_fields: List[str],
template_fields: List[str],
is_complete: bool = False) -> str:
"""
生成响应消息
Args:
collected_fields: 已收集的字段
missing_fields: 缺失的字段
template_fields: 所有模板字段
is_complete: 是否收集完成
Returns:
生成的响应消息
"""
try:
system_prompt = SYSTEM_PROMPTS["response_generation"]
# 构建状态表格
status_table = self._build_status_table(collected_fields, template_fields)
if is_complete:
user_prompt = f"""
当前收集状态:
{status_table}
所有必填信息已收集完成!请生成确认消息,询问用户是否确认提交表单。
"""
else:
user_prompt = f"""
当前收集状态:
{status_table}
已收集字段:{list(collected_fields.keys())}
缺失字段:{missing_fields}
请生成友好的响应消息:
1. 确认已收到的信息
2. 引导用户提供缺失的信息
3. 使用Markdown表格展示当前状态
4. 保持专业友好的语调
"""
response = await self.client.chat.completions.create(
model=self.llm_config.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=self.llm_config.temperature,
max_tokens=self.llm_config.max_tokens,
stream=False
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"响应生成失败: {e}")
return "抱歉,生成响应时出现错误,请稍后重试。"
async def generate_response_stream(self,
collected_fields: Dict[str, Any],
missing_fields: List[str],
template_fields: List[str],
is_complete: bool = False) -> AsyncGenerator[str, None]:
"""
生成流式响应消息
Args:
collected_fields: 已收集的字段
missing_fields: 缺失的字段
template_fields: 所有模板字段
is_complete: 是否收集完成
Yields:
响应消息片段
"""
try:
system_prompt = SYSTEM_PROMPTS["response_generation"]
# 构建状态表格
status_table = self._build_status_table(collected_fields, template_fields)
if is_complete:
user_prompt = f"""
当前收集状态:
{status_table}
所有必填信息已收集完成!请生成确认消息,询问用户是否确认提交表单。
"""
else:
user_prompt = f"""
当前收集状态:
{status_table}
已收集字段:{list(collected_fields.keys())}
缺失字段:{missing_fields}
请生成友好的响应消息:
1. 确认已收到的信息
2. 引导用户提供缺失的信息
3. 使用Markdown表格展示当前状态
4. 保持专业友好的语调
"""
stream = await self.client.chat.completions.create(
model=self.llm_config.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=self.llm_config.temperature,
max_tokens=self.llm_config.max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
except Exception as e:
logger.error(f"流式响应生成失败: {e}")
yield "抱歉,生成响应时出现错误,请稍后重试。"
async def generate_welcome_message(self, template_name: str) -> str:
"""
生成欢迎消息
Args:
template_name: 表单模板名称
Returns:
欢迎消息
"""
return SYSTEM_PROMPTS["welcome"]
async def generate_welcome_message_stream(self, template_name: str) -> AsyncGenerator[str, None]:
"""
生成流式欢迎消息
Args:
template_name: 表单模板名称
Yields:
欢迎消息片段
"""
welcome_msg = SYSTEM_PROMPTS["welcome"]
# 模拟流式输出
words = welcome_msg.split()
for i, word in enumerate(words):
if i == 0:
yield word
else:
yield " " + word
await asyncio.sleep(0.01) # 小延迟模拟流式效果
async def generate_confirmation_message(self, collected_fields: Dict[str, Any]) -> str:
"""
生成确认消息
Args:
collected_fields: 已收集的字段
Returns:
确认消息
"""
try:
system_prompt = """
你是一个专业的表单确认助手。请根据用户提供的信息生成最终确认消息。
要求:
1. 使用Markdown表格清晰展示所有收集到的信息
2. 询问用户是否确认提交
3. 保持专业友好的语调
4. 提醒用户确认信息的准确性
"""
user_prompt = f"""
请为以下收集到的信息生成确认消息:
{json.dumps(collected_fields, ensure_ascii=False, indent=2)}
请生成专业的确认消息,包含完整的信息展示表格。
"""
response = await self.client.chat.completions.create(
model=self.llm_config.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3,
max_tokens=1500,
stream=False
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"确认消息生成失败: {e}")
return "请确认以上信息是否正确,如无误请回复'确认提交'。"
def _build_status_table(self, collected_fields: Dict[str, Any], template_fields: List[str]) -> str:
"""
构建状态表格
Args:
collected_fields: 已收集的字段
template_fields: 所有模板字段
Returns:
Markdown格式的状态表格
"""
field_labels = {
"applicant_name": "申请人姓名",
"contact_phone": "联系电话",
"contact_address": "联系地址",
"incident_time": "事件发生时间",
"incident_location": "事件发生地点",
"incident_description": "事件详情描述",
"involved_parties_count": "涉及人数"
}
table_rows = ["| 字段 | 状态 | 值 |", "|------|------|------|"]
for field in template_fields:
label = field_labels.get(field, field)
if field in collected_fields:
value = str(collected_fields[field])
if len(value) > 50:
value = value[:47] + "..."
status = "✅ 已收集"
table_rows.append(f"| {label} | {status} | {value} |")
else:
table_rows.append(f"| {label} | ❌ 待收集 | - |")
return "\n".join(table_rows)
async def health_check(self) -> bool:
"""
健康检查
Returns:
服务是否正常
"""
try:
response = await self.client.chat.completions.create(
model=self.llm_config.model,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=10,
stream=False
)
return True
except Exception as e:
logger.error(f"LLM服务健康检查失败: {e}")
return False
# 全局LLM服务实例
_llm_service = None
async def get_llm_service() -> LLMService:
"""获取LLM服务实例"""
global _llm_service
if _llm_service is None:
_llm_service = LLMService()
return _llm_service