test_escalation_policies_unit.py•7.14 kB
from unittest.mock import MagicMock
import pytest
from pagerduty_mcp_server import escalation_policies, utils
from pagerduty_mcp_server.parsers.escalation_policy_parser import (
    parse_escalation_policy,
)
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_list_escalation_policies(mock_get_api_client, mock_escalation_policies):
    """Test that escalation policies are fetched correctly."""
    mock_get_api_client.list_all.return_value = mock_escalation_policies
    policy_list = escalation_policies.list_escalation_policies()
    mock_get_api_client.list_all.assert_called_once_with(
        escalation_policies.ESCALATION_POLICIES_URL, params={}
    )
    assert policy_list == utils.api_response_handler(
        results=[
            parse_escalation_policy(result=policy)
            for policy in mock_escalation_policies
        ],
        resource_name="escalation_policies",
    )
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_list_escalation_policies_with_query(
    mock_get_api_client, mock_escalation_policies
):
    """Test that escalation policies can be filtered by query parameter."""
    query = "test"
    mock_get_api_client.list_all.return_value = mock_escalation_policies
    policy_list = escalation_policies.list_escalation_policies(query=query)
    mock_get_api_client.list_all.assert_called_once_with(
        escalation_policies.ESCALATION_POLICIES_URL, params={"query": query}
    )
    assert policy_list == utils.api_response_handler(
        results=[
            parse_escalation_policy(result=policy)
            for policy in mock_escalation_policies
        ],
        resource_name="escalation_policies",
    )
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_list_escalation_policies_api_error(mock_get_api_client):
    """Test that list_escalation_policies handles API errors correctly."""
    mock_get_api_client.list_all.side_effect = RuntimeError("API Error")
    with pytest.raises(RuntimeError) as exc_info:
        escalation_policies.list_escalation_policies()
    assert str(exc_info.value) == "API Error"
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_list_escalation_policies_empty_response(mock_get_api_client):
    """Test that list_escalation_policies handles empty response correctly."""
    mock_get_api_client.list_all.return_value = []
    policy_list = escalation_policies.list_escalation_policies()
    assert policy_list == utils.api_response_handler(
        results=[], resource_name="escalation_policies"
    )
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_list_escalation_policies_preserves_full_error_message(mock_get_api_client):
    """Test that list_escalation_policies preserves the full error message from the API response."""
    mock_response = MagicMock()
    mock_response.text = '{"error":{"message":"Invalid Input Provided","code":2001,"errors":["Invalid team ID format"]}}'
    mock_error = RuntimeError("API Error")
    mock_error.response = mock_response
    mock_get_api_client.list_all.side_effect = mock_error
    with pytest.raises(RuntimeError) as exc_info:
        escalation_policies.list_escalation_policies()
    assert str(exc_info.value) == "API Error"
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_fetch_escalation_policy_ids(
    mock_get_api_client, mock_escalation_policies, mock_user
):
    """Test that escalation policy IDs are fetched correctly."""
    mock_get_api_client.list_all.return_value = mock_escalation_policies
    policy_ids = escalation_policies.fetch_escalation_policy_ids(
        user_id=mock_user["id"]
    )
    mock_get_api_client.list_all.assert_called_once_with(
        escalation_policies.ESCALATION_POLICIES_URL,
        params={"user_ids[]": [mock_user["id"]]},
    )
    assert set(policy_ids) == set([policy["id"] for policy in mock_escalation_policies])
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_fetch_escalation_policy_ids_api_error(mock_get_api_client, mock_user):
    """Test that fetch_escalation_policy_ids handles API errors correctly."""
    mock_get_api_client.list_all.side_effect = RuntimeError("API Error")
    with pytest.raises(RuntimeError) as exc_info:
        escalation_policies.fetch_escalation_policy_ids(user_id=mock_user["id"])
    assert str(exc_info.value) == "API Error"
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_show_escalation_policy(mock_get_api_client, mock_escalation_policies):
    """Test that a single escalation policy is fetched correctly."""
    policy_id = mock_escalation_policies[0]["id"]
    mock_get_api_client.jget.return_value = {
        "escalation_policy": mock_escalation_policies[0]
    }
    policy = escalation_policies.show_escalation_policy(policy_id=policy_id)
    mock_get_api_client.jget.assert_called_once_with(
        f"{escalation_policies.ESCALATION_POLICIES_URL}/{policy_id}"
    )
    assert policy == utils.api_response_handler(
        results=parse_escalation_policy(result=mock_escalation_policies[0]),
        resource_name="escalation_policy",
    )
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_show_escalation_policy_invalid_id(mock_get_api_client):
    """Test that show_escalation_policy raises ValueError for invalid policy ID."""
    with pytest.raises(ValueError) as exc_info:
        escalation_policies.show_escalation_policy(policy_id="")
    assert str(exc_info.value) == "policy_id cannot be empty"
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_show_escalation_policy_api_error(mock_get_api_client):
    """Test that show_escalation_policy handles API errors correctly."""
    policy_id = "123"
    mock_get_api_client.jget.side_effect = RuntimeError("API Error")
    with pytest.raises(RuntimeError) as exc_info:
        escalation_policies.show_escalation_policy(policy_id=policy_id)
    assert str(exc_info.value) == "API Error"
@pytest.mark.unit
@pytest.mark.escalation_policies
def test_show_escalation_policy_invalid_response(mock_get_api_client):
    """Test that show_escalation_policy handles invalid API response correctly."""
    policy_id = "123"
    mock_get_api_client.jget.return_value = {}  # Missing 'escalation_policy' key
    with pytest.raises(RuntimeError) as exc_info:
        escalation_policies.show_escalation_policy(policy_id=policy_id)
    assert (
        str(exc_info.value)
        == "Failed to fetch escalation policy 123: Response missing 'escalation_policy' field"
    )
@pytest.mark.unit
@pytest.mark.parsers
@pytest.mark.escalation_policies
def test_parse_escalation_policy(
    mock_escalation_policies, mock_escalation_policies_parsed
):
    """Test that parse_escalation_policy correctly parses raw escalation policy data."""
    parsed_policy = parse_escalation_policy(result=mock_escalation_policies[0])
    assert parsed_policy == mock_escalation_policies_parsed[0]
@pytest.mark.unit
@pytest.mark.parsers
@pytest.mark.escalation_policies
def test_parse_escalation_policy_none():
    """Test that parse_escalation_policy handles None input correctly."""
    assert parse_escalation_policy(result=None) == {}