# tests/test_task_manager.py
from unittest.mock import MagicMock
import pytest
from a2a.client import A2ACardResolver, A2AClient
from a2a.types import Task, TaskState
@pytest.mark.asyncio
async def test_send_message_async_success(
    task_manager, agent_manager, mocker, mock_agent_card
):
    """Test the success path of creating a task via ``send_message_async``.

    The agent-card lookup and ``A2AClient.send_message`` are mocked. The test
    checks three things:

    - the returned ``task_id`` is a canonical UUID string;
    - the task is tracked in ``task_manager.tasks``;
    - the downstream ``send_message`` coroutine is awaited exactly once.
    """
    import uuid
    from unittest.mock import AsyncMock

    from a2a.types import SendMessageSuccessResponse, TaskStatus

    # Arrange
    agent_url = "http://my.agent/api"
    # Pre-register the agent so that AgentManager already holds its info.
    mocker.patch.object(
        A2ACardResolver, "get_agent_card", return_value=mock_agent_card
    )
    await agent_manager.register_agent(agent_url)

    # Mock A2AClient's async send_message explicitly with AsyncMock. Keep a
    # handle to it so we can verify it was awaited, not merely called.
    mock_task = Task(
        id="task-123",
        contextId="test-context",
        status=TaskStatus(state=TaskState.working),
        kind="task",
    )
    mock_response = SendMessageSuccessResponse(result=mock_task)
    send_message_mock = mocker.patch.object(
        A2AClient,
        "send_message",
        new_callable=AsyncMock,
        return_value=mock_response,
    )

    # Act
    task_result = await task_manager.send_message_async(
        agent_url, "Hello, agent!", None
    )

    # Assert
    assert "task_id" in task_result
    task_id = task_result["task_id"]
    # The gateway generates task_id; it must parse as a UUID and be in the
    # canonical hyphenated form (a bare length check accepts any 36 chars).
    assert str(uuid.UUID(task_id)) == task_id
    # The task must have been added to the task manager.
    assert task_id in task_manager.tasks
    # assert_awaited_once also fails if the coroutine was never awaited.
    send_message_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_send_message_async_agent_not_found(task_manager):
    """Sending a message to an agent that was never registered raises ValueError."""
    unregistered_url = "http://unknown.agent/api"
    with pytest.raises(ValueError, match="Agent not registered"):
        await task_manager.send_message_async(unregistered_url, "message", None)
@pytest.mark.parametrize(
    "status_filter, expected_count",
    [("all", 3), ("working", 1), ("completed", 1), ("failed", 1), ("canceled", 0)],
)
def test_get_task_list_with_filter(task_manager, status_filter, expected_count):
    """Check that get_task_list returns the right number of tasks per status filter.

    Three fake tasks (working, completed, failed) with distinct UTC update
    timestamps are seeded directly into ``task_manager.tasks``.
    """
    from datetime import datetime, timezone

    def _fake_task(state, day):
        # Build a stand-in task exposing only the fields the filter reads.
        fake = MagicMock()
        fake.status = state.value
        fake.updated_at = datetime(2023, 1, day, tzinfo=timezone.utc)
        return fake

    # Arrange: keys task1..task3 map to working/completed/failed, dated Jan 1..3.
    seeded_states = (TaskState.working, TaskState.completed, TaskState.failed)
    task_manager.tasks = {
        f"task{index}": _fake_task(state, index)
        for index, state in enumerate(seeded_states, start=1)
    }

    # Act
    filtered = task_manager.get_task_list(status_filter)

    # Assert
    assert len(filtered) == expected_count