"""Unit tests for AnalyzeTrendCalc."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from postgres_mcp.revenue import AnalyzeTrendCalc
class MockRow:
    """Stand-in for a single SqlDriver result row used by these tests.

    Only the ``cells`` attribute is provided. It is the exact mapping
    object handed to the constructor (not a copy).
    """

    def __init__(self, cells: dict) -> None:
        # Keep the caller's mapping as-is; tests build it inline per row.
        self.cells = cells
@pytest.fixture
def mock_sql_driver():
    """Provide a fresh MagicMock standing in for the SQL driver.

    Individual tests attach their own ``execute_query`` AsyncMock to it.
    """
    driver = MagicMock()
    return driver
class TestAnalyzeTrendCalc:
    """Tests for ``AnalyzeTrendCalc.analyze``.

    Each test replaces ``execute_query`` on the mocked driver with an
    AsyncMock. Tests that use a ``side_effect`` list assume a call order:
    the first awaited query returns current-cycle rows (``current_total``)
    and the second returns previous-cycle average rows (``avg_total``).
    """

    @pytest.mark.asyncio
    async def test_analyze_no_current_data(self, mock_sql_driver):
        """Test analyze when no data exists for current cycle."""
        mock_sql_driver.execute_query = AsyncMock(return_value=[])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026")
        assert result["summary"]["total_anomalies"] == 0
        assert result["alert_required"] is False
        assert "No data found" in result.get("message", "")

    @pytest.mark.asyncio
    async def test_analyze_detects_increase(self, mock_sql_driver):
        """Test that significant increases are detected."""
        # Current cycle data
        current_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "current_total": "150000"}),
        ]
        # Previous cycle average data
        previous_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "avg_total": "100000"}),
        ]
        # Mock returns different results for different calls
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026", threshold_percent=20.0)
        assert result["summary"]["significant_increases"] == 1
        assert len(result["increases"]) == 1
        assert result["increases"][0]["email"] == "user1@test.com"
        # approx: tolerate rounding / float-order differences in the percentage.
        assert result["increases"][0]["change_percent"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_analyze_detects_decrease_churn_risk(self, mock_sql_driver):
        """Test that significant decreases (churn risks) are detected."""
        # Current cycle data - significant decrease
        current_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "current_total": "50000"}),
        ]
        # Previous cycle average data
        previous_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "avg_total": "100000"}),
        ]
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026", threshold_percent=20.0)
        assert result["summary"]["churn_risks"] == 1
        assert result["alert_required"] is True
        assert len(result["decreases"]) == 1
        assert result["decreases"][0]["change_percent"] == pytest.approx(-50.0)
        assert result["decreases"][0]["trend"] == "decrease"

    @pytest.mark.asyncio
    async def test_analyze_ignores_below_threshold(self, mock_sql_driver):
        """Test that changes below threshold are not flagged."""
        # Current cycle data - small change
        current_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "current_total": "105000"}),
        ]
        # Previous cycle average data
        previous_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "avg_total": "100000"}),
        ]
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026", threshold_percent=20.0)
        # 5% change should not trigger anomaly when threshold is 20%
        assert result["summary"]["total_anomalies"] == 0
        assert len(result["increases"]) == 0
        assert len(result["decreases"]) == 0

    @pytest.mark.asyncio
    async def test_analyze_with_service_filter(self, mock_sql_driver):
        """Test analyze with service_type filter."""
        current_rows = [
            MockRow({"email": "user1@test.com", "service_type": "dbaas", "current_total": "200000"}),
        ]
        previous_rows = [
            MockRow({"email": "user1@test.com", "service_type": "dbaas", "avg_total": "100000"}),
        ]
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026", service_type="dbaas")
        assert result["filters_applied"]["service_type"] == "dbaas"
        assert result["summary"]["significant_increases"] == 1

    @pytest.mark.asyncio
    async def test_analyze_new_customer_ignored(self, mock_sql_driver):
        """Test that new customers (no previous data) are ignored."""
        # Current cycle has data
        current_rows = [
            MockRow({"email": "newuser@test.com", "service_type": "cloud_server", "current_total": "100000"}),
        ]
        # No previous data for this customer
        previous_rows = []
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026")
        # New customer should not be flagged as anomaly
        assert result["summary"]["total_anomalies"] == 0

    @pytest.mark.asyncio
    async def test_analyze_error_handling(self, mock_sql_driver):
        """Test error handling when analysis fails."""
        mock_sql_driver.execute_query = AsyncMock(side_effect=Exception("Database error"))
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026")
        assert "error" in result
        assert result["alert_required"] is False

    @pytest.mark.asyncio
    async def test_analyze_custom_comparison_cycles(self, mock_sql_driver):
        """Test analyze with custom number of comparison cycles."""
        current_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "current_total": "200000"}),
        ]
        previous_rows = [
            MockRow({"email": "user1@test.com", "service_type": "cloud_server", "avg_total": "100000"}),
        ]
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026", comparison_cycles=6)
        assert result["summary"]["comparison_cycles"] == 6

    @pytest.mark.asyncio
    async def test_analyze_limits_results(self, mock_sql_driver):
        """Test that results are limited to top 10."""
        # Create 15 customers with increases
        current_rows = [
            MockRow({"email": f"user{i}@test.com", "service_type": "cloud_server", "current_total": "200000"})
            for i in range(15)
        ]
        previous_rows = [
            MockRow({"email": f"user{i}@test.com", "service_type": "cloud_server", "avg_total": "100000"})
            for i in range(15)
        ]
        mock_sql_driver.execute_query = AsyncMock(side_effect=[current_rows, previous_rows])
        calc = AnalyzeTrendCalc(mock_sql_driver)
        result = await calc.analyze(current_cycle="01-01-2026")
        # Every customer doubles (+100%). test_analyze_with_service_filter shows
        # +100% is flagged under the default threshold, so all 15 qualify and
        # truncation is the only thing that can drop entries. Assert the exact
        # cap: a `<= 10` check would also pass if nothing were returned.
        assert len(result["increases"]) == 10