"""Tests for news collectors."""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime, timezone
import aiohttp
import hashlib
from src.collectors.base_collector import BaseCollector, CollectorError
from src.utils.deduplicator import NewsDeduplicator
from src.utils.preprocessor import NewsPreprocessor


class TestBaseCollector:
    """Test cases for BaseCollector abstract class."""

    class ConcreteCollector(BaseCollector):
        """Concrete implementation for testing."""

        async def collect(self, **kwargs):
            """Test implementation of collect method."""
            return [{"title": "Test News", "url": "https://example.com/news/1"}]

        async def parse(self, raw_data):
            """Test implementation of parse method."""
            return {
                "title": raw_data.get("title", ""),
                "content": raw_data.get("content", ""),
                "url": raw_data.get("url", ""),
                "published_at": datetime.now(timezone.utc),
                "source": self.source_name,
            }

    @pytest.fixture
    def collector(self):
        """Create a concrete collector for testing."""
        return self.ConcreteCollector(source_name="test")

    def test_base_collector_initialization(self, collector):
        """Test BaseCollector initialization."""
        assert collector.source_name == "test"
        assert hasattr(collector, 'session')
        assert hasattr(collector, 'rate_limiter')
        assert hasattr(collector, 'preprocessor')
        assert hasattr(collector, 'deduplicator')

    def test_base_collector_abstract_methods(self):
        """Test that BaseCollector cannot be instantiated directly."""
        with pytest.raises(TypeError):
            BaseCollector(source_name="test")

    @pytest.mark.asyncio
    async def test_collect_and_parse_flow(self, collector):
        """Test the collect and parse flow."""
        # Test collect method
        raw_data = await collector.collect()
        assert len(raw_data) == 1
        assert raw_data[0]["title"] == "Test News"

        # Test parse method
        parsed = await collector.parse(raw_data[0])
        assert parsed["title"] == "Test News"
        assert parsed["source"] == "test"
        assert isinstance(parsed["published_at"], datetime)

    @pytest.mark.asyncio
    async def test_validate_method(self, collector):
        """Test the validate method."""
        # Valid data
        valid_data = {
            "title": "Valid News",
            "url": "https://example.com/valid",
            "content": "Valid content",
            "published_at": datetime.now(timezone.utc),
            "source": "test",
        }
        result = await collector.validate(valid_data)
        assert result is True

        # Invalid data - missing title
        invalid_data = {
            "url": "https://example.com/invalid",
            "content": "Content without title",
        }
        result = await collector.validate(invalid_data)
        assert result is False

    def test_generate_hash(self, collector):
        """Test hash generation for news items."""
        # generate_hash is synchronous, so no asyncio marker is needed here
        data = {
            "title": "Test News Title",
            "url": "https://example.com/news",
            "content": "Test content",
        }
        hash1 = collector.generate_hash(data)
        hash2 = collector.generate_hash(data)

        # Same data should generate the same hash
        assert hash1 == hash2
        assert len(hash1) == 64  # SHA-256 hex digest length

        # Different data should generate a different hash
        data2 = data.copy()
        data2["title"] = "Different Title"
        hash3 = collector.generate_hash(data2)
        assert hash1 != hash3
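
    # The assertions above only pin down the digest length; one implementation
    # consistent with them is a SHA-256 over the identifying fields. A minimal
    # sketch (hypothetical, not necessarily BaseCollector's actual code):
    #
    #     import hashlib
    #
    #     def generate_hash(self, data):
    #         key = f"{data.get('title', '')}|{data.get('url', '')}"
    #         return hashlib.sha256(key.encode("utf-8")).hexdigest()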

    @pytest.mark.asyncio
    async def test_collect_with_filters(self, collector):
        """Test collect method with various filters."""
        # Test with keyword filter
        result = await collector.collect(keyword="technology")
        assert isinstance(result, list)

        # Test with limit
        result = await collector.collect(limit=5)
        assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_rate_limiting(self, collector):
        """Test that rate limiting is applied."""
        # Mock the rate limiter
        mock_rate_limiter = AsyncMock()
        collector.rate_limiter = mock_rate_limiter

        # Test through collect_and_process, which uses the rate limiter
        await collector.collect_and_process()

        # Rate limiter should have been used
        mock_rate_limiter.__aenter__.assert_called_once()
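
    # The __aenter__ assertion assumes the base class acquires the limiter as
    # an async context manager, roughly (sketch of the assumed call pattern):
    #
    #     async with self.rate_limiter:
    #         raw_items = await self.collect(**kwargs)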

    @pytest.mark.asyncio
    async def test_error_handling(self, collector):
        """Test error handling in collector methods."""
        # Mock collect to raise an exception
        async def failing_collect(**kwargs):
            raise aiohttp.ClientError("Network error")

        collector.collect = failing_collect
        with pytest.raises(CollectorError):
            await collector.collect_and_process()
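
    # For this to pass, collect_and_process presumably wraps transport errors
    # in the collector's own exception type, e.g. (sketch, assumed behavior):
    #
    #     try:
    #         raw_items = await self.collect(**kwargs)
    #     except aiohttp.ClientError as exc:
    #         raise CollectorError(f"{self.source_name}: collection failed") from exc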

    @pytest.mark.asyncio
    async def test_collect_and_process_full_pipeline(self, collector):
        """Test the full collect and process pipeline."""
        # Mock external dependencies
        collector.deduplicator = Mock()
        collector.deduplicator.is_duplicate = AsyncMock(return_value=False)
        collector.preprocessor = Mock()
        collector.preprocessor.clean_html = Mock(return_value="Cleaned content")
        collector.preprocessor.normalize_text = Mock(return_value="Normalized text")

        # Run the full pipeline
        results = await collector.collect_and_process(limit=1)

        # Verify the pipeline was executed
        assert isinstance(results, list)
        collector.deduplicator.is_duplicate.assert_called()
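
    # Assumed stage order behind collect_and_process (a sketch inferred from
    # the mocks above, not confirmed against the implementation):
    # collect() -> parse() -> validate() -> deduplicator.is_duplicate()
    # -> preprocessor.clean_html() / normalize_text() -> collected results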

    def test_url_normalization(self, collector):
        """Test URL normalization utility."""
        # (input, expected) pairs; the normalizer is expected to lowercase,
        # force https (adding the scheme if missing), strip utm_* tracking
        # parameters, and resolve dot-segments in the path
        test_urls = [
            ("HTTP://EXAMPLE.COM/NEWS", "https://example.com/news"),
            ("example.com/news?utm_source=test", "https://example.com/news"),
            ("https://example.com/news/../other", "https://example.com/other"),
        ]
        for input_url, expected in test_urls:
            normalized = collector.normalize_url(input_url)
            assert normalized == expected
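
    # A stdlib-only sketch that would satisfy all three cases (hypothetical;
    # the shipped normalize_url may differ):
    #
    #     import posixpath
    #     from urllib.parse import urlsplit, urlunsplit
    #
    #     def normalize_url(self, url):
    #         if "://" not in url:
    #             url = "https://" + url
    #         parts = urlsplit(url.lower())
    #         path = posixpath.normpath(parts.path) if parts.path else "/"
    #         query = "&".join(p for p in parts.query.split("&")
    #                          if p and not p.startswith("utm_"))
    #         return urlunsplit(("https", parts.netloc, path, query, ""))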

    @pytest.mark.asyncio
    async def test_session_management(self, collector):
        """Test HTTP session management."""
        # Session should be None initially
        assert collector.session is None

        # Initialize should create a session
        await collector.initialize()
        assert collector.session is not None

        # Test session cleanup
        await collector.close()
        assert collector.session is None
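
    # Assumed lifecycle (sketch): initialize() creates an aiohttp.ClientSession
    # and close() awaits session.close() before resetting the attribute:
    #
    #     async def initialize(self):
    #         self.session = aiohttp.ClientSession()
    #
    #     async def close(self):
    #         if self.session is not None:
    #             await self.session.close()
    #             self.session = None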

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, collector):
        """Test retry mechanism for failed requests."""
        # Mock a method that fails twice, then succeeds; **kwargs keeps it
        # compatible with any keyword arguments collect_with_retry passes on
        call_count = 0

        async def flaky_method(**kwargs):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise aiohttp.ClientError("Temporary failure")
            return {"success": True}

        # Retry logic is expected to be implemented in the base class
        with patch.object(collector, 'collect', flaky_method):
            result = await collector.collect_with_retry(max_retries=3)
        assert result == {"success": True}
        assert call_count == 3
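
    # A minimal retry helper consistent with this test (sketch; the names and
    # exponential-backoff policy are assumptions, not the verified code):
    #
    #     async def collect_with_retry(self, max_retries=3, backoff=0.5, **kwargs):
    #         for attempt in range(1, max_retries + 1):
    #             try:
    #                 return await self.collect(**kwargs)
    #             except aiohttp.ClientError as exc:
    #                 if attempt == max_retries:
    #                     raise CollectorError("collect failed after retries") from exc
    #                 await asyncio.sleep(backoff * 2 ** (attempt - 1))
    #
    # (the backoff sleep would require `import asyncio`)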


class TestNewsDeduplicator:
    """Test cases for NewsDeduplicator."""

    @pytest.fixture
    def deduplicator(self):
        """Create deduplicator instance."""
        return NewsDeduplicator(threshold=0.85)

    @pytest.mark.asyncio
    async def test_duplicate_detection_identical(self, deduplicator):
        """Test detection of identical news items."""
        news1 = {
            "title": "Breaking News: Market Update",
            "content": "The market closed higher today with significant gains.",
            "url": "https://example.com/news/1",
        }
        news2 = news1.copy()
        news2["url"] = "https://different.com/news/2"  # Different URL, same content

        is_duplicate = await deduplicator.is_duplicate(news1, news2)
        assert is_duplicate is True

    @pytest.mark.asyncio
    async def test_duplicate_detection_similar(self, deduplicator):
        """Test detection of similar but not identical news."""
        news1 = {
            "title": "Market Update: Stocks Rise",
            "content": "The stock market experienced gains today.",
            "url": "https://example.com/news/1",
        }
        news2 = {
            "title": "Stocks Rise in Today's Market",
            "content": "Today's stock market saw significant gains.",
            "url": "https://example.com/news/2",
        }
        is_duplicate = await deduplicator.is_duplicate(news1, news2)
        # This might be True or False depending on the threshold
        assert isinstance(is_duplicate, bool)

    @pytest.mark.asyncio
    async def test_duplicate_detection_different(self, deduplicator):
        """Test that completely different news are not marked as duplicates."""
        news1 = {
            "title": "Stock Market Update",
            "content": "Market closed higher today.",
            "url": "https://example.com/news/1",
        }
        news2 = {
            "title": "Weather Forecast",
            "content": "Tomorrow will be sunny and warm.",
            "url": "https://example.com/news/2",
        }
        is_duplicate = await deduplicator.is_duplicate(news1, news2)
        assert is_duplicate is False

    def test_simhash_calculation(self, deduplicator):
        """Test SimHash calculation."""
        text1 = "This is a test article about technology news"
        text2 = "This is a test article about tech news"
        text3 = "Weather forecast for tomorrow"

        hash1 = deduplicator.calculate_simhash(text1)
        hash2 = deduplicator.calculate_simhash(text2)
        hash3 = deduplicator.calculate_simhash(text3)

        # Similar texts should have similar hashes
        similarity_12 = deduplicator.calculate_similarity(hash1, hash2)
        similarity_13 = deduplicator.calculate_similarity(hash1, hash3)
        assert similarity_12 > similarity_13
        assert 0 <= similarity_12 <= 1
        assert 0 <= similarity_13 <= 1
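
    # calculate_similarity is assumed to map the Hamming distance between
    # fingerprints into [0, 1], e.g. for 64-bit SimHash values (sketch):
    #
    #     def calculate_similarity(self, h1, h2):
    #         return 1.0 - bin(h1 ^ h2).count("1") / 64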

    def test_threshold_configuration(self):
        """Test different threshold configurations."""
        strict_deduplicator = NewsDeduplicator(threshold=0.95)
        lenient_deduplicator = NewsDeduplicator(threshold=0.7)
        assert strict_deduplicator.threshold == 0.95
        assert lenient_deduplicator.threshold == 0.7


class TestNewsPreprocessor:
    """Test cases for NewsPreprocessor."""

    @pytest.fixture
    def preprocessor(self):
        """Create preprocessor instance."""
        return NewsPreprocessor()

    def test_clean_html(self, preprocessor):
        """Test HTML cleaning."""
        html_content = """
        <html>
            <body>
                <h1>News Title</h1>
                <p>This is <strong>important</strong> news.</p>
                <script>alert('malicious');</script>
                <div class="ad">Advertisement</div>
            </body>
        </html>
        """
        cleaned = preprocessor.clean_html(html_content)
        assert "News Title" in cleaned
        assert "important" in cleaned
        assert "<script>" not in cleaned
        assert "<div" not in cleaned
        assert "alert" not in cleaned

    def test_normalize_text(self, preprocessor):
        """Test text normalization."""
        messy_text = "  This has\t\textra whitespace\n\nand characters!  "
        normalized = preprocessor.normalize_text(messy_text)
        assert normalized == "This has extra whitespace and characters!"
        assert "  " not in normalized  # No double spaces
        assert not normalized.startswith(" ")  # No leading space
        assert not normalized.endswith(" ")  # No trailing space
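
    # The expected output matches a plain whitespace collapse (sketch; the
    # actual normalize_text may do more):
    #
    #     def normalize_text(self, text):
    #         return " ".join(text.split())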

    def test_extract_metadata(self, preprocessor):
        """Test metadata extraction from text."""
        # Korean newswire byline fixture. Translation: "[Seoul=News1]
        # Reporter Kim = January 15, 2024 - Samsung Electronics announced
        # that it has launched a new smartphone."
        news_text = """
        [서울=뉴스1] 김기자 기자 = 2024년 1월 15일 -
        삼성전자가 새로운 스마트폰을 출시했다고 발표했다.
        """
        metadata = preprocessor.extract_metadata(news_text)
        assert "date" in metadata
        assert "reporter" in metadata
        assert "company" in metadata

        # Check extracted values (reporter name and company from the byline)
        assert "김기자" in metadata["reporter"]
        assert "삼성전자" in metadata["company"]

    def test_remove_ads_and_promotions(self, preprocessor):
        """Test removal of advertisements and promotional content."""
        # Korean fixture: genuine article sentences interleaved with ad lines.
        # Translation: "This is the actual news content." / "[Ad] Buy this
        # product now!" / "The news continues." / "* This article is an ad."
        content_with_ads = """
        실제 뉴스 내용입니다.
        [광고] 이 제품을 지금 구매하세요!
        뉴스가 계속됩니다.
        ※ 이 기사는 광고입니다.
        """
        cleaned = preprocessor.remove_ads_and_promotions(content_with_ads)
        assert "실제 뉴스 내용" in cleaned  # real article text survives
        assert "뉴스가 계속됩니다" in cleaned  # continuation text survives
        assert "광고" not in cleaned  # the word "advertisement" is gone
        assert "구매하세요" not in cleaned  # promotional copy is gone

    def test_standardize_quotes(self, preprocessor):
        """Test quote standardization."""
        text_with_various_quotes = '"Hello" and \'World\' and „German" and «French»'
        standardized = preprocessor.standardize_quotes(text_with_various_quotes)

        # Check that at least the fancy quotes are standardized
        assert '"German"' in standardized
        assert '"French"' in standardized

        # Single quotes might remain as single quotes
        assert 'Hello' in standardized
        assert 'World' in standardized
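
    # A str.translate-based sketch that satisfies these checks (assumed
    # mapping; the real table may cover more quote characters):
    #
    #     _QUOTES = str.maketrans({'„': '"', '“': '"', '”': '"', '«': '"', '»': '"'})
    #
    #     def standardize_quotes(self, text):
    #         return text.translate(_QUOTES)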

    def test_process_pipeline(self, preprocessor):
        """Test the full preprocessing pipeline."""
        # Korean fixture: an <h1> reading "Test News", a paragraph reading
        # "This is important news.", an "[Ad] product promotion" line, and a
        # script tag that must all be handled by the pipeline
        raw_html = """
        <html>
            <body>
                <h1>테스트 뉴스</h1>
                <p>이것은 <strong>중요한</strong> 뉴스입니다.</p>
                [광고] 제품 홍보
                <script>alert('bad');</script>
            </body>
        </html>
        """
        processed = preprocessor.process(raw_html)
        assert "테스트 뉴스" in processed
        assert "중요한 뉴스" in processed
        assert "script" not in processed
        assert "광고" not in processed
        assert "alert" not in processed