# nlp_processor.py (16.4 kB)
"""
MySQL MCP的自然语言处理模块
将自然语言查询转换为SQL
"""
import re
import jieba
import jieba.posseg as pseg
from typing import List, Dict, Any, Optional
from database_models import DatabaseConnection, TableMetadata, ColumnMetadata
class NLPProcessor:
    """Translate natural-language questions (Chinese or English) into SQL.

    The translation is rule based: a handful of keyword checks and regular
    expressions decide which SQL template to emit, and table/column metadata
    read through ``metadata_db_manager`` is used to map the words in the
    question onto table and column names.
    """

    # Domain words registered with jieba so they are kept as single tokens.
    _CUSTOM_WORDS = ("用户", "产品", "订单", "信息", "记录", "统计", "数量")

    # Chinese nouns -> table names, used when no schema table name appears
    # in the question. Insertion order is the lookup order.
    _TABLE_MAPPING = {
        '用户': 'users',
        '产品': 'products',
        '订单': 'orders',
        '用户信息': 'users',
        '产品信息': 'products',
        '订单信息': 'orders',
    }

    # English table words probed last; plurals come first so "users" is
    # never reported as "user".
    _COMMON_TABLES = ('users', 'products', 'orders', 'user', 'product', 'order')
    _SINGULAR_TABLES = frozenset(('user', 'product', 'order'))

    # Person names recognised as a filter value in Chinese lookups.
    _KNOWN_PERSON_NAMES = ("张三", "李四")
    # Substrings that mark a column as holding a person/user name.
    _NAME_COLUMN_HINTS = ('name', 'username', 'user_name', 'full_name')

    # English patterns. They are matched case-insensitively against the
    # ORIGINAL text so that literals in a WHERE clause keep their case.
    _SHOW_RECORDS_RE = re.compile(r"show (?:me )?all records from (\w+)", re.IGNORECASE)
    _SHOW_ALL_RE = re.compile(r"show (?:me )?all (\w+)", re.IGNORECASE)
    _COUNT_RE = re.compile(r"count records in (\w+)", re.IGNORECASE)
    _FIND_RE = re.compile(r"find (.+) from (\w+)(?: where (.*))?", re.IGNORECASE)
    _SELECT_RE = re.compile(r"select (.+) from (\w+)(?: where (.*))?", re.IGNORECASE)

    def __init__(self, metadata_db_manager):
        """Store the metadata manager and register domain words with jieba.

        Args:
            metadata_db_manager: object exposing ``connection``, ``connect()``
                and ``get_all_databases()``; used to read schema metadata.
        """
        self.metadata_db_manager = metadata_db_manager
        # load_userdict() is documented to take a file path or file-like
        # object; add_word() is the API for registering words from code.
        for word in self._CUSTOM_WORDS:
            jieba.add_word(word)

    def process_query(self, natural_query: str, database_id: int) -> str:
        """Convert a natural-language query into a SQL statement.

        Args:
            natural_query: the question, in Chinese or English.
            database_id: id of the target database in the metadata store.

        Returns:
            The generated SQL string.

        Raises:
            ValueError: if ``database_id`` is unknown or the query matches
                none of the supported patterns.
        """
        self._ensure_connected()
        databases = self.metadata_db_manager.get_all_databases()
        if not any(db.id == database_id for db in databases):
            raise ValueError(f"未找到ID为 {database_id} 的数据库")

        schema_info = self._get_database_schema(database_id)

        # Any CJK ideograph routes the query to the Chinese rules.
        if self._contains_chinese(natural_query):
            return self._process_chinese_query(natural_query, database_id, schema_info)
        return self._process_english_query(natural_query, database_id, schema_info)

    def _ensure_connected(self) -> None:
        """Open the metadata connection if it is missing or not connected."""
        connection = self.metadata_db_manager.connection
        if not connection or not connection.is_connected():
            self.metadata_db_manager.connect()

    def _contains_chinese(self, text: str) -> bool:
        """Return True if ``text`` has a character in U+4E00..U+9FFF."""
        return any('\u4e00' <= ch <= '\u9fff' for ch in text)

    def _find_table_for_chinese_query(
        self, natural_query: str, schema_info: Optional[Dict[str, Any]]
    ) -> Optional[str]:
        """Pick the table a Chinese (or mixed) query refers to.

        Lookup order: schema table names that appear verbatim in the query,
        then the Chinese noun mapping, then common English table words.

        Returns:
            The table name, or None when nothing matches.
        """
        if schema_info:
            # Prefer the longest hit so "order" does not shadow
            # "order_items"; max() keeps the first name on ties.
            hits = [
                name for name in schema_info.get('table_names', [])
                if name and name in natural_query
            ]
            if hits:
                return max(hits, key=len)

        for chinese_word, english_table in self._TABLE_MAPPING.items():
            if chinese_word in natural_query:
                return english_table

        lowered = natural_query.lower()
        for candidate in self._COMMON_TABLES:
            if candidate in lowered:
                # Singular English nouns are mapped to the plural table name.
                if candidate in self._SINGULAR_TABLES:
                    return candidate + 's'
                return candidate
        return None

    def _build_name_filter_sql(
        self, table_name: str, person: str, schema_info: Optional[Dict[str, Any]]
    ) -> str:
        """Build a ``LIKE`` lookup for ``person`` on ``table_name``.

        Uses the first schema column whose name contains one of the
        name hints; without schema information for the table it falls back
        to OR-ing the conventional name columns.
        """
        table_meta = (schema_info or {}).get('tables', {}).get(table_name)
        if table_meta:
            for column in table_meta.get('columns', []):
                column_name = column['column_name']
                if any(hint in column_name.lower() for hint in self._NAME_COLUMN_HINTS):
                    return f"SELECT * FROM `{table_name}` WHERE `{column_name}` LIKE '%{person}%'"
        return (
            f"SELECT * FROM `{table_name}` WHERE name LIKE '%{person}%' "
            f"OR username LIKE '%{person}%' OR user_name LIKE '%{person}%'"
        )

    def _process_chinese_query(
        self,
        natural_query: str,
        database_id: int,
        schema_info: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Convert a Chinese natural-language query into SQL.

        Args:
            natural_query: the Chinese (or mixed) question.
            database_id: id of the target database (not used by the rules).
            schema_info: schema dictionary as built by
                ``_get_database_schema``; may be None.

        Returns:
            The generated SQL string.

        Raises:
            ValueError: if no rule applies or no table can be identified.
        """
        # jieba segmentation; only the tokens are needed for keyword checks.
        tokens = [word for word, _flag in pseg.cut(natural_query)]
        lowered = natural_query.lower()

        # Intent detection, in priority order: lookup > listing > count.
        wants_lookup = any(tok in ('查询', '查找') for tok in tokens) or "show" in lowered
        wants_listing = "列出" in natural_query or "显示" in natural_query or "list" in lowered
        wants_count = "统计" in natural_query or "数量" in natural_query or "count" in lowered

        table_name = None
        if wants_lookup or wants_listing or wants_count:
            table_name = self._find_table_for_chinese_query(natural_query, schema_info)

        if table_name:
            if wants_lookup:
                for person in self._KNOWN_PERSON_NAMES:
                    if person in natural_query:
                        return self._build_name_filter_sql(table_name, person, schema_info)
                return f"SELECT * FROM `{table_name}`"
            if wants_listing:
                return f"SELECT * FROM `{table_name}`"
            return f"SELECT COUNT(*) AS count FROM `{table_name}`"

        raise ValueError(f"无法处理的中文查询: {natural_query}。支持的格式包括:'查询用户张三的信息'、'列出所有用户信息'、'统计用户数量',以及英文格式如 'show me all users'")

    @staticmethod
    def _match_schema_table(table_name: str, schema_info: Optional[Dict[str, Any]]) -> str:
        """Map ``table_name`` onto a schema table when it is not an exact hit.

        Returns the first schema table that contains, or is contained in,
        ``table_name``; returns ``table_name`` unchanged otherwise.
        """
        if not schema_info:
            return table_name
        known_tables = schema_info.get('table_names', [])
        if table_name in known_tables:
            return table_name
        for real_table in known_tables:
            if table_name in real_table or real_table in table_name:
                return real_table
        return table_name

    def _process_english_query(
        self,
        natural_query: str,
        database_id: int,
        schema_info: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Convert an English natural-language query into SQL.

        Supported forms: ``show [me] all records from T``, ``show [me] all T``,
        ``count records in T``, ``find COLS from T [where COND]`` and
        ``select COLS from T [where COND]``. Text that already starts with
        select/insert/update/delete is returned unchanged.

        Args:
            natural_query: the English question.
            database_id: id of the target database (not used by the rules).
            schema_info: schema dictionary as built by
                ``_get_database_schema``; may be None.

        Returns:
            The generated SQL string.

        Raises:
            ValueError: if the query matches none of the supported forms.
        """
        match = self._SHOW_RECORDS_RE.search(natural_query)
        if match:
            table_name = self._match_schema_table(match.group(1).lower(), schema_info)
            return f"SELECT * FROM `{table_name}`"

        match = self._SHOW_ALL_RE.search(natural_query)
        if match:
            table_name = self._match_schema_table(match.group(1).lower(), schema_info)
            known_tables = schema_info.get('table_names', []) if schema_info else []
            # Pluralise only when the schema does not already contain this
            # exact name; otherwise a real table "user" would be replaced
            # by a possibly non-existent "users".
            if table_name not in known_tables and table_name in self._SINGULAR_TABLES:
                table_name += 's'
            return f"SELECT * FROM `{table_name}`"

        match = self._COUNT_RE.search(natural_query)
        if match:
            table_name = self._match_schema_table(match.group(1).lower(), schema_info)
            return f"SELECT COUNT(*) AS count FROM `{table_name}`"

        # NOTE(review): the column list and WHERE text below are copied
        # verbatim from the user's input into the SQL; callers must treat
        # the result as untrusted SQL.
        match = self._FIND_RE.search(natural_query)
        if match:
            columns, raw_table, condition = match.groups()
            table_name = self._match_schema_table(raw_table.lower(), schema_info)
            if columns.strip().lower() == "all":
                columns = "*"
            sql = f"SELECT {columns} FROM `{table_name}`"
            if condition:
                sql += f" WHERE {condition}"
            return sql

        match = self._SELECT_RE.search(natural_query)
        if match:
            columns, raw_table, condition = match.groups()
            table_name = self._match_schema_table(raw_table.lower(), schema_info)
            sql = f"SELECT {columns} FROM `{table_name}`"
            if condition:
                sql += f" WHERE {condition}"
            return sql

        # Fallback: text that already looks like SQL is passed through
        # unchanged (leading whitespace is tolerated).
        if natural_query.lstrip().lower().startswith(('select', 'insert', 'update', 'delete')):
            return natural_query

        raise ValueError(f"无法将查询转换为SQL: {natural_query}。"
                         f"支持的格式包括: 'show me all records from [table]', "
                         f"'show all records from [table]', 'count records in [table]', "
                         f"'find [columns] from [table]', 'select [columns] from [table]'")

    def _get_database_schema(self, database_id: int) -> Dict[str, Any]:
        """Read table and column metadata for one database.

        Args:
            database_id: id of the database whose schema is wanted.

        Returns:
            ``{'table_names': [...], 'tables': {name: {'id', 'description',
            'columns'}}}`` where ``columns`` is the list of rows returned
            for that table from ``column_metadata``.
        """
        self._ensure_connected()
        # dictionary=True: rows are accessed by column name below.
        cursor = self.metadata_db_manager.connection.cursor(dictionary=True)
        try:
            cursor.execute("""
                SELECT id, table_name, description
                FROM table_metadata
                WHERE database_id = %s
            """, (database_id,))
            tables = cursor.fetchall()

            schema_info: Dict[str, Any] = {
                'tables': {},
                'table_names': [table['table_name'] for table in tables],
            }

            # One column query per table, keyed by the table's metadata id.
            for table in tables:
                table_id = table['id']
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_key, column_comment
                    FROM column_metadata
                    WHERE table_id = %s
                    ORDER BY id
                """, (table_id,))
                schema_info['tables'][table['table_name']] = {
                    'id': table_id,
                    'description': table['description'],
                    'columns': cursor.fetchall(),
                }
            return schema_info
        finally:
            cursor.close()