AGE-MCP-Server
by rioriost
Verified
- homebrew-age-mcp-server
- src
- age_mcp_server
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import logging
import re
import sys
from typing import Any
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
import mcp.server.stdio
import mcp.types as types
from psycopg import Connection
from psycopg.rows import dict_row
from agefreighter.cypherparser import CypherParser
# Module-wide logging setup. logging.basicConfig installs a stderr
# StreamHandler on the root logger (and is a no-op if the root logger already
# has handlers), which keeps log output off stdout, the channel the MCP stdio
# transport uses.
_ROOT_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_ROOT_LOG_LEVEL)
log = logging.getLogger(name=__name__)
class CypherQueryFormatter:
"""Utility class for formatting Cypher queries for Apache AGE."""
@staticmethod
def format_query(graph_name: str, cypher_query: str, allow_write: bool) -> str:
"""
Format the provided Cypher query for Apache AGE.
Raises:
ValueError: If the query is unsafe or incorrectly formatted.
"""
if not allow_write:
if not CypherQueryFormatter.is_safe_cypher_query(cypher_query):
raise ValueError("Unsafe query")
# Append LIMIT 50 if no limit is specified.
if "limit" not in cypher_query.lower():
cypher_query += " LIMIT 50"
# Claude misunderstands the Cypher definition
if "cast" in cypher_query.lower():
raise ValueError("'CAST' is not a reserved keyword in Cypher")
returns = CypherQueryFormatter.get_return_values(cypher_query)
log.debug(f"Return values: {returns}")
# Check for parameterized query usage.
if re.findall(r"\$(\w+)", cypher_query):
raise ValueError("Parameterized query")
if returns:
ag_types = ", ".join([f"{r} agtype" for r in returns])
return f"SELECT * FROM cypher('{graph_name}', $$ {cypher_query} $$) AS ({ag_types});"
else:
raise ValueError("No return values specified")
@staticmethod
def is_safe_cypher_query(cypher_query: str) -> bool:
"""
Ensure the Cypher query does not contain dangerous commands.
Returns:
bool: True if safe, False otherwise.
"""
tokens = cypher_query.split()
unsafe_keywords = ["add", "create", "delete", "merge", "remove", "set"]
return all(token.lower() not in unsafe_keywords for token in tokens)
@staticmethod
def get_return_values(cypher_query: str) -> list:
parser = CypherParser()
try:
result = parser.parse(cypher_query)
except Exception as e:
log.error(f"Failed to parse Cypher query: {e}")
return []
for op, opr, *_ in result:
log.debug(f"Returning values from query: {opr}")
if op == "RETURN" or op == "RETURN_DISTINCT":
results = []
for v in opr:
if isinstance(v, str):
results.append(v.split(".")[0])
elif isinstance(v, tuple):
match v[0]:
case "alias":
results.append(v[-1])
case "property":
results.append(v[-1])
case "func_call":
results.append(v[1])
case "":
pass
return list(set(results))
return []
class PostgreSQLAGE:
    """Connection wrapper for running Cypher (through Apache AGE) and plain SQL."""

    def __init__(self, pg_con_str: str, allow_write: bool, log_level: int):
        """
        Initialize connection to the PostgreSQL database.

        Args:
            pg_con_str: libpq connection string (``key=value`` form or URI).
            allow_write: Whether write Cypher queries are permitted.
            log_level: Level applied to this module's logger.

        Exits the process with status 1 if the connection cannot be opened.
        """
        log.setLevel(log_level)
        # The connection string may carry a password, so it is not logged.
        log.debug("Initializing database connection")
        self.pg_con_str = pg_con_str
        self.allow_write = allow_write
        self.con: Connection
        try:
            # Passing ``options`` as a keyword lets psycopg merge it into the
            # conninfo; this works for both "key=value" strings and
            # "postgresql://" URIs, unlike appending text to the string.
            self.con = Connection.connect(
                self.pg_con_str,
                options='-c search_path=ag_catalog,"$user",public',
            )
        except Exception as e:
            log.error(f"Failed to connect to PostgreSQL database: {e}")
            sys.exit(1)

    def _execute_query(
        self, graph_name: str, query: str, params: dict[str, Any] | None = None
    ) -> list[dict[str, Any]] | list[int]:
        """
        Execute a Cypher query through AGE's ``cypher()`` function.

        Args:
            graph_name: Graph to run the query against.
            query: Cypher query text (must contain a RETURN clause).
            params: Optional parameters forwarded to ``cursor.execute``.

        Returns:
            For read queries, the result rows as dictionaries; for write
            queries, a one-element list holding the number of returned rows.

        Raises:
            Exception: Any formatting or database error, after rolling back.
        """
        log.debug(f"Executing query: {query}")
        try:
            cypher_query = CypherQueryFormatter.format_query(
                graph_name=graph_name,
                cypher_query=query,
                allow_write=self.allow_write,
            )
            log.debug(f"Formatted query: {cypher_query}")
            # The context manager closes the cursor even if execution fails.
            with self.con.cursor(row_factory=dict_row) as cur:
                cur.execute(cypher_query, params)
                results = cur.fetchall()
            self.con.commit()
        except Exception as e:
            log.error(f"Database error executing query: {e}\n{query}")
            self.con.rollback()  # Roll back to clear the error state
            raise
        count = len(results)
        if CypherQueryFormatter.is_safe_cypher_query(query):
            log.debug(f"Read query returned {count} rows")
            return results
        log.debug(f"Write query affected {count}")
        return [count]

    def _execute_sql(self, query: str) -> list[dict[str, Any]]:
        """
        Execute a plain SQL statement and return its rows as dictionaries.

        Raises:
            Exception: Any database error, after rolling back.
        """
        log.debug(f"Executing query: {query}")
        try:
            with self.con.cursor(row_factory=dict_row) as cur:
                cur.execute(query)
                results = cur.fetchall()
            self.con.commit()
        except Exception as e:
            log.error(f"Database error executing query: {e}\n{query}")
            self.con.rollback()  # Roll back to clear the error state
            raise
        return results
# Cypher used by the "get-age-schema" tool to collect node labels with their
# property keys, and relationship types with their endpoint labels.
_NODE_SCHEMA_QUERY = """
MATCH (n)
UNWIND labels(n) AS label
RETURN DISTINCT label, collect(DISTINCT keys(n)) AS properties
"""
_EDGE_SCHEMA_QUERY = """
MATCH (a)-[r]->(b)
RETURN DISTINCT type(r) AS rel_type, collect(DISTINCT labels(a)) AS from_labels, collect(DISTINCT labels(b)) AS to_labels
"""


def _tool_definitions() -> "list[types.Tool]":
    """Return the MCP tool descriptors advertised by this server."""
    return [
        types.Tool(
            name="read-age-cypher",
            description="Execute a Cypher query on the AGE",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Cypher read query to execute",
                    },
                    "graph_name": {
                        "type": "string",
                        "description": "Name of the graph to operate",
                    },
                },
                "required": ["query", "graph_name"],
            },
        ),
        types.Tool(
            name="write-age-cypher",
            description="Execute a write Cypher query on the AGE",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Cypher write query to execute, including 'RETURN' statement",
                    },
                    "graph_name": {
                        "type": "string",
                        "description": "Name of the graph to operate",
                    },
                },
                "required": ["query", "graph_name"],
            },
        ),
        types.Tool(
            name="create-age-graph",
            description="Create a new graph in the AGE",
            inputSchema={
                "type": "object",
                "properties": {
                    "graph_name": {
                        "type": "string",
                        "description": "Name of the graph to create",
                    },
                },
                "required": ["graph_name"],
            },
        ),
        types.Tool(
            name="drop-age-graph",
            description="Drop a graph in the AGE",
            inputSchema={
                "type": "object",
                "properties": {
                    "graph_name": {
                        "type": "string",
                        "description": "Name of the graph to drop",
                    },
                },
                "required": ["graph_name"],
            },
        ),
        types.Tool(
            name="list-age-graphs",
            description="List all graphs in the AGE",
            inputSchema={
                "type": "object",
            },
        ),
        types.Tool(
            name="get-age-schema",
            description="List all node types, their attributes and their relationships TO other node-types in the AGE",
            inputSchema={
                "type": "object",
                "properties": {
                    "graph_name": {
                        "type": "string",
                        "description": "Name of the graph to inspect",
                    },
                },
                "required": ["graph_name"],
            },
        ),
    ]


def _unique_flatten(items: Any) -> list:
    """Flatten a list one level (entries may be lists or scalars), dropping
    duplicates in first-seen order; anything that is not a list yields []."""
    if not isinstance(items, list):
        return []
    flat: list = []
    for item in items:
        for value in item if isinstance(item, list) else [item]:
            # Linear membership test: values are not guaranteed hashable.
            if value not in flat:
                flat.append(value)
    return flat


def _summarize_schema(
    node_results: list[dict[str, Any]], edge_results: list[dict[str, Any]]
) -> dict[str, list]:
    """
    Build the {"nodes": [...], "edges": [...]} summary for "get-age-schema".

    Rows are the outputs of ``_NODE_SCHEMA_QUERY`` / ``_EDGE_SCHEMA_QUERY``
    whose agtype columns arrive as JSON text (labels as quoted strings).
    Property keys and labels are the union over all collected entries, not
    only the first one.
    """
    nodes_dict: dict[str, dict[str, Any]] = {}
    for node in node_results:
        label = node["label"].strip('"')
        nodes_dict[label] = {
            "label": label,
            "properties": _unique_flatten(json.loads(node["properties"])),
            "relationships": {},
        }
    edges = []
    for edge in edge_results:
        rel_type = edge["rel_type"].strip('"')
        from_labels = _unique_flatten(json.loads(edge["from_labels"]))
        to_labels = _unique_flatten(json.loads(edge["to_labels"]))
        edges.append(
            {
                "rel_type": rel_type,
                "from_labels": from_labels,
                "to_labels": to_labels,
            }
        )
        for from_label in from_labels:
            if from_label in nodes_dict and to_labels:
                nodes_dict[from_label]["relationships"][rel_type] = to_labels[0]
        for to_label in to_labels:
            if to_label in nodes_dict and from_labels:
                nodes_dict[to_label]["relationships"][rel_type] = from_labels[0]
    return {"nodes": list(nodes_dict.values()), "edges": edges}


async def main(pg_con_str: str, allow_write: bool, log_level: int) -> None:
    """
    Run the AGE MCP server over stdio.

    Args:
        pg_con_str: PostgreSQL connection string.
        allow_write: Whether graph-mutating tools and queries are permitted.
        log_level: Level applied to this module's logger.
    """
    # Local import: psycopg is already a dependency of this module.
    from psycopg import sql

    log.setLevel(log_level)
    # The connection string may contain a password, so it is not logged.
    log.info("Connecting to PostgreSQL")
    db = PostgreSQLAGE(
        pg_con_str=pg_con_str,
        allow_write=allow_write,
        log_level=log_level,
    )
    server = Server("age-manager")

    def graph_sql(template: str, graph_name: str) -> str:
        """Render ``template`` with ``graph_name`` as a safely quoted SQL literal."""
        return sql.SQL(template).format(sql.Literal(graph_name)).as_string(db.con)

    # Register handlers
    log.debug("Registering handlers")

    @server.list_tools()
    async def handle_list_tools() -> list[types.Tool]:
        """List available tools"""
        return _tool_definitions()

    @server.call_tool()
    async def handle_call_tool(
        name: str, arguments: dict[str, Any] | None
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Handle tool execution requests; errors are returned as text content."""
        # Clients may omit arguments entirely; treat that as an empty mapping.
        arguments = arguments or {}
        try:
            if name == "get-age-schema":
                node_results = db._execute_query(
                    graph_name=arguments["graph_name"], query=_NODE_SCHEMA_QUERY
                )
                log.debug(f"Node results: {node_results}")
                edge_results = db._execute_query(
                    graph_name=arguments["graph_name"], query=_EDGE_SCHEMA_QUERY
                )
                log.debug(f"Edge results: {edge_results}")
                schema = _summarize_schema(node_results, edge_results)
                return [types.TextContent(type="text", text=str(schema))]
            elif name == "create-age-graph":
                if not allow_write:
                    raise PermissionError("Not allowed to create graph")
                # graph_name is untrusted input: quote it as a literal rather
                # than formatting it into the SQL text.
                query = graph_sql("SELECT create_graph({})", arguments["graph_name"])
                log.info(f"Creating graph with name {arguments['graph_name']}")
                results = db._execute_sql(query=query)
                return [types.TextContent(type="text", text=str(results))]
            elif name == "drop-age-graph":
                if not allow_write:
                    raise PermissionError("Not allowed to drop graph")
                query = graph_sql("SELECT drop_graph({}, True)", arguments["graph_name"])
                log.info(f"Dropping graph with name {arguments['graph_name']}")
                results = db._execute_sql(query=query)
                return [types.TextContent(type="text", text=str(results))]
            elif name == "list-age-graphs":
                query = "SELECT name FROM ag_graph"
                log.info("Listing graphs")
                results = db._execute_sql(query=query)
                return [types.TextContent(type="text", text=str(results))]
            elif name == "read-age-cypher":
                if not CypherQueryFormatter.is_safe_cypher_query(arguments["query"]):
                    raise ValueError("Only MATCH queries are allowed for read-query")
                results = db._execute_query(
                    graph_name=arguments["graph_name"], query=arguments["query"]
                )
                return [types.TextContent(type="text", text=str(results))]
            elif name == "write-age-cypher":
                if CypherQueryFormatter.is_safe_cypher_query(arguments["query"]):
                    raise ValueError("Only write queries are allowed for write-query")
                results = db._execute_query(
                    graph_name=arguments["graph_name"], query=arguments["query"]
                )
                return [types.TextContent(type="text", text=str(results))]
            else:
                raise ValueError(f"Unknown tool: {name}")
        except Exception as e:
            return [types.TextContent(type="text", text=f"Error: {str(e)}")]

    async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
        log.info("Server running with stdio transport")
        await server.run(
            read_stream,
            write_stream,
            InitializationOptions(
                server_name="age",
                server_version="0.2.2",
                capabilities=server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            ),
        )