
Azure AI Search MCP Server

by mm-repos
chain.py (16.3 kB)
"""LangGraph chain implementation for Azure AI Search integration.""" import asyncio import logging import os from typing import Any, Dict, List, Optional, TypedDict from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda from langchain_google_genai import ChatGoogleGenerativeAI from langgraph.graph import END, StateGraph from langsmith import traceable # Assuming these are in the same directory from .azure_search import AzureSearchClient from .config import config from .prompt_manager import PromptManager logger = logging.getLogger(__name__) class SearchState(TypedDict): """State for the search chain.""" query: str search_type: str # "text", "hybrid" top_k: int output_format: str # "summary", "analysis", "structured" documents: List[Dict[str, Any]] context: str # This will hold the FINAL output of any chain raw_context: str # Holds the unprocessed context from documents metadata: Dict[str, Any] error: Optional[str] class AzureSearchChain: """LangGraph chain for Azure AI Search operations.""" def __init__(self): """Initialize the search chain.""" # Configure LangSmith tracing if enabled if config.langsmith.tracing_enabled and config.langsmith.api_key: os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_ENDPOINT"] = config.langsmith.endpoint os.environ["LANGCHAIN_API_KEY"] = config.langsmith.api_key os.environ["LANGCHAIN_PROJECT"] = config.langsmith.project self.search_client = AzureSearchClient() # Initialize prompt manager self.prompt_manager = PromptManager() # Initialize Google Gemini LLM if config.gemini.api_key: self.llm = ChatGoogleGenerativeAI( model=config.gemini.model_name, google_api_key=config.gemini.api_key, temperature=config.gemini.temperature, ) else: logger.warning("Gemini API key not found. 


class AzureSearchChain:
    """LangGraph chain for Azure AI Search operations."""

    def __init__(self):
        """Initialize the search chain."""
        # Configure LangSmith tracing if enabled
        if config.langsmith.tracing_enabled and config.langsmith.api_key:
            os.environ["LANGCHAIN_TRACING_V2"] = "true"
            os.environ["LANGCHAIN_ENDPOINT"] = config.langsmith.endpoint
            os.environ["LANGCHAIN_API_KEY"] = config.langsmith.api_key
            os.environ["LANGCHAIN_PROJECT"] = config.langsmith.project

        self.search_client = AzureSearchClient()

        # Initialize prompt manager
        self.prompt_manager = PromptManager()

        # Initialize Google Gemini LLM
        if config.gemini.api_key:
            self.llm = ChatGoogleGenerativeAI(
                model=config.gemini.model_name,
                google_api_key=config.gemini.api_key,
                temperature=config.gemini.temperature,
            )
        else:
            logger.warning("Gemini API key not found. LLM will not be available.")
            self.llm = None

        # --- PROMPT TEMPLATES FROM JSON CONFIGURATION ---
        # Load prompt templates from PromptManager
        self.structured_formatter_prompt = (
            self.prompt_manager.get_prompt_template_for_format("structured")
        )
        self.context_summarizer_prompt = (
            self.prompt_manager.get_prompt_template_for_format("summary")
        )
        self.relevance_analyzer_prompt = (
            self.prompt_manager.get_prompt_template_for_format("analysis")
        )

        # --- Build LCEL Chains from Prompts ---
        if self.llm:
            self.structured_formatter_chain = (
                RunnableLambda(self._log_prompt_input).with_config(
                    {"run_name": "ChatPromptTemplate_StructuredFormatter_Input"}
                )
                | self.structured_formatter_prompt.with_config(
                    {"run_name": "ChatPromptTemplate_StructuredFormatter"}
                )
                | RunnableLambda(self._log_full_prompt).with_config(
                    {"run_name": "FULL_PROMPT_StructuredFormatter"}
                )
                | self.llm.with_config({"run_name": "Gemini_StructuredFormatter"})
                | StrOutputParser().with_config(
                    {"run_name": "StructuredFormatter_Output"}
                )
            )
            self.context_summarizer_chain = (
                RunnableLambda(self._log_prompt_input).with_config(
                    {"run_name": "ChatPromptTemplate_ContextSummarizer_Input"}
                )
                | self.context_summarizer_prompt.with_config(
                    {"run_name": "ChatPromptTemplate_ContextSummarizer"}
                )
                | RunnableLambda(self._log_full_prompt).with_config(
                    {"run_name": "FULL_PROMPT_ContextSummarizer"}
                )
                | self.llm.with_config({"run_name": "Gemini_ContextSummarizer"})
                | StrOutputParser().with_config(
                    {"run_name": "ContextSummarizer_Output"}
                )
            )
            self.relevance_analyzer_chain = (
                RunnableLambda(self._log_prompt_input).with_config(
                    {"run_name": "ChatPromptTemplate_RelevanceAnalyzer_Input"}
                )
                | self.relevance_analyzer_prompt.with_config(
                    {"run_name": "ChatPromptTemplate_RelevanceAnalyzer"}
                )
                | RunnableLambda(self._log_full_prompt).with_config(
                    {"run_name": "FULL_PROMPT_RelevanceAnalyzer"}
                )
                | self.llm.with_config({"run_name": "Gemini_RelevanceAnalyzer"})
                | StrOutputParser().with_config(
                    {"run_name": "RelevanceAnalyzer_Output"}
                )
            )
        else:
            self.structured_formatter_chain = None
            self.context_summarizer_chain = None
            self.relevance_analyzer_chain = None

        # Build the graph after all components are initialized
        self.graph = self._build_graph()

    # --- Logging Helper Methods ---

    def _log_prompt_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Helper method to log prompt input for better LangSmith visibility."""
        return {
            **input_data,
            "_langsmith_metadata": {"input_type": "ChatPromptTemplate_Variables"},
        }

    def _log_full_prompt(self, messages) -> Any:
        """Helper method to log the complete formatted prompt."""
        return messages

    # --- Graph Definition ---

    def _build_graph(self):
        """Build the LangGraph state graph with routing."""
        workflow = StateGraph(SearchState)

        workflow.add_node("validate_input", self._validate_input)
        workflow.add_node("search_documents", self._search_documents)
        workflow.add_node("prepare_context", self._prepare_context)
        workflow.add_node("summarize_results", self._summarize_results)
        workflow.add_node("analyze_results", self._analyze_results)
        workflow.add_node("format_structured", self._format_results_structured)
        workflow.add_node("handle_error", self._handle_error)

        workflow.set_entry_point("validate_input")

        workflow.add_conditional_edges(
            "validate_input",
            self._should_continue_after_validation,
            {"continue": "search_documents", "error": "handle_error"},
        )
        workflow.add_conditional_edges(
            "search_documents",
            self._should_continue_after_search,
            {"continue": "prepare_context", "error": "handle_error"},
        )
        workflow.add_conditional_edges(
            "prepare_context",
            self._route_to_processor,
            {
                "summary": "summarize_results",
                "analysis": "analyze_results",
                "structured": "format_structured",
            },
        )

        workflow.add_edge("summarize_results", END)
        workflow.add_edge("analyze_results", END)
        workflow.add_edge("format_structured", END)
        workflow.add_edge("handle_error", END)

        return workflow.compile()
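
    # For reference, the control flow compiled above:
    #
    #   validate_input --ok--> search_documents --ok--> prepare_context
    #        |error                 |error                  |by output_format
    #        v                      v                       v
    #    handle_error <-------------+        summarize_results / analyze_results
    #        |                                / format_structured
    #        v                                               |
    #       END <--------------------------------------------+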

    # --- Graph Node Methods ---

    @traceable(name="validate_input")
    async def _validate_input(self, state: SearchState) -> SearchState:
        """Validate input parameters."""
        if not state["query"].strip():
            state["error"] = "Query cannot be empty"
        elif state["top_k"] <= 0:
            state["error"] = "top_k must be greater than 0"
        return state

    @traceable(name="search_documents")
    async def _search_documents(self, state: SearchState) -> SearchState:
        """Perform document search."""
        import time

        start_time = time.time()
        try:
            # Add timeout wrapper at node level too (45 seconds)
            documents = await asyncio.wait_for(
                self.search_client.search_documents(
                    query=state["query"], top_k=state["top_k"]
                ),
                timeout=45.0,
            )
            state["documents"] = self.search_client.documents_to_dict(documents)
        except asyncio.TimeoutError:
            elapsed = time.time() - start_time
            error_msg = f"Node search timed out after {elapsed:.2f} seconds"
            logger.error(error_msg)
            state["error"] = error_msg
            state["documents"] = []
        except Exception as e:
            elapsed = time.time() - start_time
            import traceback

            traceback.print_exc(file=sys.stderr)
            logger.error(f"Search failed: {e}")
            state["error"] = f"Search error: {e}"
            state["documents"] = []
        return state

    @traceable(name="prepare_context")
    async def _prepare_context(self, state: SearchState) -> SearchState:
        """Prepares the raw context from documents for LLM processing."""
        if not state["documents"]:
            state["raw_context"] = "No documents found for the given query."
            return state

        context_parts = []
        for i, doc in enumerate(state["documents"], 1):
            content = doc.get("content", "")
            title = doc.get("title", f"Document {i}")
            doc_section = f"## {title}\n\n{content}"
            context_parts.append(doc_section)

        state["raw_context"] = "\n\n---\n\n".join(context_parts)
        return state
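
    # Illustrative shape of state["raw_context"] after this node (titles and
    # contents invented for illustration):
    #
    #   ## Quickstart
    #
    #   <content of document 1>
    #
    #   ---
    #
    #   ## Pricing FAQ
    #
    #   <content of document 2>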

    @traceable(name="summarize_results")
    async def _summarize_results(self, state: SearchState) -> SearchState:
        """Node that invokes the summarization LCEL chain."""
        if not self.context_summarizer_chain:
            state["error"] = "Summarizer chain is not available (LLM not configured)."
            return state
        try:
            summary = await self.context_summarizer_chain.ainvoke(
                {"query": state["query"], "documents": state["raw_context"]}
            )
            state["context"] = summary
        except Exception as e:
            print(f"!!! NODE ERROR IN SUMMARIZATION: {e} !!!", file=sys.stderr)
            import traceback

            traceback.print_exc(file=sys.stderr)
            state["error"] = f"Error during summarization: {e}"
        return state

    @traceable(name="analyze_results")
    async def _analyze_results(self, state: SearchState) -> SearchState:
        """Node that invokes the relevance analysis LCEL chain."""
        if not self.relevance_analyzer_chain:
            state["error"] = "Analyzer chain is not available (LLM not configured)."
            return state
        try:
            analysis = await self.relevance_analyzer_chain.ainvoke(
                {"query": state["query"], "documents": state["raw_context"]}
            )
            state["context"] = analysis
        except Exception as e:
            print(f"!!! NODE ERROR IN ANALYSIS: {e} !!!", file=sys.stderr)
            import traceback

            traceback.print_exc(file=sys.stderr)
            state["error"] = f"Error during analysis: {e}"
        return state

    @traceable(name="format_results_structured")
    async def _format_results_structured(self, state: SearchState) -> SearchState:
        """Node that invokes the document formatting LCEL chain."""
        if not self.structured_formatter_chain:
            state["context"] = (
                f"Found {len(state['documents'])} documents.\n\n{state['raw_context']}"
            )
            return state
        try:
            formatted_docs = await self.structured_formatter_chain.ainvoke(
                {
                    "query": state["query"],
                    "documents": state["raw_context"],
                    "num_results": len(state["documents"]),
                }
            )
            state["context"] = formatted_docs
        except Exception as e:
            state["error"] = f"Error during structured formatting: {e}"
        return state

    async def _handle_error(self, state: SearchState) -> SearchState:
        """Handle errors in the chain."""
        error_msg = state.get("error", "Unknown error occurred")
        logger.error(f"Handling error: {error_msg}")
        state["context"] = f"Error: {error_msg}"
        return state

    # --- Graph Conditional Edges / Routers ---

    def _should_continue_after_validation(self, state: SearchState) -> str:
        return "error" if state.get("error") else "continue"

    def _should_continue_after_search(self, state: SearchState) -> str:
        return "error" if state.get("error") else "continue"

    def _route_to_processor(self, state: SearchState) -> str:
        """Router to decide which processing node to use."""
        return state["output_format"]

    # --- Main Public Methods ---

    @traceable(name="azure_search_chain")
    async def run(
        self,
        query: str,
        search_type: str = "text",
        top_k: int = 5,
        output_format: str = "analysis",  # Defaults to analysis
    ) -> Dict[str, Any]:
        """Run the search chain with a specified output format."""
        initial_state = SearchState(
            query=query,
            search_type=search_type,
            top_k=top_k,
            output_format=output_format,
            documents=[],
            context="",
            raw_context="",
            metadata={},
            error=None,
        )
        try:
            # Add timeout to graph execution to prevent hanging
            result = await asyncio.wait_for(
                self.graph.ainvoke(initial_state), timeout=25.0
            )
        except asyncio.TimeoutError:
            print(
                "!!! CRITICAL TIMEOUT IN CHAIN GRAPH EXECUTION after 25s !!!",
                file=sys.stderr,
            )
            raise Exception("Chain graph execution timed out after 25 seconds")
        except Exception as e:
            print(
                f"!!! CRITICAL ERROR IN CHAIN GRAPH EXECUTION: {e} !!!",
                file=sys.stderr,
            )
            import traceback

            traceback.print_exc(file=sys.stderr)
            raise

        result["metadata"] = {
            "total_results": len(result.get("documents", [])),
            "search_type": result.get("search_type"),
            "query": result.get("query"),
            "output_format": result.get("output_format"),
            "used_llm": bool(self.llm),
        }

        final_result = {
            "context": result["context"],
            "metadata": result["metadata"],
            "documents": result["documents"],
            "success": not bool(result.get("error")),
        }
        return final_result

    async def get_document_context_tool(self, document_ids: str) -> str:
        """Tool function to get context from specific documents by ID."""
        try:
            ids = [id.strip() for id in document_ids.split(",") if id.strip()]
            if not ids:
                return "Error: No valid document IDs provided"

            documents = await self.search_client.get_document_context(ids)
            if not documents:
                return "No documents found for the provided IDs"

            context_parts = []
            for doc in documents:
                if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                    title = doc.metadata.get("title", "Untitled")
                    content = doc.page_content or "No content available"
                    context_parts.append(f"**{title}**\n{content}")
            return "\n\n---\n\n".join(context_parts)
        except Exception as e:
            return f"Error retrieving document context: {e}"

    def print_graph_diagram(self):
        """Print a simple representation of the graph structure."""
        try:
            graph_repr = self.graph.get_graph()
            nodes = list(graph_repr.nodes.keys())
            print("LangGraph Chain Structure:")
            print(f"Nodes: {', '.join([n for n in nodes if not n.startswith('__')])}")
        except Exception:
            print("Graph structure visualization not available")

    async def close(self):
        """Clean up resources."""
        await self.search_client.close()
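
For completeness, a minimal usage sketch (not part of the module). It assumes the file is importable as `chain` alongside its `azure_search`, `config`, and `prompt_manager` siblings, and that the Azure AI Search and Gemini settings in `config` are populated; the query string is invented for illustration.

    import asyncio

    from chain import AzureSearchChain  # adjust to the actual package path

    async def main() -> None:
        search_chain = AzureSearchChain()
        search_chain.print_graph_diagram()
        try:
            result = await search_chain.run(
                query="vector index configuration",  # illustrative query
                output_format="summary",  # or "analysis" / "structured"
                top_k=3,
            )
            print(result["metadata"])
            print(result["context"])
        finally:
            await search_chain.close()

    asyncio.run(main())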

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/mm-repos/langgraph-claude-azure-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.