#!/usr/bin/env python3
"""
FHIR + GraphRAG MCP Server
Model Context Protocol (MCP) server exposing FHIR repository search,
GraphRAG knowledge graph queries, and medical data visualization tools
for AI-powered medical chat applications.
Tools provided:
- search_fhir_documents: Full-text search of FHIR DocumentReference resources
- search_knowledge_graph: Entity-based search through medical knowledge graph
- hybrid_search: Combined FHIR + GraphRAG search with RRF fusion
- get_entity_relationships: Traverse knowledge graph from seed entities
- get_document_details: Retrieve full FHIR document with decoded clinical notes
- get_entity_statistics: Knowledge graph statistics and entity distribution
- plot_entity_distribution: Generate entity type distribution chart data
- plot_patient_timeline: Generate patient encounter timeline chart data
- plot_symptom_frequency: Generate symptom frequency bar chart data
- plot_entity_network: Generate entity relationship network graph data
- visualize_graphrag_results: Visualize how knowledge graph entities and relationships connect to a query
- search_medical_images: Semantic (NV-CLIP) or keyword search over MIMIC-CXR chest X-rays
"""
import sys
import os
# Add project root to path to find src (MUST be first in path)
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
import json
import time # For execution time tracking
from typing import Any, Dict, List, Optional
from datetime import datetime
# Import MCP modules
from mcp.server import Server
from mcp.types import Resource, Tool, TextContent, ImageContent, EmbeddedResource
from mcp.server.stdio import stdio_server
# Async runtime for the stdio server loop
import asyncio
# Import boto3 for AWS Bedrock
try:
import boto3
BOTO3_AVAILABLE = True
except ImportError:
BOTO3_AVAILABLE = False
print("Warning: boto3 not available, AWS Bedrock will not be used", file=sys.stderr)
try:
from src.embeddings.nvclip_embeddings import NVCLIPEmbeddings
except ImportError:
print("Warning: Could not import NVCLIPEmbeddings. Image search will be limited.", file=sys.stderr)
NVCLIPEmbeddings = None
# Import database connection module
from src.db.connection import get_connection
# Import search modules for scoring and caching
from src.search.scoring import get_score_color, get_confidence_level
from src.search.cache import get_cached_embedding, cache_info
# Initialize MCP server
server = Server("fhir-graphrag-server")
# Initialize NV-CLIP embedder (lazy load or global)
embedder = None
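# The embedder is a lazily constructed singleton: it is built on first use so
# the server can start even when NV-CLIP is unavailable, in which case image
# search falls back to keyword mode.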
def get_embedder():
"""Get or initialize NV-CLIP embedder."""
global embedder
if embedder is None and NVCLIPEmbeddings:
try:
embedder = NVCLIPEmbeddings()
except Exception as e:
print(f"Warning: Failed to initialize NV-CLIP: {e}", file=sys.stderr)
return embedder
@server.list_tools()
async def list_tools() -> List[Tool]:
"""List all available tools."""
return [
Tool(
name="search_fhir_documents",
description="Search FHIR DocumentReference resources by text content. "
"Decodes hex-encoded clinical notes and performs full-text search. "
"Returns matching documents with clinical note previews.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query terms (e.g., 'chest pain', 'fever vomiting')"
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 10
}
},
"required": ["query"]
}
),
Tool(
name="search_knowledge_graph",
description="Search medical knowledge graph for entities matching query terms. "
"Returns entities with types (SYMPTOM, CONDITION, MEDICATION, etc.) "
"and confidence scores, plus documents containing those entities.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Medical entity search terms (e.g., 'respiratory', 'abdominal pain')"
},
"limit": {
"type": "integer",
"description": "Maximum number of entities to return",
"default": 5
}
},
"required": ["query"]
}
),
Tool(
name="hybrid_search",
description="Combined FHIR + GraphRAG search using Reciprocal Rank Fusion. "
"Searches both FHIR document text and knowledge graph entities, "
"then fuses results for optimal ranking. Best for complex medical queries.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Medical search query (e.g., 'chest pain and difficulty breathing')"
},
"top_k": {
"type": "integer",
"description": "Number of top results to return",
"default": 5
}
},
"required": ["query"]
}
),
Tool(
name="get_entity_relationships",
description="Traverse knowledge graph relationships from seed entities. "
"Performs multi-hop graph traversal to find related entities and "
"discovers semantic connections between medical concepts.",
inputSchema={
"type": "object",
"properties": {
"entity_text": {
"type": "string",
"description": "Starting entity text (e.g., 'fever', 'respiratory')"
},
"max_depth": {
"type": "integer",
"description": "Maximum traversal depth (1-3 hops)",
"default": 2
}
},
"required": ["entity_text"]
}
),
Tool(
name="get_document_details",
description="Retrieve full details of a FHIR DocumentReference resource. "
"Decodes hex-encoded clinical notes and returns complete document "
"with metadata and full text content.",
inputSchema={
"type": "object",
"properties": {
"fhir_id": {
"type": "string",
"description": "FHIR resource ID (e.g., '1474', '2079')"
}
},
"required": ["fhir_id"]
}
),
Tool(
name="get_entity_statistics",
description="Get knowledge graph statistics including entity counts by type, "
"relationship counts, and high-confidence entities. Useful for "
"understanding the scope and quality of the knowledge graph.",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="plot_entity_distribution",
description="Generate interactive pie/bar chart data showing distribution of entity types "
"(SYMPTOM, CONDITION, MEDICATION, etc.). Returns Plotly-compatible chart data.",
inputSchema={
"type": "object",
"properties": {
"chart_type": {
"type": "string",
"description": "Chart type: 'pie' or 'bar'",
"default": "pie"
}
}
}
),
Tool(
name="plot_patient_timeline",
description="Generate timeline visualization showing patient encounters over time. "
"Returns Plotly-compatible timeline chart data with dates and event counts.",
inputSchema={
"type": "object",
"properties": {
"granularity": {
"type": "string",
"description": "Time granularity: 'day', 'week', 'month'",
"default": "month"
}
}
}
),
Tool(
name="plot_symptom_frequency",
description="Generate bar chart showing most frequent symptoms in the knowledge graph. "
"Returns top N symptoms with their occurrence counts.",
inputSchema={
"type": "object",
"properties": {
"top_n": {
"type": "integer",
"description": "Number of top symptoms to show",
"default": 10
}
}
}
),
Tool(
name="plot_entity_network",
description="Generate network graph data showing entity relationships for visualization. "
"Returns nodes and edges in format compatible with network visualization libraries.",
inputSchema={
"type": "object",
"properties": {
"entity_type": {
"type": "string",
"description": "Filter by entity type (SYMPTOM, CONDITION, etc.) or 'all'",
"default": "all"
},
"max_nodes": {
"type": "integer",
"description": "Maximum number of nodes to include",
"default": 50
}
}
}
),
Tool(
name="visualize_graphrag_results",
description="Visualize GraphRAG search results showing the query, entities found, and their relationships. "
"Performs a knowledge graph search and returns visualization data showing how entities connect to the query.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Medical search query to visualize (e.g., 'diabetes symptoms', 'chest pain treatment')"
},
"limit": {
"type": "integer",
"description": "Maximum number of entities to include in visualization",
"default": 10
}
},
"required": ["query"]
}
),
Tool(
name="search_medical_images",
description="Search for medical images (X-rays) using text query or metadata. "
"Uses NV-CLIP for semantic search (e.g., 'pneumonia', 'lateral view') "
"to find relevant chest X-rays from MIMIC-CXR dataset.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query (e.g., 'chest X-ray of pneumonia', 'enlarged heart')"
},
"limit": {
"type": "integer",
"description": "Maximum number of images to return",
"default": 5
                    },
                    "min_score": {
                        "type": "number",
                        "description": "Minimum similarity score to include (semantic search only)",
                        "default": 0.0
                    }
                },
"required": ["query"]
}
)
]
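# Every tool below serializes its payload as JSON inside a single TextContent
# item, so chat clients can parse all responses uniformly.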
@server.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
"""Handle tool invocations."""
try:
conn = get_connection()
cursor = conn.cursor()
if name == "search_fhir_documents":
query = arguments["query"]
limit = arguments.get("limit", 10)
            # Fetch a candidate pool (3x the requested limit), then filter by
            # keyword in Python: the note text is hex-encoded inside the JSON,
            # so SQL text matching cannot see it
sql = """
SELECT TOP ?
FHIRResourceId,
ResourceType,
ResourceString
FROM SQLUser.FHIRDocuments
WHERE ResourceType = 'DocumentReference'
"""
cursor.execute(sql, (limit * 3,))
search_terms = query.lower().split()
results = []
for fhir_id, resource_type, resource_string in cursor.fetchall():
try:
resource_json = json.loads(resource_string)
# Decode clinical note
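                    # (Attachment.data is base64Binary in standard FHIR; this
                    # repository stores it as a hex string, hence bytes.fromhex)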
clinical_note = None
if 'content' in resource_json:
                        try:
                            encoded_data = resource_json['content'][0]['attachment']['data']
                            clinical_note = bytes.fromhex(encoded_data).decode('utf-8')
                        except (KeyError, IndexError, ValueError):
                            # Missing attachment or non-hex payload: leave the note as None
                            pass
# Filter by search terms
if clinical_note and search_terms:
clinical_note_lower = clinical_note.lower()
if any(term in clinical_note_lower for term in search_terms):
results.append({
'fhir_id': fhir_id,
'preview': clinical_note[:300],
'relevance': 'matches: ' + ', '.join([t for t in search_terms if t in clinical_note_lower])
})
if len(results) >= limit:
break
                except (json.JSONDecodeError, TypeError):
                    # Skip rows whose ResourceString is missing or malformed
                    pass
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"query": query,
"results_count": len(results),
"documents": results
}, indent=2)
)]
elif name == "search_knowledge_graph":
query = arguments["query"]
limit = arguments.get("limit", 5)
keywords = [w.strip() for w in query.lower().split() if len(w.strip()) > 2]
all_entities = []
seen_ids = set()
for keyword in keywords:
sql = """
SELECT EntityID, EntityText, EntityType, Confidence
FROM SQLUser.Entities
WHERE LOWER(EntityText) LIKE ?
ORDER BY Confidence DESC
LIMIT ?
"""
cursor.execute(sql, (f'%{keyword}%', 10))
for entity_id, text, entity_type, confidence in cursor.fetchall():
if entity_id not in seen_ids:
seen_ids.add(entity_id)
all_entities.append({
'id': entity_id,
'text': text,
'type': entity_type,
'confidence': float(confidence) if confidence else 0.0,
'matched_keyword': keyword
})
all_entities.sort(key=lambda x: x['confidence'], reverse=True)
top_entities = all_entities[:limit]
# Get documents for these entities
if top_entities:
entity_ids = [e['id'] for e in top_entities]
placeholders = ','.join(['?'] * len(entity_ids))
doc_sql = f"""
SELECT DISTINCT e.ResourceID, f.FHIRResourceId,
COUNT(DISTINCT e.EntityID) as EntityCount
FROM SQLUser.Entities e
JOIN SQLUser.FHIRDocuments f ON e.ResourceID = f.FHIRResourceId
WHERE e.EntityID IN ({placeholders})
GROUP BY e.ResourceID, f.FHIRResourceId
ORDER BY EntityCount DESC
"""
cursor.execute(doc_sql, entity_ids)
documents = []
for resource_id, fhir_id, entity_count in cursor.fetchall():
documents.append({
'fhir_id': fhir_id,
'entity_count': entity_count
})
else:
documents = []
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"query": query,
"entities_found": len(top_entities),
"entities": top_entities,
"documents_found": len(documents),
"documents": documents
}, indent=2)
)]
elif name == "hybrid_search":
query = arguments["query"]
top_k = arguments.get("top_k", 5)
# Execute FHIR search (text matching)
fhir_sql = """
SELECT TOP ?
FHIRResourceId,
ResourceType,
ResourceString
FROM SQLUser.FHIRDocuments
WHERE ResourceType = 'DocumentReference'
"""
cursor.execute(fhir_sql, (top_k * 3,))
search_terms = query.lower().split()
fhir_results = []
for fhir_id, resource_type, resource_string in cursor.fetchall():
try:
resource_json = json.loads(resource_string)
clinical_note = None
if 'content' in resource_json:
try:
encoded_data = resource_json['content'][0]['attachment']['data']
clinical_note = bytes.fromhex(encoded_data).decode('utf-8')
                        except (KeyError, IndexError, ValueError):
                            pass
if clinical_note and search_terms:
clinical_note_lower = clinical_note.lower()
if any(term in clinical_note_lower for term in search_terms):
fhir_results.append({'fhir_id': fhir_id})
if len(fhir_results) >= top_k * 2:
break
                except (json.JSONDecodeError, TypeError):
                    pass
# Execute GraphRAG search (entity matching)
keywords = [w.strip() for w in query.lower().split() if len(w.strip()) > 2]
all_entity_ids = set()
for keyword in keywords[:3]: # Limit to 3 keywords
entity_sql = """
SELECT EntityID FROM SQLUser.Entities
WHERE LOWER(EntityText) LIKE ?
LIMIT 5
"""
cursor.execute(entity_sql, (f'%{keyword}%',))
for row in cursor.fetchall():
all_entity_ids.add(row[0])
graphrag_results = []
if all_entity_ids:
entity_list = list(all_entity_ids)
placeholders = ','.join(['?'] * len(entity_list))
doc_sql = f"""
SELECT DISTINCT f.FHIRResourceId
FROM SQLUser.Entities e
JOIN SQLUser.FHIRDocuments f ON e.ResourceID = f.FHIRResourceId
WHERE e.EntityID IN ({placeholders})
"""
cursor.execute(doc_sql, entity_list)
graphrag_results = [{'fhir_id': row[0]} for row in cursor.fetchall()]
# RRF fusion
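            # Reciprocal Rank Fusion: each document scores sum_i 1/(k + rank_i)
            # over the result lists that contain it; k = 60 (the commonly used
            # default from the original RRF paper) damps the dominance of
            # rank-1 hits so both sources contribute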
k = 60
score_map = {}
for rank, doc in enumerate(fhir_results, 1):
fhir_id = doc['fhir_id']
if fhir_id not in score_map:
score_map[fhir_id] = {'fhir_rank': None, 'graph_rank': None}
score_map[fhir_id]['fhir_rank'] = rank
for rank, doc in enumerate(graphrag_results, 1):
fhir_id = doc['fhir_id']
if fhir_id not in score_map:
score_map[fhir_id] = {'fhir_rank': None, 'graph_rank': None}
score_map[fhir_id]['graph_rank'] = rank
fused = []
for fhir_id, data in score_map.items():
rrf_score = 0.0
if data['fhir_rank']:
rrf_score += 1.0 / (k + data['fhir_rank'])
if data['graph_rank']:
rrf_score += 1.0 / (k + data['graph_rank'])
sources = []
if data['fhir_rank']:
sources.append('fhir')
if data['graph_rank']:
sources.append('graphrag')
fused.append({
'fhir_id': fhir_id,
'rrf_score': rrf_score,
'sources': sources
})
fused.sort(key=lambda x: x['rrf_score'], reverse=True)
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"query": query,
"fhir_results": len(fhir_results),
"graphrag_results": len(graphrag_results),
"fused_results": len(fused),
"top_documents": fused[:top_k]
}, indent=2)
)]
elif name == "get_entity_relationships":
entity_text = arguments["entity_text"]
max_depth = arguments.get("max_depth", 2)
# Find entity
sql = """
SELECT EntityID, EntityText, EntityType, Confidence
FROM SQLUser.Entities
WHERE LOWER(EntityText) LIKE ?
ORDER BY Confidence DESC
LIMIT 1
"""
cursor.execute(sql, (f'%{entity_text.lower()}%',))
result = cursor.fetchone()
if not result:
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({"error": f"Entity '{entity_text}' not found"})
)]
entity_id, text, entity_type, confidence = result
# Traverse relationships
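            # Breadth-first traversal: each pass expands the frontier by one
            # hop, recording every relationship that touches a frontier entity
            # and queueing newly seen entities for the next depth level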
visited = {entity_id}
current_level = {entity_id}
graph = {'entities': [{'id': entity_id, 'text': text, 'type': entity_type, 'depth': 0}],
'relationships': []}
for depth in range(max_depth):
if not current_level:
break
next_level = set()
entity_list = list(current_level)
placeholders = ','.join(['?'] * len(entity_list))
rel_sql = f"""
SELECT DISTINCT
r.SourceEntityID, r.TargetEntityID, r.RelationshipType,
e1.EntityText as SourceText, e1.EntityType as SourceType,
e2.EntityText as TargetText, e2.EntityType as TargetType
FROM SQLUser.EntityRelationships r
JOIN SQLUser.Entities e1 ON r.SourceEntityID = e1.EntityID
JOIN SQLUser.Entities e2 ON r.TargetEntityID = e2.EntityID
WHERE r.SourceEntityID IN ({placeholders})
OR r.TargetEntityID IN ({placeholders})
"""
cursor.execute(rel_sql, entity_list + entity_list)
for row in cursor.fetchall():
source_id, target_id, rel_type, source_text, source_type, target_text, target_type = row
graph['relationships'].append({
'source': source_text,
'target': target_text,
'type': rel_type
})
for eid, text, etype in [(source_id, source_text, source_type),
(target_id, target_text, target_type)]:
if eid not in visited:
visited.add(eid)
next_level.add(eid)
graph['entities'].append({
'id': eid,
'text': text,
'type': etype,
'depth': depth + 1
})
current_level = next_level
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"seed_entity": entity_text,
"entities_found": len(graph['entities']),
"relationships_found": len(graph['relationships']),
"graph": graph
}, indent=2)
)]
elif name == "get_document_details":
fhir_id = arguments["fhir_id"]
sql = """
SELECT FHIRResourceId, ResourceType, ResourceString
FROM SQLUser.FHIRDocuments
WHERE FHIRResourceId = ?
"""
cursor.execute(sql, (fhir_id,))
result = cursor.fetchone()
if not result:
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({"error": f"Document '{fhir_id}' not found"})
)]
_, resource_type, resource_string = result
resource_json = json.loads(resource_string)
# Decode clinical note
clinical_note = None
if 'content' in resource_json:
try:
encoded_data = resource_json['content'][0]['attachment']['data']
clinical_note = bytes.fromhex(encoded_data).decode('utf-8')
                except (KeyError, IndexError, ValueError):
                    pass
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"fhir_id": fhir_id,
"resource_type": resource_type,
"clinical_note": clinical_note,
"metadata": {
"id": resource_json.get('id'),
"status": resource_json.get('status'),
"type": resource_json.get('type', {}).get('coding', [{}])[0].get('display') if resource_json.get('type') else None
}
}, indent=2)
)]
elif name == "get_entity_statistics":
# Entity count by type
stats_sql = """
SELECT EntityType, COUNT(*) as EntityCount
FROM SQLUser.Entities
GROUP BY EntityType
ORDER BY EntityCount DESC
"""
cursor.execute(stats_sql)
entity_stats = []
for entity_type, count in cursor.fetchall():
entity_stats.append({'type': entity_type, 'count': count})
# Total counts
cursor.execute("SELECT COUNT(*) FROM SQLUser.Entities")
total_entities = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM SQLUser.EntityRelationships")
total_relationships = cursor.fetchone()[0]
# High confidence entities
cursor.execute("""
SELECT TOP 10 EntityText, EntityType, Confidence
FROM SQLUser.Entities
ORDER BY Confidence DESC
""")
high_confidence = []
for text, entity_type, confidence in cursor.fetchall():
high_confidence.append({
'text': text,
'type': entity_type,
'confidence': float(confidence) if confidence else 0.0
})
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"total_entities": total_entities,
"total_relationships": total_relationships,
"entity_distribution": entity_stats,
"high_confidence_entities": high_confidence
}, indent=2)
)]
elif name == "plot_entity_distribution":
chart_type = arguments.get("chart_type", "pie")
# Get entity type distribution
cursor.execute("""
SELECT EntityType, COUNT(*) as EntityCount
FROM SQLUser.Entities
GROUP BY EntityType
ORDER BY EntityCount DESC
""")
labels = []
values = []
for entity_type, count in cursor.fetchall():
labels.append(entity_type)
values.append(count)
cursor.close()
conn.close()
# Return Plotly-compatible data
chart_data = {
"chart_type": chart_type,
"data": {
"labels": labels,
"values": values
},
"layout": {
"title": "Medical Entity Type Distribution",
"xaxis_title": "Entity Type" if chart_type == "bar" else None,
"yaxis_title": "Count" if chart_type == "bar" else None
}
}
return [TextContent(
type="text",
text=json.dumps(chart_data, indent=2)
)]
elif name == "plot_patient_timeline":
granularity = arguments.get("granularity", "month")
# Get temporal entities (dates)
cursor.execute("""
SELECT EntityText, COUNT(*) as Frequency
FROM SQLUser.Entities
WHERE EntityType = 'TEMPORAL'
GROUP BY EntityText
ORDER BY EntityText
""")
dates = []
counts = []
for date_text, count in cursor.fetchall():
                # Keep only ISO-format dates (YYYY-MM-DD); other TEMPORAL
                # entities ("yesterday", "3 days ago") are skipped
                if date_text and len(date_text) == 10 and '-' in date_text:
                    dates.append(date_text)
                    counts.append(count)
cursor.close()
conn.close()
chart_data = {
"chart_type": "timeline",
"data": {
"dates": dates,
"counts": counts
},
"layout": {
"title": "Patient Encounters Timeline",
"xaxis_title": "Date",
"yaxis_title": "Event Count"
}
}
return [TextContent(
type="text",
text=json.dumps(chart_data, indent=2)
)]
elif name == "plot_symptom_frequency":
            top_n = int(arguments.get("top_n", 10))  # cast: interpolated into SQL below
# Get symptom frequencies
cursor.execute(f"""
SELECT TOP {top_n}
e.EntityText,
COUNT(DISTINCT er.RelationshipID) as Frequency
FROM SQLUser.Entities e
LEFT JOIN SQLUser.EntityRelationships er
ON e.EntityID = er.SourceEntityID OR e.EntityID = er.TargetEntityID
WHERE e.EntityType = 'SYMPTOM'
GROUP BY e.EntityText
ORDER BY Frequency DESC
""")
symptoms = []
frequencies = []
for text, freq in cursor.fetchall():
symptoms.append(text)
frequencies.append(freq)
cursor.close()
conn.close()
# Return Plotly-compatible data
chart_data = {
"chart_type": "bar",
"data": {
"symptoms": symptoms,
"frequencies": frequencies
},
"layout": {
"title": f"Top {top_n} Most Frequent Symptoms",
"xaxis_title": "Symptom",
"yaxis_title": "Frequency"
}
}
return [TextContent(
type="text",
text=json.dumps(chart_data, indent=2)
)]
elif name == "plot_entity_network":
            entity_type = arguments.get("entity_type", "all")
            max_nodes = int(arguments.get("max_nodes", 50))  # cast: interpolated into SQL below
            # Get the highest-confidence entities, optionally filtered by type
            type_filter = "WHERE EntityType = ?" if entity_type != 'all' else ""
            cursor.execute(f"""
                SELECT TOP {max_nodes} EntityID, EntityText, EntityType
                FROM SQLUser.Entities
                {type_filter}
                ORDER BY Confidence DESC
            """, (entity_type,) if entity_type != 'all' else ())
nodes = []
node_ids = set()
for eid, text, etype in cursor.fetchall():
nodes.append({
"id": eid,
"name": text,
"type": etype,
                    # Placeholder coordinates; the client can apply a force-directed layout
                    "x": 0, "y": 0
})
node_ids.add(eid)
# Get relationships between these nodes
if node_ids:
placeholders = ','.join(['?'] * len(node_ids))
cursor.execute(f"""
SELECT SourceEntityID, TargetEntityID, RelationshipType
FROM SQLUser.EntityRelationships
WHERE SourceEntityID IN ({placeholders})
AND TargetEntityID IN ({placeholders})
""", list(node_ids) + list(node_ids))
                # Map entity IDs to node indices once, instead of scanning the
                # node list for every edge
                id_to_idx = {n["id"]: i for i, n in enumerate(nodes)}
                edges = []
                for source, target, rel_type in cursor.fetchall():
                    source_idx = id_to_idx.get(source, -1)
                    target_idx = id_to_idx.get(target, -1)
if source_idx != -1 and target_idx != -1:
edges.append({
"source": source_idx,
"target": target_idx,
"type": rel_type
})
else:
edges = []
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"chart_type": "network",
"data": {
"nodes": nodes,
"edges": edges
}
}, indent=2)
)]
elif name == "visualize_graphrag_results":
query = arguments["query"]
limit = arguments.get("limit", 10)
# 1. Find entities matching query
keywords = [w.strip() for w in query.lower().split() if len(w.strip()) > 2]
matched_entities = []
for keyword in keywords:
cursor.execute("""
SELECT EntityID, EntityText, EntityType
FROM SQLUser.Entities
WHERE LOWER(EntityText) LIKE ?
ORDER BY Confidence DESC
LIMIT 5
""", (f'%{keyword}%',))
for row in cursor.fetchall():
matched_entities.append(row)
# 2. Find related entities (1 hop)
nodes = []
edges = []
node_ids = set()
# Add query node
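            # The query gets a synthetic id of -1 so it can never collide with
            # a database EntityID; appended first, it always sits at index 0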
nodes.append({
"id": -1,
"name": query,
"type": "QUERY",
"color": "red",
"x": 0, "y": 0
})
# Add matched entities
for eid, text, etype in matched_entities[:limit]:
if eid not in node_ids:
nodes.append({
"id": eid,
"name": text,
"type": etype,
"color": "blue",
"x": 0, "y": 0
})
node_ids.add(eid)
# Edge from query to entity
edges.append({
"source": 0, # Query is index 0
"target": len(nodes) - 1,
"type": "MATCHES"
})
# Add relationships between found entities
if node_ids:
placeholders = ','.join(['?'] * len(node_ids))
cursor.execute(f"""
SELECT SourceEntityID, TargetEntityID, RelationshipType
FROM SQLUser.EntityRelationships
WHERE SourceEntityID IN ({placeholders})
AND TargetEntityID IN ({placeholders})
""", list(node_ids) + list(node_ids))
                id_to_idx = {n["id"]: i for i, n in enumerate(nodes)}
                for source, target, rel_type in cursor.fetchall():
                    source_idx = id_to_idx.get(source, -1)
                    target_idx = id_to_idx.get(target, -1)
if source_idx != -1 and target_idx != -1:
edges.append({
"source": source_idx,
"target": target_idx,
"type": rel_type
})
cursor.close()
conn.close()
return [TextContent(
type="text",
text=json.dumps({
"chart_type": "network",
"query": query,
"entities_found": len(matched_entities),
"data": {
"nodes": nodes,
"edges": edges
}
}, indent=2)
)]
elif name == "search_medical_images":
query = arguments["query"]
limit = arguments.get("limit", 5)
            min_score = arguments.get("min_score", 0.0)  # optional similarity floor (semantic mode only)
emb = get_embedder()
results = []
search_mode = "keyword" # Default to keyword search
cache_hit = False
fallback_reason = None
start_time = time.time()
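            # Flow: attempt semantic vector search first; on any failure fall
            # back to keyword matching and record why in fallback_reason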
if emb:
try:
                    # Semantic search with NV-CLIP, reusing cached query embeddings
                    hits_before = cache_info().hits
                    query_vector_tuple = get_cached_embedding(query)
                    query_vector = list(query_vector_tuple)  # the cache returns a hashable tuple
                    vector_str = ','.join(map(str, query_vector))
                    # A hit for *this* query bumps the global hit counter
                    cache_hit = cache_info().hits > hits_before
                    search_mode = "semantic"
                    # Rank by cosine similarity (VECTOR_COSINE is in [-1, 1];
                    # higher means closer) and return the score with each row
                    sql = """
SELECT TOP ?
ImageID, StudyID, SubjectID, ViewPosition, ImagePath,
VECTOR_COSINE(Vector, TO_VECTOR(?, double)) AS Similarity
FROM VectorSearch.MIMICCXRImages
ORDER BY Similarity DESC
"""
cursor.execute(sql, (limit, vector_str))
except Exception as e:
                    # Any failure (credentials, model, vector SQL) degrades to keyword mode
print(f"Error in vector search: {e}", file=sys.stderr)
fallback_reason = f"NV-CLIP search failed: {str(e)}"
search_mode = "keyword"
# Fallback to keyword search
emb = None
else:
# No embedder available
fallback_reason = "NV-CLIP embedder not available"
if not emb:
# Keyword search fallback
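                # Degraded mode: keywords are matched against ViewPosition and
                # ImageID only, so recall is lower than semantic search and no
                # similarity scores are produced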
keywords = query.lower().split()
conditions = []
params = []
for kw in keywords:
conditions.append("(LOWER(ViewPosition) LIKE ? OR LOWER(ImageID) LIKE ?)")
params.extend([f'%{kw}%', f'%{kw}%'])
where_clause = " OR ".join(conditions) if conditions else "1=1"
sql = f"""
SELECT TOP ?
ImageID, StudyID, SubjectID, ViewPosition, ImagePath
FROM VectorSearch.MIMICCXRImages
WHERE {where_clause}
"""
cursor.execute(sql, [limit] + params)
# Process results with scoring metadata
scores = []
for row in cursor.fetchall():
if search_mode == "semantic":
# Semantic search returns 6 columns (including similarity)
image_id, study_id, subject_id, view_pos, image_path, similarity = row
similarity_score = float(similarity) if similarity is not None else 0.0
# Filter by min_score if specified
if similarity_score < min_score:
continue
scores.append(similarity_score)
results.append({
"image_id": image_id,
"study_id": study_id,
"subject_id": subject_id,
"view_position": view_pos,
"image_path": image_path,
"similarity_score": similarity_score,
"score_color": get_score_color(similarity_score),
"confidence_level": get_confidence_level(similarity_score),
"description": f"Chest X-ray ({view_pos}) for patient {subject_id}",
"embedding_model": "nvidia/nvclip"
})
else:
# Keyword search returns 5 columns (no similarity)
image_id, study_id, subject_id, view_pos, image_path = row
results.append({
"image_id": image_id,
"study_id": study_id,
"subject_id": subject_id,
"view_position": view_pos,
"image_path": image_path,
"description": f"Chest X-ray ({view_pos}) for patient {subject_id}"
})
cursor.close()
conn.close()
# Calculate execution time and statistics
execution_time_ms = int((time.time() - start_time) * 1000)
# Build response with enhanced metadata
response = {
"query": query,
"results_count": len(results),
"total_results": len(results), # TODO: Get actual total from COUNT query
"search_mode": search_mode,
"execution_time_ms": execution_time_ms,
"images": results
}
# Add semantic search statistics
            if search_mode == "semantic" and scores:
                response["cache_hit"] = cache_hit
                response["avg_score"] = sum(scores) / len(scores)
                response["max_score"] = max(scores)
                response["min_score"] = min(scores)
# Add fallback reason if applicable
if fallback_reason:
response["fallback_reason"] = fallback_reason
return [TextContent(
type="text",
text=json.dumps(response, indent=2)
)]
else:
raise ValueError(f"Unknown tool: {name}")
except Exception as e:
import traceback
return [TextContent(
type="text",
text=json.dumps({
"error": str(e),
"traceback": traceback.format_exc()
})
)]
async def main():
    """Run the MCP server over stdio."""
    async with stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            server.create_initialization_options()
        )
if __name__ == "__main__":
    asyncio.run(main())