"""
Tests for SRD API client.
These tests use mocked HTTP responses for fast, reliable testing.
One slow test is marked for optional real API testing.
"""
import asyncio
import json
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from dm20_protocol.rulebooks.sources import SRDSource, SRDSourceError
from dm20_protocol.rulebooks.models import RulebookSource
def run_async(coro):
"""Helper to run async code in sync tests."""
return asyncio.get_event_loop().run_until_complete(coro)
# Sample API responses for mocking
MOCK_CLASSES_LIST = {
"count": 2,
"results": [
{"index": "wizard", "name": "Wizard", "url": "/api/classes/wizard"},
{"index": "fighter", "name": "Fighter", "url": "/api/classes/fighter"},
]
}
MOCK_WIZARD_CLASS = {
"index": "wizard",
"name": "Wizard",
"hit_die": 6,
"proficiencies": [{"name": "Daggers"}, {"name": "Darts"}],
"saving_throws": [{"name": "Intelligence"}, {"name": "Wisdom"}],
"starting_equipment": [],
"spellcasting": {
"level": 1,
"spellcasting_ability": {"index": "int", "name": "INT"},
},
"subclasses": [{"index": "evocation", "name": "School of Evocation"}],
}
MOCK_WIZARD_LEVELS = [
{"level": 1, "prof_bonus": 2, "features": [{"name": "Spellcasting"}, {"name": "Arcane Recovery"}]},
{"level": 2, "prof_bonus": 2, "features": [{"name": "Arcane Tradition"}]},
]
MOCK_EVOCATION_SUBCLASS = {
"index": "evocation",
"name": "School of Evocation",
"class": {"index": "wizard"},
"subclass_flavor": "Evocation Savant",
"desc": ["You focus your study on magic that creates powerful elemental effects."],
}
MOCK_RACES_LIST = {
"count": 1,
"results": [{"index": "elf", "name": "Elf", "url": "/api/races/elf"}]
}
MOCK_ELF_RACE = {
"index": "elf",
"name": "Elf",
"speed": 30,
"ability_bonuses": [{"ability_score": {"index": "dex"}, "bonus": 2}],
"size": "Medium",
"languages": [{"name": "Common"}, {"name": "Elvish"}],
"traits": [{"index": "darkvision", "name": "Darkvision"}],
"subraces": [{"index": "high-elf"}],
}
MOCK_SPELLS_LIST = {
"count": 1,
"results": [{"index": "fireball", "name": "Fireball"}]
}
MOCK_FIREBALL_SPELL = {
"index": "fireball",
"name": "Fireball",
"level": 3,
"school": {"name": "Evocation"},
"casting_time": "1 action",
"range": "150 feet",
"duration": "Instantaneous",
"components": ["V", "S", "M"],
"material": "A tiny ball of bat guano and sulfur",
"ritual": False,
"concentration": False,
"desc": ["A bright streak flashes from your pointing finger..."],
"higher_level": ["When you cast this spell using a spell slot of 4th level or higher..."],
"classes": [{"index": "wizard"}, {"index": "sorcerer"}],
"subclasses": [],
"damage": {"damage_type": {"name": "fire"}},
}
MOCK_MONSTERS_LIST = {"count": 0, "results": []}
MOCK_EQUIPMENT_LIST = {"count": 0, "results": []}
MOCK_FEATS_LIST = {"count": 0, "results": []}
MOCK_BACKGROUNDS_LIST = {"count": 0, "results": []}
MOCK_SUBRACES_LIST = {"count": 0, "results": []}
class MockResponse:
"""Mock httpx response."""
def __init__(self, data: dict, status_code: int = 200):
self._data = data
self.status_code = status_code
def json(self):
return self._data
def raise_for_status(self):
if self.status_code >= 400:
raise Exception(f"HTTP {self.status_code}")
class TestSRDSourceInit:
"""Test SRDSource initialization."""
def test_default_init(self, tmp_path):
source = SRDSource(cache_dir=tmp_path / "cache")
assert source.source_id == "srd-2014"
assert source.source_type == RulebookSource.SRD
assert source.version == "2014"
assert not source.is_loaded
def test_custom_version(self, tmp_path):
source = SRDSource(version="2024", cache_dir=tmp_path / "cache")
assert source.source_id == "srd-2024"
assert source.version == "2024"
class TestSRDSourceMocked:
"""Test SRDSource with mocked HTTP responses."""
@pytest.fixture
def mock_responses(self):
"""Create a mapping of endpoints to mock responses."""
return {
"/classes": MOCK_CLASSES_LIST,
"/classes/wizard": MOCK_WIZARD_CLASS,
"/classes/wizard/levels": MOCK_WIZARD_LEVELS,
"/classes/fighter": {"index": "fighter", "name": "Fighter", "hit_die": 10, "proficiencies": [], "saving_throws": [], "starting_equipment": [], "subclasses": []},
"/classes/fighter/levels": [],
"/subclasses/evocation": MOCK_EVOCATION_SUBCLASS,
"/races": MOCK_RACES_LIST,
"/races/elf": MOCK_ELF_RACE,
"/subraces/high-elf": {"index": "high-elf", "name": "High Elf", "race": {"index": "elf"}, "ability_bonuses": []},
"/spells": MOCK_SPELLS_LIST,
"/spells/fireball": MOCK_FIREBALL_SPELL,
"/monsters": MOCK_MONSTERS_LIST,
"/equipment": MOCK_EQUIPMENT_LIST,
"/feats": MOCK_FEATS_LIST,
"/backgrounds": MOCK_BACKGROUNDS_LIST,
}
@pytest.fixture
def source_with_mocks(self, tmp_path, mock_responses):
"""Create SRDSource with mocked HTTP client."""
source = SRDSource(cache_dir=tmp_path / "cache")
async def mock_get(url):
# Remove base URL and version prefix (e.g., /2014 or /2024)
endpoint = url.replace("https://www.dnd5eapi.co/api", "")
# Strip version prefix: /2014/classes -> /classes
if endpoint.startswith("/2014") or endpoint.startswith("/2024"):
endpoint = endpoint[5:]
if endpoint in mock_responses:
return MockResponse(mock_responses[endpoint])
return MockResponse({}, 404)
return source, mock_get
def test_load_classes(self, source_with_mocks):
"""Test loading classes from mocked API."""
source, mock_get = source_with_mocks
async def run_test():
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
assert source.is_loaded
assert len(source._classes) == 2
wizard = source.get_class("wizard")
assert wizard is not None
assert wizard.name == "Wizard"
assert wizard.hit_die == 6
assert wizard.source == "srd-2014"
run_async(run_test())
def test_load_spells(self, source_with_mocks):
"""Test loading spells from mocked API."""
source, mock_get = source_with_mocks
async def run_test():
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
fireball = source.get_spell("fireball")
assert fireball is not None
assert fireball.name == "Fireball"
assert fireball.level == 3
assert "wizard" in fireball.classes
run_async(run_test())
def test_load_races(self, source_with_mocks):
"""Test loading races from mocked API."""
source, mock_get = source_with_mocks
async def run_test():
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
elf = source.get_race("elf")
assert elf is not None
assert elf.name == "Elf"
assert elf.speed == 30
assert len(elf.ability_bonuses) == 1
run_async(run_test())
def test_search(self, source_with_mocks):
"""Test searching loaded content."""
source, mock_get = source_with_mocks
async def run_test():
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
# Search for "wizard"
results = list(source.search("wizard"))
assert len(results) >= 1
names = [r.name for r in results]
assert "Wizard" in names
run_async(run_test())
def test_content_counts(self, source_with_mocks):
"""Test content counts after load."""
source, mock_get = source_with_mocks
async def run_test():
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
counts = source.content_counts()
assert counts.classes == 2
assert counts.races == 1
assert counts.spells == 1
run_async(run_test())
class TestSRDSourceCaching:
"""Test caching behavior."""
def test_cache_created(self, tmp_path):
"""Test that cache files are created."""
cache_dir = tmp_path / "cache"
source = SRDSource(cache_dir=cache_dir)
# Manually create a cache file
cache_dir.mkdir(parents=True)
cache_file = cache_dir / "classes.json"
cache_file.write_text(json.dumps(MOCK_CLASSES_LIST))
# Load should use cache
async def run_test():
# Since we can't fully mock, just verify cache structure
assert cache_dir.exists()
assert cache_file.exists()
run_async(run_test())
def test_cache_used_on_second_load(self, tmp_path):
"""Test that cache is used instead of making HTTP requests."""
cache_dir = tmp_path / "cache"
cache_dir.mkdir(parents=True)
# Pre-populate cache
(cache_dir / "classes.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "races.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "spells.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "monsters.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "equipment.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "feats.json").write_text(json.dumps({"count": 0, "results": []}))
(cache_dir / "backgrounds.json").write_text(json.dumps({"count": 0, "results": []}))
source = SRDSource(cache_dir=cache_dir)
async def run_test():
# This should use cache and not make HTTP requests
with patch('httpx.AsyncClient') as MockClient:
mock_client = MagicMock()
# Return empty results from cache
async def mock_get(url):
return MockResponse({"count": 0, "results": []})
mock_client.get = AsyncMock(side_effect=mock_get)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
MockClient.return_value = mock_client
await source.load()
assert source.is_loaded
# Should have loaded from cache (0 items)
assert source.content_counts().classes == 0
run_async(run_test())
class TestSRDSourceRealAPI:
"""
Tests against the real 5e-srd-api.
These tests are marked slow and can be skipped with: pytest -m "not slow"
"""
@pytest.mark.slow
def test_real_api_fetch_wizard(self, tmp_path):
"""Test fetching wizard class from real API."""
source = SRDSource(cache_dir=tmp_path / "cache")
async def run_test():
await source.load()
wizard = source.get_class("wizard")
assert wizard is not None
assert wizard.name == "Wizard"
assert wizard.hit_die == 6
assert "evocation" in wizard.subclasses
fireball = source.get_spell("fireball")
assert fireball is not None
assert fireball.level == 3
run_async(run_test())