"""Tests for PluginHub WebSocket API key authentication gate."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from core.config import config
from core.constants import API_KEY_HEADER
from services.api_key_service import ApiKeyService, ValidationResult
from transport.plugin_hub import PluginHub
from transport.plugin_registry import PluginRegistry
@pytest.fixture(autouse=True)
def _reset_api_key_singleton():
ApiKeyService._instance = None
yield
ApiKeyService._instance = None
@pytest.fixture(autouse=True)
def _reset_plugin_hub():
"""Ensure PluginHub class-level state doesn't leak between tests."""
old_registry = PluginHub._registry
old_connections = PluginHub._connections.copy()
old_pending = PluginHub._pending.copy()
old_lock = PluginHub._lock
old_loop = PluginHub._loop
yield
PluginHub._registry = old_registry
PluginHub._connections = old_connections
PluginHub._pending = old_pending
PluginHub._lock = old_lock
PluginHub._loop = old_loop
def _make_mock_websocket(headers=None, state_attrs=None):
"""Create a mock WebSocket with configurable headers and state."""
ws = AsyncMock()
ws.headers = headers or {}
ws.state = SimpleNamespace(**(state_attrs or {}))
ws.accept = AsyncMock()
ws.close = AsyncMock()
ws.send_json = AsyncMock()
return ws
def _make_hub():
"""Create a PluginHub instance with a minimal ASGI scope."""
scope = {"type": "websocket"}
return PluginHub(scope, receive=AsyncMock(), send=AsyncMock())
def _init_api_key_service(validate_result=None):
"""Initialize ApiKeyService with a mocked validate method."""
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
if validate_result is not None:
svc.validate = AsyncMock(return_value=validate_result)
return svc
class TestWebSocketAuthGate:
@pytest.mark.asyncio
async def test_no_api_key_remote_hosted_rejected(self, monkeypatch):
"""WebSocket without API key in remote-hosted mode -> close 4401."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(ValidationResult(valid=True, user_id="u1"))
ws = _make_mock_websocket(headers={}) # No X-API-Key header
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=4401, reason="API key required")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_invalid_api_key_rejected(self, monkeypatch):
"""WebSocket with invalid API key -> close 4403."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(ValidationResult(
valid=False, error="Invalid API key"))
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-bad-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=4403, reason="Invalid API key")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_valid_api_key_accepted(self, monkeypatch):
"""WebSocket with valid API key -> accepted, user_id stored in state."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(valid=True, user_id="user-42",
metadata={"plan": "pro"})
)
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.accept.assert_called_once()
ws.close.assert_not_called()
assert ws.state.user_id == "user-42"
assert ws.state.api_key_metadata == {"plan": "pro"}
# Should have sent welcome message
ws.send_json.assert_called_once()
@pytest.mark.asyncio
async def test_auth_service_unavailable_close_1013(self, monkeypatch):
"""Auth service error with 'unavailable' -> close 1013 (try again later)."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(
valid=False, error="Auth service unavailable", cacheable=False)
)
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-some-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=1013, reason="Try again later")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_not_remote_hosted_accepts_without_key(self, monkeypatch):
"""When not remote-hosted, WebSocket accepted without API key."""
monkeypatch.setattr(config, "http_remote_hosted", False)
ws = _make_mock_websocket(headers={})
hub = _make_hub()
await hub.on_connect(ws)
ws.accept.assert_called_once()
ws.close.assert_not_called()
class TestUserIdFlowsToRegistration:
@pytest.mark.asyncio
async def test_user_id_passed_to_registry_on_register(self, monkeypatch):
"""After valid auth, the register message should pass user_id to registry."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(valid=True, user_id="user-99")
)
registry = PluginRegistry()
loop = asyncio.get_running_loop()
PluginHub.configure(registry, loop)
# Simulate full flow: connect, then register
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
hub = _make_hub()
await hub.on_connect(ws)
assert ws.state.user_id == "user-99"
# Simulate register message
register_data = {
"type": "register",
"project_name": "TestProject",
"project_hash": "abc123",
"unity_version": "2022.3",
}
await hub.on_receive(ws, register_data)
# Verify registry has the user_id
sessions = await registry.list_sessions(user_id="user-99")
assert len(sessions) == 1
session = next(iter(sessions.values()))
assert session.user_id == "user-99"
assert session.project_name == "TestProject"
assert session.project_hash == "abc123"