"""Tests for SQLCoach routing and analysis packet rendering."""
from __future__ import annotations
import os
import unittest
from typing import Any, Dict
from unittest import mock
from app.coach_local import SQLCoach
from app.model_manager import LocalModelManager
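

# The coach under test must never reach the real local model; this stub fails
# loudly if generation is ever attempted.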
class _FakeModelManager(LocalModelManager):
    def __init__(self) -> None:
        super().__init__(model_path="/dev/null", min_size_mb=0, expected_sha256=None)

    def generate(self, *args, **kwargs):  # pragma: no cover - should be bypassed in tests
        raise AssertionError("Model should not be invoked in these tests")
def _make_packet() -> Dict[str, Any]:
    return {
        "meta": {
            "run_id": "run_01",
            "generated_at": "2024-01-01T00:00:00Z",
            "limits": {"last_n_runs": 10, "top_k_skills": 10, "bucket_seconds": 5},
        },
        "run_summary": {
            "columns": ["run_id", "total_hits", "total_damage", "duration_seconds", "dps", "crit_rate_pct"],
            "rows": [["run_01", 18, 5500, 55.0, 100.0, 25.0]],
        },
        "runs_last_n": [
            {"run_id": "run_01", "total_hits": 18, "total_damage": 5500, "duration_seconds": 55.0, "dps": 100.0, "crit_rate_pct": 25.0},
            {"run_id": "run_02", "total_hits": 12, "total_damage": 4300, "duration_seconds": 60.0, "dps": 71.7, "crit_rate_pct": 20.0},
        ],
        "top_skills": {
            "columns": [
                "skill_name",
                "total_damage",
                "total_hits",
                "avg_damage",
                "crit_rate_pct",
                "damage_share_pct",
            ],
            "rows": [
                ["Shadow Burst", 3200, 10, 320.0, 30.0, 58.18],
                ["Curse Explosion", 2300, 8, 287.5, 20.0, 41.82],
            ],
        },
        "skill_efficiency": {
            "columns": [
                "skill_name",
                "total_hits",
                "total_damage",
                "avg_damage_per_hit",
                "crit_hits",
                "crit_rate_pct",
            ],
            "rows": [["Shadow Burst", 10, 3200, 320.0, 3, 30.0]],
        },
        "timeline": {
            "columns": ["bucket_s", "hits", "crit_hits", "crit_rate_pct", "damage"],
            "rows": [[0, 10, 2, 20.0, 1500]],
        },
        "skill_deltas": {
            "columns": [
                "skill_name", "last_share_pct", "prior_avg_share_pct", "delta_share_pp",
                "last_hits", "prior_avg_hits", "delta_hits",
                "last_crit_rate_pct", "prior_avg_crit_rate_pct", "delta_crit_pp",
            ],
            "rows": [],
        },
        "windows": {
            "early_window": {"start_s": 0, "end_s": 60, "damage": 0, "top_skills": []},
            "late_window": {"start_s": 0, "end_s": 60, "damage": 0, "top_skills": []},
            "top_damage_windows": [],
        },
        "actions": {
            "top_levers": [],
        },
        "notes": [],
    }


class SQLCoachRoutingTests(unittest.TestCase):
    def setUp(self) -> None:
        self.coach = SQLCoach(_FakeModelManager())
        self.payload: Dict[str, Any] = {"summary": {}, "runs": []}
        self.schema: Dict[str, Any] = {"columns": []}

    def test_meta_question_returns_capability_card_without_tools(self) -> None:
        analysis_calls = {"count": 0}

        def analysis_callback() -> Dict[str, Any]:
            analysis_calls["count"] += 1
            return _make_packet()

        answer, trace = self.coach.answer(
            question="help",
            payload=self.payload,
            schema=self.schema,
            query_callback=lambda sql: {},
            analysis_callback=analysis_callback,
        )
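
        # Meta/help questions are answered from the capability card: no tools run and the trace stays empty.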
self.assertIn("capability", answer.lower())
self.assertEqual(analysis_calls["count"], 0)
self.assertEqual(trace, [])

    def test_data_question_calls_analysis_packet_once(self) -> None:
        analysis_calls = {"count": 0}

        def analysis_callback() -> Dict[str, Any]:
            analysis_calls["count"] += 1
            return _make_packet()

        # Use a neutral question that triggers DEFAULT intent (no keywords).
        # The query_callback raises immediately, so the legacy SQL path must not be taken.
        answer, trace = self.coach.answer(
            question="Tell me about my performance",
            payload=self.payload,
            schema=self.schema,
            query_callback=lambda sql: (_ for _ in ()).throw(AssertionError("Should not call legacy SQL")),
            analysis_callback=analysis_callback,
        )
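
        # Exactly one packet fetch, and the rendered answer keeps all four report sections.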
        self.assertEqual(analysis_calls["count"], 1)
        self.assertIn("Insights:", answer)
        self.assertIn("Evidence:", answer)
        self.assertIn("Actions:", answer)
        self.assertIn("Next questions:", answer)
        self.assertTrue(any(call.get("tool") == "get_analysis_packet" for call in trace))

    def test_legacy_env_routes_sql_path(self) -> None:
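        # Stub the model: the first call emits a SQL request, the second the final answer.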
        calls = iter(["SQL: SELECT 1", "ANSWER: done"])
        self.coach._call_model = lambda messages: next(calls)
        query_called = {"count": 0}

        def query_callback(sql: str) -> Dict[str, Any]:
            query_called["count"] += 1
            return {"columns": ["skill_name"], "rows": [["SkillA"]]}

        with mock.patch.dict(os.environ, {"DPSCOACH_USE_LEGACY_SQL": "1"}, clear=False):
            answer, trace = self.coach.answer(
                question="Show totals",
                payload=self.payload,
                schema=self.schema,
                query_callback=query_callback,
                analysis_callback=lambda: (_ for _ in ()).throw(AssertionError("Analysis path should be skipped")),
            )
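
        # With the env flag set, routing uses the legacy SQL tool exactly once and skips the analysis packet.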
        self.assertEqual(query_called["count"], 1)
        self.assertEqual(answer, "done")
        self.assertTrue(any(call.get("tool_name") == "query_dps" for call in trace))
if __name__ == "__main__":
unittest.main()