Skip to main content
Glama

Adaptive Graph of Thoughts MCP Server

neo4j_utils.pyβ€’18.6 kB
import asyncio import atexit import threading from typing import Any, Optional import re from dataclasses import dataclass from typing import Any, Optional from loguru import logger from neo4j import ( Driver, GraphDatabase, ManagedTransaction, Record, Result, unit_of_work, ) from neo4j.exceptions import Neo4jError, ServiceUnavailable from adaptive_graph_of_thoughts.config import runtime_settings # --- Simple Configuration --- class Neo4jSettings: def __init__(self): """ Initialize Neo4j connection settings from the application's runtime configuration. """ self.uri: str = runtime_settings.neo4j.uri self.user: str = runtime_settings.neo4j.user self.password: str = runtime_settings.neo4j.password self.database: str = runtime_settings.neo4j.database # --- Global Configuration --- class GlobalSettings: def __init__(self): self.neo4j = Neo4jSettings() _neo4j_settings: Optional[GlobalSettings] = None _driver: Optional[Driver] = None # Backwards compatibility for tests class Neo4jDriverManager: """Thread-safe manager for the Neo4j driver.""" def __init__(self) -> None: self._driver: Optional[Driver] = None self._lock = threading.Lock() atexit.register(self.cleanup) def _create_driver(self) -> Driver: settings = get_neo4j_settings().neo4j return create_neo4j_driver(settings) def get_driver(self) -> Driver: with self._lock: if self._driver is None or self._driver.closed: self._driver = self._create_driver() if self._driver: global _driver _driver = self._driver return self._driver def cleanup(self) -> None: if self._driver and not self._driver.closed: logger.info("Closing Neo4j driver.") self._driver.close() self._driver = None global _driver _driver = None driver_manager = Neo4jDriverManager() # Allowed labels for node creation to mitigate injection attacks ALLOWED_LABELS = {"User", "Document", "Hypothesis", "Evidence"} def sanitize_cypher_input(value: str) -> str: """Remove potentially dangerous characters from a Cypher identifier.""" # Allow alphanumeric, underscore, hyphen, and dot (for namespaces) return re.sub(r"[^\w.-]", "", value) def mask_uri(uri: str) -> str: """Mask sensitive parts of a URI for logging.""" import re return re.sub(r"://([^:]+):([^@]+)@", r"://\1:***@", uri) def mask_username(username: str) -> str: """Mask a username for logging.""" if len(username) <= 2: return "***" return username[:2] + "*" * (len(username) - 2) def mask_credentials(uri: str, username: str) -> tuple[str, str]: """Mask both URI credentials and username.""" masked_uri = mask_uri(uri) masked_user = mask_username(username) return masked_uri, masked_user @dataclass class Neo4jConnection: """Lightweight Neo4j connection wrapper used in unit tests.""" uri: str user: str password: str database: str = "neo4j" def __post_init__(self) -> None: self._driver: Driver | None = None def connect(self) -> None: if self._driver is None: self._driver = GraphDatabase.driver( self.uri, auth=(self.user, self.password) ) def close(self) -> None: if self._driver: self._driver.close() self._driver = None def create_neo4j_driver(settings: Neo4jSettings) -> Driver: """Create and verify a new Neo4j driver using the provided settings.""" if not settings.uri or not settings.user or not settings.password: raise ServiceUnavailable("Neo4j connection details are incomplete in settings.") logger.info(f"Initializing Neo4j driver for URI: {mask_uri(settings.uri)}") driver = GraphDatabase.driver( settings.uri, auth=(settings.user, settings.password), max_connection_lifetime=3600, # 1 hour max_connection_pool_size=50, connection_acquisition_timeout=60, connection_timeout=30, max_retry_time=30, initial_retry_delay=1.0, retry_delay_multiplier=2.0, retry_delay_jitter_factor=0.2, ) driver.verify_connectivity() logger.info("Neo4j driver initialized and connectivity verified.") return driver def get_neo4j_settings() -> GlobalSettings: """Returns the Neo4j settings, initializing them if necessary.""" global _neo4j_settings if _neo4j_settings is None: logger.info("Initializing Neo4j settings.") _neo4j_settings = GlobalSettings() masked_uri, masked_user = mask_credentials( _neo4j_settings.neo4j.uri, _neo4j_settings.neo4j.user ) logger.debug( "Neo4j Settings loaded: URI='%s', User='%s', Default DB='%s'", masked_uri, masked_user, _neo4j_settings.neo4j.database, ) return _neo4j_settings # --- Driver Management --- def get_neo4j_driver() -> Driver: """Return a singleton Neo4j driver instance.""" try: return driver_manager.get_driver() except ServiceUnavailable: raise except Exception as e: logger.error(f"Unexpected error obtaining Neo4j driver: {e}") raise def close_neo4j_driver() -> None: """Close the Neo4j driver if it is open.""" driver_manager.cleanup() # --- Query Execution --- async def execute_query( query: str, parameters: Optional[dict[str, Any]] = None, database: Optional[str] = None, tx_type: str = "read", # 'read' or 'write' *, driver: Optional[Driver] = None, ) -> list[Record]: """ Executes a Cypher query asynchronously against the Neo4j database. Runs the specified query as either a read or write transaction, using the provided parameters and database name. If no database is specified, falls back to the configured default or "neo4j". Returns the list of records resulting from the query. Args: query: The Cypher query string to execute. parameters: Optional dictionary of parameters to pass to the query. database: Optional database name; uses the configured default or "neo4j" if not provided. tx_type: Transaction type, either "read" or "write". Defaults to "read". driver: Optional Neo4j ``Driver`` instance. If not provided, the module's global driver will be used. Returns: List of records returned by the query. Raises: ServiceUnavailable: If the Neo4j service is unavailable. Neo4jError: If an error occurs during query execution. ValueError: If an invalid transaction type is specified. """ driver = driver or get_neo4j_driver() if not driver: # Should not happen if get_neo4j_driver raises on failure logger.error("Neo4j driver not available. Cannot execute query.") raise ServiceUnavailable("Neo4j driver not initialized or connection failed.") settings = get_neo4j_settings() db_name = database if database else settings.neo4j.database records: list[Record] = [] def _execute_sync_query() -> list[Record]: with driver.session(database=db_name) as session: logger.debug( f"Executing query on database '{db_name}' with type '{tx_type}': {query[:100]}..." ) @unit_of_work(timeout=30) # Example timeout, adjust as needed def _transaction_work(tx: ManagedTransaction) -> list[Record]: result: Result = tx.run(query, parameters) return list(result) if tx_type == "read": sync_records = session.execute_read(_transaction_work) elif tx_type == "write": sync_records = session.execute_write(_transaction_work) else: logger.error( f"Invalid transaction type: {tx_type}. Must be 'read' or 'write'." ) raise ValueError( f"Invalid transaction type: {tx_type}. Must be 'read' or 'write'." ) logger.info( f"Query executed successfully on database '{db_name}'. Fetched {len(sync_records)} records." ) return sync_records try: # Run blocking I/O in a thread to avoid blocking the event loop records = await asyncio.to_thread(_execute_sync_query) except Neo4jError as e: logger.error(f"Neo4j error executing Cypher query on database '{db_name}': {e}") logger.error(f"Query: {query}, Parameters: {parameters}") raise # Re-raise the specific Neo4jError except ServiceUnavailable: logger.error( f"Neo4j service became unavailable while attempting to execute query on '{db_name}'." ) raise except Exception as e: logger.error( f"Unexpected error executing Cypher query on database '{db_name}': {e}" ) logger.error(f"Query: {query}, Parameters: {parameters}") raise # Re-raise any other unexpected exception return records async def create_node(label: str, properties: dict[str, Any]) -> list[Record]: # Validate label to prevent injection if not label.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid label: {label}. Labels must be alphanumeric with underscores/hyphens only." ) query = ( f"CREATE (n:{label} {{" + ", ".join(f"{k}: ${k}" for k in properties) + "}) RETURN n" ) return await execute_query(query, properties, tx_type="write") async def update_node(node_id: str, updates: dict[str, Any]) -> list[Record]: # Validate property names to prevent injection for key in updates: if not key.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid property name: {key}. Property names must be alphanumeric with underscores/hyphens only." ) try: node_id_int = int(node_id) except ValueError as exc: raise ValueError( f"Invalid node_id: {node_id}. Must be a valid integer." ) from exc query = "MATCH (n) WHERE id(n) = $id SET n += $props RETURN n" params = {"id": node_id_int, "props": updates} return await execute_query(query, params, tx_type="write") async def delete_node(node_id: str) -> list[Record]: try: node_id_int = int(node_id) except ValueError: raise ValueError(f"Invalid node_id: {node_id}. Must be a valid integer.") query = "MATCH (n) WHERE id(n) = $id DETACH DELETE n RETURN count(n)" return await execute_query(query, {"id": node_id_int}, tx_type="write") async def find_nodes(label: str, filters: dict[str, Any]) -> list[Record]: # Validate label to prevent injection if not label.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid label: {label}. Labels must be alphanumeric with underscores/hyphens only." ) # Validate property names to prevent injection for key in filters: if not key.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid property name: {key}. Property names must be alphanumeric with underscores/hyphens only." ) where = " AND ".join(f"n.{k} = ${k}" for k in filters) query = f"MATCH (n:{label}) WHERE {where} RETURN n" return await execute_query(query, filters) async def create_relationship( from_id: str, to_id: str, rel_type: str, properties: dict[str, Any] ) -> list[Record]: # Validate relationship type to prevent injection if not rel_type.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid relationship type: {rel_type}. " "Must be alphanumeric with underscores/hyphens only." ) clean_rel_type = sanitize_cypher_input(rel_type) # Validate property names to prevent injection for key in properties: if not key.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid property name: {key}. " "Property names must be alphanumeric with underscores/hyphens only." ) # Validate node IDs try: from_id_int = int(from_id) to_id_int = int(to_id) except ValueError: raise ValueError( "Invalid node IDs. Both from_id and to_id must be valid integers." ) query = ( f"MATCH (a),(b) WHERE id(a)=$from AND id(b)=$to " f"CREATE (a)-[r:{rel_type}]->(b) SET r = $props RETURN r" ) params = {"from": from_id_int, "to": to_id_int, "props": properties} return await execute_query(query, params, tx_type="write") async def get_database_info() -> list[Record]: query = "CALL db.info()" return await execute_query(query) async def validate_connection() -> bool: try: await get_database_info() return True except Exception: return False async def bulk_create_nodes(label: str, nodes: list[dict[str, Any]]) -> list[Record]: # Validate label to prevent injection if not label.replace("_", "").replace("-", "").isalnum(): raise ValueError( f"Invalid label: {label}. Labels must be alphanumeric with underscores/hyphens only." ) # Use UNWIND for efficient bulk creation query = f"UNWIND $nodes AS nodeData CREATE (n:{label}) SET n = nodeData RETURN n" return await execute_query(query, {"nodes": nodes}, tx_type="write") async def bulk_create_nodes_optimized( label: str, nodes: list[dict[str, Any]], batch_size: int = 1000 ) -> list[Record]: """Create nodes in batches with basic optimizations.""" clean_label = sanitize_cypher_input(label) if clean_label not in ALLOWED_LABELS: raise ValueError(f"Label '{label}' not in allowed list") if not nodes: return [] if batch_size <= 0 or batch_size > 10000: raise ValueError("Batch size must be between 1 and 10000") all_results: list[Record] = [] for i in range(0, len(nodes), batch_size): batch = nodes[i : i + batch_size] query = f""" UNWIND $nodes AS nodeData MERGE (n:{clean_label} {{id: nodeData.id}}) SET n += nodeData RETURN n """ batch_results = await execute_query( query, {"nodes": batch}, tx_type="write", ) all_results.extend(batch_results) if i + batch_size < len(nodes): await asyncio.sleep(0.01) return all_results async def ensure_indexes() -> None: """Create indexes and constraints required for performance.""" indexes = [ "CREATE INDEX IF NOT EXISTS FOR (n:Hypothesis) ON (n.id)", "CREATE INDEX IF NOT EXISTS FOR (n:Evidence) ON (n.id)", "CREATE INDEX IF NOT EXISTS FOR (n:Document) ON (n.timestamp)", "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Session) REQUIRE n.session_id IS UNIQUE", ] for index_query in indexes: try: await execute_query(index_query, tx_type="write") logger.info(f"Index created/verified: {index_query}") except Exception as exc: logger.warning(f"Could not create index with query '{index_query}': {exc}") async def execute_cypher_file(path: str) -> list[Record]: try: with open(path, "r", encoding="utf-8") as f: query = f.read() except FileNotFoundError as exc: raise FileNotFoundError(f"Cypher file not found: {path}") from exc except OSError as exc: raise OSError(f"Error reading Cypher file {path}: {exc}") from exc if not query.strip(): raise ValueError(f"Cypher file {path} is empty or contains only whitespace") return await execute_query(query, tx_type="write") # Example of how to use (optional, for testing or demonstration) if __name__ == "__main__": logger.add("neo4j_utils.log", rotation="500 MB") # For local testing # Ensure NEO4J_PASSWORD is set as an environment variable if different from default # For example: export NEO4J_PASSWORD="your_actual_password" async def main(): try: # Example Read Query logger.info("Attempting to execute a sample READ query...") read_query = "MATCH (n) RETURN count(n) AS node_count" await execute_query(read_query, tx_type="read") # ... rest of example code with await ... # ... exception handling ... finally: close_neo4j_driver() logger.info("Neo4j utils example finished.") if __name__ == "__main__": asyncio.run(main()) """ try: # Example Read Query logger.info("Attempting to execute a sample READ query...") read_query = "MATCH (n) RETURN count(n) AS node_count" read_results = await execute_query(read_query, tx_type="read") if read_results: logger.info(f"Read query results: Found {read_results[0]['node_count']} nodes in database.") else: logger.info("Read query returned no results or failed.") # Example Write Query (use with caution on your database) # logger.info("Attempting to execute a sample WRITE query...") # write_query = "CREATE (a:Greeting {message: $msg})" # write_params = {"msg": "Hello from neo4j_utils"} # await execute_query(write_query, parameters=write_params, tx_type="write") # logger.info("Write query executed (if no errors).") # logger.info("Attempting to read the written data...") # verify_query = "MATCH (g:Greeting) WHERE g.message = $msg RETURN g.message AS message" # verify_results = await execute_query(verify_query, parameters={"msg": "Hello from neo4j_utils"}, tx_type="read") # if verify_results: # logger.info(f"Verification query results: Found message '{verify_results[0]['message']}'") # else: # logger.warning("Verification query did not find the written data or failed.") except ServiceUnavailable: logger.error("Could not connect to Neo4j. Ensure Neo4j is running and accessible.") except Exception as e: logger.error(f"An error occurred during the example usage: {e}") finally: close_neo4j_driver() logger.info("Neo4j utils example finished.") """

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/SaptaDey/Adaptive-Graph-of-Thoughts-MCP-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server