Skip to main content
Glama

QueryNest

by niuzaishu
query_generation.py (34 kB)
# -*- coding: utf-8 -*-
"""Query generation tool v2 - with a user confirmation step.

Turns a natural-language description into a MongoDB query (find / count /
distinct / aggregate), shows it to the user, and only runs it once the user
has confirmed.
"""

import json
import re
from typing import Dict, List, Any, Optional

import structlog
from mcp.types import Tool, TextContent

from database.connection_manager import ConnectionManager
from database.metadata_manager import MetadataManager
from scanner.semantic_analyzer import SemanticAnalyzer
from utils.parameter_validator import ParameterValidator, MCPParameterHelper, ValidationResult
from utils.tool_context import get_context_manager
from utils.error_handler import with_error_handling, with_retry, RetryConfig
from utils.workflow_manager import get_workflow_manager, WorkflowStage
from utils.user_confirmation import UserConfirmationHelper, ConfirmationParser

logger = structlog.get_logger(__name__)

# Bounds mirror the "limit" property of the tool's input schema.
_DEFAULT_LIMIT = 10
_MIN_LIMIT = 1
_MAX_LIMIT = 1000

# "field = value" / "field 等于 value" / "field 是 value".
# Operators are an alternation (not a character class) so the two-character
# "等于" is consumed as one unit. The look-behinds keep "大于等于"/"小于等于"
# from being read as an equality. A value ends at a quote or at an ASCII or
# full-width comma (the full-width comma is included deliberately for Chinese
# descriptions).
_EQUALS_PATTERN = re.compile(
    r'(\w+)\s*(?:==?|(?<!大于)(?<!小于)等于|是)\s*["\']?([^"\',,]+)["\']?'
)

# "field > 10" / "field 大于等于 10". Longer operators come first so ">=" is not
# read as ">" and "大于等于" is not read as "大于".
_COMPARISON_PATTERN = re.compile(
    r'(\w+)\s*(>=|<=|>|<|大于等于|小于等于|大于|小于)\s*(\d+)'
)

# Textual operator -> internal operator name used in condition dicts.
_COMPARISON_OPERATORS = {
    ">": "gt",
    "大于": "gt",
    "<": "lt",
    "小于": "lt",
    ">=": "gte",
    "大于等于": "gte",
    "<=": "lte",
    "小于等于": "lte",
}

# Internal operator name -> MongoDB query operator.
_RANGE_OPERATORS = {"gt": "$gt", "lt": "$lt", "gte": "$gte", "lte": "$lte"}

# Python type names (as recorded by _get_collection_info) treated as numeric.
_NUMERIC_TYPE_NAMES = ("int", "float", "Decimal128")

# Field names that are usually worth returning in a projection.
_COMMON_IMPORTANT_FIELDS = frozenset(
    {"_id", "id", "name", "title", "status", "created_at", "updated_at"}
)


def _normalize_limit(raw_limit: Any) -> int:
    """Coerce *raw_limit* to an int clamped to the schema's 1..1000 range.

    Unparseable values fall back to the default. Clamping matters because a
    limit of 0 means "no limit" to MongoDB.
    """
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        limit = _DEFAULT_LIMIT
    return max(_MIN_LIMIT, min(limit, _MAX_LIMIT))


def _resolve_field_name(collection_info: Dict[str, Any], candidate: Any) -> Optional[str]:
    """Return the collection's real field name matching *candidate*.

    An exact match wins; otherwise the comparison is case-insensitive.
    Returns None when the sampled fields contain no match.
    """
    names = [f["name"] for f in collection_info.get("fields", [])]
    if candidate in names:
        return candidate
    lowered = str(candidate).lower()
    for name in names:
        if name.lower() == lowered:
            return name
    return None


def _coerce_scalar(value: str) -> Any:
    """Convert a captured token to int or float when it looks numeric."""
    if value.isdecimal():
        return int(value)
    # Removing only the first dot means "1.2.3" stays a string instead of
    # crashing float().
    if value.replace(".", "", 1).isdecimal():
        return float(value)
    return value


def _to_json(value: Any) -> str:
    """Pretty-print *value* as JSON; non-JSON types are rendered via str()."""
    return json.dumps(value, indent=2, ensure_ascii=False, default=str)


def _render_shell_query(collection_name: str, mongodb_query: Dict[str, Any]) -> str:
    """Render *mongodb_query* as a mongo-shell style statement.

    Returns an empty string for an operation this tool does not know.
    """
    operation = mongodb_query.get("operation", "find")
    filter_part = mongodb_query.get("filter", {})
    filter_text = _to_json(filter_part) if filter_part else "{}"

    if operation == "find":
        rendered = f"db.{collection_name}.find(" + filter_text
        projection_part = mongodb_query.get("projection", {})
        if projection_part:
            rendered += ",\n " + _to_json(projection_part)
        return rendered + f").limit({mongodb_query.get('limit', 10)})"

    if operation == "count":
        return f"db.{collection_name}.countDocuments(" + filter_text + ")"

    if operation == "distinct":
        field = mongodb_query.get("field", "_id")
        rendered = f'db.{collection_name}.distinct("{field}"'
        if filter_part:
            rendered += ", " + filter_text
        return rendered + ")"

    if operation == "aggregate":
        pipeline = mongodb_query.get("pipeline", [])
        return f"db.{collection_name}.aggregate(" + _to_json(pipeline) + ")"

    return ""


class QueryGenerationTool:
    """Query generation tool v2 - generates queries and asks for confirmation."""

    def __init__(self, connection_manager: ConnectionManager,
                 metadata_manager: MetadataManager,
                 semantic_analyzer: SemanticAnalyzer):
        """Store the collaborators and fetch the shared context/workflow managers."""
        self.connection_manager = connection_manager
        self.metadata_manager = metadata_manager
        self.semantic_analyzer = semantic_analyzer
        self.context_manager = get_context_manager()
        self.workflow_manager = get_workflow_manager()

    def get_tool_definition(self) -> Tool:
        """Return the MCP tool definition (name, description, input schema)."""
        return Tool(
            name="generate_query",
            description="智能查询生成工具:生成MongoDB查询语句并要求用户确认后执行",
            inputSchema={
                "type": "object",
                "properties": {
                    "instance_id": {
                        "type": "string",
                        "description": "MongoDB实例ID(可选,会从工作流上下文自动获取)"
                    },
                    "database_name": {
                        "type": "string",
                        "description": "数据库名称(可选,会从工作流上下文自动获取)"
                    },
                    "collection_name": {
                        "type": "string",
                        "description": "集合名称(可选,会从工作流上下文自动获取)"
                    },
                    "query_description": {
                        "type": "string",
                        "description": "查询需求的自然语言描述"
                    },
                    "query_type": {
                        "type": "string",
                        "description": "查询类型",
                        "enum": ["auto", "find", "count", "aggregate", "distinct"],
                        "default": "auto"
                    },
                    "limit": {
                        "type": "integer",
                        "description": "结果限制数量",
                        "default": 10,
                        "minimum": 1,
                        "maximum": 1000
                    },
                    "session_id": {
                        "type": "string",
                        "description": "会话标识符,默认为'default'",
                        "default": "default"
                    },
                    "user_confirmation": {
                        "type": "string",
                        "description": "用户对生成查询的确认选择(A=执行, B=修改, C=查看计划, D=取消)"
                    },
                    "skip_confirmation": {
                        "type": "boolean",
                        "description": "跳过用户确认,直接生成查询语句(不执行)",
                        "default": False
                    }
                },
                "required": ["query_description"]
            }
        )

    @with_error_handling("查询生成")
    async def execute(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Generate a query and, depending on the arguments, show or run it.

        Args:
            arguments: Tool arguments as described by the input schema.
                ``query_description`` is required; instance, database and
                collection fall back to the session's workflow data.

        Returns:
            A single-element list with the text to show the user: an error
            message, the generated query, a confirmation prompt, or the
            outcome of the user's confirmation choice.
        """
        session_id = arguments.get("session_id", "default")
        query_description = arguments["query_description"]
        query_type = arguments.get("query_type", "auto")
        limit = _normalize_limit(arguments.get("limit", _DEFAULT_LIMIT))
        user_confirmation = arguments.get("user_confirmation")
        skip_confirmation = arguments.get("skip_confirmation", False)

        # Fill in missing parameters from the workflow context.
        # NOTE(review): "or {}" guards against a None result; confirm whether
        # get_workflow_data can actually return None.
        workflow_data = (await self.workflow_manager.get_workflow_data(session_id)) or {}

        instance_id = arguments.get("instance_id") or workflow_data.get("instance_id")
        database_name = arguments.get("database_name") or workflow_data.get("database_name")
        collection_name = arguments.get("collection_name") or workflow_data.get("collection_name")

        # Validate required parameters.
        if not instance_id:
            return [TextContent(
                type="text",
                text="## ❌ 缺少实例信息\n\n请先选择MongoDB实例。"
            )]

        if not database_name:
            return [TextContent(
                type="text",
                text="## ❌ 缺少数据库信息\n\n请先选择数据库。"
            )]

        if not collection_name:
            return [TextContent(
                type="text",
                text="## ❌ 缺少集合信息\n\n请先选择集合。"
            )]

        # Validate the connection.
        if not self.connection_manager.has_instance(instance_id):
            return [TextContent(
                type="text",
                text=f"## ❌ 实例不存在\n\n实例 `{instance_id}` 不存在。"
            )]

        # Generate the query.
        try:
            query_info = await self._generate_query(
                instance_id, database_name, collection_name,
                query_description, query_type, limit, session_id
            )
        except Exception as e:
            return [TextContent(
                type="text",
                text=f"## ❌ 查询生成失败\n\n错误: {str(e)}\n\n请检查查询描述是否清晰,或尝试更简单的查询。"
            )]

        # Confirmation skipped: only show the query, never execute it.
        if skip_confirmation:
            return await self._show_query_only(query_info)

        # No choice yet: ask the user to confirm.
        if not user_confirmation:
            return await self._show_confirmation_prompt(query_info)

        # Act on the user's choice.
        return await self._handle_user_confirmation(user_confirmation, query_info, session_id)

    async def _generate_query(self, instance_id: str, database_name: str, collection_name: str,
                              query_description: str, query_type: str, limit: int,
                              session_id: str) -> Dict[str, Any]:
        """Build the query-info dict for a natural-language description.

        Samples the collection structure, analyses the description, builds
        the MongoDB query and estimates the result count.

        Returns:
            A dict with the target identifiers, the description, the final
            ``query_type``, the ``mongodb_query`` itself, ``limit``,
            ``estimated_result_count``, ``collection_info`` and
            ``semantic_info``.

        Raises:
            Exception: Whatever ``_get_collection_info`` raises when the
                collection cannot be inspected.
        """
        logger.info("生成查询语句",
                    instance_id=instance_id,
                    database_name=database_name,
                    collection_name=collection_name,
                    query_description=query_description,
                    query_type=query_type)

        # Collection structure.
        collection_info = await self._get_collection_info(instance_id, database_name, collection_name)

        # Query intent.
        semantic_info = await self._analyze_query_semantics(
            instance_id, database_name, collection_name, query_description
        )

        # Query built from structure + intent.
        mongodb_query = await self._build_mongodb_query(
            collection_info, semantic_info, query_description, query_type, limit
        )

        # Estimated result count.
        estimated_count = await self._estimate_result_count(
            instance_id, database_name, collection_name, mongodb_query
        )

        return {
            "instance_id": instance_id,
            "database_name": database_name,
            "collection_name": collection_name,
            "query_description": query_description,
            "query_type": mongodb_query.get("operation", query_type),
            "mongodb_query": mongodb_query,
            "limit": limit,
            "estimated_result_count": estimated_count,
            "collection_info": collection_info,
            "semantic_info": semantic_info
        }

    async def _get_collection_info(self, instance_id: str, database_name: str,
                                   collection_name: str) -> Dict[str, Any]:
        """Sample the collection and summarise its top-level fields.

        Reads up to five documents and records, per top-level field, the
        Python type names seen and up to three sample values (each truncated
        to 50 characters).

        Returns:
            A dict with ``collection_name``, ``document_count``, ``fields``
            (list of ``{"name", "types", "sample_values"}``) and
            ``sample_documents`` (at most two documents).

        Raises:
            ValueError: If the instance has no usable connection.
            Exception: Any driver error; it is logged and re-raised.
        """
        try:
            connection = self.connection_manager.get_instance_connection(instance_id)
            if not connection or not connection.client:
                raise ValueError(f"实例 {instance_id} 连接不可用")

            db = connection.client[database_name]
            collection = db[collection_name]

            # Sample documents used to infer the structure.
            sample_docs = [doc async for doc in collection.find().limit(5)]

            field_info: Dict[str, Dict[str, Any]] = {}
            for doc in sample_docs:
                if not isinstance(doc, dict):
                    continue
                for field, value in doc.items():
                    entry = field_info.setdefault(
                        field, {"name": field, "types": set(), "sample_values": []}
                    )
                    entry["types"].add(type(value).__name__)
                    # Keep only a few short samples.
                    if len(entry["sample_values"]) < 3:
                        entry["sample_values"].append(str(value)[:50])

            # Sorted so the type list is deterministic (sets have no order).
            fields = [
                {
                    "name": field_name,
                    "types": sorted(info["types"]),
                    "sample_values": info["sample_values"]
                }
                for field_name, info in field_info.items()
            ]

            return {
                "collection_name": collection_name,
                "document_count": await collection.count_documents({}),
                "fields": fields,
                "sample_documents": sample_docs[:2]  # keep two sample documents
            }

        except Exception as e:
            logger.error("获取集合信息失败", error=str(e))
            raise

    async def _analyze_query_semantics(self, instance_id: str, database_name: str,
                                       collection_name: str, query_description: str) -> Dict[str, Any]:
        """Analyse the query intent, falling back to keyword analysis on failure."""
        try:
            return await self.semantic_analyzer.analyze_query_intent(
                query_description, instance_id, database_name, collection_name
            )
        except Exception as e:
            # Deliberate best effort: the keyword-based analysis still works.
            logger.warning("语义分析失败,使用基础分析", error=str(e))
            return self._basic_query_analysis(query_description)

    def _basic_query_analysis(self, query_description: str) -> Dict[str, Any]:
        """Keyword- and regex-based intent analysis.

        The operation is detected from keywords in the lower-cased
        description. Conditions are extracted from the ORIGINAL description
        so that the case of field names and values is preserved.

        Returns:
            A dict with ``operation``, ``potential_fields`` (de-duplicated,
            in order of appearance), ``conditions`` (dicts with ``field``,
            ``operator`` and ``value``) and a fixed ``confidence`` of 0.6.
        """
        description_lower = query_description.lower()

        # Detect the operation.
        if any(keyword in description_lower for keyword in ["count", "数量", "多少", "统计"]):
            operation = "count"
        elif any(keyword in description_lower for keyword in ["distinct", "唯一", "去重", "不同"]):
            operation = "distinct"
        elif any(keyword in description_lower
                 for keyword in ["sum", "average", "max", "min", "group", "聚合", "分组", "求和", "平均"]):
            operation = "aggregate"
        else:
            operation = "find"

        potential_fields: List[str] = []
        conditions: List[Dict[str, Any]] = []

        def remember(field: str) -> None:
            if field not in potential_fields:
                potential_fields.append(field)

        # "field = value" style conditions.
        for field, value in _EQUALS_PATTERN.findall(query_description):
            remember(field)
            conditions.append({"field": field, "operator": "equals", "value": value.strip()})

        # "field > value" style conditions.
        for field, operator, value in _COMPARISON_PATTERN.findall(query_description):
            remember(field)
            conditions.append({
                "field": field,
                "operator": _COMPARISON_OPERATORS[operator],
                "value": int(value)
            })

        return {
            "operation": operation,
            "potential_fields": potential_fields,
            "conditions": conditions,
            "confidence": 0.6  # keyword analysis is a low-confidence fallback
        }

    async def _build_mongodb_query(self, collection_info: Dict[str, Any], semantic_info: Dict[str, Any],
                                   query_description: str, query_type: str, limit: int) -> Dict[str, Any]:
        """Build the MongoDB query dict from structure and intent.

        An explicit ``query_type`` (anything but "auto") takes precedence
        over the analysed operation. Conditions are kept only when their
        field exists in the sampled collection structure; the match is
        case-insensitive. Range conditions on the same field are merged into
        one operator document.

        Returns:
            A dict with ``operation`` and ``filter`` plus, depending on the
            operation, ``limit``/``projection`` (find), ``field`` (distinct)
            or ``pipeline`` (aggregate).
        """
        if query_type and query_type != "auto":
            operation = query_type
        else:
            operation = semantic_info.get("operation") or "find"
            if operation == "auto":
                operation = "find"

        query_filter: Dict[str, Any] = {}

        for condition in semantic_info.get("conditions") or []:
            if "value" not in condition:
                continue
            field = _resolve_field_name(collection_info, condition.get("field"))
            if field is None:
                # Unknown field: skip rather than emit a filter that cannot match.
                continue

            operator = condition.get("operator")
            value = condition["value"]
            if operator == "equals":
                query_filter[field] = value
            elif operator in _RANGE_OPERATORS:
                existing = query_filter.get(field)
                if isinstance(existing, dict):
                    # e.g. "age > 18" and "age < 65" become {"$gt": 18, "$lt": 65}.
                    existing[_RANGE_OPERATORS[operator]] = value
                else:
                    query_filter[field] = {_RANGE_OPERATORS[operator]: value}

        # No explicit condition: try to match field names in the description.
        if not query_filter:
            query_filter = await self._smart_field_matching(collection_info, query_description)

        mongodb_query: Dict[str, Any] = {
            "operation": operation,
            "filter": query_filter
        }

        if operation == "find":
            mongodb_query["limit"] = limit
            # Restrict the returned fields to keep the result small.
            important_fields = self._select_important_fields(collection_info, semantic_info)
            if important_fields:
                mongodb_query["projection"] = {field: 1 for field in important_fields}
        elif operation == "distinct":
            mongodb_query["field"] = self._select_distinct_field(collection_info, semantic_info)
        elif operation == "aggregate":
            pipeline = self._build_aggregation_pipeline(collection_info, semantic_info, query_description)
            if query_filter:
                # The filter lives in the pipeline so that the query shown to
                # the user is exactly the one that is executed.
                pipeline = [{"$match": query_filter}] + pipeline
            mongodb_query["pipeline"] = pipeline
        # "count" needs neither limit nor projection.

        return mongodb_query

    async def _smart_field_matching(self, collection_info: Dict[str, Any],
                                    query_description: str) -> Dict[str, Any]:
        """Build a filter from "<field> = <value>" fragments in the description.

        For every sampled field whose name occurs in the description
        (case-insensitively), looks for the name followed by ``=``, ``:``,
        ``是`` or ``为`` and a value; numeric-looking values are converted to
        int/float.
        """
        query_filter: Dict[str, Any] = {}
        description_lower = query_description.lower()

        for field_info in collection_info.get("fields", []):
            field_name = field_info["name"]
            if field_name.lower() not in description_lower:
                continue

            # re.escape: field names may contain regex metacharacters.
            # IGNORECASE on the original text keeps the value's case intact.
            pattern = re.escape(field_name) + "\\s*[=:是为]\\s*[\"']?([^\"',,\\s]+)[\"']?"
            match = re.search(pattern, query_description, re.IGNORECASE)
            if match:
                query_filter[field_name] = _coerce_scalar(match.group(1))

        return query_filter

    def _select_important_fields(self, collection_info: Dict[str, Any],
                                 semantic_info: Dict[str, Any]) -> List[str]:
        """Choose up to ten fields for the projection.

        Fields named by the analysis come first (mapped to the collection's
        real spelling when they match case-insensitively), followed by
        commonly useful fields present in the collection. ``semantic_info``
        is not modified.
        """
        all_fields = [f["name"] for f in collection_info.get("fields", [])]

        important_fields: List[str] = []
        for candidate in semantic_info.get("potential_fields") or []:
            resolved = _resolve_field_name(collection_info, candidate) or candidate
            if resolved not in important_fields:
                important_fields.append(resolved)

        for field in all_fields:
            if field.lower() in _COMMON_IMPORTANT_FIELDS and field not in important_fields:
                important_fields.append(field)

        # Cap the number of fields to avoid returning too much data.
        return important_fields[:10]

    def _select_distinct_field(self, collection_info: Dict[str, Any],
                               semantic_info: Dict[str, Any]) -> str:
        """Choose the field for a distinct query.

        Prefers the first field named by the analysis, then the first sampled
        field that is not ``_id``, and finally ``_id``.
        """
        potential_fields = semantic_info.get("potential_fields") or []
        if potential_fields:
            first = potential_fields[0]
            return _resolve_field_name(collection_info, first) or first

        for field_info in collection_info.get("fields", []):
            if field_info["name"] != "_id":
                return field_info["name"]

        return "_id"

    def _build_aggregation_pipeline(self, collection_info: Dict[str, Any],
                                    semantic_info: Dict[str, Any],
                                    query_description: str) -> List[Dict[str, Any]]:
        """Build a basic aggregation pipeline from keywords in the description.

        "group"/"分组" adds a document count; "sum"/"求和" adds the sum of the
        first numeric field. Both accumulators share ONE ``$group`` stage: a
        second ``$group`` would only see the first one's output, where the
        summed field no longer exists. Without either keyword the pipeline
        returns the total document count.
        """
        description_lower = query_description.lower()
        accumulators: Dict[str, Any] = {}

        if "group" in description_lower or "分组" in description_lower:
            accumulators["count"] = {"$sum": 1}

        if "sum" in description_lower or "求和" in description_lower:
            numeric_field = next(
                (
                    field_info["name"]
                    for field_info in collection_info.get("fields", [])
                    if any(t in _NUMERIC_TYPE_NAMES for t in field_info.get("types", []))
                ),
                None,
            )
            if numeric_field is not None:
                accumulators["total"] = {"$sum": f"${numeric_field}"}

        if accumulators:
            return [{"$group": {"_id": None, **accumulators}}]

        # No specific aggregation requested: return basic statistics.
        return [
            {"$group": {"_id": None, "count": {"$sum": 1}}},
            {"$project": {"_id": 0, "total_documents": "$count"}}
        ]

    async def _estimate_result_count(self, instance_id: str, database_name: str,
                                     collection_name: str, mongodb_query: Dict[str, Any]) -> int:
        """Count the documents matching a find/count query.

        Returns:
            The matching document count, or -1 when the operation is neither
            find nor count, the connection is unavailable, or counting fails.
        """
        try:
            connection = self.connection_manager.get_instance_connection(instance_id)
            if not connection or not connection.client:
                return -1

            if mongodb_query.get("operation") not in ("count", "find"):
                # Other operations: unknown.
                return -1

            collection = connection.client[database_name][collection_name]
            return await collection.count_documents(mongodb_query.get("filter") or {})

        except Exception as e:
            # Best effort only: the estimate is informational.
            logger.warning("估算结果数量失败", error=str(e))
            return -1

    async def _show_query_only(self, query_info: Dict[str, Any]) -> List[TextContent]:
        """Render the generated query as text without executing it."""
        collection_name = query_info['collection_name']

        parts = [
            "## 🔍 生成的MongoDB查询语句\n\n",
            f"**查询描述**: {query_info['query_description']}\n",
            f"**目标集合**: `{query_info['instance_id']}.{query_info['database_name']}.{collection_name}`\n",
            f"**查询类型**: {query_info['query_type']}\n\n",
            "### 📄 MongoDB查询语句\n\n",
            "```javascript\n",
            _render_shell_query(collection_name, query_info["mongodb_query"]),
            "\n```\n\n",
        ]

        # Expected result size.
        if query_info.get("estimated_result_count", -1) >= 0:
            parts.append(f"**预期结果数量**: 约 {query_info['estimated_result_count']} 条\n")

        parts.append(f"**结果限制**: 最多返回 {query_info.get('limit', 10)} 条\n\n")
        parts.append("💡 **提示**: 使用 `generate_query()` 并提供 `user_confirmation` 参数来执行查询")

        return [TextContent(type="text", text="".join(parts))]

    async def _show_confirmation_prompt(self, query_info: Dict[str, Any]) -> List[TextContent]:
        """Return the confirmation prompt built by UserConfirmationHelper."""
        return [UserConfirmationHelper.create_query_confirmation_prompt(query_info)]

    async def _handle_user_confirmation(self, user_confirmation: str,
                                        query_info: Dict[str, Any], session_id: str) -> List[TextContent]:
        """Dispatch on the user's choice: execute, modify, show plan or cancel.

        The choice is matched case-insensitively and ignoring surrounding
        whitespace.
        """
        choice_upper = str(user_confirmation).strip().upper()

        if choice_upper in ['A', 'CONFIRM', 'EXECUTE']:
            # Confirmed: run the query.
            return await self._execute_query(query_info, session_id)

        if choice_upper in ['B', 'MODIFY', 'REGENERATE']:
            # Regenerate the query.
            return [TextContent(
                type="text",
                text="## 🔧 重新生成查询\n\n请使用不同的查询描述重新调用 `generate_query(query_description=\"新的查询描述\")`"
            )]

        if choice_upper in ['C', 'PLAN', 'EXPLAIN']:
            # Show the execution plan.
            return await self._show_execution_plan(query_info)

        if choice_upper in ['D', 'CANCEL']:
            # Cancelled.
            return [TextContent(
                type="text",
                text="## ❌ 已取消查询执行"
            )]

        # Invalid choice.
        return [TextContent(
            type="text",
            text=f"## ❌ 无效选择\n\n选择 '{user_confirmation}' 无效。请选择 A(执行)、B(修改)、C(查看计划)或 D(取消)。"
        )]

    async def _execute_query(self, query_info: Dict[str, Any], session_id: str) -> List[TextContent]:
        """Record the query in the workflow data, run it and format the result.

        Any failure is logged and reported to the user as text instead of
        being raised.
        """
        logger.info("执行确认的查询",
                    instance_id=query_info["instance_id"],
                    database_name=query_info["database_name"],
                    collection_name=query_info["collection_name"])

        try:
            # Update the workflow state.
            update_data = {
                "instance_id": query_info["instance_id"],
                "database_name": query_info["database_name"],
                "collection_name": query_info["collection_name"],
                "generated_query": query_info["mongodb_query"]
            }
            await self.workflow_manager.update_workflow_data(session_id, update_data)

            results = await self._run_mongodb_query(query_info)
            return await self._format_query_results(query_info, results)

        except Exception as e:
            logger.error("查询执行失败", error=str(e))
            return [TextContent(
                type="text",
                text=f"## ❌ 查询执行失败\n\n错误: {str(e)}\n\n请检查查询语句或数据库连接。"
            )]

    async def _run_mongodb_query(self, query_info: Dict[str, Any]) -> Any:
        """Run the generated query against the target collection.

        Returns:
            A list of documents (find, aggregate), an int (count) or the
            driver's distinct result.

        Raises:
            ValueError: If the connection is unavailable or the operation is
                not one of find/count/distinct/aggregate.
        """
        connection = self.connection_manager.get_instance_connection(query_info["instance_id"])
        if not connection or not connection.client:
            raise ValueError("数据库连接不可用")

        db = connection.client[query_info["database_name"]]
        collection = db[query_info["collection_name"]]
        mongodb_query = query_info["mongodb_query"]
        operation = mongodb_query.get("operation", "find")

        if operation == "find":
            filter_query = mongodb_query.get("filter", {})
            # None (not {}) when no projection was chosen, so the driver
            # unambiguously returns every field.
            projection = mongodb_query.get("projection") or None
            limit = mongodb_query.get("limit", 10)

            cursor = collection.find(filter_query, projection).limit(limit)
            return [doc async for doc in cursor]

        if operation == "count":
            return await collection.count_documents(mongodb_query.get("filter", {}))

        if operation == "distinct":
            field = mongodb_query.get("field", "_id")
            return await collection.distinct(field, mongodb_query.get("filter", {}))

        if operation == "aggregate":
            cursor = collection.aggregate(mongodb_query.get("pipeline", []))
            return [doc async for doc in cursor]

        raise ValueError(f"不支持的查询操作: {operation}")

    async def _format_query_results(self, query_info: Dict[str, Any], results: Any) -> List[TextContent]:
        """Format the query results as Markdown text.

        Shows at most 20 distinct values or 5 documents, followed by a
        summary of what was left out and suggested next steps.
        """
        operation = query_info["mongodb_query"].get("operation", "find")

        parts = [
            "## ✅ 查询执行成功\n\n",
            f"**查询描述**: {query_info['query_description']}\n",
            f"**目标集合**: `{query_info['collection_name']}`\n",
            f"**查询类型**: {operation}\n\n",
        ]

        if operation == "count":
            parts.append("### 📊 统计结果\n\n")
            parts.append(f"**文档数量**: {results}\n")

        elif operation == "distinct":
            parts.append("### 📋 唯一值列表\n\n")
            if isinstance(results, list):
                # Show at most 20 values.
                parts.extend(f"{i}. {value}\n" for i, value in enumerate(results[:20], 1))
                if len(results) > 20:
                    parts.append(f"... 还有 {len(results) - 20} 个值\n")
                parts.append(f"\n**总计**: {len(results)} 个唯一值\n")
            else:
                parts.append(f"结果: {results}\n")

        elif operation in ["find", "aggregate"]:
            parts.append("### 📄 查询结果\n\n")
            if isinstance(results, list):
                parts.append(f"**返回记录数**: {len(results)}\n\n")
                # Show at most 5 records.
                for i, doc in enumerate(results[:5], 1):
                    parts.append(f"#### 记录 {i}\n")
                    parts.append("```json\n")
                    parts.append(_to_json(doc))
                    parts.append("\n```\n\n")
                if len(results) > 5:
                    parts.append(f"*... 还有 {len(results) - 5} 条记录*\n\n")
            else:
                parts.append(f"结果: {results}\n")

        # Suggested next steps.
        parts.append("## 🎯 下一步操作\n\n")
        parts.append("可以继续以下操作:\n")
        parts.append("- `generate_query(query_description=\"新的查询需求\")` - 生成新查询\n")
        parts.append("- `workflow_status()` - 查看工作流状态\n")
        parts.append("- 分析查询结果,根据需要调整查询条件\n")

        return [TextContent(type="text", text="".join(parts))]

    async def _show_execution_plan(self, query_info: Dict[str, Any]) -> List[TextContent]:
        """Show the execution plan of a find query and repeat the options.

        Only find queries are explained; other operations get an explanatory
        message in place of a plan. Failures are reported as text.
        """
        try:
            connection = self.connection_manager.get_instance_connection(query_info["instance_id"])
            if not connection or not connection.client:
                raise ValueError("数据库连接不可用")

            db = connection.client[query_info["database_name"]]
            collection = db[query_info["collection_name"]]
            mongodb_query = query_info["mongodb_query"]

            if mongodb_query.get("operation") == "find":
                # Explain the same cursor that _run_mongodb_query would run.
                cursor = collection.find(
                    mongodb_query.get("filter", {}),
                    mongodb_query.get("projection") or None,
                ).limit(mongodb_query.get("limit", 10))
                explain_result = await cursor.explain()
            else:
                explain_result = {"message": "只有find查询支持执行计划分析"}

            parts = [
                "## 📊 查询执行计划\n\n",
                f"**查询类型**: {mongodb_query.get('operation', 'find')}\n",
                f"**集合**: `{query_info['collection_name']}`\n\n",
                "### 📄 执行计划详情\n\n",
                "```json\n",
                _to_json(explain_result),
                "\n```\n\n",
                "### 📋 确认选项\n\n",
                "查看执行计划后,请选择下一步操作:\n",
                "- `generate_query(..., user_confirmation=\"A\")` - 确认执行查询\n",
                "- `generate_query(..., user_confirmation=\"B\")` - 修改查询\n",
                "- `generate_query(..., user_confirmation=\"D\")` - 取消查询\n",
            ]

            return [TextContent(type="text", text="".join(parts))]

        except Exception as e:
            return [TextContent(
                type="text",
                text=f"## ❌ 获取执行计划失败\n\n错误: {str(e)}"
            )]

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/niuzaishu/QueryNest'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.