"""Unit tests for QueryRevenueCalc."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from postgres_mcp.revenue import QueryRevenueCalc
class MockRow:
    """Lightweight stand-in for a SqlDriver result row.

    The only attribute exposed is ``cells``: the column-name -> value mapping
    passed to the constructor.
    """

    def __init__(self, cells: dict) -> None:
        # Keep a reference to the caller's mapping (no defensive copy), so the
        # test sees exactly the object it handed in.
        self.cells = cells
@pytest.fixture
def mock_sql_driver():
    """Create a mock SQL driver with an awaitable ``execute_query``.

    ``execute_query`` defaults to an ``AsyncMock`` that returns ``[]``. The
    tests in this module treat ``execute_query`` as a coroutine (each one
    installs an ``AsyncMock``). A plain ``MagicMock`` attribute cannot be
    awaited, so a test that forgot to configure it would fail inside the code
    under test rather than getting a clean empty result. Tests may still
    replace ``execute_query`` with their own ``AsyncMock``.
    """
    driver = MagicMock()
    driver.execute_query = AsyncMock(return_value=[])
    return driver
class TestQueryRevenueCalc:
    """Tests for QueryRevenueCalc class.

    Each async test replaces the driver's ``execute_query`` with an
    ``AsyncMock`` that returns canned ``MockRow`` objects, then inspects the
    dict returned by ``QueryRevenueCalc.query``.
    """

    @staticmethod
    def _user_row(email, total, paid, open_, billing_cycle="01-12-2025", record_count=1):
        """Build a per-user revenue row.

        Revenue amounts are passed as strings, matching the string-typed
        numeric cells used throughout these tests.
        """
        return MockRow({
            "email": email,
            "billing_cycle": billing_cycle,
            "total_revenue": total,
            "paid_revenue": paid,
            "open_revenue": open_,
            "record_count": record_count,
        })

    @pytest.mark.asyncio
    async def test_query_no_filters(self, mock_sql_driver):
        """Test query with no filters returns all data."""
        mock_rows = [
            self._user_row("user1@test.com", "100000", "90000", "10000"),
            self._user_row("user2@test.com", "50000", "50000", "0"),
        ]
        mock_sql_driver.execute_query = AsyncMock(return_value=mock_rows)
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query()
        assert result["total_records"] == 2
        assert len(result["data"]) == 2
        # Summary totals are numeric sums of the string cells above.
        assert result["summary"]["total_revenue"] == 150000
        assert result["summary"]["total_paid"] == 140000
        mock_sql_driver.execute_query.assert_awaited()

    @pytest.mark.asyncio
    async def test_query_with_billing_cycle(self, mock_sql_driver):
        """Test query with billing_cycle filter."""
        mock_rows = [
            self._user_row("user1@test.com", "100000", "100000", "0"),
        ]
        mock_sql_driver.execute_query = AsyncMock(return_value=mock_rows)
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query(billing_cycle="01-12-2025")
        # The filter is echoed back in the result metadata. Whether it reaches
        # the SQL is not asserted here.
        assert result["filters_applied"]["billing_cycle"] == "01-12-2025"
        assert result["total_records"] == 1
        mock_sql_driver.execute_query.assert_awaited()

    @pytest.mark.asyncio
    async def test_query_with_service_type_filter(self, mock_sql_driver):
        """Test query with service_type filter."""
        mock_rows = [
            self._user_row("user1@test.com", "80000", "80000", "0"),
        ]
        mock_sql_driver.execute_query = AsyncMock(return_value=mock_rows)
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query(service_type="cloud_server")
        assert result["filters_applied"]["service_type"] == "cloud_server"
        assert result["total_records"] == 1
        mock_sql_driver.execute_query.assert_awaited()

    @pytest.mark.asyncio
    async def test_query_with_billing_plan_filter(self, mock_sql_driver):
        """Test query with billing_plan filter echoes the filter and runs one query.

        NOTE(review): the original intent was to check that the billing-plan
        path uses the correct table. That is not asserted yet. TODO: inspect
        ``execute_query.await_args`` once the expected table name is pinned
        down.
        """
        mock_rows = [
            self._user_row("user1@test.com", "50000", "50000", "0"),
        ]
        mock_sql_driver.execute_query = AsyncMock(return_value=mock_rows)
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query(billing_plan="on_demand")
        assert result["filters_applied"]["billing_plan"] == "on_demand"
        assert result["total_records"] == 1
        # assert_awaited_once (not assert_called_once): calling an AsyncMock
        # without awaiting it would otherwise still satisfy the check.
        mock_sql_driver.execute_query.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_query_empty_results(self, mock_sql_driver):
        """Test query with no matching results."""
        mock_sql_driver.execute_query = AsyncMock(return_value=[])
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query(email="nonexistent@test.com")
        assert result["total_records"] == 0
        assert result["data"] == []
        assert result["summary"]["total_revenue"] == 0

    @pytest.mark.asyncio
    async def test_query_aggregation_service(self, mock_sql_driver):
        """Test query with service aggregation."""
        # Aggregated rows are keyed by service_type instead of email/billing_cycle.
        mock_rows = [
            MockRow({"service_type": "cloud_server",
                     "total_revenue": "500000", "paid_revenue": "450000",
                     "open_revenue": "50000", "record_count": 10}),
            MockRow({"service_type": "dbaas",
                     "total_revenue": "200000", "paid_revenue": "200000",
                     "open_revenue": "0", "record_count": 5}),
        ]
        mock_sql_driver.execute_query = AsyncMock(return_value=mock_rows)
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query(aggregation="service")
        assert result["aggregation"] == "service"
        assert result["total_records"] == 2
        mock_sql_driver.execute_query.assert_awaited()

    @pytest.mark.asyncio
    async def test_query_error_handling(self, mock_sql_driver):
        """Test error handling when query fails."""
        # Awaiting the mock raises, so query() should report an error
        # instead of propagating the exception.
        mock_sql_driver.execute_query = AsyncMock(side_effect=Exception("Database error"))
        calc = QueryRevenueCalc(mock_sql_driver)
        result = await calc.query()
        assert "error" in result
        assert result["total_records"] == 0

    def test_service_types_constant(self):
        """Test that SERVICE_TYPES constant is properly defined."""
        assert "cloud_server" in QueryRevenueCalc.SERVICE_TYPES
        assert "dbaas" in QueryRevenueCalc.SERVICE_TYPES
        assert "cdn" in QueryRevenueCalc.SERVICE_TYPES
        assert len(QueryRevenueCalc.SERVICE_TYPES) == 19

    def test_billing_plans_constant(self):
        """Test that BILLING_PLANS constant is properly defined."""
        assert "on_demand" in QueryRevenueCalc.BILLING_PLANS
        assert "subscription" in QueryRevenueCalc.BILLING_PLANS
        assert len(QueryRevenueCalc.BILLING_PLANS) == 2