Skip to main content
Glama
yooztech

yooztech_mcp_mysql

Official
by yooztech
app.py (16.3 kB)
#!/usr/bin/env python3
"""MCP server exposing guarded, read-only access to a MySQL instance."""
from __future__ import annotations

import os
import re
from typing import Any, Dict, List, Optional, Tuple

import mysql.connector as mc
from mcp.server.fastmcp import FastMCP

server = FastMCP("yooztech_mcp_mysql")

# Patterns used to guess a database name from project files. They are
# compiled once at import time instead of on every scanned file.
# An optional opening quote is accepted so `DB_NAME="shop"` is recognised
# as well as `DB_NAME=shop`.
_ENV_KEY_PATTERNS = tuple(
    re.compile(rf"{key}\s*=\s*[\"']?([A-Za-z0-9_\-]+)")
    for key in ("MYSQL_DATABASE", "DB_NAME", "DATABASE_NAME", "MYSQL_DB")
)
# JDBC / URL style: .../(dbname)
_URL_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"jdbc:mysql://[^/\s]+/([A-Za-z0-9_\-]+)",
        r"mysql:\/\/[^/\s]+/([A-Za-z0-9_\-]+)",
    )
)

# File suffixes considered "text-like" during the shallow project scan.
_TEXT_SUFFIXES = (
    ".env",
    ".yml",
    ".yaml",
    ".json",
    ".py",
    ".ts",
    ".js",
    ".toml",
    ".ini",
    ".properties",
)


def _quote_ident(name: str) -> str:
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled so a legal-but-unusual identifier cannot
    terminate the quoting early.
    """
    return "`" + name.replace("`", "``") + "`"


class MySQLGuard:
    """Wrap one MySQL connection with whitelist checks and database inference.

    Only read-only queries are issued. Connection settings come from the
    environment variables DB_HOST, DB_PORT, DB_USER and DB_PASS.
    """

    def __init__(self) -> None:
        self.host = os.getenv("DB_HOST", "127.0.0.1")
        self.port = int(os.getenv("DB_PORT", "3306"))
        self.user = os.getenv("DB_USER", "root")
        self.password = os.getenv("DB_PASS", "")
        # DB_NAME is deliberately not read: the database is either inferred
        # at runtime or passed explicitly as a tool argument.
        self.inferred_db: Optional[str] = None
        # Maps "db.table" -> ordered list of column names.
        self._schema_cache: Dict[str, List[str]] = {}
        # autocommit=True: this connection lives for the whole process and
        # nothing here ever commits or rolls back. With autocommit off, the
        # first SELECT would open a transaction that never ends, pinning
        # every later read to that transaction's snapshot (stale results)
        # and holding metadata locks that block DDL on the tables read.
        self._conn = mc.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            database=None,
            autocommit=True,
        )

    def _non_system_databases(self) -> List[str]:
        """List non-system schemas, used to resolve the accessible database.

        Note: this filters system schema names out of
        information_schema.SCHEMATA; it does not precisely check privileges,
        which is generally good enough for a read-only account.
        """
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT SCHEMA_NAME
                FROM information_schema.SCHEMATA
                WHERE SCHEMA_NAME NOT IN ('mysql','information_schema','performance_schema','sys')
                ORDER BY SCHEMA_NAME
                """
            )
            return [row[0] for row in cur.fetchall()]

    def _resolve_database(self, db: Optional[str]) -> str:
        """Return the database to use for a call.

        Resolution order: the explicit *db* argument, a previously inferred
        database, a fresh inference from the current working directory, and
        finally the only non-system schema if exactly one exists.

        Raises:
            ValueError: if no database can be determined unambiguously.
        """
        if db:
            return db
        if self.inferred_db:
            return self.inferred_db

        # Try one automatic inference based on the current working directory.
        guessed, _ = self._infer_database_internal(os.getcwd())
        if guessed:
            self.inferred_db = guessed
            return guessed

        # Fallback: a single non-system schema is used directly.
        candidates = self._non_system_databases()
        if len(candidates) == 1:
            self.inferred_db = candidates[0]
            return candidates[0]
        raise ValueError("存在多个可访问库且无法从项目中推断,请在参数中指定 db 或先调用 infer_database 工具")

    # --- inference logic ---
    def _extract_db_hints_from_text(self, text: str) -> List[str]:
        """Extract possible database names from *text* (simple heuristics).

        Looks for common ``KEY=value`` settings (first occurrence per key)
        and for JDBC / ``mysql://`` URLs (every occurrence). The result is
        de-duplicated while preserving first-seen order.
        """
        hints: List[str] = []
        for pattern in _ENV_KEY_PATTERNS:
            found = pattern.search(text)
            if found:
                hints.append(found.group(1))
        for pattern in _URL_PATTERNS:
            hints.extend(m.group(1) for m in pattern.finditer(text))
        return list(dict.fromkeys(hints))  # de-duplicate, keep order

    def _infer_database_internal(
        self, project_root: Optional[str]
    ) -> Tuple[Optional[str], Dict[str, Any]]:
        """Infer the database name from a project directory.

        Args:
            project_root: directory to scan; the current working directory
                is used when empty or None.

        Returns:
            ``(db, evidence)``. ``db`` is None when no unique choice exists.
            ``evidence`` holds "candidates", "matches", "hints" and, when a
            database was chosen, "selected".
        """
        if not project_root:
            project_root = os.getcwd()

        candidates = self._non_system_databases()
        evidence: Dict[str, Any] = {"candidates": candidates, "matches": []}
        if not candidates:
            return None, evidence

        # Only a bounded set of files is scanned to keep the cost low.
        prefer_names = [
            ".env",
            ".env.local",
            "env.example",
            "config.yml",
            "application.yml",
            "application.yaml",
            "config.json",
            "settings.py",
            "database.yml",
            "package.json",
            "pyproject.toml",
        ]
        scanned = 0
        max_files = 200
        size_limit = 256 * 1024
        found_hints: List[str] = []

        # Well-known file names directly under the root are checked first.
        for name in prefer_names:
            path = os.path.join(project_root, name)
            if not os.path.isfile(path):
                continue
            try:
                with open(path, "r", encoding="utf-8", errors="ignore") as f:
                    text = f.read(size_limit)
            except OSError:
                # Best effort: an unreadable file simply contributes nothing.
                continue
            found_hints.extend(self._extract_db_hints_from_text(text))
            scanned += 1

        # Then a shallow scan: the root plus two directory levels below it.
        for root, dirs, files in os.walk(project_root):
            if scanned >= max_files:
                break
            rel = os.path.relpath(root, project_root)
            depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
            if depth >= 2:
                # Prune in place so os.walk never descends further; merely
                # skipping deep directories would still traverse all of them.
                dirs[:] = []
            for fn in files:
                if scanned >= max_files:
                    break
                if not fn.lower().endswith(_TEXT_SUFFIXES):
                    continue
                path = os.path.join(root, fn)
                try:
                    with open(path, "r", encoding="utf-8", errors="ignore") as f:
                        text = f.read(size_limit)
                except OSError:
                    continue
                hints = self._extract_db_hints_from_text(text)
                if hints:
                    evidence["matches"].append({"file": path, "hints": hints})
                    found_hints.extend(hints)
                scanned += 1

        found_hints = list(dict.fromkeys(found_hints))
        # Keep only the hints that name a schema we can actually see.
        intersection = [h for h in found_hints if h in candidates]
        if len(intersection) == 1:
            return intersection[0], {
                **evidence,
                "selected": intersection[0],
                "hints": found_hints,
            }
        if not intersection and len(candidates) == 1:
            # Only one schema is accessible, so use it directly.
            return candidates[0], {
                **evidence,
                "selected": candidates[0],
                "hints": found_hints,
            }
        # No unique answer.
        return None, {**evidence, "hints": found_hints}

    def _ensure_table_cached(self, db: str, table: str) -> None:
        """Load the column list of ``db.table`` into the schema cache.

        Raises:
            ValueError: if the table does not exist or has no columns.
        """
        key = f"{db}.{table}"
        if key in self._schema_cache:
            return
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT COLUMN_NAME
                FROM information_schema.COLUMNS
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                ORDER BY ORDINAL_POSITION
                """,
                (db, table),
            )
            cols = [row[0] for row in cur.fetchall()]
        if not cols:
            raise ValueError(f"表不存在或无列: {table}")
        self._schema_cache[key] = cols

    def list_tables(self, db: Optional[str] = None) -> List[str]:
        """Return the table names of *db* (resolved when omitted), sorted."""
        dbname = self._resolve_database(db)
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT TABLE_NAME
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = %s
                ORDER BY TABLE_NAME
                """,
                (dbname,),
            )
            return [row[0] for row in cur.fetchall()]

    def get_table_schema(self, table: str, db: Optional[str] = None) -> Dict[str, Any]:
        """Return structure information for a table.

        Covers column definitions, primary key, secondary indexes and the
        table comment.

        Example result:
            {
              "db": "mydb",
              "table": "users",
              "comment": "table comment",
              "columns": [
                {
                  "name": "id",
                  "data_type": "int",
                  "column_type": "int(11)",
                  "nullable": false,
                  "default": null,
                  "key": "PRI",
                  "extra": "auto_increment",
                  "comment": "primary key",
                  "ordinal_position": 1
                },
                ...
              ],
              "primary_key": ["id"],
              "indexes": [
                {"name": "idx_email", "columns": ["email"], "unique": false, "index_type": "BTREE"}
              ]
            }

        Raises:
            ValueError: if the table does not exist or has no columns, or if
                the database cannot be resolved.
        """
        dbname = self._resolve_database(db)

        # Column definitions.
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    COLUMN_NAME,
                    DATA_TYPE,
                    COLUMN_TYPE,
                    IS_NULLABLE,
                    COLUMN_DEFAULT,
                    COLUMN_KEY,
                    EXTRA,
                    COLUMN_COMMENT,
                    ORDINAL_POSITION
                FROM information_schema.COLUMNS
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                ORDER BY ORDINAL_POSITION
                """,
                (dbname, table),
            )
            col_rows = cur.fetchall()
        if not col_rows:
            raise ValueError(f"表不存在或无列: {table}")

        columns: List[Dict[str, Any]] = []
        primary_key_cols: List[str] = []
        for (
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra,
            column_comment,
            ordinal_position,
        ) in col_rows:
            columns.append(
                {
                    "name": column_name,
                    "data_type": data_type,
                    "column_type": column_type,
                    "nullable": str(is_nullable).upper() == "YES",
                    "default": column_default,
                    "key": column_key,
                    "extra": extra,
                    "comment": column_comment,
                    "ordinal_position": int(ordinal_position),
                }
            )
            if str(column_key).upper() == "PRI":
                primary_key_cols.append(column_name)

        # Table comment.
        table_comment: Optional[str] = None
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT TABLE_COMMENT
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                """,
                (dbname, table),
            )
            row = cur.fetchone()
            if row:
                table_comment = row[0]

        # Index information (including PRIMARY).
        with self._conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    INDEX_NAME,
                    NON_UNIQUE,
                    INDEX_TYPE,
                    SEQ_IN_INDEX,
                    COLUMN_NAME
                FROM information_schema.STATISTICS
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                ORDER BY INDEX_NAME, SEQ_IN_INDEX
                """,
                (dbname, table),
            )
            idx_rows = cur.fetchall()

        indexes_map: Dict[str, Dict[str, Any]] = {}
        pk_from_index: List[str] = []
        for index_name, non_unique, index_type, _seq_in_index, col_name in idx_rows:
            if str(index_name).upper() == "PRIMARY":
                # Rows arrive ordered by SEQ_IN_INDEX, i.e. in key order.
                pk_from_index.append(col_name)
                continue
            if index_name not in indexes_map:
                indexes_map[index_name] = {
                    "name": index_name,
                    "columns": [],
                    "unique": int(non_unique) == 0,
                    "index_type": index_type,
                }
            indexes_map[index_name]["columns"].append(col_name)

        # STATISTICS is authoritative for the primary key: the COLUMN_KEY
        # pass above yields columns in table order, which differs from key
        # order for composite keys declared in another order.
        if pk_from_index:
            primary_key_cols = pk_from_index

        return {
            "db": dbname,
            "table": table,
            "comment": table_comment,
            "columns": columns,
            "primary_key": primary_key_cols,
            "indexes": list(indexes_map.values()),
        }

    def select_rows(
        self,
        table: str,
        db: Optional[str] = None,
        columns: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
        order_by: Optional[List[str]] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """Run a whitelisted ``SELECT`` against one table.

        Args:
            table: table name; must exist in the resolved database.
            db: database name; resolved automatically when omitted.
            columns: columns to return; all columns when empty or None.
            where: equality filters combined with AND. Keys must be existing
                columns; values are bound as parameters. A value of None is
                rendered as ``IS NULL``.
            order_by: column names, optionally prefixed with '-' for DESC
                (or '+' for ASC).
            limit: maximum number of rows, 1..1000.

        Returns:
            The matching rows as dictionaries keyed by column name.

        Raises:
            ValueError: for an out-of-range limit, an unknown table, or any
                column name that is not part of the table.
        """
        if limit <= 0 or limit > 1000:
            raise ValueError("limit 必须在 1..1000 之间")

        dbname = self._resolve_database(db)
        self._ensure_table_cached(dbname, table)
        allowed_cols = self._schema_cache[f"{dbname}.{table}"]
        allowed = set(allowed_cols)

        # Column whitelist.
        if columns:
            for c in columns:
                if c not in allowed:
                    raise ValueError(f"非法列: {c}")
        else:
            columns = allowed_cols

        # WHERE: keys must be real columns, values are always parameterised.
        where_clauses: List[str] = []
        params: List[Any] = []
        for k, v in (where or {}).items():
            if k not in allowed:
                raise ValueError(f"非法条件列: {k}")
            if v is None:
                # `col = NULL` never matches in SQL; NULL needs IS NULL.
                where_clauses.append(f"{_quote_ident(k)} IS NULL")
            else:
                where_clauses.append(f"{_quote_ident(k)} = %s")
                params.append(v)

        # ORDER BY: columns must be real columns.
        order_terms: List[str] = []
        for ob in order_by or []:
            col = ob.lstrip("-+")
            if col not in allowed:
                raise ValueError(f"非法排序列: {col}")
            direction = "DESC" if ob.startswith("-") else "ASC"
            order_terms.append(f"{_quote_ident(col)} {direction}")

        where_clause = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""
        order_clause = (" ORDER BY " + ", ".join(order_terms)) if order_terms else ""
        col_sql = ", ".join(_quote_ident(c) for c in columns)
        sql = (
            f"SELECT {col_sql} FROM {_quote_ident(dbname)}.{_quote_ident(table)}"
            f"{where_clause}{order_clause} LIMIT {int(limit)}"
        )

        with self._conn.cursor(dictionary=True) as cur:
            cur.execute("SET SESSION sql_safe_updates=1")
            cur.execute(sql, tuple(params))
            rows = cur.fetchall()
        return rows


guard = MySQLGuard()


# NOTE: the docstring of a tool function is kept verbatim because FastMCP
# publishes it to clients as the tool description.
@server.tool()
async def list_databases() -> List[str]:
    """列出当前账号可访问的非系统库。"""
    return guard._non_system_databases()
# NOTE: the docstrings of the tool functions below are kept verbatim because
# FastMCP publishes them to clients as the tool descriptions.
@server.tool()
async def infer_database(
    project_root: Optional[str] = None, include_evidence: bool = False
) -> Dict[str, Any]:
    """从项目内容推断数据库。默认仅返回 { db };当 include_evidence=true 时返回经过脱敏的证据统计(不含文件路径/内容)。"""
    db, ev = guard._infer_database_internal(project_root)
    if db:
        # Remember the result so later calls can omit `db`.
        guard.inferred_db = db

    result: Dict[str, Any] = {"db": db}
    if include_evidence:
        raw_hints = ev.get("hints")
        hints = raw_hints if isinstance(raw_hints, list) else []
        selected = ev.get("selected")
        # Report how the database was chosen. A selection that appears among
        # the hints came from project files; otherwise it was the only
        # accessible schema. The previous `a and b or None` expression could
        # only ever produce True or None.
        if not selected:
            method = None
        elif selected in hints:
            method = "hint_match"
        else:
            method = "single_candidate"
        # Only counts and flags are exposed, never file paths or contents.
        result["evidence"] = {
            "candidates_count": len(ev.get("candidates", [])),
            "hint_count": len(hints),
            "selected": selected is not None,
            "method": method,
        }
    return result


@server.tool()
async def list_tables(db: Optional[str] = None) -> List[str]:
    """列出数据库中的所有表。db 省略时,使用已推断的库或自动推断。"""
    return guard.list_tables(db)


@server.tool()
async def get_table_schema(table: str, db: Optional[str] = None) -> Dict[str, Any]:
    """获取表结构:列定义、主键、索引。db 省略时自动推断。"""
    return guard.get_table_schema(table, db)


@server.tool()
async def select_rows(
    table: str,
    db: Optional[str] = None,
    columns: Optional[List[str]] = None,
    where: Optional[Dict[str, Any]] = None,
    order_by: Optional[List[str]] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """从指定表安全查询。order_by 项可用前缀 '-' 表示 DESC。db 省略时自动推断。"""
    return guard.select_rows(table, db, columns, where, order_by, limit)


def main() -> None:
    """Start the MCP server."""
    server.run()


if __name__ == "__main__":
    main()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/yooztech/mcp_mysql'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.