import pytest
from pathlib import Path
from unittest.mock import AsyncMock, patch, MagicMock
import tempfile
import json
from src.server import Evo2Server, ServerSettings, Evo2ExecutionMode
@pytest.fixture
def server():
    """Provide an ``Evo2Server`` configured for API execution mode.

    The settings use a placeholder NIM API key and the system temporary
    directory, so each test receives a freshly constructed server.
    """
    return Evo2Server(
        ServerSettings(
            execution_mode=Evo2ExecutionMode.API,
            nim_api_key="test-api-key",
            temp_dir=tempfile.gettempdir(),
        )
    )
@pytest.mark.asyncio
async def test_list_tools(server):
    """Check that the server lists exactly four tools with the expected names."""
    # NOTE(review): assumes `server.server.list_tools()` is awaitable and yields
    # objects exposing `.name` -- confirm against the server implementation.
    listed = await server.server.list_tools()
    assert len(listed) == 4

    names = [entry.name for entry in listed]
    for expected in (
        "evo2_generate",
        "evo2_score",
        "evo2_embed",
        "evo2_variant_effect",
    ):
        assert expected in names
@pytest.mark.asyncio
async def test_validate_sequence(server):
    """Exercise ``_validate_sequence`` on valid, invalid-base and over-long input.

    The validator is expected to return ``None`` on success and an error
    message string otherwise.
    """
    # A sequence made only of A/T/C/G yields no error.
    assert await server._validate_sequence("ATCGATCG") is None

    # X, Y and Z must be reported as invalid bases.
    invalid_bases_error = await server._validate_sequence("ATCGXYZ")
    assert "Invalid DNA bases" in invalid_bases_error

    # Lower the limit to 10 so the 16-base input below exceeds it.
    server.settings.max_sequence_length = 10
    too_long_error = await server._validate_sequence("ATCGATCGATCGATCG")
    assert "Sequence too long" in too_long_error
@pytest.mark.asyncio
async def test_generate_api_mode(server):
    """Check that ``_generate`` reports the sequence returned by the mocked API.

    ``httpx.AsyncClient`` is patched so no network request is made; the fake
    client answers ``post`` with a 200 response carrying one generated choice.
    """
    with patch("httpx.AsyncClient") as mock_client_class:
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"text": "ATCGATCGATCG"}]
        }

        mock_client = MagicMock()
        # `httpx.AsyncClient.post` is a coroutine. A plain MagicMock attribute
        # returns a non-awaitable object, so `await client.post(...)` would
        # raise TypeError; AsyncMock makes the call awaitable.
        mock_client.post = AsyncMock(return_value=mock_response)
        # The object bound by `async with httpx.AsyncClient(...) as client`.
        mock_client_class.return_value.__aenter__.return_value = mock_client

        result = await server._generate({
            "prompt": "ATCG",
            "n_tokens": 10,
            "temperature": 1.0
        })

    # Guard against the assertions passing without the mocked client being used.
    mock_client.post.assert_awaited()
    assert len(result) == 1
    assert "Generated sequence:" in result[0].text
    assert "ATCGATCGATCG" in result[0].text
@pytest.mark.asyncio
async def test_score_api_mode(server):
    """Check that ``_score`` formats the perplexity returned by the mocked API.

    ``httpx.AsyncClient`` is patched so no network request is made; the fake
    client answers ``post`` with a 200 response whose JSON holds a perplexity.
    """
    with patch("httpx.AsyncClient") as mock_client_class:
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "perplexity": 2.345
        }

        mock_client = MagicMock()
        # `httpx.AsyncClient.post` is a coroutine. A plain MagicMock attribute
        # returns a non-awaitable object, so `await client.post(...)` would
        # raise TypeError; AsyncMock makes the call awaitable.
        mock_client.post = AsyncMock(return_value=mock_response)
        # The object bound by `async with httpx.AsyncClient(...) as client`.
        mock_client_class.return_value.__aenter__.return_value = mock_client

        result = await server._score({
            "sequence": "ATCGATCG",
            "return_logits": False
        })

    # Guard against the assertions passing without the mocked client being used.
    mock_client.post.assert_awaited()
    assert len(result) == 1
    assert "Sequence perplexity: 2.3450" in result[0].text
@pytest.mark.asyncio
async def test_variant_effect(server):
    """Check that ``_variant_effect`` reports both perplexities and their difference.

    ``httpx.AsyncClient`` is patched so no network request is made. The shared
    fake response yields perplexity 2.0 on the first ``json()`` call (reference)
    and 3.0 on the second (variant), so the expected effect score is 1.0.
    """
    with patch("httpx.AsyncClient") as mock_client_class:
        mock_response = MagicMock()
        mock_response.status_code = 200
        # Successive `json()` calls return successive items of this list.
        mock_response.json.side_effect = [
            {"perplexity": 2.0},  # Reference
            {"perplexity": 3.0}   # Variant
        ]

        mock_client = MagicMock()
        # `httpx.AsyncClient.post` is a coroutine. A plain MagicMock attribute
        # returns a non-awaitable object, so `await client.post(...)` would
        # raise TypeError; AsyncMock makes the call awaitable.
        mock_client.post = AsyncMock(return_value=mock_response)
        # The object bound by `async with httpx.AsyncClient(...) as client`.
        mock_client_class.return_value.__aenter__.return_value = mock_client

        result = await server._variant_effect({
            "reference_sequence": "ATCGATCG",
            "variant_sequence": "ATCAATCG"
        })

    # One scoring request each for the reference and the variant sequence.
    assert mock_client.post.await_count == 2
    assert len(result) == 1
    assert "Reference perplexity: 2.0000" in result[0].text
    assert "Variant perplexity: 3.0000" in result[0].text
    assert "Effect score: 1.0000" in result[0].text
    assert "Deleterious" in result[0].text
@pytest.mark.asyncio
async def test_sbatch_execution(server):
    """Check that ``_execute_sbatch`` returns the parsed JSON job output.

    Subprocess creation and the ``pathlib.Path`` file checks are patched, so
    no job is submitted and no file is read from disk.
    """
    server.settings.execution_mode = Evo2ExecutionMode.SBATCH

    # Stand-in for the spawned process: exit code 0 and canned stdout/stderr.
    fake_process = AsyncMock(returncode=0)
    fake_process.communicate.return_value = (b"Job submitted", b"")

    # Contents the patched `Path.read_text` hands back as the job's output file.
    output_payload = json.dumps({
        "generated_sequence": "ATCGATCG"
    })

    with patch(
        "asyncio.create_subprocess_exec", return_value=fake_process
    ), patch("pathlib.Path.exists", return_value=True), patch(
        "pathlib.Path.read_text", return_value=output_payload
    ):
        result = await server._execute_sbatch("generate", {
            "prompt": "ATCG",
            "n_tokens": 10
        })

    assert result["generated_sequence"] == "ATCGATCG"