graph_maintenance.py
"""GraphRAG 维护工具:graph_status, extract_graph_missing, rebuild_communities, clear_graph""" from typing import Any import asyncio from fastmcp import FastMCP from paperlib_mcp.db import get_db, query_all, query_one from paperlib_mcp.settings import get_settings from paperlib_mcp.models_graph import ( MCPErrorModel, GraphStatusOut, ExtractGraphMissingOut, ClearGraphOut, ) # ============================================================ # 工具注册 # ============================================================ def register_graph_maintenance_tools(mcp: FastMCP) -> None: """注册 GraphRAG 维护工具""" @mcp.tool() def graph_status(doc_id: str | None = None) -> dict[str, Any]: """查看 GraphRAG 覆盖状态 统计每个文档(或全局)的 entities、mentions、claims 覆盖率。 Args: doc_id: 文档 ID(可选,若无则返回全局统计) Returns: 覆盖率统计信息 """ try: if doc_id: # 单文档统计 chunks_count = query_one( "SELECT COUNT(*) as count FROM chunks WHERE doc_id = %s", (doc_id,) ) mentions_count = query_one( "SELECT COUNT(*) as count FROM mentions WHERE doc_id = %s", (doc_id,) ) claims_count = query_one( "SELECT COUNT(*) as count FROM claims WHERE doc_id = %s", (doc_id,) ) # 获取该文档关联的实体数 entities_count = query_one( """ SELECT COUNT(DISTINCT entity_id) as count FROM mentions WHERE doc_id = %s """, (doc_id,) ) # 获取该文档的关系数 relations_count = query_one( """ SELECT COUNT(*) as count FROM relations r JOIN entities p ON p.entity_id = r.subj_entity_id WHERE p.type = 'Paper' AND p.canonical_key = %s """, (doc_id,) ) coverage = { "doc_id": doc_id, "chunks": chunks_count["count"] if chunks_count else 0, "mentions": mentions_count["count"] if mentions_count else 0, "claims": claims_count["count"] if claims_count else 0, "entities_referenced": entities_count["count"] if entities_count else 0, "relations": relations_count["count"] if relations_count else 0, "is_extracted": (mentions_count["count"] if mentions_count else 0) > 0, } else: # 全局统计 total_docs = query_one("SELECT COUNT(*) as count FROM documents") total_chunks = query_one("SELECT COUNT(*) as count FROM chunks") total_entities = query_one("SELECT COUNT(*) as count FROM entities") total_mentions = query_one("SELECT COUNT(*) as count FROM mentions") total_relations = query_one("SELECT COUNT(*) as count FROM relations") total_claims = query_one("SELECT COUNT(*) as count FROM claims") total_communities = query_one("SELECT COUNT(*) as count FROM communities") # 已抽取的文档数(有 mentions 的文档) extracted_docs = query_one( "SELECT COUNT(DISTINCT doc_id) as count FROM mentions" ) # 有社区摘要的社区数 summarized_communities = query_one( "SELECT COUNT(*) as count FROM community_summaries" ) # 实体类型分布 entity_types = query_all( """ SELECT type, COUNT(*) as count FROM entities GROUP BY type ORDER BY count DESC """ ) coverage = { "total_documents": total_docs["count"] if total_docs else 0, "extracted_documents": extracted_docs["count"] if extracted_docs else 0, "extraction_coverage": ( round(extracted_docs["count"] / total_docs["count"] * 100, 1) if total_docs and total_docs["count"] > 0 else 0 ), "total_chunks": total_chunks["count"] if total_chunks else 0, "total_entities": total_entities["count"] if total_entities else 0, "total_mentions": total_mentions["count"] if total_mentions else 0, "total_relations": total_relations["count"] if total_relations else 0, "total_claims": total_claims["count"] if total_claims else 0, "total_communities": total_communities["count"] if total_communities else 0, "summarized_communities": summarized_communities["count"] if summarized_communities else 0, "entity_type_distribution": {r["type"]: r["count"] for r in entity_types}, } return 
GraphStatusOut(coverage=coverage).model_dump() except Exception as e: return GraphStatusOut( coverage={}, error=MCPErrorModel(code="DB_CONN_ERROR", message=str(e)), ).model_dump() @mcp.tool() async def extract_graph_missing( limit_docs: int = 50, llm_model: str | None = None, min_confidence: float = 0.8, concurrency: int = 30, doc_concurrency: int = 15, max_chunks: int = 60, ) -> dict[str, Any]: """批量补跑未抽取的文档 找出没有 mentions 的文档,并对它们执行 extract_graph_v1。 Args: limit_docs: 最大处理文档数,默认 50 llm_model: LLM 模型,默认使用环境变量 LLM_MODEL 配置 min_confidence: 最小置信度阈值 Returns: 处理的文档数和文档 ID 列表 """ try: settings = get_settings() actual_llm_model = llm_model or settings.llm_model # 找出未抽取的文档 missing_docs = query_all( """ SELECT d.doc_id FROM documents d LEFT JOIN mentions m ON m.doc_id = d.doc_id GROUP BY d.doc_id HAVING COUNT(m.mention_id) = 0 ORDER BY d.created_at DESC LIMIT %s """, (limit_docs,) ) if not missing_docs: return ExtractGraphMissingOut( processed_docs=0, doc_ids=[], ).model_dump() # 导入抽取工具 from paperlib_mcp.tools.graph_extract import ( extract_graph_v1_run, HIGH_VALUE_KEYWORDS_DEFAULT, ) processed_doc_ids = [] # 使用 Semaphore 限制文档级并发 sem = asyncio.Semaphore(doc_concurrency) async def process_single_doc(doc_row): async with sem: doc_id = doc_row["doc_id"] # 获取该文档的高价值 chunks fts_query = " OR ".join(f"'{kw}'" for kw in HIGH_VALUE_KEYWORDS_DEFAULT) chunks = query_all( """ SELECT chunk_id, doc_id, page_start, page_end, text FROM chunks WHERE doc_id = %s AND tsv @@ websearch_to_tsquery('english', %s) ORDER BY ts_rank(tsv, websearch_to_tsquery('english', %s)) DESC LIMIT %s """, (doc_id, fts_query, fts_query, max_chunks) ) if not chunks: # 没有高价值 chunks,使用所有 chunks chunks = query_all( """ SELECT chunk_id, doc_id, page_start, page_end, text FROM chunks WHERE doc_id = %s ORDER BY chunk_index LIMIT %s """, (doc_id, max_chunks) ) if not chunks: return None # 直接调用优化后的 extract_graph_v1_run (支持并行) # from paperlib_mcp.tools.graph_extract import extract_graph_v1_run # Already imported above # 提取 chunk_ids chunk_ids = [c["chunk_id"] for c in chunks] await extract_graph_v1_run( doc_id=doc_id, chunk_ids=chunk_ids, mode="custom", # 使用传入的 chunk_ids max_chunks=len(chunks), llm_model=actual_llm_model, min_confidence=min_confidence, concurrency=concurrency, ) return doc_id # 并发执行所有文档 tasks = [process_single_doc(doc) for doc in missing_docs] results = await asyncio.gather(*tasks) # 收集能够成功处理的 doc_id processed_doc_ids = [r for r in results if r is not None] return ExtractGraphMissingOut( processed_docs=len(processed_doc_ids), doc_ids=processed_doc_ids, ).model_dump() except Exception as e: return ExtractGraphMissingOut( processed_docs=0, doc_ids=[], error=MCPErrorModel(code="LLM_ERROR", message=str(e)), ).model_dump() @mcp.tool() def rebuild_communities( level: str = "macro", min_df: int = 3, resolution: float = 1.0, ) -> dict[str, Any]: """重建社区 清除指定层级的旧社区并重新构建。 Args: level: 社区层级,"macro" 或 "micro" min_df: 节点最小文档频率,默认 3 resolution: Leiden 分辨率参数,默认 1.0 Returns: 新社区列表 """ try: # 直接调用 build_communities_v1 with rebuild=True from paperlib_mcp.tools.graph_community import register_graph_community_tools # 由于我们需要直接调用逻辑,这里重新实现简化版本 # 或者导入并使用内部逻辑 try: import igraph as ig import leidenalg except ImportError: return { "communities": [], "error": { "code": "DEPENDENCY_ERROR", "message": "igraph and leidenalg are required" } } from paperlib_mcp.tools.graph_community import ( COMMUNITY_ENTITY_TYPES, get_edge_weight, ) from paperlib_mcp.models_graph import CommunityBrief from collections import defaultdict import json with get_db() as conn: # 清除旧结果 
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        DELETE FROM community_members
                        WHERE comm_id IN (SELECT comm_id FROM communities WHERE level = %s)
                        """,
                        (level,)
                    )
                    cur.execute(
                        """
                        DELETE FROM community_summaries
                        WHERE comm_id IN (SELECT comm_id FROM communities WHERE level = %s)
                        """,
                        (level,)
                    )
                    cur.execute(
                        "DELETE FROM communities WHERE level = %s",
                        (level,)
                    )

                # Fetch Paper->Entity relations
                relations = query_all(
                    """
                    SELECT p.entity_id AS paper_eid, p.canonical_key AS doc_id,
                           x.entity_id AS node_eid, x.type AS node_type, x.canonical_name
                    FROM relations r
                    JOIN entities p ON p.entity_id = r.subj_entity_id AND p.type = 'Paper'
                    JOIN entities x ON x.entity_id = r.obj_entity_id
                    WHERE r.predicate IN (
                        'PAPER_HAS_TOPIC', 'PAPER_USES_MEASURE', 'PAPER_IDENTIFIES_WITH',
                        'PAPER_IN_SETTING', 'PAPER_USES_DATA'
                    )
                    AND x.type = ANY(%s)
                    """,
                    (COMMUNITY_ENTITY_TYPES,)
                )

                if not relations:
                    return {"communities": [], "error": {"code": "NOT_FOUND", "message": "No relations found"}}

                # Build the graph and cluster (same as build_communities_v1)
                paper_to_nodes: dict[str, set[int]] = defaultdict(set)
                node_info: dict[int, dict] = {}

                for r in relations:
                    paper_to_nodes[r["doc_id"]].add(r["node_eid"])
                    if r["node_eid"] not in node_info:
                        node_info[r["node_eid"]] = {
                            "entity_id": r["node_eid"],
                            "type": r["node_type"],
                            "canonical_name": r["canonical_name"],
                        }

                node_df: dict[int, int] = defaultdict(int)
                for doc_id, nodes in paper_to_nodes.items():
                    for node_id in nodes:
                        node_df[node_id] += 1

                valid_nodes = {nid for nid, df in node_df.items() if df >= min_df}
                if not valid_nodes:
                    return {"communities": [], "error": {"code": "NOT_FOUND", "message": f"No nodes with df >= {min_df}"}}

                # Build edges
                edge_counts: dict[tuple[int, int], float] = defaultdict(float)
                for doc_id, nodes in paper_to_nodes.items():
                    valid_doc_nodes = [n for n in nodes if n in valid_nodes]
                    for i, n1 in enumerate(valid_doc_nodes):
                        for n2 in valid_doc_nodes[i+1:]:
                            key = (min(n1, n2), max(n1, n2))
                            weight = get_edge_weight(node_info[n1]["type"], node_info[n2]["type"])
                            edge_counts[key] += weight

                if not edge_counts:
                    return {"communities": [], "error": {"code": "NOT_FOUND", "message": "No edges found"}}

                # Build the igraph graph
                node_list = sorted(valid_nodes)
                node_to_idx = {nid: idx for idx, nid in enumerate(node_list)}

                edges = []
                weights = []
                for (n1, n2), w in edge_counts.items():
                    edges.append((node_to_idx[n1], node_to_idx[n2]))
                    weights.append(w)

                g = ig.Graph(n=len(node_list), edges=edges, directed=False)
                g.es["weight"] = weights

                # Leiden clustering
                partition = leidenalg.find_partition(
                    g,
                    leidenalg.RBConfigurationVertexPartition,
                    weights="weight",
                    resolution_parameter=resolution
                )

                # Write results to the database
                communities_result = []
                community_members_map: dict[int, list[tuple[int, float]]] = defaultdict(list)

                for node_idx, comm_idx in enumerate(partition.membership):
                    node_id = node_list[node_idx]
                    weight = float(node_df[node_id])
                    community_members_map[comm_idx].append((node_id, weight))

                with conn.cursor() as cur:
                    for comm_idx, members in community_members_map.items():
                        if len(members) < 2:
                            continue

                        cur.execute(
                            """
                            INSERT INTO communities(level, method, params)
                            VALUES (%s, 'leiden', %s::jsonb)
                            RETURNING comm_id
                            """,
                            (level, json.dumps({"resolution": resolution, "min_df": min_df}))
                        )
                        result = cur.fetchone()
                        comm_id = result["comm_id"]

                        for node_id, weight in members:
                            cur.execute(
                                """
                                INSERT INTO community_members(comm_id, entity_id, role, weight)
                                VALUES (%s, %s, 'member', %s)
                                """,
                                (comm_id, node_id, weight)
                            )

                        sorted_members = sorted(members, key=lambda x: x[1], reverse=True)
                        top_entities = []
                        for node_id, weight in sorted_members[:20]:
                            info = node_info.get(node_id, {})
                            top_entities.append({
                                "entity_id": node_id,
                                "type": info.get("type", ""),
                                "canonical_name": info.get("canonical_name", ""),
                                "weight": weight,
                            })

                        communities_result.append({
                            "comm_id": comm_id,
                            "size": len(members),
                            "top_entities": top_entities,
                        })

                communities_result.sort(key=lambda x: x["size"], reverse=True)

                return {"communities": communities_result}
        except Exception as e:
            return {"communities": [], "error": {"code": "DB_CONN_ERROR", "message": str(e)}}

    @mcp.tool()
    def clear_graph(doc_id: str | None = None, clear_all: bool = False) -> dict[str, Any]:
        """Clear GraphRAG data.

        Removes GraphRAG data for a single document, or for everything.

        Args:
            doc_id: Document ID (clear a single document)
            clear_all: Whether to clear everything (dangerous operation)

        Returns:
            Cleanup result, including the number of deleted rows
        """
        try:
            if not doc_id and not clear_all:
                return ClearGraphOut(
                    ok=False,
                    error=MCPErrorModel(code="VALIDATION_ERROR", message="Must provide doc_id or set clear_all=True"),
                ).model_dump()

            deleted_counts = {}

            with get_db() as conn:
                with conn.cursor() as cur:
                    if clear_all:
                        # Delete in dependency order (dependent tables first)
                        cur.execute("DELETE FROM community_summaries")
                        deleted_counts["community_summaries"] = cur.rowcount
                        cur.execute("DELETE FROM community_members")
                        deleted_counts["community_members"] = cur.rowcount
                        cur.execute("DELETE FROM communities")
                        deleted_counts["communities"] = cur.rowcount
                        cur.execute("DELETE FROM claims")
                        deleted_counts["claims"] = cur.rowcount
                        cur.execute("DELETE FROM relations")
                        deleted_counts["relations"] = cur.rowcount
                        cur.execute("DELETE FROM mentions")
                        deleted_counts["mentions"] = cur.rowcount
                        cur.execute("DELETE FROM entity_aliases")
                        deleted_counts["entity_aliases"] = cur.rowcount
                        cur.execute("DELETE FROM entities")
                        deleted_counts["entities"] = cur.rowcount
                        cur.execute("DELETE FROM entity_merge_log")
                        deleted_counts["entity_merge_log"] = cur.rowcount
                    else:
                        # Clear a single document
                        # Delete claims
                        cur.execute("DELETE FROM claims WHERE doc_id = %s", (doc_id,))
                        deleted_counts["claims"] = cur.rowcount
                        # Delete mentions
                        cur.execute("DELETE FROM mentions WHERE doc_id = %s", (doc_id,))
                        deleted_counts["mentions"] = cur.rowcount
                        # Delete relations (matched via evidence)
                        cur.execute(
                            "DELETE FROM relations WHERE evidence->>'doc_id' = %s",
                            (doc_id,)
                        )
                        deleted_counts["relations"] = cur.rowcount
                        # Optional: delete the Paper entity
                        # cur.execute(
                        #     "DELETE FROM entities WHERE type = 'Paper' AND canonical_key = %s",
                        #     (doc_id,)
                        # )
                        # deleted_counts["paper_entity"] = cur.rowcount

            return ClearGraphOut(
                ok=True,
                deleted_counts=deleted_counts,
            ).model_dump()
        except Exception as e:
            return ClearGraphOut(
                ok=False,
                error=MCPErrorModel(code="DB_CONN_ERROR", message=str(e)),
            ).model_dump()
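
For orientation, the sketch below shows one way these tools could be exposed from a server entry point. It is a minimal sketch, not part of graph_maintenance.py: the module path paperlib_mcp.tools.graph_maintenance is assumed by analogy with the paperlib_mcp.tools.* imports used above, and the server name and stdio transport are illustrative choices.

# Minimal usage sketch (assumed module path; not part of graph_maintenance.py).
from fastmcp import FastMCP

from paperlib_mcp.tools.graph_maintenance import register_graph_maintenance_tools

mcp = FastMCP("paperlib-mcp")

# Registers graph_status, extract_graph_missing, rebuild_communities, clear_graph.
register_graph_maintenance_tools(mcp)

if __name__ == "__main__":
    mcp.run()  # FastMCP defaults to the stdio transport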
