"""Tests for Celery MCP."""
from unittest.mock import Mock, patch
import pytest
from celery_mcp import CeleryMCP
class TestCeleryMCP:
    """Unit tests for the ``CeleryMCP`` wrapper, with ``Celery`` patched out.

    Every test patches ``celery_mcp.celery_mcp.Celery`` and keeps all of its
    checks inside the patch context.
    """

    def test_init(self):
        """Constructing the wrapper builds one Celery app and exposes it as ``app``."""
        broker_url = "redis://localhost:6379/0"
        with patch("celery_mcp.celery_mcp.Celery") as celery_cls:
            wrapper = CeleryMCP(broker_url)
            # The app is named "celery_mcp" and no result backend is passed.
            celery_cls.assert_called_once_with(
                "celery_mcp", broker=broker_url, backend=None
            )
            assert wrapper.app == celery_cls.return_value

    def test_send_task(self):
        """``send_task`` forwards its arguments to ``app.send_task`` and returns its result."""
        fake_app = Mock()
        with patch("celery_mcp.celery_mcp.Celery", return_value=fake_app):
            wrapper = CeleryMCP("redis://localhost:6379/0")
            outcome = wrapper.send_task(
                "test_task", args=[1, 2], kwargs={"key": "value"}
            )
            fake_app.send_task.assert_called_once_with(
                "test_task", args=[1, 2], kwargs={"key": "value"}
            )
            assert outcome == fake_app.send_task.return_value

    def test_register_task(self):
        """``register_task`` calls ``app.task(name=...)`` and applies the result to the function."""
        fake_app = Mock()

        def placeholder():
            return None

        with patch("celery_mcp.celery_mcp.Celery", return_value=fake_app):
            wrapper = CeleryMCP("redis://localhost:6379/0")
            wrapper.register_task("dummy_task", placeholder)
            fake_app.task.assert_called_once_with(name="dummy_task")
            # The object returned by app.task(...) must be called with the function.
            decorator = fake_app.task.return_value
            decorator.assert_called_once_with(placeholder)
class TestMCPServer:
    """Tests for the async tool functions in ``celery_mcp.server``.

    Each test patches ``celery_mcp.server.get_celery_app`` and, where needed,
    gives it a mock app. The server functions are imported inside each test
    body.
    """

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_initialize_celery_connection_success(self, mock_get_app):
        """A successful ``initialize_celery`` call produces a success message."""
        from celery_mcp.server import initialize_celery_connection

        broker_url = "redis://localhost:6379/0"
        with patch("celery_mcp.server.initialize_celery") as init_spy:
            message = await initialize_celery_connection(broker_url)
            # The backend argument is expected to default to None.
            init_spy.assert_called_once_with(broker_url, None)
            assert "Successfully connected" in message

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_initialize_celery_connection_failure(self, mock_get_app):
        """An exception from ``initialize_celery`` is turned into a failure message."""
        from celery_mcp.server import initialize_celery_connection

        connection_error = Exception("Connection failed")
        with patch(
            "celery_mcp.server.initialize_celery",
            side_effect=connection_error,
        ):
            message = await initialize_celery_connection("redis://localhost:6379/0")
            assert "Failed to connect" in message

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_list_registered_tasks(self, mock_get_app):
        """Both task names present in ``app.tasks`` appear in the listing."""
        from celery_mcp.server import list_registered_tasks

        expected_names = ("task1", "task2")
        fake_app = Mock()
        fake_app.tasks = {name: Mock() for name in expected_names}
        mock_get_app.return_value = fake_app

        listing = await list_registered_tasks()

        for name in expected_names:
            assert name in listing

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_send_task_success(self, mock_get_app):
        """``send_task`` forwards to ``app.send_task`` and reports the new task id."""
        from celery_mcp.server import send_task

        fake_app = Mock()
        fake_app.send_task.return_value = Mock(id="test-task-id")
        mock_get_app.return_value = fake_app

        message = await send_task("test_task", [1, 2], {"key": "value"})

        fake_app.send_task.assert_called_once_with(
            "test_task", args=[1, 2], kwargs={"key": "value"}
        )
        assert "test-task-id" in message

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_get_task_status(self, mock_get_app):
        """The status and result of the (patched) ``AsyncResult`` appear in the report."""
        from celery_mcp.server import get_task_status

        fake_result = Mock(status="SUCCESS", result=42)
        fake_result.successful.return_value = True
        mock_get_app.return_value = Mock()

        with patch("celery_mcp.server.AsyncResult", return_value=fake_result):
            report = await get_task_status("test-task-id")
            assert "SUCCESS" in report
            # The integer result is expected to be rendered as text.
            assert "42" in report

    @pytest.mark.asyncio
    @patch("celery_mcp.server.get_celery_app")
    async def test_revoke_task(self, mock_get_app):
        """Revoking with terminate enabled calls ``control.revoke(..., terminate=True)``."""
        from celery_mcp.server import revoke_task

        fake_app = Mock()
        mock_get_app.return_value = fake_app

        message = await revoke_task("test-task-id", True)

        fake_app.control.revoke.assert_called_once_with(
            "test-task-id", terminate=True
        )
        assert "terminated" in message