"""Tests for WebSocket real-time communication."""
import pytest
import asyncio
import json
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from datetime import datetime, timezone
from typing import Dict, List, Any
from fastapi.testclient import TestClient
from fastapi.websockets import WebSocketDisconnect
import websockets
from websockets.exceptions import ConnectionClosed
from src.server.websocket_server import (
WebSocketManager,
WebSocketConnection,
WebSocketMessage,
WebSocketResponse,
NewsWebSocketHandler,
AnalysisWebSocketHandler,
MonitoringWebSocketHandler,
ConnectionManager,
RealTimeNotifier,
MessageBroadcaster,
SubscriptionManager,
WebSocketError
)
class TestWebSocketManager:
    """Test cases for WebSocket Manager.

    Every test receives a fresh ``WebSocketManager`` and ``AsyncMock``
    websockets, so assertions on ``send_text`` / ``ping`` inspect exactly what
    the manager pushed to the socket.
    """

    @pytest.fixture
    def ws_manager(self):
        """Create WebSocketManager instance for testing."""
        return WebSocketManager()

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket connection with a fake client address."""
        websocket = AsyncMock()
        websocket.client = Mock()
        websocket.client.host = "127.0.0.1"
        websocket.client.port = 8000
        return websocket

    @pytest.mark.asyncio
    async def test_connection_manager_initialization(self, ws_manager):
        """A new manager exposes connect/disconnect and starts with no connections."""
        assert ws_manager is not None
        assert hasattr(ws_manager, 'active_connections')
        assert hasattr(ws_manager, 'connect')
        assert hasattr(ws_manager, 'disconnect')
        assert len(ws_manager.active_connections) == 0

    @pytest.mark.asyncio
    async def test_websocket_connection_lifecycle(self, ws_manager, mock_websocket):
        """Connecting registers the id; disconnecting removes it again."""
        connection_id = "test_connection_1"

        # Test connection
        await ws_manager.connect(mock_websocket, connection_id)
        assert connection_id in ws_manager.active_connections
        assert len(ws_manager.active_connections) == 1

        # Test disconnection
        await ws_manager.disconnect(connection_id)
        assert connection_id not in ws_manager.active_connections
        assert len(ws_manager.active_connections) == 0

    @pytest.mark.asyncio
    async def test_multiple_connections(self, ws_manager):
        """Five distinct connections are tracked and can all be disconnected."""
        connections = []
        for i in range(5):
            websocket = AsyncMock()
            websocket.client = Mock()
            websocket.client.host = f"127.0.0.{i+1}"
            connection_id = f"connection_{i}"
            await ws_manager.connect(websocket, connection_id)
            connections.append(connection_id)

        assert len(ws_manager.active_connections) == 5

        # Disconnect all
        for connection_id in connections:
            await ws_manager.disconnect(connection_id)

        assert len(ws_manager.active_connections) == 0

    @pytest.mark.asyncio
    async def test_send_personal_message(self, ws_manager, mock_websocket):
        """A personal message is JSON-encoded and sent once to its target socket."""
        connection_id = "test_connection"
        message = {"type": "news_update", "data": {"title": "Test News"}}

        await ws_manager.connect(mock_websocket, connection_id)
        await ws_manager.send_personal_message(message, connection_id)

        mock_websocket.send_text.assert_called_once_with(json.dumps(message))

    @pytest.mark.asyncio
    async def test_broadcast_message(self, ws_manager):
        """A broadcast is sent exactly once to every connected socket."""
        # Create multiple connections. Named ``sockets`` (not ``websockets``)
        # so the imported ``websockets`` module is not shadowed.
        sockets = []
        for i in range(3):
            websocket = AsyncMock()
            websocket.client = Mock()
            websocket.client.host = f"127.0.0.{i+1}"
            await ws_manager.connect(websocket, f"connection_{i}")
            sockets.append(websocket)

        message = {"type": "broadcast", "data": "Hello everyone"}
        await ws_manager.broadcast(message)

        # Verify all connections received the message
        for websocket in sockets:
            websocket.send_text.assert_called_once_with(json.dumps(message))

    @pytest.mark.asyncio
    async def test_connection_error_handling(self, ws_manager, mock_websocket):
        """A closed socket surfaces as WebSocketError and its connection is dropped."""
        connection_id = "error_connection"

        # Simulate connection error
        mock_websocket.send_text.side_effect = ConnectionClosed(None, None)

        await ws_manager.connect(mock_websocket, connection_id)

        # Should handle error gracefully and remove connection
        with pytest.raises(WebSocketError):
            await ws_manager.send_personal_message({"test": "message"}, connection_id)

        # The comment above promises cleanup; assert it (same expectation as
        # TestWebSocketIntegration.test_connection_recovery).
        assert connection_id not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_connection_heartbeat(self, ws_manager, mock_websocket):
        """ping_connection pings the underlying socket exactly once."""
        connection_id = "heartbeat_connection"
        await ws_manager.connect(mock_websocket, connection_id)

        # Test ping/pong
        await ws_manager.ping_connection(connection_id)
        mock_websocket.ping.assert_called_once()

    @pytest.mark.asyncio
    async def test_connection_metadata(self, ws_manager, mock_websocket):
        """Metadata passed to connect() is retrievable by connection id."""
        connection_id = "metadata_connection"
        metadata = {
            "user_id": "user_123",
            "subscriptions": ["news", "analysis"],
            "connected_at": datetime.now(timezone.utc).isoformat()
        }

        await ws_manager.connect(mock_websocket, connection_id, metadata)

        stored_metadata = ws_manager.get_connection_metadata(connection_id)
        assert stored_metadata["user_id"] == "user_123"
        assert "news" in stored_metadata["subscriptions"]
class TestNewsWebSocketHandler:
    """Test cases for News WebSocket Handler.

    Covers subscribe/unsubscribe bookkeeping, filtered delivery, rate
    limiting and batched (aggregated) delivery.
    """

    @pytest.fixture
    def news_handler(self):
        """Create NewsWebSocketHandler instance."""
        return NewsWebSocketHandler()

    @pytest.fixture
    def mock_connection(self):
        """Create a mock connection wrapping an AsyncMock websocket."""
        connection = Mock()
        connection.websocket = AsyncMock()
        connection.connection_id = "test_connection"
        connection.metadata = {"subscriptions": ["news_updates"]}
        return connection

    @pytest.mark.asyncio
    async def test_news_subscription(self, news_handler, mock_connection):
        """Subscribing stores the topics and filters under the connection id."""
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자", "LG전자"],
            "filters": {
                "sentiment": ["positive", "negative"],
                "sources": ["naver", "daum"]
            }
        }

        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Verify subscription was processed
        assert mock_connection.connection_id in news_handler.subscriptions
        subscription = news_handler.subscriptions[mock_connection.connection_id]
        assert "삼성전자" in subscription["topics"]
        assert "positive" in subscription["filters"]["sentiment"]

    @pytest.mark.asyncio
    async def test_news_unsubscription(self, news_handler, mock_connection):
        """Unsubscribing removes the topic from the connection's subscription."""
        # First subscribe
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자"]
        }
        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Then unsubscribe
        unsubscribe_data = {
            "action": "unsubscribe",
            "topics": ["삼성전자"]
        }
        await news_handler.handle_subscription(mock_connection, unsubscribe_data)

        # Verify unsubscription. Either outcome is accepted: the connection
        # entry is dropped entirely, or it remains without the topic.
        if mock_connection.connection_id in news_handler.subscriptions:
            subscription = news_handler.subscriptions[mock_connection.connection_id]
            assert "삼성전자" not in subscription.get("topics", [])

    @pytest.mark.asyncio
    async def test_real_time_news_delivery(self, news_handler, mock_connection):
        """A matching article is pushed to the subscriber as a news_update message."""
        # Subscribe to news
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자"]
        }
        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Simulate new news article
        news_article = {
            "id": "news_123",
            "title": "삼성전자 실적 발표",
            "content": "삼성전자가 좋은 실적을 발표했습니다.",
            "entity": "삼성전자",
            "sentiment": {"score": 0.8, "label": "positive"},
            "published_at": datetime.now(timezone.utc).isoformat()
        }

        await news_handler.broadcast_news_update(news_article)

        # Verify news was sent to subscriber
        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "news_update"
        assert sent_data["data"]["id"] == "news_123"

    @pytest.mark.asyncio
    async def test_news_filtering(self, news_handler, mock_connection):
        """Sentiment label and minimum score filters gate delivery."""
        # Subscribe with filters
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자"],
            "filters": {
                "sentiment": ["positive"],
                "min_sentiment_score": 0.7
            }
        }
        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Test positive news (should be delivered)
        positive_news = {
            "id": "news_positive",
            "entity": "삼성전자",
            "sentiment": {"score": 0.8, "label": "positive"}
        }

        should_deliver = news_handler.should_deliver_news(
            mock_connection.connection_id, positive_news
        )
        assert should_deliver is True

        # Test negative news (should be filtered out)
        negative_news = {
            "id": "news_negative",
            "entity": "삼성전자",
            "sentiment": {"score": 0.3, "label": "negative"}
        }

        should_deliver = news_handler.should_deliver_news(
            mock_connection.connection_id, negative_news
        )
        assert should_deliver is False

    @pytest.mark.asyncio
    async def test_news_rate_limiting(self, news_handler, mock_connection):
        """A max_per_minute limit caps deliveries without suppressing them all."""
        # Subscribe with rate limit
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자"],
            "rate_limit": {"max_per_minute": 10}
        }
        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Send multiple news updates rapidly
        for i in range(15):
            news_article = {
                "id": f"news_{i}",
                "entity": "삼성전자",
                "sentiment": {"score": 0.8, "label": "positive"}
            }
            await news_handler.broadcast_news_update(news_article)

        # Should respect rate limit. The lower bound guards against the test
        # passing vacuously when nothing is delivered at all.
        call_count = mock_connection.websocket.send_text.call_count
        assert 0 < call_count <= 10

    @pytest.mark.asyncio
    async def test_news_aggregation(self, news_handler, mock_connection):
        """Buffered articles are flushed as a single news_batch message."""
        # Subscribe with aggregation
        subscription_data = {
            "action": "subscribe",
            "topics": ["삼성전자"],
            "aggregation": {
                "enabled": True,
                "window_seconds": 5,
                "max_items": 5
            }
        }
        await news_handler.handle_subscription(mock_connection, subscription_data)

        # Send multiple news items
        for i in range(3):
            news_article = {
                "id": f"news_{i}",
                "entity": "삼성전자",
                "title": f"뉴스 {i}"
            }
            await news_handler.add_to_aggregation_buffer(
                mock_connection.connection_id, news_article
            )

        # Trigger aggregation
        await news_handler.flush_aggregation_buffer(mock_connection.connection_id)

        # Should send aggregated message
        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "news_batch"
        assert len(sent_data["data"]["articles"]) == 3
class TestAnalysisWebSocketHandler:
    """Tests for ``AnalysisWebSocketHandler``: requests, progress, batches, caching, errors."""

    @staticmethod
    def _last_sent_payload(connection):
        """Assert *connection* sent at least one frame and decode the latest one as JSON."""
        send_text = connection.websocket.send_text
        send_text.assert_called()
        return json.loads(send_text.call_args.args[0])

    @pytest.fixture
    def analysis_handler(self):
        """Provide a fresh AnalysisWebSocketHandler."""
        return AnalysisWebSocketHandler()

    @pytest.fixture
    def mock_connection(self):
        """Provide a connection double whose websocket records sends."""
        conn = Mock()
        conn.websocket = AsyncMock()
        conn.connection_id = "analysis_connection"
        return conn

    @pytest.mark.asyncio
    async def test_analysis_request_handling(self, analysis_handler, mock_connection):
        """The raw request is forwarded unchanged to _process_analysis."""
        request = dict(
            type="analysis_request",
            analysis_type="sentiment",
            news_ids=["news_1", "news_2"],
            request_id="req_123",
        )
        canned_result = {
            "request_id": "req_123",
            "results": {"sentiment": "positive", "score": 0.8},
        }

        with patch.object(
            analysis_handler, '_process_analysis', return_value=canned_result
        ) as fake_process:
            await analysis_handler.handle_analysis_request(mock_connection, request)
            fake_process.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_analysis_progress_updates(self, analysis_handler, mock_connection):
        """Progress data is wrapped in an analysis_progress message."""
        progress = dict(
            request_id="req_123",
            status="processing",
            progress=45,
            estimated_completion=30,  # seconds
        )

        await analysis_handler.send_progress_update(mock_connection, progress)

        payload = self._last_sent_payload(mock_connection)
        assert payload["type"] == "analysis_progress"
        assert payload["data"]["progress"] == 45

    @pytest.mark.asyncio
    async def test_batch_analysis_streaming(self, analysis_handler, mock_connection):
        """A three-part batch produces at least three outgoing frames."""
        sub_analyses = [
            {"type": "sentiment", "news_ids": ["news_1"]},
            {"type": "market_impact", "news_ids": ["news_2"]},
            {"type": "summarization", "news_ids": ["news_3"]},
        ]
        batch_request = {
            "type": "batch_analysis",
            "analyses": sub_analyses,
            "request_id": "batch_123",
        }

        await analysis_handler.handle_batch_analysis(mock_connection, batch_request)

        # One update per sub-analysis is the minimum expected stream.
        sent_frames = mock_connection.websocket.send_text.call_count
        assert sent_frames >= 3

    @pytest.mark.asyncio
    async def test_analysis_caching(self, analysis_handler, mock_connection):
        """A repeated request differing only by request_id is served from cache."""
        request = dict(
            type="analysis_request",
            analysis_type="sentiment",
            news_ids=["news_1"],
            request_id="req_1",
        )

        with patch.object(
            analysis_handler, '_process_analysis', return_value={"sentiment": "positive"}
        ) as fake_process:
            await analysis_handler.handle_analysis_request(mock_connection, request)

            # Re-issue the same analysis under a new request id; the request
            # dict is mutated in place on purpose, exactly as a client resend.
            request["request_id"] = "req_2"
            await analysis_handler.handle_analysis_request(mock_connection, request)

            # Only the first request should have reached the processor.
            assert fake_process.call_count == 1

    @pytest.mark.asyncio
    async def test_analysis_error_handling(self, analysis_handler, mock_connection):
        """An unknown analysis type yields an error frame mentioning 'invalid'."""
        bad_request = dict(
            type="analysis_request",
            analysis_type="invalid_type",
            news_ids=[],
            request_id="req_error",
        )

        await analysis_handler.handle_analysis_request(mock_connection, bad_request)

        payload = self._last_sent_payload(mock_connection)
        assert payload["type"] == "error"
        assert "invalid" in payload["message"].lower()
class TestMonitoringWebSocketHandler:
    """Test cases for Monitoring WebSocket Handler.

    Covers metric streaming, alerts, dashboard pushes, per-connection
    threshold evaluation and filtered log streaming.
    """

    @pytest.fixture
    def monitoring_handler(self):
        """Create MonitoringWebSocketHandler instance."""
        return MonitoringWebSocketHandler()

    @pytest.fixture
    def mock_connection(self):
        """Create a mock connection wrapping an AsyncMock websocket."""
        connection = Mock()
        connection.websocket = AsyncMock()
        connection.connection_id = "monitoring_connection"
        return connection

    @pytest.mark.asyncio
    async def test_system_metrics_streaming(self, monitoring_handler, mock_connection):
        """Subscribed connections receive metrics as a metrics_update message."""
        subscription_data = {
            "action": "subscribe",
            "metrics": ["cpu_usage", "memory_usage", "response_time"],
            "interval": 5  # seconds
        }

        await monitoring_handler.handle_subscription(mock_connection, subscription_data)

        # Simulate metrics update
        metrics_data = {
            "cpu_usage": 45.2,
            "memory_usage": 67.8,
            "response_time": 120,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

        await monitoring_handler.broadcast_metrics(metrics_data)

        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "metrics_update"
        assert sent_data["data"]["cpu_usage"] == 45.2

    @pytest.mark.asyncio
    async def test_alert_notifications(self, monitoring_handler, mock_connection):
        """send_alert wraps the alert payload in an alert message."""
        alert_data = {
            "alert_id": "alert_123",
            "type": "high_error_rate",
            "severity": "high",
            "message": "Error rate exceeded 5% threshold",
            "triggered_at": datetime.now(timezone.utc).isoformat(),
            "metrics": {"error_rate": 0.06}
        }

        await monitoring_handler.send_alert(mock_connection, alert_data)

        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "alert"
        assert sent_data["data"]["severity"] == "high"

    @pytest.mark.asyncio
    async def test_dashboard_data_updates(self, monitoring_handler, mock_connection):
        """update_dashboard sends the dashboard payload as dashboard_update."""
        dashboard_data = {
            "overview": {
                "total_news": 25000,
                "active_users": 150,
                "system_health": "healthy"
            },
            "charts": {
                "news_volume": [120, 135, 145],
                "sentiment_distribution": {"positive": 65, "negative": 15, "neutral": 20}
            },
            "last_updated": datetime.now(timezone.utc).isoformat()
        }

        await monitoring_handler.update_dashboard(mock_connection, dashboard_data)

        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "dashboard_update"
        assert sent_data["data"]["overview"]["total_news"] == 25000

    @pytest.mark.asyncio
    async def test_monitoring_thresholds(self, monitoring_handler, mock_connection):
        """Metrics above their warning threshold are reported as warnings."""
        threshold_config = {
            "action": "configure_thresholds",
            "thresholds": {
                "cpu_usage": {"warning": 70, "critical": 85},
                "memory_usage": {"warning": 80, "critical": 90},
                "response_time": {"warning": 200, "critical": 500}
            }
        }

        await monitoring_handler.handle_threshold_config(mock_connection, threshold_config)

        # Test threshold violation
        metrics_data = {
            "cpu_usage": 75,  # Above warning threshold
            "memory_usage": 85,  # Above warning threshold
            "response_time": 180  # Normal
        }

        violations = monitoring_handler.check_thresholds(
            mock_connection.connection_id, metrics_data
        )

        assert len(violations) == 2
        # Check both violations without depending on their list order.
        levels_by_metric = {v["metric"]: v["level"] for v in violations}
        assert levels_by_metric == {"cpu_usage": "warning", "memory_usage": "warning"}

    @pytest.mark.asyncio
    async def test_log_streaming(self, monitoring_handler, mock_connection):
        """A log entry matching the level/component filters is streamed."""
        log_subscription = {
            "action": "subscribe_logs",
            "filters": {
                "level": ["ERROR", "WARNING"],
                "components": ["news_collector", "analysis_engine"],
                "tail_lines": 100
            }
        }

        await monitoring_handler.handle_log_subscription(mock_connection, log_subscription)

        # Simulate log entry
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": "ERROR",
            "component": "news_collector",
            "message": "Failed to collect news from source",
            "context": {"source": "naver", "error": "timeout"}
        }

        await monitoring_handler.broadcast_log_entry(log_entry)

        mock_connection.websocket.send_text.assert_called()
        call_args = mock_connection.websocket.send_text.call_args[0][0]
        sent_data = json.loads(call_args)
        assert sent_data["type"] == "log_entry"
        assert sent_data["data"]["level"] == "ERROR"
class TestWebSocketIntegration:
    """Integration tests for WebSocket functionality.

    Exercises ``WebSocketManager`` end to end with ``AsyncMock`` sockets:
    broadcast fan-out, timing, failure cleanup, ordering, volume, auth and
    per-client rate limiting.
    """

    @pytest.fixture
    def ws_manager(self):
        """Create WebSocketManager for integration tests."""
        return WebSocketManager()

    @pytest.mark.asyncio
    async def test_end_to_end_news_flow(self, ws_manager):
        """A news update broadcast through the manager reaches every client."""
        # Setup mock connections
        websocket1 = AsyncMock()
        websocket2 = AsyncMock()

        # Connect clients
        await ws_manager.connect(websocket1, "client1")
        await ws_manager.connect(websocket2, "client2")

        # No subscription step here: the manager-level broadcast is expected to
        # reach every client. Topic/sentiment filtering is covered by
        # TestNewsWebSocketHandler.

        # Simulate news update
        news_update = {
            "type": "news_update",
            "data": {
                "id": "news_123",
                "title": "삼성전자 실적 호조",
                "entity": "삼성전자",
                "sentiment": {"score": 0.8, "label": "positive"}
            }
        }

        await ws_manager.broadcast(news_update)

        # Both clients should receive the update
        websocket1.send_text.assert_called()
        websocket2.send_text.assert_called()

    @pytest.mark.asyncio
    async def test_concurrent_connections_performance(self, ws_manager):
        """Broadcasting to 100 connections finishes in under a second."""
        # Create 100 concurrent connections
        connections = []
        for i in range(100):
            websocket = AsyncMock()
            connection_id = f"perf_test_{i}"
            await ws_manager.connect(websocket, connection_id)
            connections.append((websocket, connection_id))

        # Inside a coroutine, get_running_loop() is the supported way to reach
        # the active loop (get_event_loop() is discouraged here).
        loop = asyncio.get_running_loop()

        # Broadcast message to all
        start_time = loop.time()

        broadcast_msg = {
            "type": "performance_test",
            "data": "High volume message"
        }

        await ws_manager.broadcast(broadcast_msg)

        end_time = loop.time()

        # Should complete within reasonable time (< 1 second)
        assert (end_time - start_time) < 1.0

        # All connections should receive message
        for websocket, _ in connections:
            websocket.send_text.assert_called_once()

    @pytest.mark.asyncio
    async def test_connection_recovery(self, ws_manager):
        """A send on a closed socket raises WebSocketError and drops the connection."""
        websocket = AsyncMock()
        connection_id = "recovery_test"

        await ws_manager.connect(websocket, connection_id)

        # Simulate connection error
        websocket.send_text.side_effect = ConnectionClosed(None, None)

        # Try to send message (should fail and trigger cleanup)
        with pytest.raises(WebSocketError):
            await ws_manager.send_personal_message(
                {"test": "message"}, connection_id
            )

        # Connection should be removed from active connections
        assert connection_id not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_message_ordering(self, ws_manager):
        """Messages to one connection are sent in the order they were issued."""
        websocket = AsyncMock()
        connection_id = "ordering_test"

        await ws_manager.connect(websocket, connection_id)

        # Send multiple messages rapidly
        messages = []
        for i in range(10):
            msg = {"sequence": i, "data": f"message_{i}"}
            messages.append(msg)
            await ws_manager.send_personal_message(msg, connection_id)

        # Verify messages were sent in order
        assert websocket.send_text.call_count == 10

        call_args_list = websocket.send_text.call_args_list
        for i, call_args in enumerate(call_args_list):
            sent_message = json.loads(call_args[0][0])
            assert sent_message["sequence"] == i

    @pytest.mark.asyncio
    async def test_memory_usage_with_long_running_connections(self, ws_manager):
        """Many messages over 50 connections leave all 50 connections tracked."""
        # Create connections with message history
        for i in range(50):
            websocket = AsyncMock()
            connection_id = f"long_running_{i}"
            metadata = {
                "connected_at": datetime.now(timezone.utc).isoformat(),
                "message_count": 0
            }
            await ws_manager.connect(websocket, connection_id, metadata)

            # Send many messages to build up history
            for j in range(100):
                msg = {"message_id": f"{i}_{j}", "data": "test"}
                await ws_manager.send_personal_message(msg, connection_id)

        # Memory usage should remain reasonable
        # (This is more of a monitoring test in real scenarios)
        assert len(ws_manager.active_connections) == 50

    @pytest.mark.asyncio
    async def test_websocket_authentication(self, ws_manager):
        """connect() verifies the auth token from metadata before registering."""
        websocket = AsyncMock()
        connection_id = "auth_test"

        # Test with valid auth token
        auth_metadata = {
            "auth_token": "valid_token_123",
            "user_id": "user_456",
            "permissions": ["read_news", "receive_alerts"]
        }

        with patch('src.server.websocket_server.verify_auth_token') as mock_verify:
            mock_verify.return_value = True

            await ws_manager.connect(websocket, connection_id, auth_metadata)

            assert connection_id in ws_manager.active_connections
            mock_verify.assert_called_once_with("valid_token_123")

    @pytest.mark.asyncio
    async def test_rate_limiting_across_connections(self, ws_manager):
        """Connections sharing one client IP are collectively rate limited."""
        # Create multiple connections from same IP. Named ``clients`` so the
        # imported ``websockets`` module is not shadowed.
        clients = []
        for i in range(5):
            websocket = AsyncMock()
            websocket.client = Mock()
            websocket.client.host = "127.0.0.1"  # Same IP
            connection_id = f"rate_limit_{i}"
            await ws_manager.connect(websocket, connection_id)
            clients.append((websocket, connection_id))

        # Send many messages rapidly
        for i in range(100):
            for websocket, connection_id in clients:
                try:
                    await ws_manager.send_personal_message(
                        {"rapid_message": i}, connection_id
                    )
                except WebSocketError:
                    # Rate limiting may trigger errors
                    pass

        # Rate limiting should prevent excessive messages
        total_calls = sum(ws.send_text.call_count for ws, _ in clients)
        assert total_calls < 500  # Should be limited