Skip to main content
Glama
graph_state.py9.23 kB
import uuid from typing import Any, Dict, List, Optional, Set, TypeVar import networkx as nx from loguru import logger from pydantic import BaseModel, Field, field_validator from .common import TimestampedModel from .graph_elements import Edge, Hyperedge, Node # Import your domain models TNode = TypeVar("TNode", bound=Node) TEdge = TypeVar("TEdge", bound=Edge) THyperedge = TypeVar("THyperedge", bound=Hyperedge) class GraphStatistics(BaseModel): node_count: int = 0 edge_count: int = 0 hyperedge_count: int = 0 layer_count: int = 0 # Add more specific stats as needed, e.g., counts by node_type class ASRGoTGraph(TimestampedModel): """ Core graph data structure for ASR-GoT, using NetworkX internally. Implements aspects of the mathematical formalism from P1.11. Manages domain model instances (Node, Edge, Hyperedge). """ id: str = Field(default_factory=lambda: f"graph-{uuid.uuid4()}") # We store domain model instances in dictionaries for quick access by ID nodes: Dict[str, Node] = Field(default_factory=dict) edges: Dict[str, Edge] = Field(default_factory=dict) # Edge ID to Edge object hyperedges: Dict[str, Hyperedge] = Field(default_factory=dict) # Layer structure (P1.23) layers: Dict[str, Set[str]] = Field( default_factory=dict ) # Layer name to set of node_ids # Internal NetworkX graph for topology and algorithms # It will store node IDs and edge IDs (or tuples for edges) # Node attributes in nx_graph can point back to Node objects or store lightweight data nx_graph: nx.MultiDiGraph = Field(default_factory=nx.MultiDiGraph, exclude=True) # Metadata about the graph itself, e.g., current stage, overall query graph_metadata: Dict[str, Any] = Field(default_factory=dict) @field_validator("nx_graph", mode="before") @classmethod def init_nx_graph(cls, v): return v or nx.MultiDiGraph() def add_node(self, node: Node) -> None: if node.id in self.nodes: logger.warning(f"Node with ID {node.id} already exists. Overwriting.") self.nodes[node.id] = node # Add to NetworkX graph. Store essential data for quick access, or just the ID. # For simplicity, we can store the full Pydantic model dict, but be mindful of memory # if graphs are huge. Or, just store key attributes. self.nx_graph.add_node( node.id, type=node.type.value, label=node.label, confidence=node.confidence.model_dump(), ) # Add pydantic model directly logger.debug(f"Added node '{node.label}' (ID: {node.id}) to graph.") if node.metadata.layer_id: self.assign_node_to_layer(node.id, node.metadata.layer_id) self.touch() def get_node(self, node_id: str) -> Optional[Node]: return self.nodes.get(node_id) def remove_node(self, node_id: str) -> Optional[Node]: node = self.nodes.pop(node_id, None) if node: if self.nx_graph.has_node(node_id): self.nx_graph.remove_node(node_id) # Also remove from any layers for layer_name, node_ids_in_layer in list(self.layers.items()): if node_id in node_ids_in_layer: node_ids_in_layer.remove(node_id) if not node_ids_in_layer: # Remove layer if empty del self.layers[layer_name] # TODO: Also handle removal of incident edges & hyperedges logger.info( f"Removed node ID: {node_id}. Further cleanup of edges/hyperedges might be needed." ) self.touch() return node def add_edge(self, edge: Edge) -> None: if edge.id in self.edges: logger.warning(f"Edge with ID {edge.id} already exists. Overwriting.") if edge.source_id not in self.nodes or edge.target_id not in self.nodes: raise ValueError( f"Cannot add edge '{edge.id}': Source or target node does not exist." ) self.edges[edge.id] = edge # Use edge.id as the key in MultiDiGraph for potentially multiple edges between nodes self.nx_graph.add_edge( edge.source_id, edge.target_id, key=edge.id, # Important for MultiDiGraph to distinguish multiple edges type=edge.type.value, id=edge.id, confidence=edge.confidence, ) logger.debug( f"Added edge '{edge.type.value}' (ID: {edge.id}) from {edge.source_id} to {edge.target_id}." ) self.touch() def get_edge(self, edge_id: str) -> Optional[Edge]: return self.edges.get(edge_id) def remove_edge(self, edge_id: str) -> Optional[Edge]: edge = self.edges.pop(edge_id, None) if edge: if self.nx_graph.has_edge(edge.source_id, edge.target_id, key=edge_id): self.nx_graph.remove_edge(edge.source_id, edge.target_id, key=edge_id) logger.info(f"Removed edge ID: {edge_id}.") self.touch() return edge def add_hyperedge(self, hyperedge: Hyperedge) -> None: if hyperedge.id in self.hyperedges: logger.warning( f"Hyperedge with ID {hyperedge.id} already exists. Overwriting." ) for node_id in hyperedge.node_ids: if node_id not in self.nodes: raise ValueError( f"Cannot add hyperedge '{hyperedge.id}': Node {node_id} does not exist." ) self.hyperedges[hyperedge.id] = hyperedge # Representation in nx_graph for hyperedges can be tricky. # Common approaches: # 1. Star graph: Create a central "hyperedge node" and connect all involved nodes to it. # 2. Clique: Add edges between all pairs of nodes in the hyperedge (can make graph dense). # For now, we just store it separately. Visualization/algorithms might need to handle it. logger.debug( f"Added hyperedge (ID: {hyperedge.id}) connecting nodes: {hyperedge.node_ids}." ) self.touch() def get_hyperedge(self, hyperedge_id: str) -> Optional[Hyperedge]: return self.hyperedges.get(hyperedge_id) def remove_hyperedge(self, hyperedge_id: str) -> Optional[Hyperedge]: hyperedge = self.hyperedges.pop(hyperedge_id, None) if hyperedge: logger.info(f"Removed hyperedge ID: {hyperedge_id}.") self.touch() return hyperedge def assign_node_to_layer(self, node_id: str, layer_id: str): if node_id not in self.nodes: raise ValueError(f"Node {node_id} not found.") if layer_id not in self.layers: self.layers[layer_id] = set() logger.info(f"Created new layer: {layer_id}") self.layers[layer_id].add(node_id) # Update node's metadata as well node = self.get_node(node_id) if node: node.metadata.layer_id = layer_id node.touch() self.touch() def get_statistics(self) -> GraphStatistics: return GraphStatistics( node_count=len(self.nodes), edge_count=len(self.edges), hyperedge_count=len(self.hyperedges), layer_count=len(self.layers), ) def get_neighbors(self, node_id: str) -> List[str]: if node_id not in self.nx_graph: return [] return list(self.nx_graph.neighbors(node_id)) def get_predecessors(self, node_id: str) -> List[str]: if node_id not in self.nx_graph: return [] return list(self.nx_graph.predecessors(node_id)) def get_successors(self, node_id: str) -> List[str]: if node_id not in self.nx_graph: return [] return list(self.nx_graph.successors(node_id)) def to_serializable_dict(self) -> Dict[str, Any]: """Converts graph to a dict suitable for API responses (like GraphStateSchema)""" return { "id": self.id, "nodes": [ node.model_dump(exclude_none=True) for node in self.nodes.values() ], "edges": [ edge.model_dump(exclude_none=True) for edge in self.edges.values() ], "hyperedges": [ h.model_dump(exclude_none=True) for h in self.hyperedges.values() ], "layers": {name: list(node_ids) for name, node_ids in self.layers.items()}, "graph_metadata": self.graph_metadata, "statistics": self.get_statistics().model_dump(), "created_at": self.created_at.isoformat(), "updated_at": self.updated_at.isoformat(), } # In Pydantic v2, model_rebuild is used if you dynamically add fields to models, # but for managing the graph structure, manual updates to nx_graph are fine. # We need to ensure nx_graph is not directly part of serialization if it becomes too complex. # The `exclude=True` for nx_graph handles this for default Pydantic serialization. class Config: arbitrary_types_allowed = True # For NetworkX graph object

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/SaptaDey/NexusMind'

If you have feedback or need assistance with the MCP directory API, please join our Discord server