"""Tests for websocket manager."""
import asyncio
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from preloop.services.websocket_manager import (
WebSocketManager,
nats_consumer,
persist_execution_log,
)
# Module-level mark: pytest applies the `asyncio` marker to every test
# collected from this file, so the `async def` tests below need no per-test
# decorator.
pytestmark = pytest.mark.asyncio
class TestPersistExecutionLog:
    """Exercise persist_execution_log against a mocked database session."""

    @patch("preloop.services.websocket_manager.get_db")
    async def test_persist_execution_log_success(self, mock_get_db):
        """A simple log entry is executed, committed, and the session closed."""
        session = MagicMock()
        # get_db is stubbed to return an iterator yielding the fake session.
        mock_get_db.return_value = iter([session])

        await persist_execution_log(
            "exec_123", {"message": "Step completed", "level": "INFO"}
        )

        # All three database operations must have been invoked.
        session.execute.assert_called()
        session.commit.assert_called()
        session.close.assert_called()

    @patch("preloop.services.websocket_manager.get_db")
    async def test_persist_execution_log_with_complex_data(self, mock_get_db):
        """A nested log payload reaches execute() as a string parameter."""
        target_id = "exec_456"
        payload = {
            "message": "Complex step",
            "level": "DEBUG",
            "metadata": {"key1": "value1", "key2": [1, 2, 3]},
        }
        session = MagicMock()
        mock_get_db.return_value = iter([session])

        await persist_execution_log(target_id, payload)

        session.execute.assert_called()
        # The bound parameters are the second positional argument of execute().
        bound_params = session.execute.call_args.args[1]
        assert bound_params["execution_id"] == target_id
        # The log entry is expected to have been serialized to a string.
        assert isinstance(bound_params["log_entry"], str)

    @patch("preloop.services.websocket_manager.get_db")
    @patch("preloop.services.websocket_manager.logger")
    async def test_persist_execution_log_database_error(self, mock_logger, mock_get_db):
        """A failing execute() is logged as an error and the session is closed."""
        session = MagicMock()
        # Make the very first database call blow up.
        session.execute.side_effect = Exception("Database error")
        mock_get_db.return_value = iter([session])

        await persist_execution_log("exec_789", {"message": "Test"})

        # The failure must be reported through the module logger ...
        mock_logger.error.assert_called()
        # ... and the session must still be released.
        session.close.assert_called()

    @patch("preloop.services.websocket_manager.get_db")
    async def test_persist_execution_log_closes_db_on_success(self, mock_get_db):
        """The session is closed on the success path as well."""
        session = MagicMock()
        mock_get_db.return_value = iter([session])

        await persist_execution_log("exec_id", {"message": "test"})

        session.close.assert_called()
class TestWebSocketManager:
    """Exercise WebSocketManager connect / disconnect / broadcast behaviour."""

    @staticmethod
    def _accepting_socket():
        """Return a fake WebSocket whose ``accept`` is an explicit AsyncMock."""
        sock = AsyncMock()
        sock.accept = AsyncMock()
        return sock

    @pytest.fixture
    def manager(self):
        """Provide a fresh WebSocketManager for each test."""
        return WebSocketManager()

    @pytest.fixture
    def mock_websocket(self):
        """Provide a fake WebSocket with awaitable accept() and send_text()."""
        sock = self._accepting_socket()
        sock.send_text = AsyncMock()
        return sock

    async def test_connect_websocket(self, manager, mock_websocket):
        """connect() accepts the socket, returns a str id, and stores the socket."""
        new_id = await manager.connect(mock_websocket)

        mock_websocket.accept.assert_called()
        assert isinstance(new_id, str)
        # The socket must be registered under the returned id.
        assert new_id in manager.active_connections
        assert manager.active_connections[new_id] == mock_websocket

    async def test_connect_multiple_websockets(self, manager, mock_websocket):
        """Two connects yield two distinct ids, each mapped to its own socket."""
        first_sock = self._accepting_socket()
        second_sock = self._accepting_socket()

        first_id = await manager.connect(first_sock)
        second_id = await manager.connect(second_sock)

        assert len(manager.active_connections) == 2
        assert first_id != second_id
        assert manager.active_connections[first_id] == first_sock
        assert manager.active_connections[second_id] == second_sock

    async def test_disconnect_websocket(self, manager, mock_websocket):
        """disconnect() removes a previously registered connection."""
        new_id = await manager.connect(mock_websocket)
        assert new_id in manager.active_connections

        manager.disconnect(new_id)

        assert new_id not in manager.active_connections

    async def test_disconnect_nonexistent_connection(self, manager):
        """disconnect() with an unknown id does not raise."""
        unknown_id = str(uuid.uuid4())

        manager.disconnect(unknown_id)

        assert unknown_id not in manager.active_connections

    async def test_broadcast_to_single_client(self, manager, mock_websocket):
        """broadcast() sends the text exactly once to the only client."""
        await manager.connect(mock_websocket)
        text = "Test message"

        await manager.broadcast(text)

        mock_websocket.send_text.assert_called_once_with(text)

    async def test_broadcast_to_multiple_clients(self, manager):
        """broadcast() sends the text exactly once to every connected client."""
        sockets = []
        for _ in range(2):
            sock = self._accepting_socket()
            sock.send_text = AsyncMock()
            await manager.connect(sock)
            sockets.append(sock)
        text = "Broadcast to all"

        await manager.broadcast(text)

        for sock in sockets:
            sock.send_text.assert_called_once_with(text)

    @patch("preloop.services.websocket_manager.logger")
    async def test_broadcast_with_failed_connection(self, mock_logger, manager):
        """A send failure on one socket is logged and does not block the other."""
        # First socket raises on send, second one is healthy.
        broken_sock = self._accepting_socket()
        broken_sock.send_text = AsyncMock(side_effect=Exception("Connection closed"))
        healthy_sock = self._accepting_socket()
        healthy_sock.send_text = AsyncMock()
        await manager.connect(broken_sock)
        await manager.connect(healthy_sock)

        await manager.broadcast("Test message")

        # The failure is reported as a warning ...
        mock_logger.warning.assert_called()
        # ... and the healthy socket still got the message.
        healthy_sock.send_text.assert_called()

    async def test_broadcast_json(self, manager, mock_websocket):
        """broadcast_json() sends a JSON string that round-trips to the input."""
        await manager.connect(mock_websocket)
        payload = {"type": "update", "value": 42}

        await manager.broadcast_json(payload)

        mock_websocket.send_text.assert_called_once()
        wire_text = mock_websocket.send_text.call_args.args[0]
        assert isinstance(wire_text, str)
        # Decoding the wire text must give back the original dict.
        assert json.loads(wire_text) == payload

    async def test_broadcast_json_with_complex_data(self, manager, mock_websocket):
        """broadcast_json() round-trips a nested payload."""
        await manager.connect(mock_websocket)
        payload = {
            "type": "execution_update",
            "execution_id": "exec_123",
            "status": "running",
            "logs": [{"message": "Step 1 complete"}, {"message": "Step 2 started"}],
        }

        await manager.broadcast_json(payload)

        mock_websocket.send_text.assert_called_once()
        wire_text = mock_websocket.send_text.call_args.args[0]
        assert json.loads(wire_text) == payload

    async def test_manager_initially_empty(self):
        """A newly constructed manager holds no connections."""
        fresh = WebSocketManager()
        assert len(fresh.active_connections) == 0

    async def test_connection_count_after_operations(self, manager):
        """The connection count tracks each connect and disconnect."""
        first_sock = self._accepting_socket()
        second_sock = self._accepting_socket()

        # Nothing registered yet.
        assert len(manager.active_connections) == 0

        # One connection after the first connect.
        first_id = await manager.connect(first_sock)
        assert len(manager.active_connections) == 1

        # Two connections after the second connect.
        second_id = await manager.connect(second_sock)
        assert len(manager.active_connections) == 2

        # Back down to one after dropping the first.
        manager.disconnect(first_id)
        assert len(manager.active_connections) == 1

        # Empty again after dropping the second.
        manager.disconnect(second_id)
        assert len(manager.active_connections) == 0
class TestNatsConsumer:
    """Test nats_consumer function.

    Tests that need the consumer's message handler start ``nats_consumer`` as
    a background task with a fake ``subscribe`` that records the callback.
    The background task is always cancelled in a ``finally`` block so that a
    failing assertion cannot leave a pending task behind.
    """

    @staticmethod
    def _install_publisher(mock_get_publisher, *, connected):
        """Make the patched ``get_task_publisher`` return a fake publisher.

        Args:
            mock_get_publisher: The mock that replaces ``get_task_publisher``.
            connected: Value assigned to the fake client's ``is_connected``.

        Returns:
            The fake NATS client exposed as ``publisher.nc``.
        """
        mock_nc = MagicMock()
        mock_nc.is_connected = connected
        mock_publisher = MagicMock()
        mock_publisher.nc = mock_nc
        mock_get_publisher.return_value = mock_publisher
        return mock_nc

    @staticmethod
    async def _stop_consumer(consumer_task):
        """Cancel the background consumer task and wait for it to finish."""
        consumer_task.cancel()
        try:
            await consumer_task
        except asyncio.CancelledError:
            pass

    async def _start_consumer(self, manager, mock_nc):
        """Run ``nats_consumer`` in the background and capture its handler.

        Args:
            manager: The WebSocketManager passed to ``nats_consumer``.
            mock_nc: Fake NATS client whose ``subscribe`` is replaced here.

        Returns:
            A ``(task, handler)`` tuple: the running consumer task and the
            callback most recently passed to ``subscribe``.

        Raises:
            AssertionError: If no callback was registered within the wait.
        """
        captured = []

        async def mock_subscribe(subject, cb):
            captured.append(cb)
            return AsyncMock()

        mock_nc.subscribe = mock_subscribe
        consumer_task = asyncio.create_task(nats_consumer(manager))
        try:
            # Give the consumer time to reach its subscribe call.
            await asyncio.sleep(0.1)
            assert captured, "nats_consumer did not subscribe a message handler"
        except BaseException:
            # Do not leak the background task when startup fails.
            await self._stop_consumer(consumer_task)
            raise
        return consumer_task, captured[-1]

    @patch("preloop.services.websocket_manager.get_task_publisher")
    @patch("preloop.services.websocket_manager.persist_execution_log")
    async def test_nats_consumer_processes_message(
        self, mock_persist, mock_get_publisher
    ):
        """Test that NATS consumer processes messages correctly."""
        manager = WebSocketManager()
        mock_nc = self._install_publisher(mock_get_publisher, connected=True)
        consumer_task, handler = await self._start_consumer(manager, mock_nc)
        try:
            test_message = {
                "execution_id": "exec_123",
                "message": "Test update",
                "level": "INFO",
            }
            mock_msg = MagicMock()
            mock_msg.data.decode.return_value = json.dumps(test_message)

            await handler(mock_msg)

            # The handler must persist the decoded message under its id.
            assert mock_persist.called
            call_args = mock_persist.call_args
            assert call_args[0][0] == "exec_123"
            assert call_args[0][1] == test_message
        finally:
            await self._stop_consumer(consumer_task)

    @patch("preloop.services.websocket_manager.get_task_publisher")
    @patch("preloop.services.websocket_manager.logger")
    async def test_nats_consumer_no_connection(self, mock_logger, mock_get_publisher):
        """Test NATS consumer when NATS is not connected."""
        manager = WebSocketManager()
        self._install_publisher(mock_get_publisher, connected=False)

        await nats_consumer(manager)

        # Verify error was logged
        assert mock_logger.error.called

    @patch("preloop.services.websocket_manager.get_task_publisher")
    async def test_nats_consumer_handles_invalid_json(self, mock_get_publisher):
        """Test NATS consumer handles invalid JSON messages."""
        manager = WebSocketManager()
        mock_nc = self._install_publisher(mock_get_publisher, connected=True)
        consumer_task, handler = await self._start_consumer(manager, mock_nc)
        try:
            mock_msg = MagicMock()
            mock_msg.data.decode.return_value = "invalid json {{"

            # The handler should log a warning rather than raise.
            with patch("preloop.services.websocket_manager.logger") as mock_logger:
                await handler(mock_msg)
                assert mock_logger.warning.called
        finally:
            await self._stop_consumer(consumer_task)

    @patch("preloop.services.websocket_manager.get_task_publisher")
    async def test_nats_consumer_message_without_execution_id(self, mock_get_publisher):
        """Test NATS consumer handles messages without execution_id."""
        manager = WebSocketManager()
        # Add a mock websocket to receive broadcasts
        mock_ws = AsyncMock()
        mock_ws.accept = AsyncMock()
        mock_ws.send_text = AsyncMock()
        await manager.connect(mock_ws)

        mock_nc = self._install_publisher(mock_get_publisher, connected=True)
        consumer_task, handler = await self._start_consumer(manager, mock_nc)
        try:
            # Message without execution_id
            test_message = {"message": "No execution ID", "level": "INFO"}
            mock_msg = MagicMock()
            mock_msg.data.decode.return_value = json.dumps(test_message)

            with patch(
                "preloop.services.websocket_manager.persist_execution_log"
            ) as mock_persist:
                await handler(mock_msg)
                # persist should not be called (no execution_id)
                assert not mock_persist.called
                # But broadcast should still happen
                assert mock_ws.send_text.called
        finally:
            await self._stop_consumer(consumer_task)