EOSC-Data-Commons

EOSC Data Commons Search

main.py (19.7 kB)
"""HTTP API to deploy the EOSC Data Commons search agent.""" import json import uuid from collections.abc import AsyncGenerator from datetime import datetime, timezone from ag_ui.core import ( RunFinishedEvent, RunStartedEvent, TextMessageChunkEvent, TextMessageEndEvent, TextMessageStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, ToolCallStartEvent, ) from langchain.chat_models import BaseChatModel from langchain.messages import AnyMessage, HumanMessage from langchain_mcp_adapters.client import MultiServerMCPClient from mcp.types import TextContent from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import FileResponse, StreamingResponse from starlette.staticfiles import StaticFiles from data_commons_search.config import settings from data_commons_search.logging import BLUE, BOLD, RESET, YELLOW from data_commons_search.mcp_server import mcp from data_commons_search.models import ( AgentInput, LangChainRerankingOutputMsg, LangChainResponseMetadata, OpenSearchResults, RankedSearchResponse, RerankingOutput, TokenUsageMetadata, ) from data_commons_search.prompts import RERANK_PROMPT, SUMMARIZE_PROMPT, TOOL_CALL_PROMPT from data_commons_search.utils import ( file_logger, get_langchain_msgs, get_system_prompt, load_chat_model, logger, sse_event, ) # Get the MCP server Starlette app, and mount our routes to it app = mcp.streamable_http_app() if settings.cors_enabled: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) mcp_client = MultiServerMCPClient( { "data-commons-search": { "url": f"{settings.server_url}/mcp", "transport": "streamable_http", } } ) logger.info(f"""💬 {BOLD}{BLUE}Search UI{RESET} started on {BOLD}{YELLOW}{settings.server_url}{RESET} ⚡️ Streamable HTTP MCP server started on {BOLD}{settings.server_url}/mcp{RESET} 🔎 Using OpenSearch service on {BOLD}{settings.opensearch_url}{RESET}""") async def chat_handler(request: Request) -> StreamingResponse: """Chat with the assistant main endpoint.""" auth_header = request.headers.get("Authorization", "") if settings.chat_api_key and (not auth_header or not auth_header.startswith("Bearer ")): raise ValueError("Missing or invalid Authorization header") if settings.chat_api_key and auth_header.split(" ")[1] != settings.chat_api_key: raise ValueError("Invalid API key") return StreamingResponse( stream_chat_response(AgentInput.model_validate(await request.json())), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", }, ) def get_timestamp() -> int: """Get the current UTC timestamp in seconds.""" return int(datetime.now(timezone.utc).timestamp()) async def stream_chat_response(request: AgentInput) -> AsyncGenerator[str, None]: """Stream the chat response with tool calls, reranking, and results.""" msg_id = str(uuid.uuid4()) token_usage = TokenUsageMetadata() yield sse_event(RunStartedEvent(thread_id=request.thread_id, run_id=request.run_id, timestamp=get_timestamp())) yield sse_event(TextMessageStartEvent(message_id=msg_id, role="assistant", timestamp=get_timestamp())) # Get tools from the MCP client tools = await mcp_client.get_tools() # Get model with tools for the initial query llm = load_chat_model(request.model) llm_with_tools = llm.bind_tools(tools) # Step 1: Call LLM to get tool calls msgs = get_langchain_msgs(request.messages) tc_llm_resp = llm_with_tools.invoke([get_system_prompt(TOOL_CALL_PROMPT), *msgs]) token_usage += 
    if tc_llm_resp.content and isinstance(tc_llm_resp.content, str):
        # If tc_llm_resp has text send it as a TextMessage content alongside tool calls
        yield sse_event(
            TextMessageChunkEvent(
                delta=tc_llm_resp.content,
                timestamp=get_timestamp(),
            )
        )

    # Step 2: Execute each tool and collect search results and textual outputs
    search_results = OpenSearchResults(total_found=0, hits=[])
    tool_text_outputs: list[str] = []
    async with mcp_client.session("data-commons-search") as session:
        for tool_call in tc_llm_resp.tool_calls:
            tool_call_id = tool_call["name"]
            yield sse_event(
                ToolCallStartEvent(
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call["name"],
                    parent_message_id=msg_id,
                    timestamp=get_timestamp(),
                )
            )
            yield sse_event(
                ToolCallArgsEvent(
                    tool_call_id=tool_call_id, delta=json.dumps(tool_call["args"]), timestamp=get_timestamp()
                )
            )
            tc_exec_res = await session.call_tool(tool_call["name"], tool_call["args"])
            if tc_exec_res.structuredContent:
                # Handle structured content, try to parse as `OpenSearchResults`
                try:
                    tool_results = OpenSearchResults(**tc_exec_res.structuredContent)
                    search_results.hits.extend(tool_results.hits)
                    search_results.total_found += tool_results.total_found
                finally:
                    tool_results_str = json.dumps(tc_exec_res.structuredContent)
                    yield sse_event(
                        ToolCallResultEvent(
                            message_id=msg_id,
                            tool_call_id=tool_call_id,
                            content=tool_results_str,
                            role="tool",
                            timestamp=get_timestamp(),
                        )
                    )
            elif tc_exec_res.content:
                # Handle if text content is sent back
                for resp_content in tc_exec_res.content:
                    if isinstance(resp_content, TextContent):
                        # Stream the raw tool text back to the UI, and record it for fallback summarization
                        yield sse_event(
                            ToolCallResultEvent(
                                message_id=msg_id,
                                tool_call_id=tool_call_id,
                                content=resp_content.text,
                                role="tool",
                                timestamp=get_timestamp(),
                            )
                        )
                        try:
                            if resp_content.text:
                                tool_text_outputs.append(resp_content.text)
                        except Exception as exc:
                            logger.exception("Failed to record tool text output: %s", exc)
            yield sse_event(ToolCallEndEvent(tool_call_id=tool_call_id, timestamp=get_timestamp()))

    # Handle the case where tool calls produced output but no search results: ask the LLM to summarize the tool outputs
    if tc_llm_resp.tool_calls and search_results.total_found == 0 and tool_text_outputs:
        summary_msgs: list[AnyMessage] = [
            get_system_prompt(SUMMARIZE_PROMPT),
            *msgs,
            HumanMessage(
                content=(
                    "The following tool outputs were produced when handling the user's query:\n\n"
                    + "\n\n---\n\n".join(tool_text_outputs)
                    + "\n\nPlease provide a concise summary for the user explaining what the tools returned and any recommendation or next steps."
                )
            ),
        ]
        try:
            fallback_tool_id = "search_summary"
            yield sse_event(
                ToolCallStartEvent(
                    tool_call_id=fallback_tool_id, tool_call_name=fallback_tool_id, parent_message_id=msg_id
                )
            )
            summary_resp = llm.invoke(summary_msgs)
            token_usage += LangChainResponseMetadata.model_validate(summary_resp.response_metadata).token_usage
            # Send the summary back as a ToolCallResult-like event so the UI can display it
            # NOTE: use TextMessageChunkEvent?
            yield sse_event(
                ToolCallResultEvent(
                    message_id=msg_id,
                    tool_call_id=fallback_tool_id,
                    content=str(summary_resp.content),
                    role="tool",
                    timestamp=get_timestamp(),
                )
            )
            yield sse_event(ToolCallEndEvent(tool_call_id=fallback_tool_id))
            return
        except Exception as e:
            logger.error(f"Fallback summarization failed: {e}")
    # Step 3: If no results found or no tool calls, handle early exit
    if not tc_llm_resp.tool_calls or search_results.total_found == 0:
        yield sse_event(TextMessageEndEvent(message_id=msg_id, timestamp=get_timestamp()))
        yield sse_event(RunFinishedEvent(thread_id=request.thread_id, run_id=request.run_id, timestamp=get_timestamp()))
        return
    # print(json.dumps(search_results.model_dump(), indent=2))

    # Step 4: Rerank search results using LLM with structured output
    rerank_tc_id = "rerank_results"
    yield sse_event(
        ToolCallStartEvent(
            tool_call_id=rerank_tc_id,
            tool_call_name="rerank_results",
            parent_message_id=msg_id,
            timestamp=get_timestamp(),
        )
    )
    final_response = await rerank_search_results(
        llm,
        msgs,
        search_results,
        token_usage,
    )
    yield sse_event(
        ToolCallResultEvent(
            message_id=msg_id,
            tool_call_id=rerank_tc_id,
            content=final_response.model_dump_json(by_alias=True),
            role="tool",
            timestamp=get_timestamp(),
        )
    )
    yield sse_event(ToolCallEndEvent(tool_call_id=rerank_tc_id, timestamp=get_timestamp()))
    yield sse_event(TextMessageEndEvent(message_id=msg_id, timestamp=get_timestamp()))
    yield sse_event(RunFinishedEvent(thread_id=request.thread_id, run_id=request.run_id, timestamp=get_timestamp()))

    file_logger.info(
        json.dumps(
            {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "token_usage": token_usage.model_dump(),
                "input": request.model_dump(),
                "response": final_response.model_dump(),
            }
        )
    )
    logger.info(f'/chat "{request.messages[-1].content}" | {token_usage.model_dump()}')


app.router.add_route("/chat", chat_handler, methods=["POST"])
async def rerank_search_results(
    llm: BaseChatModel,
    chat_messages: list[AnyMessage],
    search_results: OpenSearchResults,
    token_usage: TokenUsageMetadata,
) -> RankedSearchResponse:
    """Rerank search results using LLM with structured output.

    Args:
        llm: The LLM model to use for reranking
        chat_messages: Original chat messages for context
        search_results: Search results to rerank
        token_usage: Token usage metadata to add the reranking LLM call usage to

    Returns:
        RankedSearchResponse with reranked hits and summary
    """
    # Format the context for the LLM
    last_msg = chat_messages[-1] if chat_messages else None
    last_msg_content = last_msg.content if last_msg and isinstance(last_msg.content, str) else ""
    formatted_context = f"Found {search_results.total_found} datasets relevant to the query '{last_msg_content}':\n\n"
    for i, hit in enumerate(search_results.hits[: settings.reranking_results_count]):
        formatted_context += f"{i + 1}. **{hit.id}**\n"
        formatted_context += f" {' | '.join([title.title for title in hit.source.titles])}\n"
        if hit.source.dates:
            formatted_context += (
                f" Dates: {' | '.join([f'{date.date_type}: {date.date}' for date in hit.source.dates])}\n"
            )
        if hit.source.creators:
            formatted_context += f" Authors: {', '.join([creator.creator_name for creator in hit.source.creators])}\n"
        if hit.source.subjects:
            formatted_context += f" Keywords: {', '.join([subj.subject for subj in hit.source.subjects])}\n"
        formatted_context += f" Description: {hit.description}\n\n"

    rerank_msgs: list[AnyMessage] = [
        get_system_prompt(RERANK_PROMPT),
        *chat_messages,
        HumanMessage(content=formatted_context),
    ]
    try:
        # Call LLM with structured output for reranking
        llm_structured_rerank = llm.with_structured_output(RerankingOutput, method="function_calling", include_raw=True)
        rerank_resp = LangChainRerankingOutputMsg.model_validate(llm_structured_rerank.invoke(rerank_msgs))
        token_usage += LangChainResponseMetadata.model_validate(rerank_resp.raw.response_metadata).token_usage
        # Only keep the hits that were sent for reranking
        reranked_hits = search_results.hits[: settings.reranking_results_count]
        # Add scores to the reranked datasets
        score_lookup = {hit.url: hit.score for hit in rerank_resp.parsed.hits}
        # print(f"Rerank response: {score_lookup}")
        for hit in reranked_hits:
            hit.score = score_lookup.get(hit.id, 0.0)
        # Sort hits by score in descending order
        reranked_hits.sort(key=lambda h: h.score or 0.0, reverse=True)
        # await get_relevant_tools(reranked_hits)
        return RankedSearchResponse(summary=rerank_resp.parsed.summary, hits=reranked_hits)
    except Exception as e:
        logger.error(f"Reranking failed: {e}")
        # Fallback: return results as-is without reranking
        return RankedSearchResponse(
            summary=f"Found {search_results.total_found} relevant datasets.",
            hits=search_results.hits,
        )


# Serve website built using vite
app.mount(
    "/assets",
    StaticFiles(directory="src/data_commons_search/webapp/assets"),
    name="static",
)


async def ui_handler(request: Request) -> FileResponse:
    """Serve the chat UI HTML file directly."""
    return FileResponse("src/data_commons_search/webapp/index.html")


# Serve index.html for root and any other unmatched GET paths, so a SPA can handle routing
app.router.add_route("/", ui_handler, methods=["GET"])
app.router.add_route("/{path:path}", ui_handler, methods=["GET"])


# In OpenSearch and Filemetrix: https://doi.org/10.17026/DANS-2B8-ZGY2
# Data to Monitor Soil Aggregate Breakdown
# Data on fair evaluation
# NOTE: commented out for now as this is done directly from the frontend when a user shows interest in a dataset (e.g. clicks on it)
#
# https://confluence.egi.eu/display/EOSCDATACOMMONS/API+Definitions+and+Implementation+Guidelines
#
# https://dev.matchmaker.eosc-data-commons.eu/search?q=search for data about Cognitive load in cyclists while navigating in traffic&model=einfracz%2Fqwen3-coder
#
# curl -X POST http://localhost:8001/chat -H "Content-Type: application/json" -H "Authorization: SECRET_KEY" -d '{"messages": [{"role": "user", "content": "Datasets about representation of dogs in medieval time"}], "model": "einfracz/qwen3-coder", "stream": true}'
#
# curl -X POST http://localhost:8001/chat -H "Content-Type: application/json" -H "Authorization: SECRET_KEY" -d '{"messages": [{"role": "user", "content": "search for data about Harelbeke Evolis"}], "model": "einfracz/qwen3-coder", "stream": true}'
#
# curl -X POST http://localhost:8001/chat -H "Content-Type: application/json" -H "Authorization: SECRET_KEY" -d '{"messages": [{"role": "user", "content": "search for data about Cognitive load in cyclists while navigating in traffic"}], "model": "einfracz/qwen3-coder", "stream": true}'

# async def get_relevant_tools(search_hits: list[SearchHit]) -> None:
#     """Fetch file extensions and relevant tools from the FileMetrix API in parallel for each hit's DOI,
#     and update hits in-place.

#     Args:
#         search_hits: The OpenSearch results to enhance with file extensions and relevant tools.
#     """

#     async def fetch_extensions(client: httpx.AsyncClient, doi: str) -> FileMetrixExtensionsResponse | None:
#         """Fetch extensions for a single DOI."""
#         try:
#             encoded = quote(doi, safe="")
#             resp = await client.get(
#                 f"{settings.filemetrix_api}/extensions/{encoded}",
#                 headers={"accept": "application/json"},
#             )
#             if resp.status_code == 200:
#                 return FileMetrixExtensionsResponse.model_validate(resp.json())
#             logger.warning(f"FileMetrix returned {resp.status_code} for DOI {doi}")
#         except Exception as e:
#             logger.warning(f"FileMetrix fetch error for {doi}: {e}")
#         return None

#     async def fetch_tools_for_extension(client: httpx.AsyncClient, extension: str) -> list[dict[str, str]] | None:
#         """Fetch relevant tools for a file extension from the tool registry."""
#         try:
#             resp = await client.get(
#                 f"{settings.tool_registry_api}/input/{extension}",
#                 headers={"accept": "application/json"},
#             )
#             if resp.status_code == 200:
#                 return resp.json()
#             logger.warning(f"Tool registry returned {resp.status_code} for extension {extension}")
#         except Exception as e:
#             logger.warning(f"Tool registry fetch error for {extension}: {e}")
#         return None

#     # Extract DOI from hit and create fetch task
#     async def process_hit(client: httpx.AsyncClient, hit: SearchHit) -> None:
#         """Extract DOI from hit and fetch/apply extensions and relevant tools."""
#         doi = None
#         try:
#             if hit.id.startswith("http"):
#                 parsed = urlparse(hit.id)
#                 if "doi.org" in parsed.netloc:
#                     doi = unquote(parsed.path.lstrip("/"))
#             else:
#                 doi = hit.id
#         except Exception:
#             return
#         if not doi:
#             return
#         # Fetch file extensions
#         fm = await fetch_extensions(client, doi)
#         if fm:
#             hit.file_extensions = fm.extensions
#             logger.info(f"📁 https://doi.org/{doi} -> extensions: {fm.extensions}")
#             # Fetch relevant tools for each extension
#             all_tools = []
#             for ext in fm.extensions:
#                 tools_data = await fetch_tools_for_extension(client, ext)
#                 if tools_data:
#                     try:
#                         for tool_dict in tools_data:
#                             tool = ToolRegistryTool.model_validate(tool_dict)
#                             all_tools.append(tool)
#                             logger.info(f"🔧 {ext} -> tool: {tool.tool_label}")
#                     except Exception as e:
#                         logger.warning(f"Error parsing tool data for {ext}: {e}")
#             # Remove duplicates by tool_uri while preserving order
#             seen = set()
#             unique_tools = []
#             for tool in all_tools:
#                 if tool.tool_uri not in seen:
#                     seen.add(tool.tool_uri)
#                     unique_tools.append(tool)
#             hit.relevant_tools = unique_tools

#     async with httpx.AsyncClient(timeout=10.0) as client:
#         await asyncio.gather(*(process_hit(client, hit) for hit in search_hits))
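
For reference, the /chat route above streams AG-UI events (RunStarted, ToolCallStart, ToolCallResult, TextMessageChunk, ...) as server-sent events. The snippet below is a minimal client sketch, not part of main.py: it assumes the server runs locally on port 8001 as in the commented curl examples, that the placeholder SECRET_KEY matches the configured chat_api_key, and that sse_event() serializes each event as a standard "data: <json>" SSE frame. Adjust names and URLs to your deployment.

# chat_client_example.py (illustrative sketch only)
import json

import httpx

CHAT_URL = "http://localhost:8001/chat"
CHAT_API_KEY = "SECRET_KEY"  # placeholder: use the key configured as chat_api_key

payload = {
    "messages": [{"role": "user", "content": "search for data about Harelbeke Evolis"}],
    "model": "einfracz/qwen3-coder",
    "stream": True,
}

# Stream the response so events can be handled as they arrive
with httpx.stream(
    "POST",
    CHAT_URL,
    json=payload,
    headers={"Authorization": f"Bearer {CHAT_API_KEY}"},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        # Each AG-UI event is assumed to arrive as one "data: <json>" SSE line
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("type"), flush=True)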
