# NebulaGraph MCP Server
#
# by PsiACE
# Verified
#   • src
#   • nebulagraph_mcp_server
"""NebulaGraph MCP server.

Exposes NebulaGraph schema, path and neighbor lookups, plus free-form query
execution, as MCP resources and tools backed by one shared connection pool.
"""

import argparse
import os
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Iterator, List, Optional

from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from nebula3.Config import Config
from nebula3.gclient.net import ConnectionPool

load_dotenv()


@dataclass
class NebulaContext:
    """Lifespan context object yielded by ``nebula_lifespan``."""

    # The shared connection pool used by every resource and tool.
    pool: ConnectionPool


# Create a global connection pool
config = Config()
config.max_connection_pool_size = 10
global_pool = ConnectionPool()


def get_connection_pool() -> ConnectionPool:
    """Get the global connection pool"""
    return global_pool


@asynccontextmanager
async def nebula_lifespan(server: FastMCP) -> AsyncIterator[NebulaContext]:
    """Initialize the global connection pool for the server's lifetime.

    Reads ``NEBULA_VERSION`` (defaults to ``"v3"`` when unset),
    ``NEBULA_HOST`` (default ``127.0.0.1``) and ``NEBULA_PORT`` (default
    ``9669``) from the environment.

    Args:
        server: The FastMCP server this lifespan is attached to (unused).

    Yields:
        A ``NebulaContext`` wrapping the initialized global pool.

    Raises:
        ValueError: If ``NEBULA_VERSION`` is set to anything other than "v3".
        RuntimeError: If the pool reports that initialization failed.
    """
    try:
        # An unset variable is treated as "v3" instead of surfacing a bare
        # KeyError that does not explain what is wrong.
        if os.getenv("NEBULA_VERSION", "v3") != "v3":
            raise ValueError("NebulaGraph version must be v3")
        address = (
            os.getenv("NEBULA_HOST", "127.0.0.1"),
            int(os.getenv("NEBULA_PORT", "9669")),
        )
        # Only an explicit False is treated as failure, so a driver version
        # that returns None on success keeps working.
        if global_pool.init([address], config) is False:
            raise RuntimeError("Failed to initialize the NebulaGraph connection pool")
        yield NebulaContext(pool=global_pool)
    finally:
        # Clean up the connection
        global_pool.close()


# Create MCP server
mcp = FastMCP("NebulaGraph MCP Server", lifespan=nebula_lifespan)


@contextmanager
def _nebula_session() -> Iterator[Any]:
    """Yield a pool session for the configured user and always release it.

    Credentials come from ``NEBULA_USER`` (default ``root``) and
    ``NEBULA_PASSWORD`` (default ``nebula``).
    """
    session = get_connection_pool().get_session(
        os.getenv("NEBULA_USER", "root"),
        os.getenv("NEBULA_PASSWORD", "nebula"),
    )
    try:
        yield session
    finally:
        session.release()


def _as_text(value: Any) -> str:
    """Return the plain text of a result value.

    Values exposing ``is_string()``/``as_string()`` are unwrapped through
    those methods; anything else falls back to ``str()``.
    """
    # NOTE(review): nebula3's ValueWrapper appears to render string values
    # with surrounding double quotes under str(); unwrapping keeps names
    # usable inside follow-up statements. Confirm against the driver in use.
    is_string = getattr(value, "is_string", None)
    if callable(is_string) and is_string():
        return value.as_string()
    return str(value)


def _quote_identifier(name: Any) -> str:
    """Wrap a schema name in backticks for use as an nGQL identifier.

    Raises:
        ValueError: If the name is empty or contains a backtick, since
            either would let it escape the quoting.
    """
    text = _as_text(name)
    if not text or "`" in text:
        raise ValueError(f"Invalid identifier: {text!r}")
    return f"`{text}`"


def _quote_string(value: Any) -> str:
    """Render a value as a double-quoted nGQL string literal.

    Backslashes and double quotes are escaped so the value cannot terminate
    the literal early.
    """
    escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def _use_space(session: Any, space: str) -> Optional[str]:
    """Switch the session to ``space``.

    Returns:
        ``None`` on success, otherwise a message describing why the space
        could not be selected.
    """
    try:
        quoted = _quote_identifier(space)
    except ValueError as exc:
        return str(exc)
    result = session.execute(f"USE {quoted}")
    if result.is_succeeded():
        return None
    return result.error_msg()


def _describe_schema_items(session: Any, kind: str) -> List[str]:
    """Collect the formatted description of every tag or edge type.

    Args:
        session: A session already switched to the target space.
        kind: Either ``"TAG"`` or ``"EDGE"``.

    Returns:
        Text fragments, one heading per item followed by its fields.
    """
    listing = session.execute(f"SHOW {kind}S")
    if not listing.is_succeeded():
        return [f"\nFailed to list {kind.lower()}s: {listing.error_msg()}\n"]

    parts: List[str] = []
    for name in map(_as_text, listing.column_values("Name")):
        parts.append(f"\n{name}:\n")
        described = session.execute(f"DESCRIBE {kind} {_quote_identifier(name)}")
        if not described.is_succeeded():
            parts.append(f" (failed to describe: {described.error_msg()})\n")
            continue
        # Iterate through all rows; the first two columns are reported.
        for index in range(described.row_size()):
            field = described.row_values(index)
            parts.append(f" - {_as_text(field[0])}: {_as_text(field[1])}\n")
    return parts


@mcp.resource("schema://space/{space}")
def get_space_schema_resource(space: str) -> str:
    """Get the schema information of the specified space

    Args:
        space: The space to get the schema for

    Returns:
        The schema information of the specified space, or a failure message
        when the space cannot be selected
    """
    with _nebula_session() as session:
        error = _use_space(session, space)
        if error is not None:
            return f"Failed to get schema for space {space}: {error}"
        parts = [f"Space: {space}\n\nTags:\n"]
        parts.extend(_describe_schema_items(session, "TAG"))
        parts.append("\nEdges:\n")
        parts.extend(_describe_schema_items(session, "EDGE"))
        return "".join(parts)


@mcp.resource("path://space/{space}/from/{src}/to/{dst}/depth/{depth}/limit/{limit}")
def get_path_resource(space: str, src: str, dst: str, depth: int, limit: int) -> str:
    """Get the path between two vertices

    Args:
        space: The space to use
        src: The source vertex ID
        dst: The destination vertex ID
        depth: The maximum path depth
        limit: The maximum number of paths to return

    Returns:
        The path between the source and destination vertices

    Raises:
        ValueError: If ``depth`` or ``limit`` cannot be converted to an int.
    """
    # Coerce first so that only integers are ever interpolated into the query.
    depth = int(depth)
    limit = int(limit)
    with _nebula_session() as session:
        error = _use_space(session, space)
        if error is not None:
            return f"Query failed: {error}"
        query = (
            f"FIND ALL PATH WITH PROP FROM {_quote_string(src)} "
            f"TO {_quote_string(dst)} OVER * BIDIRECT UPTO {depth} STEPS "
            f"YIELD PATH AS paths | LIMIT {limit}"
        )
        result = session.execute(query)
        if not result.is_succeeded():
            return f"Query failed: {result.error_msg()}"
        if result.row_size() == 0:
            return f"No paths found from {src} to {dst}"
        # The path should be in the first column of each row.
        paths = [
            f"Path {index + 1}:\n{result.row_values(index)[0]}\n\n"
            for index in range(result.row_size())
        ]
        return f"Find paths from {src} to {dst}: \n\n" + "".join(paths)


@mcp.tool()
def list_spaces() -> str:
    """List all available spaces

    Returns:
        The available spaces
    """
    with _nebula_session() as session:
        result = session.execute("SHOW SPACES")
        if not result.is_succeeded():
            return f"Failed to list spaces: {result.error_msg()}"
        names = map(_as_text, result.column_values("Name"))
        return "Available spaces:\n" + "\n".join(f"- {name}" for name in names)


@mcp.tool()
def get_space_schema(space: str) -> str:
    """Get the schema information of the specified space

    Args:
        space: The space to get the schema for

    Returns:
        The schema information of the specified space
    """
    return get_space_schema_resource(space)


@mcp.tool()
def execute_query(query: str, space: str) -> str:
    """Execute a query

    The query text is sent to the server unchanged.

    Args:
        query: The query to execute
        space: The space to use

    Returns:
        The results of the query
    """
    with _nebula_session() as session:
        error = _use_space(session, space)
        if error is not None:
            return f"Query failed: {error}"
        result = session.execute(query)
        if not result.is_succeeded():
            return f"Query failed: {result.error_msg()}"
        if result.row_size() == 0:
            return "Query executed successfully (no results)"
        # Format the query results as a simple pipe-separated table.
        header = " | ".join(result.keys())
        lines = ["Results:", header, "-" * len(header)]
        lines.extend(
            " | ".join(str(value) for value in result.row_values(index))
            for index in range(result.row_size())
        )
        return "\n".join(lines) + "\n"


@mcp.tool()
def find_path(src: str, dst: str, space: str, depth: int = 3, limit: int = 10) -> str:
    """Find paths between two vertices

    Args:
        src: The source vertex ID
        dst: The destination vertex ID
        space: The space to use
        depth: The maximum path depth
        limit: The maximum number of paths to return

    Returns:
        The path results
    """
    return get_path_resource(space, src, dst, depth, limit)


@mcp.resource("neighbors://space/{space}/vertex/{vertex}/depth/{depth}")
def get_neighbors_resource(space: str, vertex: str, depth: int) -> str:
    """Get the neighbors of the specified vertex

    Args:
        space: The space to use
        vertex: The vertex ID to query
        depth: The depth of the query

    Returns:
        The neighbors of the specified vertex

    Raises:
        ValueError: If ``depth`` cannot be converted to an int.
    """
    # Coerce first so that only an integer is ever interpolated into the query.
    depth = int(depth)
    with _nebula_session() as session:
        error = _use_space(session, space)
        if error is not None:
            return f"Query failed: {error}"
        query = (
            f"MATCH (u)-[e*1..{depth}]-(v) "
            f"WHERE id(u) == {_quote_string(vertex)} "
            f"RETURN DISTINCT v, e"
        )
        result = session.execute(query)
        if not result.is_succeeded():
            return f"Query failed: {result.error_msg()}"
        if result.row_size() == 0:
            return f"No neighbors found for vertex {vertex}"
        entries = []
        for index in range(result.row_size()):
            neighbor_vertex, edges = result.row_values(index)[:2]
            entries.append(f"Neighbor Vertex:\n{neighbor_vertex}\nEdges:\n{edges}\n\n")
        return f"Vertex {vertex} neighbors (depth {depth}):\n\n" + "".join(entries)


@mcp.tool()
def find_neighbors(vertex: str, space: str, depth: int = 1) -> str:
    """Find the neighbors of the specified vertex

    Args:
        vertex: The vertex ID to query
        space: The space to use
        depth: The depth of the query, default is 1

    Returns:
        The neighbors of the specified vertex
    """
    return get_neighbors_resource(space, vertex, depth)


def main():
    """Parse the command line and run the MCP server on the chosen transport."""
    parser = argparse.ArgumentParser(description="NebulaGraph MCP server")
    parser.add_argument(
        "--transport",
        type=str,
        choices=["stdio", "sse"],
        default="stdio",
        help="Transport method (stdio or sse)",
    )
    args = parser.parse_args()
    # argparse restricts the value to "stdio" or "sse", so pass it through.
    mcp.run(args.transport)


if __name__ == "__main__":
    main()