"""
Unit tests for the AMP client.
"""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from prometheus_mcp.amp_client import AMPClient
class TestAMPClient:
    """Tests for AMPClient.

    Every test requests the ``mock_boto3_session`` fixture, which is defined
    outside this file (presumably in ``conftest.py`` — TODO confirm) and
    presumably stubs the boto3 session/credentials used for SigV4 signing.
    The async API tests replace ``AMPClient._request`` with an ``AsyncMock``
    so no HTTP traffic happens; they assert on the HTTP method, API path and
    ``data``/``params`` payload the client passes to ``_request``.
    """

    def test_init_with_env_vars(self, mock_boto3_session):
        """Test client initialization with environment variables."""
        client = AMPClient()
        # NOTE(review): these expected values are assumed to be injected into
        # the environment by a fixture/conftest outside this file.
        assert client.workspace_id == "ws-test-workspace-id"
        assert client.region == "us-east-1"
        assert "aps-workspaces.us-east-1.amazonaws.com" in client.base_url

    def test_init_with_params(self, mock_boto3_session):
        """Test client initialization with explicit parameters."""
        client = AMPClient(workspace_id="ws-custom", region="us-west-2")
        assert client.workspace_id == "ws-custom"
        assert client.region == "us-west-2"
        # The base URL must be built from the explicit region, not the env.
        assert "aps-workspaces.us-west-2.amazonaws.com" in client.base_url

    def test_init_missing_workspace_id(self, mock_boto3_session):
        """Test that missing workspace ID raises an error."""
        import os

        # patch.dict with no new values snapshots os.environ and restores it
        # on exit, so removing the variable here cannot leak into other tests.
        with patch.dict(os.environ):
            os.environ.pop("PROMETHEUS_WORKSPACE_ID", None)
            with pytest.raises(ValueError, match="workspace_id must be provided"):
                AMPClient(workspace_id=None)

    def test_sign_request(self, mock_boto3_session):
        """Test that requests are signed with SigV4."""
        client = AMPClient()
        headers = client._sign_request(
            method="GET",
            url="https://aps-workspaces.us-east-1.amazonaws.com/workspaces/ws-test/api/v1/labels",
            headers={"Content-Type": "application/json"},
        )
        # SigV4 adds Authorization header
        assert "Authorization" in headers
        assert "AWS4-HMAC-SHA256" in headers["Authorization"]

    @pytest.mark.asyncio
    async def test_query_instant(self, mock_boto3_session, sample_instant_query_response):
        """Test instant query execution."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_instant_query_response
            result = await client.query("up")
            # Instant queries are sent as a form POST to /api/v1/query.
            mock_request.assert_called_once_with(
                "POST",
                "/api/v1/query",
                data={"query": "up"},
            )
            assert result["status"] == "success"
            assert len(result["data"]["result"]) == 2

    @pytest.mark.asyncio
    async def test_query_instant_with_time(self, mock_boto3_session, sample_instant_query_response):
        """Test instant query with specific time."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_instant_query_response
            result = await client.query("up", time="2024-01-15T10:00:00Z")
            # The optional evaluation time is forwarded verbatim in the body.
            mock_request.assert_called_once_with(
                "POST",
                "/api/v1/query",
                data={"query": "up", "time": "2024-01-15T10:00:00Z"},
            )
            # The mocked response must still be returned to the caller.
            assert result["status"] == "success"

    @pytest.mark.asyncio
    async def test_query_range(self, mock_boto3_session, sample_range_query_response):
        """Test range query execution."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_range_query_response
            result = await client.query_range(
                promql="up",
                start="2024-01-15T00:00:00Z",
                end="2024-01-15T01:00:00Z",
                step="1m",
            )
            # The `promql` argument is sent under the API's `query` key.
            mock_request.assert_called_once_with(
                "POST",
                "/api/v1/query_range",
                data={
                    "query": "up",
                    "start": "2024-01-15T00:00:00Z",
                    "end": "2024-01-15T01:00:00Z",
                    "step": "1m",
                },
            )
            assert result["status"] == "success"

    @pytest.mark.asyncio
    async def test_labels(self, mock_boto3_session, sample_labels_response):
        """Test labels retrieval."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_labels_response
            result = await client.labels()
            # Without a match filter, no query params are passed at all.
            mock_request.assert_called_once_with("GET", "/api/v1/labels", params=None)
            assert "__name__" in result["data"]

    @pytest.mark.asyncio
    async def test_labels_with_match(self, mock_boto3_session, sample_labels_response):
        """Test labels with match filter."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_labels_response
            await client.labels(match=["up", "http_requests_total"])
            # Series selectors go under the repeated `match[]` query parameter.
            mock_request.assert_called_once_with(
                "GET",
                "/api/v1/labels",
                params={"match[]": ["up", "http_requests_total"]},
            )

    @pytest.mark.asyncio
    async def test_label_values(self, mock_boto3_session, sample_label_values_response):
        """Test label values retrieval."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_label_values_response
            result = await client.label_values("job")
            # The label name is interpolated into the URL path.
            mock_request.assert_called_once_with(
                "GET",
                "/api/v1/label/job/values",
                params=None,
            )
            assert "prometheus" in result["data"]

    @pytest.mark.asyncio
    async def test_metadata(self, mock_boto3_session, sample_metadata_response):
        """Test metadata retrieval."""
        client = AMPClient()
        with patch.object(client, "_request", new_callable=AsyncMock) as mock_request:
            mock_request.return_value = sample_metadata_response
            result = await client.metadata()
            mock_request.assert_called_once_with("GET", "/api/v1/metadata", params=None)
            assert "up" in result["data"]

    @pytest.mark.asyncio
    async def test_context_manager(self, mock_boto3_session):
        """Test async context manager."""
        async with AMPClient() as client:
            assert client is not None
            assert client.workspace_id == "ws-test-workspace-id"