# test_dart_collector.py (12.2 kB)
"""
Unit tests for DART API data collector
TDD Red Phase: Write failing tests first
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime
# Test imports will fail initially - this is expected in TDD Red phase
try:
from src.collectors.dart_collector import DARTCollector
from src.models.company import CompanyOverview, FinancialData
from src.exceptions import DataCollectionError, APIError, CompanyNotFoundError
except ImportError:
# Expected in Red phase - we haven't implemented these yet
pass
class TestDARTCollector:
    """Unit tests for the DART API data collector.

    No HTTP patching is done in this class: every test talks to a
    ``DARTCollector`` built with the key ``"test_api_key"``.
    NOTE(review): presumably the collector serves canned/offline data for
    that key -- confirm against the collector implementation.
    """

    @pytest.fixture
    def dart_collector(self):
        """Return a collector constructed with the placeholder API key."""
        # Fails with NameError until DARTCollector exists (TDD red phase).
        return DARTCollector(api_key="test_api_key")

    @pytest.fixture
    def mock_dart_response(self):
        """Return a sample DART payload holding one company entry.

        The payload has a ``status``/``message`` envelope and a ``list`` with
        a single record. No test in this class requests this fixture.
        """
        return {
            "status": "000",
            "message": "정상",
            "list": [
                {
                    "corp_code": "00126380",
                    "corp_name": "삼성전자",
                    "corp_name_eng": "SAMSUNG ELECTRONICS CO., LTD.",
                    "stock_code": "005930",
                    "modify_date": "20231201",
                }
            ],
        }

    @pytest.mark.asyncio
    async def test_dart_collector_initialization(self, dart_collector):
        """The collector exposes the given API key and an OpenDART base URL."""
        assert dart_collector is not None
        assert hasattr(dart_collector, 'api_key')
        assert hasattr(dart_collector, 'base_url')
        assert dart_collector.api_key == "test_api_key"
        assert "opendart.fss.or.kr" in dart_collector.base_url

    @pytest.mark.asyncio
    async def test_get_company_list(self, dart_collector):
        """``get_company_list`` returns a non-empty list of company records."""
        companies = await dart_collector.get_company_list()

        assert companies is not None
        assert isinstance(companies, list)
        assert len(companies) > 0

        # Check company data structure on the first record.
        company = companies[0]
        assert "corp_code" in company
        assert "corp_name" in company
        assert "stock_code" in company

    @pytest.mark.asyncio
    async def test_get_company_info_by_code(self, dart_collector):
        """``get_company_info`` returns an object matching the stock code."""
        company_code = "005930"  # Samsung Electronics

        company_info = await dart_collector.get_company_info(company_code)

        assert company_info is not None
        assert company_info.stock_code == company_code
        assert company_info.company_name is not None
        assert len(company_info.company_name) > 0

    @pytest.mark.asyncio
    async def test_get_company_info_invalid_code(self, dart_collector):
        """An unknown stock code raises ``CompanyNotFoundError``."""
        invalid_code = "999999"

        with pytest.raises(CompanyNotFoundError):
            await dart_collector.get_company_info(invalid_code)

    @pytest.mark.asyncio
    async def test_get_financial_statements(self, dart_collector):
        """``get_financial_statements`` returns the three statement sections."""
        company_code = "005930"
        year = 2023

        financial_data = await dart_collector.get_financial_statements(
            company_code=company_code,
            year=year,
            report_code="11011",  # Annual report
        )

        assert financial_data is not None
        assert isinstance(financial_data, dict)

        # Expected sections: income_statement, balance_sheet, cash_flow.
        assert "income_statement" in financial_data
        assert "balance_sheet" in financial_data
        assert "cash_flow" in financial_data

        # Check specific financial values.
        income = financial_data["income_statement"]
        balance = financial_data["balance_sheet"]
        assert income["revenue"] > 0
        assert balance["total_assets"] > 0

    @pytest.mark.asyncio
    async def test_get_financial_statements_multiple_years(self, dart_collector):
        """Statements can be fetched year by year for several years."""
        company_code = "005930"
        years = [2021, 2022, 2023]

        all_financial_data = []
        for year in years:
            financial_data = await dart_collector.get_financial_statements(
                company_code=company_code,
                year=year,
                report_code="11011",
            )
            all_financial_data.append((year, financial_data))

        assert len(all_financial_data) == 3

        # Each year must carry the expected financial data structure.
        for _year, data in all_financial_data:
            assert isinstance(data, dict)
            assert "income_statement" in data
            assert "balance_sheet" in data
            assert data["income_statement"]["revenue"] > 0

    @pytest.mark.asyncio
    async def test_api_error_handling(self, dart_collector):
        """A collector built with ``"invalid_key"`` raises ``APIError``."""
        # Test with invalid API key.
        invalid_collector = DARTCollector(api_key="invalid_key")

        with pytest.raises(APIError):
            await invalid_collector.get_company_info("005930")

    @pytest.mark.asyncio
    async def test_rate_limiting(self, dart_collector):
        """Sequential lookups succeed or fail only with a rate-limit error."""
        # DART API has rate limits; check that they are handled properly.
        company_codes = ["005930", "000660", "035420"]  # Multiple companies
        results = []

        for code in company_codes:
            try:
                company_info = await dart_collector.get_company_info(code)
                results.append(company_info)
            except APIError as e:
                # Any APIError tolerated here must describe rate limiting.
                message = str(e).lower()
                assert "rate limit" in message or "too many requests" in message

        # Should get at least some results.
        assert len(results) > 0

    @pytest.mark.asyncio
    async def test_data_caching(self, dart_collector):
        """A repeated lookup is no slower than the first and returns equal data."""
        # Local import keeps this block self-contained. perf_counter is
        # monotonic and high resolution, unlike wall-clock datetime.now().
        import time

        company_code = "005930"

        # First call should hit the API.
        started = time.perf_counter()
        company_info1 = await dart_collector.get_company_info(company_code)
        first_call_time = time.perf_counter() - started

        # Second call should use the cache (faster).
        started = time.perf_counter()
        company_info2 = await dart_collector.get_company_info(company_code)
        second_call_time = time.perf_counter() - started

        # Deliberately lenient: faster than the first call, or under 1 ms.
        assert second_call_time < first_call_time or second_call_time < 0.001

        # Data should be identical.
        assert company_info1.company_name == company_info2.company_name
        assert company_info1.stock_code == company_info2.stock_code
class TestDARTCollectorIntegration:
    """Integration tests with mocked DART API responses.

    ``httpx.AsyncClient.get`` is patched in each test so that awaiting it
    yields a hand-built response object carrying the desired JSON payload.
    """

    @pytest.fixture
    def dart_collector(self):
        """Create DART collector instance."""
        return DARTCollector(api_key="test_api_key")

    @pytest.fixture
    def mock_dart_response(self):
        """Return a sample DART payload holding one company entry.

        Defined on this class because pytest does not share fixtures that
        are declared inside a different test class.
        """
        return {
            "status": "000",
            "message": "정상",
            "list": [
                {
                    "corp_code": "00126380",
                    "corp_name": "삼성전자",
                    "corp_name_eng": "SAMSUNG ELECTRONICS CO., LTD.",
                    "stock_code": "005930",
                    "modify_date": "20231201",
                }
            ],
        }

    @staticmethod
    def _json_response(payload, status_code=200):
        """Build a synchronous stand-in for an HTTP response.

        ``patch`` replaces an async target with an ``AsyncMock`` whose
        auto-created children are also async, so ``response.json()`` would
        return a coroutine instead of ``payload``. A plain ``Mock`` keeps
        ``json()`` and ``raise_for_status()`` synchronous.

        Args:
            payload: Object returned by ``response.json()``.
            status_code: Value exposed as ``response.status_code``.

        Returns:
            A ``Mock`` configured as described above.
        """
        response = Mock()
        response.json.return_value = payload
        response.status_code = status_code
        response.raise_for_status.return_value = None
        return response

    @pytest.mark.asyncio
    async def test_company_overview_integration(self, dart_collector, mock_dart_response):
        """Company info is built from the mocked company payload."""
        with patch('httpx.AsyncClient.get') as mock_get:
            mock_get.return_value = self._json_response(mock_dart_response)

            company_info = await dart_collector.get_company_info("005930")

            assert company_info.company_name == "삼성전자"
            assert company_info.stock_code == "005930"

    @pytest.mark.asyncio
    async def test_financial_data_parsing(self, dart_collector):
        """Revenue from the mocked statement rows lands in ``income_statement``."""
        mock_financial_response = {
            "status": "000",
            "message": "정상",
            "list": [
                {
                    "rcept_no": "20240314001234",
                    "reprt_code": "11011",
                    "bsns_year": "2023",
                    "corp_code": "00126380",
                    # Revenue is an income-statement account, so the row is
                    # tagged IS to agree with the assertion below.
                    "sj_div": "IS",
                    "sj_nm": "손익계산서",
                    "account_nm": "매출액",
                    "thstrm_amount": "258774000000000",  # 258.77 trillion
                    "currency": "KRW",
                },
                {
                    "rcept_no": "20240314001234",
                    "reprt_code": "11011",
                    "bsns_year": "2023",
                    "corp_code": "00126380",
                    "sj_div": "IS",
                    "sj_nm": "손익계산서",
                    "account_nm": "영업이익",
                    "thstrm_amount": "22034000000000",  # 22.03 trillion
                    "currency": "KRW",
                },
            ],
        }

        with patch('httpx.AsyncClient.get') as mock_get:
            mock_get.return_value = self._json_response(mock_financial_response)

            financial_data = await dart_collector.get_financial_statements(
                company_code="005930",
                year=2023,
                report_code="11011",
            )

            assert isinstance(financial_data, dict)
            assert "income_statement" in financial_data
            income = financial_data["income_statement"]
            assert income["revenue"] > 250_000_000_000_000  # > 250 trillion
class TestDARTCollectorErrorScenarios:
    """Test error handling scenarios.

    ``httpx.AsyncClient.get`` is patched in every test, either to raise or
    to return a hand-built response object.
    """

    @pytest.fixture
    def dart_collector(self):
        """Create DART collector instance."""
        return DARTCollector(api_key="test_api_key")

    @staticmethod
    def _json_response(payload, status_code=200):
        """Build a synchronous stand-in for an HTTP response.

        A plain ``Mock`` is used because the children of the ``AsyncMock``
        created by ``patch`` are async, which would make ``response.json()``
        return a coroutine rather than ``payload``.

        Args:
            payload: Object returned by ``response.json()``.
            status_code: Value exposed as ``response.status_code``.

        Returns:
            A ``Mock`` whose ``json()`` returns ``payload`` and whose
            ``raise_for_status()`` returns ``None``.
        """
        response = Mock()
        response.json.return_value = payload
        response.status_code = status_code
        response.raise_for_status.return_value = None
        return response

    @pytest.mark.asyncio
    async def test_network_error_handling(self, dart_collector):
        """A failure raised by the HTTP call surfaces as ``DataCollectionError``."""
        # Create non-test collector to test actual error handling.
        real_collector = DARTCollector(api_key="real_key")

        with patch('httpx.AsyncClient.get') as mock_get:
            mock_get.side_effect = Exception("Network error")

            with pytest.raises(DataCollectionError):
                await real_collector.get_company_info("005930")

    @pytest.mark.asyncio
    async def test_api_error_response(self, dart_collector):
        """A DART payload with error status ``013`` raises ``APIError``."""
        # Create non-test collector to test actual error handling.
        error_collector = DARTCollector(api_key="invalid_key")

        # Mock the response for the invalid key test.
        error_response = {
            "status": "013",
            "message": "등록되지 않은 키입니다.",
            "list": [],
        }

        with patch('httpx.AsyncClient.get') as mock_get:
            mock_get.return_value = self._json_response(error_response)

            with pytest.raises(APIError):
                await error_collector.get_company_info("005930")

    @pytest.mark.asyncio
    async def test_empty_response_handling(self, dart_collector):
        """A successful payload with an empty ``list`` raises ``CompanyNotFoundError``."""
        empty_response = {
            "status": "000",
            "message": "정상",
            "list": [],
        }

        with patch('httpx.AsyncClient.get') as mock_get:
            mock_get.return_value = self._json_response(empty_response)

            with pytest.raises(CompanyNotFoundError):
                await dart_collector.get_company_info("999999")
if __name__ == "__main__":
    # Run this file's tests with pytest and propagate its exit status, so a
    # failing run is visible to the calling shell or CI job.
    raise SystemExit(pytest.main([__file__, "-v"]))