import asyncio
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Optional
from enum import Enum
import httpx
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent, ImageContent
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
class Evo2ExecutionMode(str, Enum):
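    """Supported execution backends for evo2 inference."""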
LOCAL = "local" # Direct Python execution with GPU
SBATCH = "sbatch" # Submit to SLURM cluster
SINGULARITY = "singularity" # Use Singularity container
DOCKER = "docker" # Use Docker container
API = "api" # Use Nvidia NIM API
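# Every mode except API requires GPU access on the host that runs the model.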
class ServerSettings(BaseSettings):
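    """Runtime configuration, loaded from BIO_MCP_EVO2_* environment variables."""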
max_sequence_length: int = Field(default=1_000_000, description="Maximum DNA sequence length")
temp_dir: Optional[str] = Field(default=None, description="Temporary directory for processing")
timeout: int = Field(default=600, description="Command timeout in seconds")
# Execution mode settings
execution_mode: Evo2ExecutionMode = Field(default=Evo2ExecutionMode.API, description="Execution mode")
# GPU settings for local mode
cuda_device: int = Field(default=0, description="CUDA device index")
model_size: str = Field(default="7b", description="Model size: 7b or 40b")
model_path: Optional[str] = Field(default=None, description="Path to local model checkpoint")
# SBATCH settings
sbatch_partition: str = Field(default="gpu", description="SLURM partition")
sbatch_time: str = Field(default="01:00:00", description="SLURM time limit")
sbatch_memory: str = Field(default="64G", description="SLURM memory")
sbatch_gpus: int = Field(default=1, description="Number of GPUs")
sbatch_gpu_type: str = Field(default="h100", description="GPU type constraint")
sbatch_account: Optional[str] = Field(default=None, description="SLURM account")
# Container settings
singularity_image_path: Optional[str] = Field(default=None, description="Path to Singularity image")
docker_image: str = Field(default="nvcr.io/nvidia/bionemo/evo2:latest", description="Docker image")
# API settings
nim_api_key: Optional[str] = Field(default=None, description="Nvidia NIM API key")
nim_api_url: str = Field(default="https://integrate.api.nvidia.com/v1", description="NIM API URL")
class Config:
env_prefix = "BIO_MCP_EVO2_"
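# All settings can be supplied as environment variables using the prefix above,
# e.g. BIO_MCP_EVO2_EXECUTION_MODE=local or BIO_MCP_EVO2_NIM_API_KEY=<key>.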
class Evo2Server:
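    """MCP server exposing evo2 sequence generation, scoring, embedding, and variant-effect tools."""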
def __init__(self, settings: Optional[ServerSettings] = None):
self.settings = settings or ServerSettings()
self.server = Server("bio-mcp-evo2")
self._setup_handlers()
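        # Model and tokenizer handles; loaded lazily on the first local-mode request.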
self._model = None
self._tokenizer = None
async def _detect_gpu(self) -> bool:
"""Detect if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
async def _validate_sequence(self, sequence: str) -> Optional[str]:
"""Validate DNA sequence."""
if len(sequence) > self.settings.max_sequence_length:
return f"Sequence too long. Maximum length: {self.settings.max_sequence_length}"
valid_bases = set("ACGTN")
        invalid_bases = set(sequence.upper()) - valid_bases
        if invalid_bases:
            return f"Invalid DNA bases found: {sorted(invalid_bases)}"
return None
def _setup_handlers(self):
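        """Register MCP tool listing and dispatch handlers."""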
@self.server.list_tools()
async def list_tools() -> list[Tool]:
return [
Tool(
name="evo2_generate",
description="Generate DNA sequences using evo2 model",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "DNA sequence prompt (e.g., 'ATCG...')"
},
"n_tokens": {
"type": "integer",
"description": "Number of tokens to generate",
"default": 100,
"minimum": 1,
"maximum": 10000
},
"temperature": {
"type": "number",
"description": "Sampling temperature",
"default": 1.0,
"minimum": 0.1,
"maximum": 2.0
},
"top_k": {
"type": "integer",
"description": "Top-k sampling parameter",
"default": 4,
"minimum": 1
}
},
"required": ["prompt"]
}
),
Tool(
name="evo2_score",
description="Score DNA sequences with evo2 model",
inputSchema={
"type": "object",
"properties": {
"sequence": {
"type": "string",
"description": "DNA sequence to score"
},
"return_logits": {
"type": "boolean",
"description": "Return raw logits instead of perplexity",
"default": False
}
},
"required": ["sequence"]
}
),
Tool(
name="evo2_embed",
description="Extract embeddings from DNA sequences",
inputSchema={
"type": "object",
"properties": {
"sequence": {
"type": "string",
"description": "DNA sequence to embed"
},
"layer": {
"type": "string",
"description": "Layer to extract embeddings from",
"default": "blocks.28.mlp.l3"
}
},
"required": ["sequence"]
}
),
Tool(
name="evo2_variant_effect",
description="Predict the effect of DNA variants",
inputSchema={
"type": "object",
"properties": {
"reference_sequence": {
"type": "string",
"description": "Reference DNA sequence"
},
"variant_sequence": {
"type": "string",
"description": "Variant DNA sequence"
},
"context_window": {
"type": "integer",
"description": "Context window size around variant",
"default": 1000,
"minimum": 100,
"maximum": 10000
}
},
"required": ["reference_sequence", "variant_sequence"]
}
),
]
@self.server.call_tool()
        async def call_tool(name: str, arguments: Any) -> list[TextContent | ImageContent]:
if name == "evo2_generate":
return await self._generate(arguments)
elif name == "evo2_score":
return await self._score(arguments)
elif name == "evo2_embed":
return await self._embed(arguments)
elif name == "evo2_variant_effect":
return await self._variant_effect(arguments)
else:
                return [TextContent(type="text", text=f"Error: unknown tool: {name}")]
async def _execute_local(self, function: str, arguments: dict) -> dict:
"""Execute evo2 locally with GPU."""
try:
# Lazy import to avoid loading if not using local mode
import torch
from evo2 import Evo2
# Initialize model if not already loaded
if self._model is None:
logger.info(f"Loading evo2 {self.settings.model_size} model...")
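                # CUDA_VISIBLE_DEVICES is read when CUDA initializes, so set it before any CUDA call.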
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.settings.cuda_device)
if self.settings.model_path:
self._model = Evo2.from_checkpoint(self.settings.model_path)
else:
self._model = Evo2(f'evo2_{self.settings.model_size}')
# Execute the requested function
            if function == "generate":
                # Run generation in a worker thread so the event loop is not blocked
                # while the GPU works.
                output = await asyncio.to_thread(
                    self._model.generate,
                    prompt_seqs=[arguments["prompt"]],
                    n_tokens=arguments.get("n_tokens", 100),
                    temperature=arguments.get("temperature", 1.0),
                    top_k=arguments.get("top_k", 4)
                )
                return {"generated_sequence": output["generated_text"][0]}
            elif function == "score":
                sequence = arguments["sequence"]
                input_ids = torch.tensor(
                    self._model.tokenizer.tokenize(sequence),
                    dtype=torch.int
                ).unsqueeze(0).to(f'cuda:{self.settings.cuda_device}')
                # Forward pass in a worker thread to keep the event loop responsive
                outputs, _ = await asyncio.to_thread(self._model, input_ids)
                logits = outputs[0]
                if arguments.get("return_logits", False):
                    return {"logits": logits.tolist()}
                else:
                    # Perplexity over next-token predictions: logits at i predict token i+1;
                    # cross_entropy needs float logits and int64 targets.
                    loss = torch.nn.functional.cross_entropy(
                        logits[:-1].reshape(-1, logits.size(-1)).float(),
                        input_ids[0, 1:].reshape(-1).long()
                    )
                    perplexity = torch.exp(loss).item()
                    return {"perplexity": perplexity}
elif function == "embed":
sequence = arguments["sequence"]
input_ids = torch.tensor(
self._model.tokenizer.tokenize(sequence),
dtype=torch.int
).unsqueeze(0).to(f'cuda:{self.settings.cuda_device}')
layer_name = arguments.get("layer", "blocks.28.mlp.l3")
                # Forward pass in a worker thread, requesting embeddings from the chosen layer
                outputs, embeddings = await asyncio.to_thread(
                    self._model,
                    input_ids,
                    return_embeddings=True,
                    layer_names=[layer_name]
                )
return {
"embeddings": embeddings[layer_name].tolist(),
"shape": list(embeddings[layer_name].shape)
}
else:
raise ValueError(f"Unknown function: {function}")
        except Exception as e:
            raise RuntimeError(f"Local execution failed: {e}") from e
async def _execute_sbatch(self, function: str, arguments: dict) -> dict:
"""Submit job to SLURM cluster."""
with tempfile.TemporaryDirectory(dir=self.settings.temp_dir) as tmpdir:
# Create job script
job_script = Path(tmpdir) / "evo2_job.sh"
input_file = Path(tmpdir) / "input.json"
output_file = Path(tmpdir) / "output.json"
# Write input arguments
input_file.write_text(json.dumps({
"function": function,
"arguments": arguments
}))
# Build SLURM script
account_line = f"#SBATCH --account={self.settings.sbatch_account}" if self.settings.sbatch_account else ""
script_content = f"""#!/bin/bash
#SBATCH --job-name=evo2-mcp
#SBATCH --partition={self.settings.sbatch_partition}
#SBATCH --time={self.settings.sbatch_time}
#SBATCH --mem={self.settings.sbatch_memory}
#SBATCH --gpus={self.settings.sbatch_gpus}
#SBATCH --constraint={self.settings.sbatch_gpu_type}
#SBATCH --output={tmpdir}/slurm-%j.out
#SBATCH --error={tmpdir}/slurm-%j.err
{account_line}
module load python/3.11
module load cuda/12.0
# Activate environment or load evo2 module
source /path/to/evo2/env/bin/activate || module load evo2
python -c "
import json
import torch
from evo2 import Evo2
# Load input
with open('{input_file}', 'r') as f:
data = json.load(f)
function = data['function']
arguments = data['arguments']
# Initialize model
model = Evo2('evo2_{self.settings.model_size}')
# Execute function
result = None
if function == 'generate':
output = model.generate(
prompt_seqs=[arguments['prompt']],
n_tokens=arguments.get('n_tokens', 100),
temperature=arguments.get('temperature', 1.0),
top_k=arguments.get('top_k', 4)
)
result = {{'generated_sequence': output['generated_text'][0]}}
elif function == 'score':
sequence = arguments['sequence']
input_ids = torch.tensor(
model.tokenizer.tokenize(sequence),
dtype=torch.int
).unsqueeze(0).cuda()
outputs, _ = model(input_ids)
logits = outputs[0]
if arguments.get('return_logits', False):
result = {{'logits': logits.tolist()}}
else:
        loss = torch.nn.functional.cross_entropy(
            logits[:-1].reshape(-1, logits.size(-1)).float(),
            input_ids[0, 1:].reshape(-1).long()
        )
        perplexity = torch.exp(loss).item()
        result = {{'perplexity': perplexity}}
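elif function == 'embed':
    # Mirror the local embed path so embed requests submitted via SLURM also work
    sequence = arguments['sequence']
    input_ids = torch.tensor(
        model.tokenizer.tokenize(sequence),
        dtype=torch.int
    ).unsqueeze(0).cuda()
    layer_name = arguments.get('layer', 'blocks.28.mlp.l3')
    outputs, embeddings = model(
        input_ids,
        return_embeddings=True,
        layer_names=[layer_name]
    )
    result = {{'embeddings': embeddings[layer_name].tolist(),
               'shape': list(embeddings[layer_name].shape)}}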
# Save output
with open('{output_file}', 'w') as f:
json.dump(result, f)
"
"""
job_script.write_text(script_content)
            # Submit the job; --wait blocks until it finishes, keeping the temp dir alive
process = await asyncio.create_subprocess_exec(
"sbatch", "--wait", str(job_script),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise RuntimeError(f"SBATCH submission failed: {stderr.decode()}")
# Read output
if output_file.exists():
return json.loads(output_file.read_text())
else:
raise RuntimeError("Job completed but no output file found")
async def run(self):
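        """Serve MCP requests over stdio until the client disconnects."""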
async with stdio_server() as (read_stream, write_stream):
            await self.server.run(
                read_stream,
                write_stream,
                self.server.create_initialization_options()
            )
async def _execute_container(self, function: str, arguments: dict, use_docker: bool = False) -> dict:
"""Execute using container (Docker or Singularity)."""
with tempfile.TemporaryDirectory(dir=self.settings.temp_dir) as tmpdir:
input_file = Path(tmpdir) / "input.json"
output_file = Path(tmpdir) / "output.json"
# Write input
input_file.write_text(json.dumps({
"function": function,
"arguments": arguments
}))
# Create Python script to run inside container
script_file = Path(tmpdir) / "run_evo2.py"
script_content = f"""import json
import torch
from evo2 import Evo2
# Load input
with open('{input_file}', 'r') as f:
data = json.load(f)
function = data['function']
arguments = data['arguments']
# Initialize model
model = Evo2('evo2_{self.settings.model_size}')
# Execute function
result = None
if function == 'generate':
output = model.generate(
prompt_seqs=[arguments['prompt']],
n_tokens=arguments.get('n_tokens', 100),
temperature=arguments.get('temperature', 1.0),
top_k=arguments.get('top_k', 4)
)
result = {{'generated_sequence': output['generated_text'][0]}}
elif function == 'score':
sequence = arguments['sequence']
input_ids = torch.tensor(
model.tokenizer.tokenize(sequence),
dtype=torch.int
).unsqueeze(0).cuda()
outputs, _ = model(input_ids)
logits = outputs[0]
if arguments.get('return_logits', False):
result = {{'logits': logits.tolist()}}
else:
        loss = torch.nn.functional.cross_entropy(
            logits[:-1].reshape(-1, logits.size(-1)).float(),
            input_ids[0, 1:].reshape(-1).long()
        )
        perplexity = torch.exp(loss).item()
        result = {{'perplexity': perplexity}}
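elif function == 'embed':
    # Mirror the local embed path so embed requests also work in container mode
    sequence = arguments['sequence']
    input_ids = torch.tensor(
        model.tokenizer.tokenize(sequence),
        dtype=torch.int
    ).unsqueeze(0).cuda()
    layer_name = arguments.get('layer', 'blocks.28.mlp.l3')
    outputs, embeddings = model(
        input_ids,
        return_embeddings=True,
        layer_names=[layer_name]
    )
    result = {{'embeddings': embeddings[layer_name].tolist(),
               'shape': list(embeddings[layer_name].shape)}}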
# Save output
with open('{output_file}', 'w') as f:
json.dump(result, f)
"""
script_file.write_text(script_content)
# Build container command
if use_docker:
cmd = [
"docker", "run", "--rm",
"--gpus", "all",
"-v", f"{tmpdir}:{tmpdir}",
self.settings.docker_image,
"python", str(script_file)
]
else:
if not self.settings.singularity_image_path:
raise RuntimeError("Singularity image path not configured")
cmd = [
"singularity", "run",
"--nv", # Enable GPU support
"--bind", f"{tmpdir}:{tmpdir}",
self.settings.singularity_image_path,
"python", str(script_file)
]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=self.settings.timeout
)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()  # Reap the killed process before raising
                raise RuntimeError(f"Container execution timed out after {self.settings.timeout} seconds")
if process.returncode != 0:
raise RuntimeError(f"Container execution failed: {stderr.decode()}")
# Read output
if output_file.exists():
return json.loads(output_file.read_text())
else:
raise RuntimeError("Container completed but no output file found")
async def _execute_api(self, function: str, arguments: dict) -> dict:
"""Execute using Nvidia NIM API."""
if not self.settings.nim_api_key:
raise RuntimeError("NIM API key not configured. Set BIO_MCP_EVO2_NIM_API_KEY environment variable.")
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"Bearer {self.settings.nim_api_key}",
"Content-Type": "application/json"
}
# Map functions to API endpoints
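            # The OpenAI-style /completions and /embeddings routes below are assumptions
            # about the NIM deployment; adjust paths and model names to match your endpoint.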
if function == "generate":
response = await client.post(
f"{self.settings.nim_api_url}/completions",
headers=headers,
json={
"model": f"evo2-{self.settings.model_size}",
"prompt": arguments["prompt"],
"max_tokens": arguments.get("n_tokens", 100),
"temperature": arguments.get("temperature", 1.0),
"top_k": arguments.get("top_k", 4)
},
timeout=self.settings.timeout
)
if response.status_code != 200:
raise RuntimeError(f"API request failed: {response.text}")
data = response.json()
return {"generated_sequence": data["choices"][0]["text"]}
elif function == "score":
# Use embeddings endpoint for scoring
response = await client.post(
f"{self.settings.nim_api_url}/embeddings",
headers=headers,
json={
"model": f"evo2-{self.settings.model_size}",
"input": arguments["sequence"],
"encoding_format": "float"
},
timeout=self.settings.timeout
)
if response.status_code != 200:
raise RuntimeError(f"API request failed: {response.text}")
                data = response.json()
                # The embeddings endpoint does not expose token log-likelihoods, so scoring
                # only works if the deployment returns a perplexity field.
                if "perplexity" not in data:
                    raise RuntimeError("API scoring requires a deployment that returns perplexity")
                return {"perplexity": data["perplexity"]}
else:
raise ValueError(f"Function {function} not supported via API")
    async def _generate(self, arguments: dict) -> list[TextContent]:
"""Generate DNA sequences."""
try:
# Validate prompt
prompt = arguments.get("prompt", "")
error = await self._validate_sequence(prompt)
if error:
                return [TextContent(type="text", text=f"Error: {error}")]
# Execute based on mode
result = None
if self.settings.execution_mode == Evo2ExecutionMode.LOCAL:
result = await self._execute_local("generate", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SBATCH:
result = await self._execute_sbatch("generate", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SINGULARITY:
result = await self._execute_container("generate", arguments, use_docker=False)
elif self.settings.execution_mode == Evo2ExecutionMode.DOCKER:
result = await self._execute_container("generate", arguments, use_docker=True)
elif self.settings.execution_mode == Evo2ExecutionMode.API:
result = await self._execute_api("generate", arguments)
            return [TextContent(type="text", text=f"Generated sequence:\n{result['generated_sequence']}")]
except Exception as e:
logger.error(f"Error in generate: {e}", exc_info=True)
            return [TextContent(type="text", text=f"Generation failed: {e}")]
    async def _score(self, arguments: dict) -> list[TextContent]:
"""Score DNA sequences."""
try:
# Validate sequence
sequence = arguments.get("sequence", "")
error = await self._validate_sequence(sequence)
if error:
                return [TextContent(type="text", text=f"Error: {error}")]
# Execute based on mode
result = None
if self.settings.execution_mode == Evo2ExecutionMode.LOCAL:
result = await self._execute_local("score", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SBATCH:
result = await self._execute_sbatch("score", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SINGULARITY:
result = await self._execute_container("score", arguments, use_docker=False)
elif self.settings.execution_mode == Evo2ExecutionMode.DOCKER:
result = await self._execute_container("score", arguments, use_docker=True)
elif self.settings.execution_mode == Evo2ExecutionMode.API:
result = await self._execute_api("score", arguments)
            if "perplexity" in result:
                return [TextContent(type="text", text=f"Sequence perplexity: {result['perplexity']:.4f}")]
            else:
                return [TextContent(type="text", text=f"Returned raw logits for {len(result.get('logits', []))} positions")]
except Exception as e:
logger.error(f"Error in score: {e}", exc_info=True)
            return [TextContent(type="text", text=f"Scoring failed: {e}")]
    async def _embed(self, arguments: dict) -> list[TextContent]:
"""Extract embeddings from sequences."""
try:
# Validate sequence
sequence = arguments.get("sequence", "")
error = await self._validate_sequence(sequence)
if error:
                return [TextContent(type="text", text=f"Error: {error}")]
# Note: Embedding might not be available in all modes
if self.settings.execution_mode == Evo2ExecutionMode.API:
                return [TextContent(type="text", text="Error: embedding extraction not available via API yet")]
# Execute based on mode
result = None
if self.settings.execution_mode == Evo2ExecutionMode.LOCAL:
result = await self._execute_local("embed", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SBATCH:
result = await self._execute_sbatch("embed", arguments)
elif self.settings.execution_mode == Evo2ExecutionMode.SINGULARITY:
result = await self._execute_container("embed", arguments, use_docker=False)
elif self.settings.execution_mode == Evo2ExecutionMode.DOCKER:
result = await self._execute_container("embed", arguments, use_docker=True)
shape = result.get("shape", [])
            return [TextContent(type="text", text=f"Embeddings extracted. Shape: {shape}")]
except Exception as e:
logger.error(f"Error in embed: {e}", exc_info=True)
            return [TextContent(type="text", text=f"Embedding failed: {e}")]
    async def _variant_effect(self, arguments: dict) -> list[TextContent]:
"""Predict variant effects."""
try:
# Validate sequences
ref_seq = arguments.get("reference_sequence", "")
var_seq = arguments.get("variant_sequence", "")
for seq, name in [(ref_seq, "reference"), (var_seq, "variant")]:
error = await self._validate_sequence(seq)
if error:
                    return [TextContent(type="text", text=f"Error in {name} sequence: {error}")]
# Score both sequences
ref_args = {"sequence": ref_seq, "return_logits": False}
var_args = {"sequence": var_seq, "return_logits": False}
# Execute based on mode
if self.settings.execution_mode == Evo2ExecutionMode.LOCAL:
ref_result = await self._execute_local("score", ref_args)
var_result = await self._execute_local("score", var_args)
            elif self.settings.execution_mode == Evo2ExecutionMode.API:
                ref_result = await self._execute_api("score", ref_args)
                var_result = await self._execute_api("score", var_args)
            elif self.settings.execution_mode == Evo2ExecutionMode.SBATCH:
                ref_result = await self._execute_sbatch("score", ref_args)
                var_result = await self._execute_sbatch("score", var_args)
            else:
                # Container modes: route through Docker or Singularity as configured
                use_docker = self.settings.execution_mode == Evo2ExecutionMode.DOCKER
                ref_result = await self._execute_container("score", ref_args, use_docker=use_docker)
                var_result = await self._execute_container("score", var_args, use_docker=use_docker)
ref_perplexity = ref_result.get("perplexity", 0)
var_perplexity = var_result.get("perplexity", 0)
# Calculate effect
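            # Positive score: the variant is less likely under the model than the reference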
effect_score = var_perplexity - ref_perplexity
            return [TextContent(type="text", text=f"""Variant Effect Analysis:
Reference perplexity: {ref_perplexity:.4f}
Variant perplexity: {var_perplexity:.4f}
Effect score: {effect_score:.4f}
Interpretation: {'Deleterious' if effect_score > 0 else 'Neutral/Beneficial'}""")]
except Exception as e:
logger.error(f"Error in variant_effect: {e}", exc_info=True)
            return [TextContent(type="text", text=f"Variant analysis failed: {e}")]
async def main():
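    """Entry point: configure logging and serve over stdio."""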
logging.basicConfig(level=logging.INFO)
server = Evo2Server()
await server.run()
if __name__ == "__main__":
asyncio.run(main())
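# Example invocation (settings come from BIO_MCP_EVO2_* environment variables;
# the module path is illustrative for wherever this file lives in your package):
#   BIO_MCP_EVO2_EXECUTION_MODE=api BIO_MCP_EVO2_NIM_API_KEY=<key> python -m bio_mcp_evo2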