from mcp.server.models import InitializationOptions
import mcp.types as types
from mcp.server import NotificationOptions, Server
import mcp.server.stdio
import os
import scanpy as sc
from .tool.scvi import scvi_tools, run_scvi_func
from .logging_config import setup_logger
from . import __version__
# Module logger; writes to the file named by SCVI_MCP_LOG_FILE when that
# variable is set (os.environ.get returns None otherwise).
logger = setup_logger(log_file=os.environ.get("SCVI_MCP_LOG_FILE"))
class ModelState:
    """
    Manages AnnData objects and scvi-tools models.

    Attributes:
        adata_dic: Mapping from dataset name to AnnData object. A dataset
            loaded at construction is stored under the key "adata0".
        active: Key into ``adata_dic`` naming the active dataset, or None
            when no dataset has been loaded.
        scvi_model, scanvi_model, totalvi_model, peakvi_model: Model slots,
            initialized to None here (presumably filled in by tool
            functions — confirm against tool.scvi).
        de_results, da_results: Result slots, initialized to None here.
    """

    def __init__(self, data_path=None):
        """
        Initialize empty model/result slots and optionally load a dataset.

        Parameters:
            data_path: Path to an ``.h5ad`` file to load as "adata0". When
                None (the default), the ``SCVI_MCP_DATA`` environment
                variable is consulted instead. An empty or unset value means
                no data is loaded.
        """
        if data_path is None:
            data_path = os.environ.get("SCVI_MCP_DATA", None)
        self.adata_dic = {}
        self.active = None
        # Model storage
        self.scvi_model = None
        self.scanvi_model = None
        self.totalvi_model = None
        self.peakvi_model = None
        # Results storage
        self.de_results = None
        self.da_results = None
        if data_path:
            # Log before reading so a failing or slow read is attributable
            # to the path being loaded.
            logger.info(f"Loading data from {data_path}")
            self.adata_dic["adata0"] = sc.read_h5ad(data_path)
            self.active = "adata0"
# Process-wide model/data state; the same object is handed to every tool call.
state: ModelState = ModelState()
# Low-level MCP server; the tool handlers below register themselves on it.
server: Server = Server("scvi-mcp")
@server.list_tools()
async def list_tools() -> list[types.Tool]:
    """
    Report every scvi-tools MCP tool this server exposes.

    Returns:
        A new list holding the Tool definitions registered in ``scvi_tools``.
    """
    registered = scvi_tools.values()
    return [*registered]
@server.call_tool()
async def call_tool(name: str, arguments: dict):
    """
    Execute a scvi-tools MCP tool.

    Parameters:
        name: Tool name; must be a key of ``scvi_tools``.
        arguments: Tool arguments, forwarded unchanged to ``run_scvi_func``.

    Returns:
        A one-element list of TextContent. On success its text is
        ``str({"output": ...})``, where the output is ``str(result)`` or
        "Operation completed successfully" when the tool returns None.
        On any exception, including an unknown tool name, its text is
        ``str({"error": ...})``; errors are reported, not raised.
    """
    try:
        logger.info(f"Running {name} with {arguments}")
        if name not in scvi_tools:
            raise ValueError(f"Unknown tool: {name}")
        res = run_scvi_func(state, name, arguments)
        output = str(res) if res is not None else "Operation completed successfully"
        return [
            types.TextContent(
                type="text",
                text=str({"output": output})
            )
        ]
    except Exception as error:
        # logger.exception logs at ERROR level and also records the
        # traceback, which is essential for failures raised deep inside
        # scvi-tools / scanpy.
        logger.exception(f"Error in {name}: {error}")
        return [
            types.TextContent(
                type="text",
                text=str({"error": str(error)})
            )
        ]
# Run server with stdio transport
async def run_stdio():
    """
    Serve the MCP server over standard input/output.

    Opens the stdio transport and runs ``server`` on its read/write streams
    until ``server.run`` returns.
    """
    stdio_transport = mcp.server.stdio.stdio_server()
    async with stdio_transport as (reader, writer):
        capabilities = server.get_capabilities(
            notification_options=NotificationOptions(),
            experimental_capabilities={},
        )
        init_options = InitializationOptions(
            server_name="scvi-mcp",
            server_version=__version__,
            capabilities=capabilities,
        )
        await server.run(reader, writer, init_options)
# Create application using SSE transport
def create_sse_app(port=8000):
    """
    Create a Starlette application that serves the MCP server over SSE.

    Routes:
        /sse        -- opens an SSE stream and runs one MCP session on it.
        /messages/  -- receives client-to-server messages (POST).

    Parameters:
        port: Not used by this function; kept for backward compatibility.
            NOTE(review): presumably the caller passes the port to the ASGI
            server that runs the returned app — confirm.

    Returns:
        Starlette application instance
    """
    from starlette.applications import Starlette
    from starlette.responses import Response
    from starlette.routing import Mount, Route
    from mcp.server.sse import SseServerTransport

    # SSE transport; clients post their messages to /messages/
    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # The transport streams the HTTP response itself through the raw
        # ASGI send callable, hence request._send.
        async with sse.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            await server.run(
                streams[0], streams[1],
                InitializationOptions(
                    server_name="scvi-mcp",
                    server_version=__version__,
                    capabilities=server.get_capabilities(
                        notification_options=NotificationOptions(),
                        experimental_capabilities={},
                    ),
                )
            )
        # A Starlette Route endpoint must return a Response. Returning None
        # makes Starlette call None and raise TypeError when the session ends.
        return Response()

    starlette_app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ]
    )
    return starlette_app