import os
import json
import asyncio
from typing import Any, Dict, List
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.components.types import Neo4jGraph, Neo4jNode, Neo4jRelationship
from langchain_openai import ChatOpenAI # Swap: from langchain_huggingface import HuggingFaceHub; llm = HuggingFaceHub(repo_id="microsoft/BioGPT")
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from mcp.server.fastmcp import FastMCP
import fitz # PyMuPDF for PDFs
import logging
from pydantic import BaseModel
# Logging is configured once at import time; every tool below reports
# progress and failures through this module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Neo4j connection settings, read from the environment (set in the Claude
# config or a .env file). The fallbacks target a local development instance.
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
USER = os.environ.get("NEO4J_USERNAME", "neo4j")
PASS = os.environ.get("NEO4J_PASSWORD", "password")

# Chat model shared by the extraction chain and the natural-language query tool.
# Swap for a custom model if needed, e.g.:
#   from langchain_community.llms import Ollama; llm = Ollama(model="biogpt")
llm = ChatOpenAI(
    model="gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# Pydantic models for structured extraction
class CardiologyEntity(BaseModel):
    """One entity extracted from cardiology text.

    When the extraction is written to the graph, ``label`` becomes the
    Neo4j node label, ``name`` is the key the node is merged on, and
    ``properties`` is merged into the node's properties.
    """

    # Name of the entity; nodes are merged on this value.
    name: str
    # Category of the entity, e.g. "Condition", "Procedure", "Anatomy".
    label: str
    # Optional extra attributes; empty when none were extracted.
    properties: Dict[str, Any] = {}
class CardiologyRelationship(BaseModel):
    """One directed relationship between two extracted entities.

    ``source`` and ``target`` hold entity names; when the extraction is
    written to the graph they are matched against the ``name`` property
    of existing nodes, and ``type`` becomes the relationship type.
    """

    # Name of the entity the relationship starts from.
    source: str
    # Name of the entity the relationship points to.
    target: str
    # Relationship type, e.g. "CAUSES", "TREATS", "PART_OF".
    type: str
    # Optional extra attributes; empty when none were extracted.
    properties: Dict[str, Any] = {}
class ExtractionResult(BaseModel):
    """Complete result of one extraction: the entities and their relationships.

    This is the schema the output parser validates LLM output against, and
    the shape of the JSON that the ``add_to_graph`` tool accepts.
    """

    # Every entity found in the text.
    entities: List[CardiologyEntity]
    # Every relationship found between those entities.
    relationships: List[CardiologyRelationship]
# Output parser that validates the LLM reply against ExtractionResult and
# supplies the format instructions embedded in the prompt below.
parser = PydanticOutputParser(pydantic_object=ExtractionResult)

# Cardiology-tuned extraction prompt. `{text}` is filled per call;
# `{format_instructions}` is bound once from the parser.
_EXTRACT_TEMPLATE = """Extract cardiology entities and relationships from the following text.
Focus on:
- Medical conditions (e.g., atrial fibrillation, heart failure)
- Anatomical structures (e.g., left ventricle, aortic valve)
- Procedures (e.g., echocardiogram, catheterization)
- Physiological processes (e.g., cardiac cycle, systole)
- Medications (e.g., beta-blockers, ACE inhibitors)
For relationships, identify:
- CAUSES (condition A causes condition B)
- TREATS (medication/procedure treats condition)
- PART_OF (anatomical relationships)
- AFFECTS (how conditions affect structures/processes)
- FOLLOWS (temporal/sequential relationships)
Text: {text}
{format_instructions}"""

_format_instructions = parser.get_format_instructions()
extract_prompt = PromptTemplate(
    template=_EXTRACT_TEMPLATE,
    input_variables=["text"],
    partial_variables={"format_instructions": _format_instructions},
)

# Pipeline: prompt -> chat model -> structured ExtractionResult.
extract_chain = extract_prompt | llm | parser

# MCP server instance; the tools below register themselves on it.
mcp = FastMCP(name="CardiologyKG")
@mcp.tool()
async def ingest_document(file_path_or_text: str) -> str:
    """Ingest a cardiology document (PDF path or raw text) and extract draft entities/relationships.

    Args:
        file_path_or_text: Either the path of an existing ``.pdf`` file or the
            raw document text. Anything that is not an existing ``.pdf`` path
            is treated as raw text.

    Returns:
        A human-readable report containing the extraction counts and the
        draft JSON, or an error message if ingestion failed. Errors are
        reported in the returned string rather than raised.
    """
    try:
        logger.info(f"Starting document ingestion: {file_path_or_text[:100]}...")

        # Determine whether the input is a PDF path or raw text.
        if file_path_or_text.endswith('.pdf') and os.path.exists(file_path_or_text):
            logger.info("Processing PDF file")
            doc = fitz.open(file_path_or_text)
            try:
                text = "".join(page.get_text() for page in doc)
            finally:
                # Release the file handle even if text extraction fails.
                doc.close()
        else:
            text = file_path_or_text

        # Do not spend an LLM call on empty input (e.g. a scanned PDF with
        # no text layer, or an empty string).
        if not text.strip():
            return (
                "Error during document ingestion: no text to process "
                "(empty input or PDF without extractable text)."
            )

        logger.info("Extracting entities and relationships...")
        extraction_result = await extract_chain.ainvoke({"text": text})

        # Serialise for review. Pydantic v2 rejects json.dumps-style keyword
        # arguments such as `indent` on the deprecated `.json()` method, so
        # use `model_dump_json` when it is available and fall back to
        # `.json()` only on Pydantic v1 models.
        if hasattr(extraction_result, "model_dump_json"):
            draft_json = extraction_result.model_dump_json(indent=2)
        else:
            draft_json = extraction_result.json(indent=2)

        entity_count = len(extraction_result.entities)
        relationship_count = len(extraction_result.relationships)
        logger.info(f"Extraction complete. Found {entity_count} entities and {relationship_count} relationships")

        return f"""Document ingestion complete!
**Extraction Summary:**
- Entities found: {entity_count}
- Relationships found: {relationship_count}
**Draft JSON (review and edit as needed):**
```json
{draft_json}
```
To add this to your knowledge graph, review the extraction above and then use the 'add_to_graph' tool with your finalized JSON."""
    except Exception as e:
        # Tool boundary: log with traceback and report the failure to the caller.
        logger.exception(f"Ingestion error: {str(e)}")
        return f"Error during document ingestion: {str(e)}"
def _is_safe_cypher_identifier(value: str) -> bool:
    """Return True if *value* can be interpolated as a label or relationship type."""
    # Labels and relationship types cannot be passed as Cypher parameters, so
    # they are formatted into the query text. Restricting them to ASCII
    # letters, digits and underscores (not starting with a digit) rules out
    # any character that could terminate the identifier and inject Cypher.
    return value.isascii() and value.isidentifier()


def _write_extraction(extraction_result: "ExtractionResult") -> List[str]:
    """Write a validated extraction to Neo4j and return descriptions of skipped relationships."""
    node_query = (
        "MERGE (n:{label} {{name: $name}})\n"
        "SET n += $properties"
    )
    # count(r) is 0 when either endpoint is missing, which lets the caller
    # report relationships that were not actually written.
    rel_query = (
        "MATCH (source {{name: $source_name}})\n"
        "MATCH (target {{name: $target_name}})\n"
        "MERGE (source)-[r:{rel_type}]->(target)\n"
        "SET r += $properties\n"
        "RETURN count(r) AS written"
    )
    skipped: List[str] = []
    with GraphDatabase.driver(URI, auth=(USER, PASS)) as driver:
        for entity in extraction_result.entities:
            driver.execute_query(
                node_query.format(label=entity.label),
                name=entity.name,
                properties=entity.properties,
            )
        for rel in extraction_result.relationships:
            result = driver.execute_query(
                rel_query.format(rel_type=rel.type),
                source_name=rel.source,
                target_name=rel.target,
                properties=rel.properties,
            )
            if not result.records or result.records[0]["written"] == 0:
                skipped.append(f"({rel.source})-[{rel.type}]->({rel.target})")
    return skipped


@mcp.tool()
async def add_to_graph(extraction_json: str) -> str:
    """Add finalized entities/relationships JSON to the Neo4j knowledge graph.

    Args:
        extraction_json: JSON text matching the ``ExtractionResult`` schema
            (``entities`` and ``relationships`` lists).

    Returns:
        A human-readable summary of what was written, or an error message.
        Errors are reported in the returned string rather than raised.
        Nothing is written when the JSON is invalid or when any label or
        relationship type is not a plain identifier.
    """
    try:
        logger.info("Parsing extraction JSON...")
        try:
            extraction_data = json.loads(extraction_json)
            extraction_result = ExtractionResult(**extraction_data)
        except json.JSONDecodeError as e:
            return f"Invalid JSON format: {str(e)}"
        except Exception as e:
            return f"Error parsing extraction data: {str(e)}"

        # Validate every interpolated identifier up front so that a bad item
        # cannot leave the graph partially updated.
        bad_labels = sorted(
            {e.label for e in extraction_result.entities if not _is_safe_cypher_identifier(e.label)}
        )
        bad_types = sorted(
            {r.type for r in extraction_result.relationships if not _is_safe_cypher_identifier(r.type)}
        )
        if bad_labels or bad_types:
            return (
                "Error adding to knowledge graph: labels and relationship types "
                "may only contain ASCII letters, digits and underscores and must "
                f"not start with a digit. Invalid labels: {bad_labels}; "
                f"invalid relationship types: {bad_types}"
            )

        logger.info("Connecting to Neo4j...")
        # The Neo4j driver used here is synchronous; run it in a worker
        # thread so the server's event loop is not blocked.
        skipped = await asyncio.to_thread(_write_extraction, extraction_result)

        written_rels = len(extraction_result.relationships) - len(skipped)
        logger.info("Knowledge graph updated successfully")
        message = f"""✅ Knowledge graph updated successfully!
Added:
- {len(extraction_result.entities)} entities
- {written_rels} relationships
Your cardiology knowledge graph now contains this new information and is ready for querying."""
        if skipped:
            logger.warning(f"Skipped {len(skipped)} relationships with missing endpoints")
            message += (
                f"\n⚠️ Skipped {len(skipped)} relationship(s) whose source or target "
                "node was not found: " + ", ".join(skipped)
            )
        return message
    except Exception as e:
        # Tool boundary: log with traceback and report the failure to the caller.
        logger.exception(f"Graph add error: {str(e)}")
        return f"Error adding to knowledge graph: {str(e)}"
# Prompt that turns a natural-language question into Cypher. Built once at
# import time instead of on every tool call.
_CYPHER_PROMPT = PromptTemplate(
    template="""Convert this natural language query about cardiology into a Cypher query for Neo4j.
The graph contains:
- Nodes with labels like: Condition, Anatomy, Procedure, Medication, Process
- Relationships like: CAUSES, TREATS, PART_OF, AFFECTS, FOLLOWS
Query: {query}
Return only the Cypher query, no explanations:""",
    input_variables=["query"],
)


def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (and language tag) from *text*."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    # Whatever follows the opening fence on its own line is a language tag
    # such as "cypher" (or nothing); drop that line. A single-line fenced
    # reply has no newline and is kept whole.
    first_line, newline, rest = cleaned.partition("\n")
    tag = first_line.strip()
    if newline and (not tag or tag.isalpha()):
        cleaned = rest
    return cleaned.strip()


def _run_read_query(cypher: str) -> List[Dict[str, Any]]:
    """Execute *cypher* in read access mode and return the records as dicts."""
    with GraphDatabase.driver(URI, auth=(USER, PASS)) as driver:
        # The query text is LLM-generated, so request a read-only transaction.
        # NOTE(review): this relies on the server rejecting writes in read
        # access mode -- confirm on the deployed Neo4j version.
        result = driver.execute_query(cypher, routing_="r")
        return [record.data() for record in result.records]


@mcp.tool()
async def query_graph(natural_language_query: str) -> str:
    """Query the cardiology knowledge graph with natural language.

    Args:
        natural_language_query: The question to answer from the graph.

    Returns:
        A human-readable report with the generated Cypher and the results as
        JSON, or an error message. Errors are reported in the returned string
        rather than raised.
    """
    try:
        logger.info(f"Processing query: {natural_language_query}")

        cypher_chain = _CYPHER_PROMPT | llm
        cypher_response = await cypher_chain.ainvoke({"query": natural_language_query})
        # Models often wrap the query in a Markdown fence; remove it.
        cypher = _strip_code_fence(cypher_response.content)
        logger.info(f"Generated Cypher: {cypher}")
        if not cypher:
            return "Error querying knowledge graph: the model did not return a Cypher query."

        # The Neo4j driver used here is synchronous; run it in a worker
        # thread so the server's event loop is not blocked.
        records = await asyncio.to_thread(_run_read_query, cypher)
        logger.info(f"Query returned {len(records)} results")

        # default=str is deliberate: results may contain Neo4j temporal or
        # spatial values that json cannot serialise, and this output is for
        # display only.
        results_text = json.dumps(records, indent=2, default=str) if records else "No results found."
        return f"""📊 **Query Results**
**Your question:** {natural_language_query}
**Generated Cypher:**
```cypher
{cypher}
```
**Results ({len(records)} found):**
{results_text}"""
    except Exception as e:
        # Tool boundary: log with traceback and report the failure to the caller.
        logger.exception(f"Query error: {str(e)}")
        return f"Error querying knowledge graph: {str(e)}"
@mcp.tool()
async def get_graph_stats() -> str:
    """Get statistics about the current cardiology knowledge graph.

    Returns:
        A human-readable report with the total node and relationship counts,
        a per-label node breakdown (using each node's first label) and a
        per-type relationship breakdown, or an error message on failure.
    """
    nodes_by_label_query = """
MATCH (n)
RETURN labels(n)[0] as label, count(n) as count
ORDER BY count DESC
"""
    rels_by_type_query = """
MATCH ()-[r]->()
RETURN type(r) as relationship_type, count(r) as count
ORDER BY count DESC
"""
    try:
        with GraphDatabase.driver(URI, auth=(USER, PASS)) as driver:
            # Breakdown queries first, then the two grand totals.
            node_stats = driver.execute_query(nodes_by_label_query)
            rel_stats = driver.execute_query(rels_by_type_query)
            node_total_result = driver.execute_query("MATCH (n) RETURN count(n) as total")
            total_nodes = node_total_result.records[0]["total"]
            rel_total_result = driver.execute_query("MATCH ()-[r]->() RETURN count(r) as total")
            total_rels = rel_total_result.records[0]["total"]

        # One bullet line per label / relationship type.
        node_lines = []
        for row in node_stats.records:
            node_lines.append(f"- {row['label']}: {row['count']}")
        node_breakdown = "\n".join(node_lines)

        rel_lines = []
        for row in rel_stats.records:
            rel_lines.append(f"- {row['relationship_type']}: {row['count']}")
        rel_breakdown = "\n".join(rel_lines)

        return f"""📈 **Cardiology Knowledge Graph Statistics**
**Total Nodes:** {total_nodes}
**Total Relationships:** {total_rels}
**Node Types:**
{node_breakdown}
**Relationship Types:**
{rel_breakdown}"""
    except Exception as e:
        logger.error(f"Stats error: {e}")
        return f"Error getting graph statistics: {e}"
@mcp.tool()
async def clear_graph() -> str:
    """Clear all data from the knowledge graph (use with caution!).

    Deletes every node together with its relationships. There is no
    confirmation step and no undo.

    Returns:
        A confirmation message, or an error message on failure.
    """
    delete_everything = "MATCH (n) DETACH DELETE n"
    try:
        with GraphDatabase.driver(URI, auth=(USER, PASS)) as driver:
            driver.execute_query(delete_everything)
    except Exception as e:
        logger.error(f"Clear error: {e}")
        return f"Error clearing knowledge graph: {e}"
    return "⚠️ Knowledge graph cleared successfully. All nodes and relationships have been deleted."
def _main() -> None:
    """Log a startup message and run the MCP server over stdio."""
    logger.info("Starting Cardiology Knowledge Graph MCP Server...")
    # NOTE(review): other transports may be available depending on the
    # installed mcp version (e.g. 'sse' or 'streamable-http') -- confirm
    # before switching.
    mcp.run(transport='stdio')


if __name__ == "__main__":
    _main()