Skip to main content
Glama
sql_parser.py15.1 kB
import sqlparse
import re
import logging
from typing import List, Set, Tuple, Optional, Dict

from ..config import SQLConfig

logger = logging.getLogger(__name__)


class SQLParser:
    """SQL parser that uses the sqlparse library for structured SQL analysis.

    All methods are static. ``parse_query`` is the public entry point; the
    underscore-prefixed methods are helpers. When structured parsing raises,
    ``parse_query`` falls back to a plain string-based analysis.
    """

    # Risk order used to pick the operation reported for a multi-statement
    # query: earlier entries win.
    _DDL_PRIORITY = ('DROP', 'TRUNCATE', 'ALTER', 'CREATE')
    _DML_PRIORITY = ('DELETE', 'UPDATE', 'INSERT', 'SELECT')

    # Characters that do not make a statement "real" on their own.
    _STATEMENT_PADDING = ' \t\r\n;'

    # Whole-word matchers for the string-based fallback, so that e.g.
    # "NOWHERE" is not mistaken for a WHERE clause.
    _WHERE_RE = re.compile(r'\bWHERE\b', re.IGNORECASE)
    _LIMIT_RE = re.compile(r'\bLIMIT\b', re.IGNORECASE)

    # Keywords after which the fallback expects a table name, and words that
    # may follow those keywords without being a table name.
    _FALLBACK_TABLE_KEYWORDS = frozenset({'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'})
    _FALLBACK_NON_TABLES = frozenset({'SELECT', 'WHERE', 'SET'})

    @staticmethod
    def parse_query(sql_query: str) -> Dict:
        """Parse a SQL query and describe it.

        Args:
            sql_query: The SQL text; may contain several statements.

        Returns:
            Dict with the keys ``operation_type``, ``tables``, ``has_where``,
            ``has_limit``, ``is_valid``, ``normalized_query``, ``category``,
            ``multi_statement`` and ``statement_count``. For multi-statement
            input, ``tables``/``has_where``/``has_limit`` are aggregated over
            all statements and ``operation_type``/``category`` describe the
            highest-risk statement (DDL before DML).
        """
        if not sql_query or not sql_query.strip():
            return SQLParser._empty_result()

        try:
            formatted_sql = SQLParser._format_sql(sql_query)

            # Drop statements that consist only of whitespace / semicolons so
            # that a trailing ';' is not counted as an extra statement.
            statements = [
                stmt for stmt in sqlparse.parse(formatted_sql)
                if str(stmt).strip(SQLParser._STATEMENT_PADDING)
            ]
            if not statements:
                return SQLParser._empty_result(formatted_sql)

            operations = [SQLParser._get_operation_type(stmt) for stmt in statements]
            categories = [SQLParser._get_operation_category(op) for op in operations]

            # Aggregate tables and clause flags across every statement.
            tables = set()
            has_where = False
            has_limit = False
            for stmt in statements:
                tables.update(SQLParser._extract_tables(stmt))
                has_where = has_where or SQLParser._has_where_clause(stmt)
                has_limit = has_limit or SQLParser._has_limit_clause(stmt)

            is_multi_statement = len(statements) > 1
            if is_multi_statement:
                operation_type, category = SQLParser._highest_risk(operations, categories)
            else:
                operation_type, category = operations[0], categories[0]

            return {
                'operation_type': operation_type,
                'tables': list(tables),
                'has_where': has_where,
                'has_limit': has_limit,
                'is_valid': True,
                'normalized_query': formatted_sql,
                'category': category,
                'multi_statement': is_multi_statement,
                'statement_count': len(statements)
            }
        except Exception as e:
            # Deliberate best-effort boundary: any parser failure degrades to
            # the string-based analysis instead of propagating.
            logger.error(f"SQL解析错误: {str(e)}")
            return SQLParser._fallback_parse(sql_query)

    @staticmethod
    def _empty_result(normalized_query: str = '') -> Dict:
        """Return the result used when there is nothing to analyse."""
        return {
            'operation_type': '',
            'tables': [],
            'has_where': False,
            'has_limit': False,
            'is_valid': False,
            'normalized_query': normalized_query,
            'category': 'UNKNOWN',
            'multi_statement': False,
            'statement_count': 0
        }

    @staticmethod
    def _highest_risk(operations: List[str], categories: List[str]) -> Tuple[str, str]:
        """Pick the (operation, category) pair that represents several statements.

        DDL outranks DML; inside a category the class-level priority tuples
        decide. If the category is present but none of its operations is in
        the priority tuple, the first operation of that category is reported
        so that operation and category always agree. When neither DDL nor DML
        occurs, the first statement's values are returned.
        """
        ranked = (
            ('DDL', SQLParser._DDL_PRIORITY),
            ('DML', SQLParser._DML_PRIORITY),
        )
        for category, priority in ranked:
            if category not in categories:
                continue
            for operation in priority:
                if operation in operations:
                    return operation, category
            first_of_category = next(
                op for op, cat in zip(operations, categories) if cat == category
            )
            return first_of_category, category
        return operations[0], categories[0]

    @staticmethod
    def _count_statements(sql_query: str) -> int:
        """Count non-empty ';'-separated chunks (string-level approximation)."""
        return sum(1 for chunk in sql_query.split(';') if chunk.strip())

    @staticmethod
    def _format_sql(sql_query: str) -> str:
        """Normalize the query: strip comments, reindent, upper-case keywords."""
        return sqlparse.format(
            sql_query,
            strip_comments=True,
            reindent=True,
            keyword_case='upper'
        )

    @staticmethod
    def _get_operation_type(stmt: sqlparse.sql.Statement) -> str:
        """Return the upper-cased leading keyword of a statement.

        Leading whitespace and comments are skipped. Returns an empty string
        when the first meaningful token is not a DML, DDL or plain keyword.
        """
        first = stmt.token_first(skip_cm=True)
        if first is None:
            return ""
        keyword_types = (
            sqlparse.tokens.DML,
            sqlparse.tokens.DDL,
            sqlparse.tokens.Keyword,
        )
        if any(first.ttype is ttype for ttype in keyword_types):
            return first.value.upper()
        return ""

    @staticmethod
    def _get_operation_category(operation_type: str) -> str:
        """Map an operation to 'DDL', 'DML', 'METADATA' or 'UNKNOWN' via SQLConfig."""
        if operation_type in SQLConfig.DDL_OPERATIONS:
            return 'DDL'
        if operation_type in SQLConfig.DML_OPERATIONS:
            return 'DML'
        if operation_type in SQLConfig.METADATA_OPERATIONS:
            return 'METADATA'
        return 'UNKNOWN'

    @staticmethod
    def _extract_tables(stmt: sqlparse.sql.Statement) -> List[str]:
        """Extract the de-duplicated table names referenced by a statement.

        UPDATE, INSERT, DELETE and CREATE/ALTER/DROP/TRUNCATE use a targeted
        lookup of their target table(s); every other statement is walked
        recursively for FROM/JOIN clauses, including subqueries.
        """
        operation_type = SQLParser._get_operation_type(stmt)

        if operation_type == 'UPDATE':
            tables = SQLParser._update_targets(stmt)
        elif operation_type == 'INSERT':
            # "INSERT INTO t(a, b)" can be grouped by sqlparse as a Function.
            tables = SQLParser._tables_after_keyword(stmt, 'INTO', allow_function=True)
        elif operation_type == 'DELETE':
            tables = SQLParser._tables_after_keyword(stmt, 'FROM')
        elif operation_type in {'CREATE', 'ALTER', 'DROP', 'TRUNCATE'}:
            tables = SQLParser._tables_after_keyword(stmt, 'TABLE')
        else:
            tables = SQLParser._collect_from_tables(stmt)

        # Drop empty names and duplicates.
        return list({table for table in tables if table})

    @staticmethod
    def _real_name(token) -> Optional[str]:
        """Return ``token.get_real_name()`` if the token supports it, else None."""
        getter = getattr(token, 'get_real_name', None)
        return getter() if callable(getter) else None

    @staticmethod
    def _names_of(token) -> List[str]:
        """Return the table names carried by an Identifier or IdentifierList."""
        if isinstance(token, sqlparse.sql.Identifier):
            candidates = [token]
        elif isinstance(token, sqlparse.sql.IdentifierList):
            candidates = list(token.get_identifiers())
        else:
            return []
        names = (SQLParser._real_name(candidate) for candidate in candidates)
        return [name for name in names if name]

    @staticmethod
    def _update_targets(stmt: sqlparse.sql.Statement) -> List[str]:
        """Return the table(s) named directly after the UPDATE keyword."""
        for index, token in enumerate(stmt.tokens):
            if token.ttype is sqlparse.tokens.DML and token.value.upper() == 'UPDATE':
                # The token right after UPDATE is whitespace, so look at the
                # next meaningful token rather than tokens[index + 1].
                _, target = stmt.token_next(index, skip_ws=True, skip_cm=True)
                return SQLParser._names_of(target)
        return []

    @staticmethod
    def _tables_after_keyword(stmt: sqlparse.sql.Statement, keyword: str, *,
                              allow_function: bool = False) -> List[str]:
        """Return the table(s) from the first identifier that follows ``keyword``.

        Args:
            stmt: Statement whose top-level tokens are scanned.
            keyword: Upper-case keyword (e.g. 'INTO', 'FROM', 'TABLE').
            allow_function: Also accept a Function token and use its name.
        """
        keyword_seen = False
        for token in stmt.tokens:
            if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == keyword:
                keyword_seen = True
            elif keyword_seen:
                if isinstance(token, (sqlparse.sql.Identifier, sqlparse.sql.IdentifierList)):
                    return SQLParser._names_of(token)
                if allow_function and isinstance(token, sqlparse.sql.Function):
                    name = token.get_name()
                    return [name] if name else []
        return []

    @staticmethod
    def _identifier_tables(identifier) -> List[str]:
        """Return the table names behind one FROM/JOIN item.

        An aliased subquery such as ``(SELECT ... FROM t) AS s`` is grouped
        as an Identifier whose first token is a Parenthesis; in that case the
        subquery is searched instead of reporting the alias as a table.
        """
        if isinstance(identifier, sqlparse.sql.Identifier):
            first = identifier.token_first(skip_cm=True)
            if isinstance(first, sqlparse.sql.Parenthesis):
                return SQLParser._collect_from_tables(first)
        name = SQLParser._real_name(identifier)
        return [name] if name else []

    @staticmethod
    def _collect_from_tables(token_list) -> List[str]:
        """Recursively collect table names that follow FROM or JOIN keywords.

        May return duplicates; the caller de-duplicates.
        """
        found = []
        in_table_clause = False

        for token in token_list.tokens:
            if token.ttype is sqlparse.tokens.Keyword:
                keyword = token.value.upper()
                if keyword == 'FROM' or 'JOIN' in keyword:
                    in_table_clause = True
                    continue

            # Set when this token was already searched below, so the generic
            # recursion at the end does not walk it a second time.
            already_descended = False

            if in_table_clause:
                if isinstance(token, sqlparse.sql.Identifier):
                    found.extend(SQLParser._identifier_tables(token))
                elif isinstance(token, sqlparse.sql.IdentifierList):
                    # Several tables, e.g. "FROM table1, table2".
                    for identifier in token.get_identifiers():
                        found.extend(SQLParser._identifier_tables(identifier))
                elif isinstance(token, sqlparse.sql.Function):
                    found.extend(SQLParser._collect_from_tables(token))
                    already_descended = True
                elif isinstance(token, sqlparse.sql.Parenthesis):
                    inner = token.tokens[1] if len(token.tokens) > 1 else None
                    if isinstance(inner, sqlparse.sql.Statement):
                        found.extend(SQLParser._extract_tables(inner))
                    else:
                        found.extend(SQLParser._collect_from_tables(token))
                    already_descended = True

            # A keyword or punctuation ends the table list, so identifiers
            # from other parts of the statement are not collected.
            if token.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.Punctuation):
                in_table_clause = False

            # Search remaining groups (WHERE, comparisons, select lists, ...)
            # for nested FROM/JOIN clauses.
            if (not already_descended
                    and isinstance(token, sqlparse.sql.TokenList)
                    and not isinstance(token, sqlparse.sql.Identifier)):
                found.extend(SQLParser._collect_from_tables(token))

        return found

    @staticmethod
    def _has_where_clause(stmt: sqlparse.sql.Statement) -> bool:
        """Return True if the statement has a top-level WHERE group."""
        return any(isinstance(token, sqlparse.sql.Where) for token in stmt.tokens)

    @staticmethod
    def _has_limit_clause(stmt: sqlparse.sql.Statement) -> bool:
        """Return True if a LIMIT keyword occurs at the top level or one level down."""
        def is_limit(token) -> bool:
            return (token.ttype is sqlparse.tokens.Keyword
                    and token.value.upper() == 'LIMIT')

        for token in stmt.tokens:
            if is_limit(token):
                return True
            if isinstance(token, sqlparse.sql.TokenList) and any(
                    is_limit(subtoken) for subtoken in token.tokens):
                return True
        return False

    @staticmethod
    def _fallback_parse(sql_query: str) -> Dict:
        """Analyse a query with plain string operations when sqlparse fails.

        Table names keep the case used in the query. WHERE and LIMIT are
        matched as whole words. Statements are counted by splitting on ';',
        ignoring empty chunks, so a single trailing ';' does not make the
        query multi-statement.
        """
        words = sql_query.strip().split()
        operation_type = words[0].upper().rstrip(';') if words else ""
        category = SQLParser._get_operation_category(operation_type)

        # Take the word after each table-introducing keyword.
        tables = set()
        for keyword, candidate in zip(words, words[1:]):
            if keyword.upper() not in SQLParser._FALLBACK_TABLE_KEYWORDS:
                continue
            table = candidate.strip('`;')
            if table and table.upper() not in SQLParser._FALLBACK_NON_TABLES:
                tables.add(table)

        statement_count = SQLParser._count_statements(sql_query)

        return {
            'operation_type': operation_type,
            'tables': list(tables),
            'has_where': bool(SQLParser._WHERE_RE.search(sql_query)),
            'has_limit': bool(SQLParser._LIMIT_RE.search(sql_query)),
            'is_valid': bool(operation_type),
            'normalized_query': sql_query,
            'category': category,
            'multi_statement': statement_count > 1,
            'statement_count': statement_count
        }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/mangooer/mysql-mcp-server-sse'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.