"""Unit tests for source collection helpers.
Covers success paths, 404 handling per-group, and error propagation.
"""
# pyright: reportPrivateUsage=false
import json
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
from cribl_control_plane.errors import CriblControlPlaneError, ResponseValidationError
from cribl_control_plane.models.productscore import ProductsCore
from cribl_control_plane.models.security import Security
from fastmcp import Context
from pydantic import BaseModel, ValidationError
from snc_cribl_mcp.operations.sources import (
_merge_source_results,
collect_product_sources,
)
# Test helper models for validation error tests
class _DummyOutputModel(BaseModel):
"""Model requiring an 'output' field, used to generate ValidationErrors in tests."""
output: str
class _DummyConnectionModel(BaseModel):
"""Model requiring an 'output' field, used to generate ValidationErrors in tests."""
output: str
@pytest.fixture
def mock_ctx() -> Context:
"""Provide a Context-like AsyncMock for tool logging."""
ctx = MagicMock(spec=Context)
ctx.info = AsyncMock()
ctx.warning = AsyncMock()
return ctx
@pytest.mark.asyncio
async def test_collect_product_sources_success(mock_ctx: Context) -> None:
"""It should list sources for each group via the product-scoped client and aggregate results."""
# Mock groups list
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock(), MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
groups_response.items[1].model_dump.return_value = {"id": "g2"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
# Mock top-level sources client and sdk_configuration for base URL
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
# First group returns 2 items, second returns 1
resp_g1 = MagicMock(items=[MagicMock(), MagicMock()], count=2)
resp_g1.items[0].model_dump.return_value = {"name": "s1"}
resp_g1.items[1].model_dump.return_value = {"name": "s2"}
resp_g2 = MagicMock(items=[MagicMock()], count=1)
resp_g2.items[0].model_dump.return_value = {"name": "s3"}
async def list_async_side_effect(*_args: object, **kwargs: object) -> MagicMock:
# Ensure correct kwargs are used
assert "server_url" in kwargs
assert kwargs["server_url"].endswith("/m/g1") or kwargs["server_url"].endswith("/m/g2") # type: ignore[index]
assert "timeout_ms" in kwargs
if kwargs["server_url"].endswith("/m/g1"): # type: ignore[index]
return resp_g1
return resp_g2
stream_sources.list_async = AsyncMock(side_effect=list_async_side_effect)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
assert result["status"] == "ok"
assert result["total_count"] == 3
assert len(result["groups"]) == 2
assert result["groups"][0]["group_id"] == "g1"
assert result["groups"][0]["count"] == 2
assert result["groups"][0]["reported_count"] == 2
@pytest.mark.asyncio
async def test_collect_product_sources_404_per_group(mock_ctx: Context) -> None:
"""404 on a group's sources should be treated as empty for that group."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g404"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
api_error_404 = CriblControlPlaneError(message="Not found", body=None, raw_response=MagicMock(status_code=404))
stream_sources.list_async = AsyncMock(side_effect=api_error_404)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
assert result["status"] == "ok"
assert result["total_count"] == 0
assert result["groups"][0]["group_id"] == "g404"
assert result["groups"][0]["count"] == 0
# warning should have been awaited at least once
assert getattr(mock_ctx.warning, "await_count", 0) >= 1
@pytest.mark.asyncio
async def test_collect_product_sources_api_error_non_404(mock_ctx: Context) -> None:
"""Non-404 API errors should be raised as RuntimeError."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
edge_sources = MagicMock()
mock_client.sources = edge_sources
api_error_500 = CriblControlPlaneError(message="Boom", body=None, raw_response=MagicMock(status_code=500))
edge_sources.list_async = AsyncMock(side_effect=api_error_500)
with pytest.raises(RuntimeError, match="Cribl API error while listing sources"):
await collect_product_sources(
mock_client,
product=ProductsCore.EDGE,
timeout_ms=10000,
ctx=mock_ctx,
)
@pytest.mark.asyncio
async def test_collect_product_sources_network_error(mock_ctx: Context) -> None:
"""Network failures should be raised as RuntimeError."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
stream_sources.list_async = AsyncMock(side_effect=httpx.ConnectError("fail"))
with pytest.raises(RuntimeError, match="Network error while listing sources"):
await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
@pytest.mark.asyncio
async def test_collect_product_sources_unavailable_product_returns_unavailable(mock_ctx: Context) -> None:
"""If listing groups returns 404, the function should return an 'unavailable' status."""
mock_client = MagicMock()
api_error_404 = CriblControlPlaneError(message="Not found", body=None, raw_response=MagicMock(status_code=404))
mock_client.groups.list_async = AsyncMock(side_effect=api_error_404)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
assert result["status"] == "unavailable"
assert result["total_count"] == 0
assert result["groups"] == []
@pytest.mark.asyncio
async def test_collect_product_sources_network_error_on_groups(mock_ctx: Context) -> None:
"""Network error while listing groups should raise RuntimeError."""
mock_client = MagicMock()
mock_client.groups.list_async = AsyncMock(side_effect=httpx.ConnectError("Network failure"))
with pytest.raises(RuntimeError, match="Network error while listing stream groups"):
await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
@pytest.mark.asyncio
async def test_collect_product_sources_api_error_non_404_on_groups(mock_ctx: Context) -> None:
"""Non-404 API error while listing groups should raise RuntimeError."""
mock_client = MagicMock()
api_error_500 = CriblControlPlaneError(message="Server error", body=None, raw_response=MagicMock(status_code=500))
mock_client.groups.list_async = AsyncMock(side_effect=api_error_500)
with pytest.raises(RuntimeError, match="Cribl API error while listing stream groups for sources"):
await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
@pytest.mark.asyncio
async def test_collect_product_sources_skips_groups_without_id(mock_ctx: Context) -> None:
"""Groups without id or groupId should be skipped."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock(), MagicMock()])
# First group has no id, second has id
groups_response.items[0].model_dump.return_value = {"name": "no_id_group"}
groups_response.items[1].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
resp_g1 = MagicMock(items=[MagicMock()], count=1)
resp_g1.items[0].model_dump.return_value = {"name": "s1"}
stream_sources.list_async = AsyncMock(return_value=resp_g1)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
# Only one group should be processed (the one with id)
assert len(result["groups"]) == 1
assert result["groups"][0]["group_id"] == "g1"
assert result["total_count"] == 1
@pytest.mark.asyncio
async def test_collect_product_sources_validation_error_returns_structured_error(mock_ctx: Context) -> None:
"""SDK validation errors should return a structured error response, not raise."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "default"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
# Add error method to context mock
mock_ctx.error = AsyncMock()
# Create a mock HTTP response
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = httpx.Headers({})
mock_response.text = "{}"
# Create a real Pydantic ValidationError with a realistic structure
# This simulates the SDK failing to parse a source with a missing field
raw_body = json.dumps(
{
"items": [
{"id": "source1", "type": "syslog"},
{"id": "source2", "type": "tcpjson", "tcpjson": {"connections": [{}]}},
],
"count": 2,
}
)
# Create a Pydantic ValidationError with realistic location info
pydantic_error: ValidationError
try:
_DummyOutputModel.model_validate({})
except ValidationError as ve:
pydantic_error = ve
else:
pytest.fail("Expected ValidationError was not raised")
# Wrap it in ResponseValidationError
mock_response.text = raw_body
validation_exc = ResponseValidationError(
"Response validation failed",
mock_response,
pydantic_error,
raw_body,
)
stream_sources.list_async = AsyncMock(side_effect=validation_exc)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
# Should return error response, not raise
assert result["status"] == "validation_error"
assert "product" in result
assert result["product"] == "stream"
assert result["group_id"] == "default"
assert "errors" in result
assert "message" in result
assert "SDK could not validate" in result["message"]
@pytest.mark.asyncio
async def test_collect_product_sources_validation_error_extracts_object_id(mock_ctx: Context) -> None:
"""SDK validation errors should return structured error with helpful info."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "default"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
stream_sources = MagicMock()
mock_client.sources = stream_sources
mock_ctx.error = AsyncMock()
# Create a mock HTTP response
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = httpx.Headers({})
# Body with identifiable source
raw_body = json.dumps(
{
"items": [
{"id": "in_tcp_json", "type": "tcpjson", "connections": [{}]},
],
"count": 1,
}
)
mock_response.text = raw_body
# Create a simple Pydantic ValidationError by catching the exception
pydantic_error: ValidationError
try:
_DummyConnectionModel.model_validate({})
except ValidationError as ve:
pydantic_error = ve
else:
pytest.fail("Expected ValidationError was not raised")
validation_exc = ResponseValidationError(
"Response validation failed",
mock_response,
pydantic_error,
raw_body,
)
stream_sources.list_async = AsyncMock(side_effect=validation_exc)
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
)
# Verify the error response structure
assert result["status"] == "validation_error"
# Simple validation errors may not have full location info, but errors list should exist
assert "errors" in result
# The message should be informative even without location details
assert "validation_error" in result["status"]
# Should include resolution guidance
assert "resolution" in result
# =============================================================================
# Tests for merge edge cases (with security parameter to trigger collector fetch)
# =============================================================================
@pytest.mark.asyncio
async def test_collect_product_sources_with_collectors_merged(mock_ctx: Context) -> None:
"""When security is provided, both regular sources and collectors should be merged."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
# Mock SDK sources
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
resp_g1 = MagicMock(items=[MagicMock()], count=1)
resp_g1.items[0].model_dump.return_value = {"name": "regular_source"}
mock_client.sources = MagicMock(list_async=AsyncMock(return_value=resp_g1))
# Mock HTTP client for collectors - return a collector job
mock_http_client = AsyncMock()
resp_jobs = MagicMock()
resp_jobs.status_code = 200
resp_jobs.json.return_value = {
"items": [
{"id": "s3_collector", "type": "collection", "collector": {"type": "s3", "bucket": "test"}},
]
}
mock_http_client.get = AsyncMock(return_value=resp_jobs)
mock_client.sdk_configuration.async_client = mock_http_client
security = Security(bearer_auth="test-token")
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
security=security,
)
assert result["status"] == "ok"
# Should have merged items: 1 regular + 1 collector = 2
assert result["total_count"] == 2
assert len(result["groups"]) == 1
assert result["groups"][0]["count"] == 2
@pytest.mark.asyncio
async def test_collect_product_sources_collector_failure_returns_regular_only(mock_ctx: Context) -> None:
"""When collectors fail but regular sources succeed, return regular sources only."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
# Mock SDK sources - success
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
resp_g1 = MagicMock(items=[MagicMock()], count=1)
resp_g1.items[0].model_dump.return_value = {"name": "regular_source"}
mock_client.sources = MagicMock(list_async=AsyncMock(return_value=resp_g1))
# Mock HTTP client for collectors - return 500 error
mock_http_client = AsyncMock()
resp_error = MagicMock()
resp_error.status_code = 500
resp_error.text = "Internal Server Error"
mock_http_client.get = AsyncMock(return_value=resp_error)
mock_client.sdk_configuration.async_client = mock_http_client
security = Security(bearer_auth="test-token")
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
security=security,
)
# Should still succeed with just regular sources
assert result["status"] == "ok"
assert result["total_count"] == 1
@pytest.mark.asyncio
async def test_collect_product_sources_collector_json_error_returns_regular_only(mock_ctx: Context) -> None:
"""When collector JSON parsing fails, return regular sources only with warning."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
# Mock SDK sources - success
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
resp_g1 = MagicMock(items=[MagicMock()], count=1)
resp_g1.items[0].model_dump.return_value = {"name": "regular_source"}
mock_client.sources = MagicMock(list_async=AsyncMock(return_value=resp_g1))
# Mock HTTP client for collectors - return invalid JSON
mock_http_client = AsyncMock()
resp_invalid = MagicMock()
resp_invalid.status_code = 200
resp_invalid.json.side_effect = ValueError("Invalid JSON")
resp_invalid.raise_for_status = MagicMock() # Does nothing
mock_http_client.get = AsyncMock(return_value=resp_invalid)
mock_client.sdk_configuration.async_client = mock_http_client
security = Security(bearer_auth="test-token")
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
security=security,
)
# Should still succeed with just regular sources
assert result["status"] == "ok"
assert result["total_count"] == 1
# Should have logged a warning about the collector failure
assert getattr(mock_ctx.warning, "await_count", 0) >= 1
@pytest.mark.asyncio
async def test_collect_product_sources_regular_failure_returns_collectors(mock_ctx: Context) -> None:
"""When regular sources fail but collectors succeed, return collector sources."""
mock_client = MagicMock()
groups_response = MagicMock(items=[MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
# Mock SDK sources - failure (non-404)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
api_error_500 = CriblControlPlaneError(message="Boom", body=None, raw_response=MagicMock(status_code=500))
mock_client.sources = MagicMock(list_async=AsyncMock(side_effect=api_error_500))
# Mock HTTP client for collectors - success
mock_http_client = AsyncMock()
resp_jobs = MagicMock()
resp_jobs.status_code = 200
resp_jobs.json.return_value = {
"items": [
{"id": "s3_collector", "type": "collection", "collector": {"type": "s3", "bucket": "test"}},
]
}
mock_http_client.get = AsyncMock(return_value=resp_jobs)
mock_client.sdk_configuration.async_client = mock_http_client
security = Security(bearer_auth="test-token")
# The regular sources will raise RuntimeError, but with security we still try collectors
# Actually, since regular sources raise, the whole function should raise
# Let me check the actual implementation flow...
# Looking at the code: collect_items_via_sdk raises RuntimeError for non-404
# So collectors won't even be attempted if regular sources fail hard
with pytest.raises(RuntimeError, match="Cribl API error while listing sources"):
await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
security=security,
)
@pytest.mark.asyncio
async def test_collect_product_sources_collector_only_groups(mock_ctx: Context) -> None:
"""Groups that only exist in collectors should be included in merged result."""
mock_client = MagicMock()
# Return two groups
groups_response = MagicMock(items=[MagicMock(), MagicMock()])
groups_response.items[0].model_dump.return_value = {"id": "g1"}
groups_response.items[1].model_dump.return_value = {"id": "g2"}
mock_client.groups.list_async = AsyncMock(return_value=groups_response)
mock_client.sdk_configuration = MagicMock(server_url="https://example/api/v1")
# Mock SDK sources - g1 has sources, g2 is empty
resp_g1 = MagicMock(items=[MagicMock()], count=1)
resp_g1.items[0].model_dump.return_value = {"name": "regular_source"}
resp_g2 = MagicMock(items=[], count=0)
async def sources_list_async(*_args: object, **kwargs: object) -> MagicMock:
if str(kwargs.get("server_url", "")).endswith("/m/g1"):
return resp_g1
return resp_g2
mock_client.sources = MagicMock(list_async=AsyncMock(side_effect=sources_list_async))
# Mock HTTP client for collectors - g1 empty, g2 has collector
mock_http_client = AsyncMock()
async def http_get(url: str, **_kwargs: object) -> MagicMock:
resp = MagicMock()
resp.status_code = 200
if "/m/g1/" in url:
resp.json.return_value = {"items": []}
else:
resp.json.return_value = {
"items": [
{"id": "collector_g2", "type": "collection", "collector": {"type": "rest"}},
]
}
return resp
mock_http_client.get = AsyncMock(side_effect=http_get)
mock_client.sdk_configuration.async_client = mock_http_client
security = Security(bearer_auth="test-token")
result = await collect_product_sources(
mock_client,
product=ProductsCore.STREAM,
timeout_ms=10000,
ctx=mock_ctx,
security=security,
)
assert result["status"] == "ok"
# g1: 1 regular, g2: 1 collector = 2 total
assert result["total_count"] == 2
assert len(result["groups"]) == 2
# Find each group
g1_result = next(g for g in result["groups"] if g["group_id"] == "g1")
g2_result = next(g for g in result["groups"] if g["group_id"] == "g2")
assert g1_result["count"] == 1
assert g2_result["count"] == 1
# =============================================================================
# Direct unit tests for _merge_source_results helper
# =============================================================================
class TestMergeSourceResults:
"""Unit tests for the _merge_source_results helper function."""
def test_merge_both_ok(self) -> None:
"""When both results are ok, items should be merged per group."""
regular = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "src1"}]}],
}
collector = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "coll1"}]}],
}
result = _merge_source_results(regular, collector)
assert result["status"] == "ok"
assert result["groups"][0]["count"] == 2
assert len(result["groups"][0]["items"]) == 2
def test_merge_regular_failed_collector_ok(self) -> None:
"""When regular fails but collector succeeds, return collector result."""
regular = {"status": "error", "error": "boom"}
collector = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "coll1"}]}],
}
result = _merge_source_results(regular, collector)
assert result["status"] == "ok"
assert result["total_count"] == 1
def test_merge_regular_failed_collector_failed(self) -> None:
"""When both fail, return regular result (primary)."""
regular = {"status": "error", "error": "regular boom"}
collector = {"status": "error", "error": "collector boom"}
result = _merge_source_results(regular, collector)
assert result["status"] == "error"
assert result["error"] == "regular boom"
def test_merge_collector_failed_regular_ok(self) -> None:
"""When collector fails but regular succeeds, return regular result."""
regular = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "src1"}]}],
}
collector = {"status": "error", "error": "collector boom"}
result = _merge_source_results(regular, collector)
assert result["status"] == "ok"
assert result["total_count"] == 1
def test_merge_collector_only_group(self) -> None:
"""Groups that exist only in collector should be added to result."""
regular = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "src1"}]}],
}
collector = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g2", "count": 1, "items": [{"name": "coll1"}]}],
}
result = _merge_source_results(regular, collector)
assert result["status"] == "ok"
assert len(result["groups"]) == 2
group_ids = {g["group_id"] for g in result["groups"]}
assert group_ids == {"g1", "g2"}
def test_merge_skips_collector_groups_without_id(self) -> None:
"""Collector groups without group_id should be skipped."""
regular = {
"status": "ok",
"total_count": 1,
"groups": [{"group_id": "g1", "count": 1, "items": [{"name": "src1"}]}],
}
collector = {
"status": "ok",
"total_count": 1,
"groups": [
{"group_id": None, "count": 1, "items": [{"name": "bad"}]},
{"count": 1, "items": [{"name": "also_bad"}]}, # missing group_id key
],
}
result = _merge_source_results(regular, collector)
assert result["status"] == "ok"
assert len(result["groups"]) == 1
assert result["groups"][0]["group_id"] == "g1"