# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""Location: ./tests/unit/mcpgateway/test_main.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
Comprehensive tests for the main API endpoints with full coverage.
"""
# Standard
import asyncio
from copy import deepcopy
import datetime
import json
import os
from unittest.mock import ANY, AsyncMock, MagicMock, patch
# Third-Party
from fastapi import HTTPException
from fastapi.testclient import TestClient
import jwt
from pydantic import BaseModel, ValidationError
import pytest
import sqlalchemy as sa
from starlette.requests import Request
from starlette.websockets import WebSocketDisconnect
# First-Party
from mcpgateway.config import settings
from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities
import mcpgateway.db as db_mod
from mcpgateway.schemas import (
GatewayRead,
PromptRead,
ResourceRead,
ServerRead,
ToolRead,
)
# --------------------------------------------------------------------------- #
# Constants #
# --------------------------------------------------------------------------- #
PROTOCOL_VERSION = os.getenv("PROTOCOL_VERSION", "2025-03-26")
# Mock data templates with complete field structures
MOCK_METRICS = {
"total_executions": 10,
"successful_executions": 8,
"failed_executions": 2,
"failure_rate": 0.2,
"min_response_time": 0.1,
"max_response_time": 2.5,
"avg_response_time": 1.2,
"last_execution_time": "2023-01-01T00:00:00+00:00",
}
MOCK_SERVER_READ = {
"id": "1",
"name": "test_server",
"description": "A test server",
"icon": "server-icon",
"created_at": "2023-01-01T00:00:00+00:00",
"updated_at": "2023-01-01T00:00:00+00:00",
"enabled": True,
"associated_tools": ["101"],
"associated_resources": ["201"],
"associated_prompts": ["301"],
"metrics": MOCK_METRICS,
}
MOCK_TOOL_READ = {
"id": "1",
"name": "test_tool",
"originalName": "test_tool",
"customName": "test_tool",
"url": "http://example.com/tools/test",
"description": "A test tool",
"requestType": "POST",
"integrationType": "MCP",
"headers": {"Content-Type": "application/json"},
"inputSchema": {"type": "object", "properties": {"param": {"type": "string"}}},
"annotations": {},
"jsonpathFilter": None,
"auth": {"auth_type": "none"},
"createdAt": "2023-01-01T00:00:00+00:00",
"updatedAt": "2023-01-01T00:00:00+00:00",
"enabled": True,
"reachable": True,
"gatewayId": "gateway-1",
"executionCount": 5,
"metrics": MOCK_METRICS,
"gatewaySlug": "gateway-1",
"customNameSlug": "test-tool",
}
# camelCase → snake_case key map for the fields that differ
_TOOL_KEY_MAP = {
"originalName": "original_name",
"requestType": "request_type",
"integrationType": "integration_type",
"inputSchema": "input_schema",
"jsonpathFilter": "jsonpath_filter",
"createdAt": "created_at",
"updatedAt": "updated_at",
"gatewayId": "gateway_id",
"gatewaySlug": "gateway_slug",
"originalNameSlug": "original_name_slug",
"customNameSlug": "custom_name_slug",
}
def camel_to_snake_tool(d: dict) -> dict:
out = deepcopy(d)
# id must be str
out["id"] = str(out["id"])
for camel, snake in _TOOL_KEY_MAP.items():
if camel in out:
out[snake] = out.pop(camel)
return out
MOCK_TOOL_READ_SNAKE = camel_to_snake_tool(MOCK_TOOL_READ)
MOCK_RESOURCE_READ = {
"id": "39334ce0ed2644d79ede8913a66930c9",
"uri": "test/resource",
"name": "Test Resource",
"description": "A test resource",
"mime_type": "text/plain",
"size": 12,
"created_at": "2023-01-01T00:00:00+00:00",
"updated_at": "2023-01-01T00:00:00+00:00",
"enabled": True,
"metrics": MOCK_METRICS,
}
MOCK_PROMPT_READ = {
"id": "ca627760127d409080fdefc309147e08",
"name": "test_prompt",
"original_name": "test_prompt",
"custom_name": "test_prompt",
"custom_name_slug": "test-prompt",
"display_name": "Test Prompt",
"description": "A test prompt",
"template": "Hello {name}",
"arguments": [],
"created_at": "2023-01-01T00:00:00+00:00",
"updated_at": "2023-01-01T00:00:00+00:00",
"enabled": True,
"metrics": MOCK_METRICS,
}
MOCK_GATEWAY_READ = {
"id": "1",
"name": "test_gateway",
"url": "http://example.com",
"description": "A test gateway",
"transport": "SSE",
"created_at": "2023-01-01T00:00:00+00:00",
"updated_at": "2023-01-01T00:00:00+00:00",
"enabled": True,
"reachable": True,
"auth_type": None,
}
MOCK_ROOT = {
"uri": "/test",
"name": "Test Root",
}
class _ValidationModel(BaseModel):
value: int
def _make_validation_error() -> ValidationError:
try:
_ValidationModel(value="bad")
except ValidationError as exc:
return exc
raise AssertionError("Expected validation error")
VALIDATION_ERROR = _make_validation_error()
INTEGRITY_ERROR = sa.exc.IntegrityError("stmt", {}, Exception("orig"))
def _make_a2a_agent_read(**overrides):
now = datetime.datetime.now(datetime.timezone.utc)
data = {
"id": "agent-1",
"name": "agent-1",
"slug": "agent-1",
"description": "Test agent",
"endpoint_url": "https://example.com/agent",
"agent_type": "generic",
"protocol_version": PROTOCOL_VERSION,
"capabilities": {},
"config": {},
"enabled": True,
"reachable": True,
"created_at": now,
"updated_at": now,
"last_interaction": None,
"tags": [],
"metrics": None,
"passthrough_headers": None,
"auth_type": None,
"auth_value": None,
"oauth_config": None,
}
data.update(overrides)
return data
def _tool_create_payload():
return {
"tool": {"name": "test_tool", "url": "http://example.com", "description": "A test tool"},
"team_id": None,
"visibility": "private",
}
def _server_create_payload():
return {"server": {"name": "test-server"}, "team_id": None, "visibility": "public"}
def _a2a_create_payload():
return {
"agent": {"name": "agent-1", "endpoint_url": "https://example.com/agent", "agent_type": "generic"},
"team_id": None,
"visibility": "public",
}
def _make_request(path: str = "/", headers: dict | None = None) -> Request:
header_list = []
for key, value in (headers or {}).items():
header_list.append((key.lower().encode(), str(value).encode()))
scope = {
"type": "http",
"scheme": "http",
"server": ("testserver", 80),
"path": path,
"headers": header_list,
}
return Request(scope)
# --------------------------------------------------------------------------- #
# Fixtures #
# --------------------------------------------------------------------------- #
@pytest.fixture(autouse=True)
def reset_db(app_with_temp_db):
"""Clear the temp DB between tests when using the module-scoped app."""
engine = db_mod.engine
if engine is None:
yield
return
with engine.begin() as conn:
if engine.dialect.name == "sqlite":
conn.exec_driver_sql("PRAGMA foreign_keys=OFF")
for table in reversed(db_mod.Base.metadata.sorted_tables):
conn.execute(table.delete())
if engine.dialect.name == "sqlite":
try:
conn.exec_driver_sql("DELETE FROM sqlite_sequence")
except sa.exc.DatabaseError:
pass
conn.exec_driver_sql("PRAGMA foreign_keys=ON")
yield
@pytest.fixture
def test_client(app_with_temp_db):
"""
Return a TestClient whose dependency graph bypasses real authentication.
Every FastAPI dependency on ``require_auth`` is overridden to return the
static user name ``"test_user"``. This keeps the protected endpoints
accessible without needing to furnish JWTs in every request.
Also overrides RBAC dependencies to bypass permission checks for tests.
"""
# First-Party
# Mock user object for RBAC system
from mcpgateway.db import EmailUser
from mcpgateway.middleware.rbac import get_current_user_with_permissions
from mcpgateway.utils.verify_credentials import require_auth
mock_user = EmailUser(
email="test_user@example.com",
full_name="Test User",
is_admin=True, # Give admin privileges for tests
is_active=True,
auth_provider="test",
)
# Override old auth system
app_with_temp_db.dependency_overrides[require_auth] = lambda: "test_user"
# Patch the auth function used by DocsAuthMiddleware
# Standard
from unittest.mock import MagicMock, patch
# Third-Party
from fastapi import HTTPException, status
# First-Party
# Mock security_logger to prevent database access
mock_sec_logger = MagicMock()
mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None)
mock_sec_logger.log_security_event = MagicMock(return_value=None)
sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger)
sec_patcher.start()
# Create a mock that validates JWT tokens properly
async def mock_require_auth_override(auth_header=None, jwt_token=None):
# Third-Party
import jwt as jwt_lib
# First-Party
from mcpgateway.config import settings
# Try to get token from auth_header or jwt_token
token = jwt_token
if not token and auth_header and auth_header.startswith("Bearer "):
token = auth_header[7:] # Remove "Bearer " prefix
if not token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization required")
try:
# Always coerce key to str in case SecretStr leaks through
key = settings.jwt_secret_key
# Only call get_secret_value if it exists and is callable (not a string)
if hasattr(key, "get_secret_value") and callable(getattr(key, "get_secret_value", None)):
key = key.get_secret_value()
payload = jwt_lib.decode(token, key, algorithms=[settings.jwt_algorithm], options={"verify_aud": False})
username = payload.get("sub")
if username:
return username
else:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
except jwt_lib.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired")
except jwt_lib.InvalidTokenError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
patcher = patch("mcpgateway.main.require_docs_auth_override", mock_require_auth_override)
patcher.start()
# Override the core auth function used by RBAC system
# First-Party
from mcpgateway.auth import get_current_user
app_with_temp_db.dependency_overrides[get_current_user] = lambda credentials=None, db=None: mock_user
# Override get_current_user_with_permissions for RBAC system
def mock_get_current_user_with_permissions(request=None, credentials=None, jwt_token=None):
return {"email": "test_user@example.com", "full_name": "Test User", "is_admin": True, "ip_address": "127.0.0.1", "user_agent": "test"}
app_with_temp_db.dependency_overrides[get_current_user_with_permissions] = mock_get_current_user_with_permissions
# Mock the permission service to always return True for tests
# First-Party
from mcpgateway.services.permission_service import PermissionService
# Store original method
if not hasattr(PermissionService, "_original_check_permission"):
PermissionService._original_check_permission = PermissionService.check_permission
# Mock with correct async signature matching the real method
async def mock_check_permission(self, user_email: str, permission: str, resource_type=None, resource_id=None, team_id=None, ip_address=None, user_agent=None) -> bool:
return True
PermissionService.check_permission = mock_check_permission
client = TestClient(app_with_temp_db)
yield client
# Clean up overrides and restore original methods
app_with_temp_db.dependency_overrides.pop(require_auth, None)
app_with_temp_db.dependency_overrides.pop(get_current_user, None)
app_with_temp_db.dependency_overrides.pop(get_current_user_with_permissions, None)
patcher.stop() # Stop the require_auth_override patch
sec_patcher.stop() # Stop the security_logger patch
if hasattr(PermissionService, "_original_check_permission"):
PermissionService.check_permission = PermissionService._original_check_permission
@pytest.fixture
def mock_jwt_token():
"""Create a valid JWT token for testing."""
payload = {"sub": "test_user@example.com", "email": "test_user@example.com", "iss": "mcpgateway", "aud": "mcpgateway-api"}
secret = settings.jwt_secret_key
if hasattr(secret, "get_secret_value") and callable(getattr(secret, "get_secret_value", None)):
secret = secret.get_secret_value()
algorithm = settings.jwt_algorithm
return jwt.encode(payload, secret, algorithm=algorithm)
@pytest.fixture
def auth_headers(mock_jwt_token):
"""Default auth header (still accepted by the overridden dependency)."""
return {"Authorization": f"Bearer {mock_jwt_token}"}
# ========================================================================== #
# TEST CLASSES #
# ========================================================================== #
# ----------------------------------------------------- #
# Health & Infrastructure Tests #
# ----------------------------------------------------- #
class TestHealthAndInfrastructure:
"""Tests for health checks, readiness, and basic infrastructure endpoints."""
def test_health_check(self, test_client):
"""Test the basic health check endpoint."""
response = test_client.get("/health")
assert response.status_code == 200
assert response.json()["status"] == "healthy"
def test_ready_check(self, test_client):
"""Test the readiness check endpoint."""
response = test_client.get("/ready")
assert response.status_code == 200
assert response.json()["status"] == "ready"
def test_health_check_db_error(self):
"""Test health check error path with rollback failure."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummySession:
def __init__(self):
self.invalidate_called = False
def execute(self, *_args, **_kwargs):
raise Exception("boom")
def commit(self):
pass
def rollback(self):
raise Exception("rollback failed")
def invalidate(self):
self.invalidate_called = True
def close(self):
pass
session = DummySession()
with patch("mcpgateway.main.SessionLocal", return_value=session):
response = mcpgateway_main.healthcheck()
assert response["status"] == "unhealthy"
assert session.invalidate_called is True
@pytest.mark.asyncio
async def test_ready_check_db_error(self):
"""Test readiness check error path with rollback failure."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummySession:
def __init__(self):
self.invalidate_called = False
def execute(self, *_args, **_kwargs):
raise Exception("boom")
def commit(self):
pass
def rollback(self):
raise Exception("rollback failed")
def invalidate(self):
self.invalidate_called = True
def close(self):
pass
session = DummySession()
with (
patch("mcpgateway.main.SessionLocal", return_value=session),
patch("mcpgateway.main.asyncio.to_thread", side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs)),
):
response = await mcpgateway_main.readiness_check()
assert response.status_code == 503
assert session.invalidate_called is True
def test_root_redirect(self, test_client):
"""Test that root path behavior depends on UI configuration."""
response = test_client.get("/", follow_redirects=False)
# Check if UI is enabled
if settings.mcpgateway_ui_enabled:
# When UI is enabled, should redirect to admin with trailing slash
assert response.status_code == 303
assert response.headers["location"] == f"{settings.app_root_path}/admin/"
else:
# When UI is disabled, should return API info
assert response.status_code == 200
data = response.json()
assert data["name"] == "MCP_Gateway"
assert data["ui_enabled"] is False
def test_static_files(self, test_client):
"""Test static file serving (when files don't exist)."""
with patch("os.path.exists", return_value=True), patch("builtins.open", MagicMock()):
response = test_client.get("/static/test.css")
assert response.status_code == 404 # route registered, file absent
# ----------------------------------------------------- #
# Protocol & MCP Core Tests #
# ----------------------------------------------------- #
class TestProtocolEndpoints:
"""Tests for MCP protocol operations: initialize, ping, notifications, etc."""
# @patch("mcpgateway.main.validate_request")
@patch("mcpgateway.main.session_registry.handle_initialize_logic")
def test_initialize_endpoint(self, mock_handle_initialize, test_client, auth_headers):
"""Test MCP protocol initialization."""
mock_capabilities = ServerCapabilities(
prompts={"listChanged": True},
resources={"subscribe": True, "listChanged": True},
tools={"listChanged": True},
logging={},
roots={"listChanged": True},
sampling={},
)
mock_result = InitializeResult(
protocolVersion=PROTOCOL_VERSION,
capabilities=mock_capabilities,
serverInfo={"name": "MCP Gateway", "version": "1.0.0"},
instructions="MCP Gateway providing federated tools, resources and prompts.",
)
mock_handle_initialize.return_value = mock_result
req = {
"protocol_version": PROTOCOL_VERSION,
"capabilities": {},
"client_info": {"name": "Test Client", "version": "1.0.0"},
}
response = test_client.post("/protocol/initialize", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["protocolVersion"] == PROTOCOL_VERSION
mock_handle_initialize.assert_called_once()
# @patch("mcpgateway.main.validate_request")
def test_ping_endpoint(self, test_client, auth_headers):
"""Test MCP ping endpoint."""
req = {"jsonrpc": "2.0", "method": "ping", "id": "test-id"}
response = test_client.post("/protocol/ping", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body == {"jsonrpc": "2.0", "id": "test-id", "result": {}}
def test_ping_invalid_method(self, test_client, auth_headers):
"""Test ping endpoint with invalid method."""
req = {"jsonrpc": "2.0", "method": "invalid", "id": "test-id"}
response = test_client.post("/protocol/ping", json=req, headers=auth_headers)
# Implementation raises 5xx for unsupported method
assert response.status_code == 500
@patch("mcpgateway.main.logging_service.notify")
def test_handle_notification_initialized(self, mock_notify, test_client, auth_headers):
"""Test handling client initialized notification."""
req = {"method": "notifications/initialized"}
response = test_client.post("/protocol/notifications", json=req, headers=auth_headers)
assert response.status_code == 200
mock_notify.assert_called_once()
@patch("mcpgateway.main.logging_service.notify")
def test_handle_notification_cancelled(self, mock_notify, test_client, auth_headers):
"""Test handling request cancelled notification."""
req = {"method": "notifications/cancelled", "params": {"requestId": "123"}}
response = test_client.post("/protocol/notifications", json=req, headers=auth_headers)
assert response.status_code == 200
mock_notify.assert_called_once()
@patch("mcpgateway.main.logging_service.notify")
def test_handle_notification_message(self, mock_notify, test_client, auth_headers):
"""Test handling log message notification."""
req = {
"method": "notifications/message",
"params": {"data": "Test message", "level": "info", "logger": "test"},
}
response = test_client.post("/protocol/notifications", json=req, headers=auth_headers)
assert response.status_code == 200
mock_notify.assert_called_once()
@patch("mcpgateway.main.completion_service.handle_completion")
def test_handle_completion_endpoint(self, mock_completion, test_client, auth_headers):
"""Test completion handling endpoint."""
mock_completion.return_value = {"result": "completion_result"}
req = {"ref": {"type": "ref/prompt", "name": "test"}}
response = test_client.post("/protocol/completion/complete", json=req, headers=auth_headers)
assert response.status_code == 200
mock_completion.assert_called_once()
@patch("mcpgateway.main.sampling_handler.create_message")
def test_handle_sampling_endpoint(self, mock_sampling, test_client, auth_headers):
"""Test sampling message creation endpoint."""
mock_sampling.return_value = {"messageId": "123"}
req = {"messages": [{"role": "user", "content": {"type": "text", "text": "Hello"}}]}
response = test_client.post("/protocol/sampling/createMessage", json=req, headers=auth_headers)
assert response.status_code == 200
mock_sampling.assert_called_once()
# ----------------------------------------------------- #
# Server Management Tests #
# ----------------------------------------------------- #
class TestServerEndpoints:
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_not_found(self, mock_update, test_client, auth_headers):
"""Test update_server returns 404 if server not found."""
# First-Party
from mcpgateway.services.server_service import ServerNotFoundError
mock_update.side_effect = ServerNotFoundError("Server not found")
req = {"description": "Updated description"}
response = test_client.put("/servers/999", json=req, headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_validation_error(self, mock_create, test_client, auth_headers):
"""Test create_server returns 422 for missing required fields."""
mock_create.side_effect = None # Let validation error happen
req = {"description": "Missing name"}
response = test_client.post("/servers/", json=req, headers=auth_headers)
assert response.status_code == 422
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_service_error(self, mock_create, test_client, auth_headers):
"""Test create_server returns 400 for service errors."""
# First-Party
from mcpgateway.services.server_service import ServerError
mock_create.side_effect = ServerError("Bad server")
response = test_client.post("/servers/", json=_server_create_payload(), headers=auth_headers)
assert response.status_code == 400
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_validation_and_integrity_errors(self, mock_create, exc, status_code, test_client, auth_headers):
"""Test create_server returns correct status for validation/integrity errors."""
mock_create.side_effect = exc
response = test_client.post("/servers/", json=_server_create_payload(), headers=auth_headers)
assert response.status_code == status_code
"""Tests for virtual server management: CRUD operations, status toggles, etc."""
@patch("mcpgateway.main.server_service.list_servers")
def test_list_servers_endpoint(self, mock_list_servers, test_client, auth_headers):
"""Test listing all servers."""
mock_list_servers.return_value = ([ServerRead(**MOCK_SERVER_READ)], None)
response = test_client.get("/servers/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1 and data[0]["name"] == "test_server"
mock_list_servers.assert_called_once()
@patch("mcpgateway.main.server_service.get_server")
def test_get_server_endpoint(self, mock_get, test_client, auth_headers):
"""Test retrieving a specific server."""
mock_get.return_value = ServerRead(**MOCK_SERVER_READ)
response = test_client.get("/servers/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["name"] == "test_server"
mock_get.assert_called_once()
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_endpoint(self, mock_create, test_client, auth_headers):
"""Test creating a new server."""
mock_create.return_value = ServerRead(**MOCK_SERVER_READ)
req = {"server": {"name": "test_server", "description": "A test server"}, "team_id": None, "visibility": "private"}
response = test_client.post("/servers/", json=req, headers=auth_headers)
assert response.status_code == 201
mock_create.assert_called_once()
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing server."""
mock_update.return_value = ServerRead(**MOCK_SERVER_READ)
req = {"description": "Updated description"}
response = test_client.put("/servers/1", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
"""Test update_server error branches for validation/integrity errors."""
mock_update.side_effect = exc
req = {"description": "Updated description"}
response = test_client.put("/servers/1", json=req, headers=auth_headers)
assert response.status_code == status_code
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state(self, mock_toggle, test_client, auth_headers):
"""Test setting server active/inactive state."""
updated_server = MOCK_SERVER_READ.copy()
updated_server["enabled"] = False
mock_toggle.return_value = ServerRead(**updated_server)
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
mock_toggle.assert_called_once()
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test server state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test server state change not found error."""
# First-Party
from mcpgateway.services.server_service import ServerNotFoundError
mock_toggle.side_effect = ServerNotFoundError("Missing")
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.server_service.delete_server")
@patch("mcpgateway.main.server_service.get_server")
def test_delete_server_endpoint(self, mock_get, mock_delete, test_client, auth_headers):
"""Test permanently deleting a server."""
mock_get.return_value = ServerRead(**MOCK_SERVER_READ)
mock_delete.return_value = None
response = test_client.delete("/servers/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.server_service.get_server")
def test_delete_server_not_found(self, mock_get, test_client, auth_headers):
"""Test deleting a non-existent server returns 404."""
from mcpgateway.services.server_service import ServerNotFoundError
mock_get.side_effect = ServerNotFoundError("Server not found: nonexistent-id")
response = test_client.delete("/servers/nonexistent-id", headers=auth_headers)
assert response.status_code == 404
assert "Server not found" in response.json()["detail"]
@patch("mcpgateway.main.tool_service.list_server_tools")
def test_server_get_tools(self, mock_list_tools, test_client, auth_headers):
"""Test listing tools associated with a server."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = MOCK_TOOL_READ
mock_list_tools.return_value = [mock_tool]
response = test_client.get("/servers/1/tools", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list_tools.assert_called_once()
@patch("mcpgateway.main.resource_service.list_server_resources")
def test_server_get_resources(self, mock_list_resources, test_client, auth_headers):
"""Test listing resources associated with a server."""
mock_resource = MagicMock()
mock_resource.model_dump.return_value = MOCK_RESOURCE_READ
mock_list_resources.return_value = [mock_resource]
response = test_client.get("/servers/1/resources", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list_resources.assert_called_once()
@patch("mcpgateway.main.prompt_service.list_server_prompts")
def test_server_get_prompts(self, mock_list_prompts, test_client, auth_headers):
"""Test listing prompts associated with a server."""
# First-Party
from mcpgateway.schemas import PromptRead
mock_list_prompts.return_value = [PromptRead(**MOCK_PROMPT_READ)]
response = test_client.get("/servers/1/prompts", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list_prompts.assert_called_once()
# ----------------------------------------------------- #
# Tool Management Tests #
# ----------------------------------------------------- #
class TestToolEndpoints:
@patch("mcpgateway.main.tool_service.update_tool")
def test_update_tool_not_found(self, mock_update, test_client, auth_headers):
"""Test update_tool returns 404 if tool not found."""
# First-Party
from mcpgateway.services.tool_service import ToolNotFoundError
mock_update.side_effect = ToolNotFoundError("Tool not found")
req = {"description": "Updated description"}
response = test_client.put("/tools/999", json=req, headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.create_tool")
def test_create_tool_validation_error(self, mock_create, test_client, auth_headers):
"""Test create_tool returns 422 for missing required fields."""
mock_create.side_effect = None # Let validation error happen
req = {"description": "Missing name and url"}
response = test_client.post("/tools/", json=req, headers=auth_headers)
assert response.status_code == 422
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.tool_service.register_tool")
def test_create_tool_service_errors(self, mock_create, exc, status_code, test_client, auth_headers):
"""Test create_tool returns correct status for validation/integrity errors."""
mock_create.side_effect = exc
response = test_client.post("/tools/", json=_tool_create_payload(), headers=auth_headers)
assert response.status_code == status_code
"""Tests for tool management: registration, invocation, updates, etc."""
@patch("mcpgateway.main.tool_service.list_tools")
def test_list_tools_endpoint(self, mock_list_tools, test_client, auth_headers):
"""Test listing all registered tools."""
tool_read = ToolRead(**MOCK_TOOL_READ_SNAKE)
mock_list_tools.return_value = ([tool_read], None)
response = test_client.get("/tools/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1 and data[0]["name"] == "test_tool"
mock_list_tools.assert_called_once()
@patch("mcpgateway.main.tool_service.register_tool")
def test_create_tool_endpoint(self, mock_create, test_client, auth_headers):
mock_create.return_value = MOCK_TOOL_READ_SNAKE
req = {"tool": {"name": "test_tool", "url": "http://example.com", "description": "A test tool"}, "team_id": None, "visibility": "private"}
response = test_client.post("/tools/", json=req, headers=auth_headers)
assert response.status_code == 200
mock_create.assert_called_once()
@patch("mcpgateway.main.tool_service.get_tool")
def test_get_tool_endpoint(self, mock_get, test_client, auth_headers):
mock_get.return_value = MOCK_TOOL_READ_SNAKE
response = test_client.get("/tools/1", headers=auth_headers)
assert response.status_code == 200
mock_get.assert_called_once()
@patch("mcpgateway.main.tool_service.update_tool")
def test_update_tool_endpoint(self, mock_update, test_client, auth_headers):
updated = {**MOCK_TOOL_READ_SNAKE, "description": "Updated description"}
mock_update.return_value = updated
req = {"description": "Updated description"}
response = test_client.put("/tools/1", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.tool_service.update_tool")
def test_update_tool_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
"""Test update_tool error branches for validation/integrity errors."""
mock_update.side_effect = exc
req = {"description": "Updated description"}
response = test_client.put("/tools/1", json=req, headers=auth_headers)
assert response.status_code == status_code
@patch("mcpgateway.main.tool_service.set_tool_state")
def test_set_tool_state(self, mock_toggle, test_client, auth_headers):
"""Test setting tool active/inactive state."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = {"id": 1, "name": "test", "is_active": False}
mock_toggle.return_value = mock_tool
response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.tool_service.set_tool_state")
def test_set_tool_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test tool state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.tool_service.set_tool_state")
def test_set_tool_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test tool state change not found error."""
# First-Party
from mcpgateway.services.tool_service import ToolNotFoundError
mock_toggle.side_effect = ToolNotFoundError("Missing")
response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.tool_service.delete_tool")
def test_delete_tool_endpoint(self, mock_delete, test_client, auth_headers):
"""Test permanently deleting a tool."""
mock_delete.return_value = None
response = test_client.delete("/tools/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Resource Management Tests #
# ----------------------------------------------------- #
class TestResourceEndpoints:
@patch("mcpgateway.main.resource_service.update_resource")
def test_update_resource_not_found(self, mock_update, test_client, auth_headers):
"""Test update_resource returns 404 if resource not found."""
# First-Party
from mcpgateway.services.resource_service import ResourceNotFoundError
mock_update.side_effect = ResourceNotFoundError("Resource not found")
req = {"description": "Updated description"}
response = test_client.put("/resources/nonexistent", json=req, headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.resource_service.register_resource")
def test_create_resource_validation_error(self, mock_create, test_client, auth_headers):
"""Test create_resource returns 422 for missing required fields."""
mock_create.side_effect = None # Let validation error happen
req = {"description": "Missing uri and name"}
response = test_client.post("/resources/", json=req, headers=auth_headers)
assert response.status_code == 422
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.resource_service.update_resource")
def test_update_resource_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
"""Test update_resource error branches for validation/integrity errors."""
mock_update.side_effect = exc
req = {"description": "Updated description"}
response = test_client.put("/resources/1", json=req, headers=auth_headers)
assert response.status_code == status_code
"""Tests for resource management: reading, creation, caching, etc."""
@patch("mcpgateway.main.resource_service.list_resources")
def test_list_resources_endpoint(self, mock_list_resources, test_client, auth_headers):
"""Test listing all available resources."""
mock_list_resources.return_value = ([ResourceRead(**MOCK_RESOURCE_READ)], None)
response = test_client.get("/resources/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1 and data[0]["name"] == "Test Resource"
mock_list_resources.assert_called_once()
@patch("mcpgateway.main.resource_service.register_resource")
def test_create_resource_endpoint(self, mock_create, test_client, auth_headers):
"""Test registering a new resource."""
mock_create.return_value = ResourceRead(**MOCK_RESOURCE_READ)
req = {"resource": {"uri": "test/resource", "name": "Test Resource", "description": "A test resource", "content": "Hello world"}, "team_id": None, "visibility": "private"}
response = test_client.post("/resources/", json=req, headers=auth_headers)
assert response.status_code == 200 # route returns 200 on success
mock_create.assert_called_once()
@patch("mcpgateway.main.resource_service.read_resource")
def test_read_resource_endpoint(self, mock_read_resource, test_client, auth_headers):
"""Test reading resource content."""
# Clear the resource cache to avoid stale/cached values
from mcpgateway import main as mcpgateway_main
mcpgateway_main.resource_cache.clear()
mock_read_resource.return_value = ResourceContent(
type="resource",
id="1",
uri="test/resource",
mime_type="text/plain",
text="This is test content",
)
response = test_client.get("/resources/1", headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["uri"] == "test/resource" and body["text"] == "This is test content"
mock_read_resource.assert_called_once()
@patch("mcpgateway.main.resource_service.update_resource")
def test_update_resource_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing resource."""
mock_update.return_value = ResourceRead(**MOCK_RESOURCE_READ)
resource_id = mock_update.return_value.id
req = {"description": "Updated description"}
response = test_client.put(f"/resources/{resource_id}", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@patch("mcpgateway.main.resource_service.delete_resource")
def test_delete_resource_endpoint(self, mock_delete, test_client, auth_headers):
"""Test deleting a resource."""
mock_delete.return_value = None
# Use the same resource_id as in test_update_resource_endpoint
resource_id = MOCK_RESOURCE_READ["id"]
response = test_client.delete(f"/resources/{resource_id}", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.resource_service.list_resource_templates")
def test_list_resource_templates(self, mock_list, test_client, auth_headers):
"""Test listing available resource templates."""
mock_list.return_value = []
response = test_client.get("/resources/templates/list", headers=auth_headers)
assert response.status_code == 200
mock_list.assert_called_once()
@patch("mcpgateway.main.resource_service.set_resource_state")
def test_set_resource_state(self, mock_toggle, test_client, auth_headers):
"""Test setting resource active/inactive state."""
mock_resource = MagicMock()
mock_resource.model_dump.return_value = {"id": "1", "enabled": False}
mock_toggle.return_value = mock_resource
response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
@patch("mcpgateway.main.resource_service.set_resource_state")
def test_set_resource_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test resource state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.resource_service.set_resource_state")
def test_set_resource_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test resource state change not found error."""
# First-Party
from mcpgateway.services.resource_service import ResourceNotFoundError
mock_toggle.side_effect = ResourceNotFoundError("Missing")
response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.resource_service.subscribe_events")
def test_subscribe_resource_events(self, mock_subscribe, test_client, auth_headers):
"""Test subscribing to resource change events via SSE."""
mock_subscribe.return_value = iter(["data: test\n\n"])
resource_id = MOCK_RESOURCE_READ["id"]
response = test_client.post(f"/resources/subscribe", headers=auth_headers)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# ----------------------------------------------------- #
# Prompt Management Tests #
# ----------------------------------------------------- #
class TestPromptEndpoints:
@patch("mcpgateway.main.prompt_service.delete_prompt")
def test_delete_prompt_not_found(self, mock_delete, test_client, auth_headers):
"""Test delete_prompt returns 404 if prompt not found."""
# First-Party
from mcpgateway.services.prompt_service import PromptNotFoundError
mock_delete.side_effect = PromptNotFoundError("Prompt not found")
response = test_client.delete("/prompts/nonexistent", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.prompt_service.update_prompt")
def test_update_prompt_not_found(self, mock_update, test_client, auth_headers):
"""Test update_prompt returns 404 if prompt not found."""
# First-Party
from mcpgateway.services.prompt_service import PromptNotFoundError
mock_update.side_effect = PromptNotFoundError("Prompt not found")
req = {"description": "Updated description"}
response = test_client.put("/prompts/nonexistent", json=req, headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.prompt_service.register_prompt")
def test_create_prompt_validation_error(self, mock_create, test_client, auth_headers):
"""Test create_prompt returns 422 for missing required fields."""
mock_create.side_effect = None # Let validation error happen
req = {"description": "Missing name and template"}
response = test_client.post("/prompts/", json=req, headers=auth_headers)
assert response.status_code == 422
@patch("mcpgateway.main.prompt_service.get_prompt")
def test_get_prompt_no_args(self, mock_get, test_client, auth_headers):
"""Test getting a prompt without arguments."""
mock_get.return_value = {"name": "test", "template": "Hello"}
response = test_client.get("/prompts/test", headers=auth_headers)
assert response.status_code == 200
mock_get.assert_called_once_with(ANY, "test", {}, user=None, server_id=None, token_teams=None, plugin_context_table=None, plugin_global_context=ANY)
@patch("mcpgateway.main.prompt_service.update_prompt")
def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing prompt."""
updated = {**MOCK_PROMPT_READ, "description": "Updated description"}
mock_update.return_value = PromptRead(**updated)
req = {"description": "Updated description"}
response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.prompt_service.update_prompt")
def test_update_prompt_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
"""Test update_prompt error branches for validation/integrity errors."""
mock_update.side_effect = exc
req = {"description": "Updated description"}
response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
assert response.status_code == status_code
@patch("mcpgateway.main.prompt_service.delete_prompt")
def test_delete_prompt_endpoint(self, mock_delete, test_client, auth_headers):
"""Test deleting a prompt."""
mock_delete.return_value = None
response = test_client.delete("/prompts/test_prompt", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
mock_delete.assert_called_once()
@patch("mcpgateway.main.prompt_service.set_prompt_state")
def test_set_prompt_state(self, mock_toggle, test_client, auth_headers):
"""Test setting prompt active/inactive state."""
mock_prompt = MagicMock()
mock_prompt.model_dump.return_value = {"id": 1, "enabled": False}
mock_toggle.return_value = mock_prompt
response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
mock_toggle.assert_called_once()
@patch("mcpgateway.main.prompt_service.set_prompt_state")
def test_set_prompt_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test prompt state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.prompt_service.set_prompt_state")
def test_set_prompt_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test prompt state change not found error."""
# First-Party
from mcpgateway.services.prompt_service import PromptNotFoundError
mock_toggle.side_effect = PromptNotFoundError("Missing")
response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
"""Tests for prompt template management: creation, rendering, arguments, etc."""
@patch("mcpgateway.main.prompt_service.list_prompts")
def test_list_prompts_endpoint(self, mock_list_prompts, test_client, auth_headers):
"""Test listing all available prompts."""
prompt_read = PromptRead(**MOCK_PROMPT_READ)
mock_list_prompts.return_value = ([prompt_read], None)
response = test_client.get("/prompts/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1
mock_list_prompts.assert_called_once()
@patch("mcpgateway.main.prompt_service.register_prompt")
def test_create_prompt_endpoint(self, mock_create, test_client, auth_headers):
"""Test creating a new prompt template."""
# Return an actual model instance
mock_create.return_value = PromptRead(**MOCK_PROMPT_READ)
req = {"prompt": {"name": "test_prompt", "template": "Hello {name}", "description": "A test prompt"}, "team_id": None, "visibility": "private"}
response = test_client.post("/prompts/", json=req, headers=auth_headers)
assert response.status_code == 200
mock_create.assert_called_once()
@patch("mcpgateway.main.prompt_service.get_prompt")
def test_get_prompt_with_args(self, mock_get, test_client, auth_headers):
"""Test getting a prompt with template arguments."""
mock_get.return_value = {
"messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}],
"description": "A test prompt",
}
req = {"name": "value"}
response = test_client.post("/prompts/test_prompt", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["messages"][0]["content"]["text"] == "Rendered prompt"
mock_get.assert_called_once()
@patch("mcpgateway.main.prompt_service.get_prompt")
def test_get_prompt_no_args(self, mock_get, test_client, auth_headers):
"""Test getting a prompt without arguments."""
mock_get.return_value = {"name": "test", "template": "Hello"}
response = test_client.get("/prompts/test", headers=auth_headers)
assert response.status_code == 200
mock_get.assert_called_once_with(ANY, "test", {}, user=None, server_id=None, token_teams=None, plugin_context_table=None, plugin_global_context=ANY)
@patch("mcpgateway.main.prompt_service.update_prompt")
def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing prompt."""
updated = {**MOCK_PROMPT_READ, "description": "Updated description"}
mock_update.return_value = PromptRead(**updated) # <- real model
req = {"description": "Updated description"}
response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@patch("mcpgateway.main.prompt_service.delete_prompt")
def test_delete_prompt_endpoint(self, mock_delete, test_client, auth_headers):
"""Test deleting a prompt."""
mock_delete.return_value = None
response = test_client.delete("/prompts/test_prompt", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.prompt_service.set_prompt_state")
def test_set_prompt_state(self, mock_toggle, test_client, auth_headers):
"""Test setting prompt active/inactive state."""
mock_prompt = MagicMock()
mock_prompt.model_dump.return_value = {"id": 1, "enabled": False}
mock_toggle.return_value = mock_prompt
response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Gateway Federation Tests #
# ----------------------------------------------------- #
class TestGatewayEndpoints:
@patch("mcpgateway.main.gateway_service.list_gateways")
def test_list_gateways_endpoint(self, mock_list, test_client, auth_headers):
"""Test listing all registered gateways."""
gateway_read = GatewayRead(**MOCK_GATEWAY_READ)
mock_list.return_value = ([gateway_read], None)
response = test_client.get("/gateways/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1
mock_list.assert_called_once()
@patch("mcpgateway.main.gateway_service.register_gateway")
def test_create_gateway_endpoint(self, mock_create, test_client, auth_headers):
"""Test registering a new gateway."""
mock_create.return_value = MOCK_GATEWAY_READ
req = {"name": "test_gateway", "url": "http://example.com"}
response = test_client.post("/gateways/", json=req, headers=auth_headers)
assert response.status_code == 200
mock_create.assert_called_once()
@patch("mcpgateway.main.gateway_service.get_gateway")
def test_get_gateway_endpoint(self, mock_get, test_client, auth_headers):
"""Test retrieving a specific gateway."""
mock_get.return_value = MOCK_GATEWAY_READ
response = test_client.get("/gateways/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["name"] == "test_gateway"
mock_get.assert_called_once()
@patch("mcpgateway.main.gateway_service.update_gateway")
def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing gateway."""
mock_update.return_value = MOCK_GATEWAY_READ
req = {"description": "Updated description"}
response = test_client.put("/gateways/1", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@patch("mcpgateway.main.gateway_service.delete_gateway")
@patch("mcpgateway.main.gateway_service.get_gateway")
def test_delete_gateway_endpoint_no_resources(self, mock_get, mock_delete, test_client, auth_headers):
"""Test deleting a gateway that doesn't have resources."""
mock_delete.return_value = None
mock_get.return_value.capabilities = {}
response = test_client.delete("/gateways/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
mock_delete.assert_called_once()
@patch("mcpgateway.main.gateway_service.delete_gateway")
@patch("mcpgateway.main.gateway_service.get_gateway")
@patch("mcpgateway.main.invalidate_resource_cache")
def test_delete_gateway_endpoint_with_resources(self, mock_invalidate_cache, mock_get, mock_delete, test_client, auth_headers):
"""Test deleting a gateway that does have resources."""
mock_delete.return_value = None
mock_get.return_value = MagicMock()
mock_get.return_value.capabilities = {"resources": {"some": "thing"}}
response = test_client.delete("/gateways/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
mock_delete.assert_called_once()
mock_invalidate_cache.assert_called_once()
@patch("mcpgateway.main.gateway_service.set_gateway_state")
def test_set_gateway_state(self, mock_toggle, test_client, auth_headers):
"""Test setting gateway active/inactive state."""
mock_gateway = MagicMock()
mock_gateway.model_dump.return_value = {"id": "1", "is_active": False}
mock_toggle.return_value = mock_gateway
response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
mock_toggle.assert_called_once()
@patch("mcpgateway.main.gateway_service.set_gateway_state")
def test_set_gateway_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test gateway state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.gateway_service.set_gateway_state")
def test_set_gateway_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test gateway state change not found error."""
# First-Party
from mcpgateway.services.gateway_service import GatewayNotFoundError
mock_toggle.side_effect = GatewayNotFoundError("Missing")
response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
"""Tests for gateway federation: registration, discovery, forwarding, etc."""
@patch("mcpgateway.main.gateway_service.list_gateways")
def test_list_gateways_endpoint(self, mock_list, test_client, auth_headers):
"""Test listing all registered gateways."""
gateway_read = GatewayRead(**MOCK_GATEWAY_READ)
mock_list.return_value = ([gateway_read], None)
response = test_client.get("/gateways/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1
mock_list.assert_called_once()
@patch("mcpgateway.main.gateway_service.register_gateway")
def test_create_gateway_endpoint(self, mock_create, test_client, auth_headers):
"""Test registering a new gateway."""
mock_create.return_value = MOCK_GATEWAY_READ
req = {"name": "test_gateway", "url": "http://example.com"}
response = test_client.post("/gateways/", json=req, headers=auth_headers)
assert response.status_code == 200
mock_create.assert_called_once()
@patch("mcpgateway.main.gateway_service.get_gateway")
def test_get_gateway_endpoint(self, mock_get, test_client, auth_headers):
"""Test retrieving a specific gateway."""
mock_get.return_value = MOCK_GATEWAY_READ
response = test_client.get("/gateways/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["name"] == "test_gateway"
mock_get.assert_called_once()
@patch("mcpgateway.main.gateway_service.update_gateway")
def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing gateway."""
mock_update.return_value = MOCK_GATEWAY_READ
req = {"description": "Updated description"}
response = test_client.put("/gateways/1", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@patch("mcpgateway.main.gateway_service.delete_gateway")
@patch("mcpgateway.main.gateway_service.get_gateway")
def test_delete_gateway_endpoint(self, mock_get, mock_delete, test_client, auth_headers):
"""Test deleting a gateway."""
mock_delete.return_value = None
mock_get.return_value.capabilities = {}
response = test_client.delete("/gateways/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.gateway_service.set_gateway_state")
def test_set_gateway_state(self, mock_toggle, test_client, auth_headers):
"""Test setting gateway active/inactive state."""
mock_gateway = MagicMock()
mock_gateway.model_dump.return_value = {"id": "1", "is_active": False}
mock_toggle.return_value = mock_gateway
response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Tag Endpoints Tests #
# ----------------------------------------------------- #
class TestTagEndpoints:
@patch("mcpgateway.main.tag_service.get_all_tags")
def test_list_tags_error(self, mock_get_tags, test_client, auth_headers):
"""Test tag list error handling."""
mock_get_tags.side_effect = Exception("Tag failure")
response = test_client.get("/tags", headers=auth_headers)
assert response.status_code == 500
@patch("mcpgateway.main.tag_service.get_entities_by_tag")
def test_get_entities_by_tag_error(self, mock_get_entities, test_client, auth_headers):
"""Test tag entity lookup error handling."""
mock_get_entities.side_effect = Exception("Entity lookup failure")
response = test_client.get("/tags/test/entities", headers=auth_headers)
assert response.status_code == 500
# ----------------------------------------------------- #
# Root Management Tests #
# ----------------------------------------------------- #
class TestRootEndpoints:
"""Tests for root directory management: registration, listing, changes, etc."""
@patch("mcpgateway.main.root_service.list_roots")
def test_list_roots_endpoint(self, mock_list, test_client, auth_headers):
"""Test listing all registered roots."""
# First-Party
from mcpgateway.common.models import Root
mock_list.return_value = [Root(uri="file:///test", name="Test Root")] # valid URI
response = test_client.get("/roots/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list.assert_called_once()
@patch("mcpgateway.main.root_service.add_root")
def test_add_root_endpoint(self, mock_add, test_client, auth_headers):
"""Test adding a new root directory."""
# First-Party
from mcpgateway.common.models import Root
mock_add.return_value = Root(uri="file:///test", name="Test Root") # valid URI
req = {"uri": "file:///test", "name": "Test Root"} # valid body
response = test_client.post("/roots/", json=req, headers=auth_headers)
assert response.status_code == 200
mock_add.assert_called_once()
@patch("mcpgateway.main.root_service.remove_root")
def test_remove_root_endpoint(self, mock_remove, test_client, auth_headers):
"""Test removing a root directory."""
mock_remove.return_value = None
response = test_client.delete("/roots/%2Ftest", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.root_service.subscribe_changes")
def test_subscribe_root_changes(self, mock_subscribe, test_client, auth_headers):
"""Test subscribing to root directory changes via SSE."""
async def mock_async_gen():
yield {"event": "test"}
mock_subscribe.return_value = mock_async_gen()
response = test_client.get("/roots/changes", headers=auth_headers)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# ----------------------------------------------------- #
# JSON-RPC & Utility Tests #
# ----------------------------------------------------- #
class TestRPCEndpoints:
"""Tests for JSON-RPC functionality and utility endpoints."""
@patch("mcpgateway.main.tool_service.invoke_tool")
def test_rpc_tool_invocation(self, mock_invoke_tool, test_client, auth_headers):
"""Test tool invocation via JSON-RPC."""
mock_invoke_tool.return_value = {"content": [{"type": "text", "text": "Tool response"}], "is_error": False}
req = {"jsonrpc": "2.0", "id": "test-id", "method": "tools/call", "params": {"name": "test_tool", "arguments": {"param": "value"}}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["result"]["content"][0]["text"] == "Tool response"
mock_invoke_tool.assert_called_once_with(
db=ANY,
name="test_tool",
arguments={"param": "value"},
request_headers=ANY,
app_user_email="test_user@example.com", # Updated: now uses email from JWT/RBAC
user_email="test_user@example.com",
token_teams=[],
server_id=None,
plugin_context_table=None,
plugin_global_context=ANY,
meta_data=None,
)
@patch("mcpgateway.main.prompt_service.get_prompt")
# @patch("mcpgateway.main.validate_request")
def test_rpc_prompt_get(self, mock_get_prompt, test_client, auth_headers):
"""Test prompt retrieval via JSON-RPC."""
mock_get_prompt.return_value = {
"messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}],
"description": "A test prompt",
}
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "prompts/get",
"params": {"name": "test_prompt", "arguments": {"param": "value"}},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["result"]["messages"][0]["content"]["text"] == "Rendered prompt"
mock_get_prompt.assert_called_once_with(
ANY, # db
"test_prompt", # name
{"param": "value"}, # arguments
user="test_user@example.com",
server_id=None,
token_teams=[],
plugin_context_table=None,
plugin_global_context=ANY,
_meta_data=None,
)
@patch("mcpgateway.main.tool_service.list_tools")
# @patch("mcpgateway.main.validate_request")
def test_rpc_list_tools(self, mock_list_tools, test_client, auth_headers):
"""Test listing tools via JSON-RPC."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = MOCK_TOOL_READ
mock_list_tools.return_value = ([mock_tool], None)
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "tools/list",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert isinstance(body["result"]["tools"], list)
mock_list_tools.assert_called_once()
@patch("mcpgateway.main.tool_service.list_server_tools", new_callable=AsyncMock)
def test_rpc_list_tools_with_server_id(self, mock_list_tools, test_client, auth_headers):
"""Test listing tools via JSON-RPC for a specific server."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = MOCK_TOOL_READ
mock_list_tools.return_value = [mock_tool]
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "tools/list",
"params": {"server_id": "server-1"},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["result"]["tools"][0]["name"] == "test_tool"
mock_list_tools.assert_called_once()
@patch("mcpgateway.main.tool_service.list_tools")
def test_rpc_legacy_list_tools_next_cursor(self, mock_list_tools, test_client, auth_headers):
"""Test legacy list_tools JSON-RPC method with nextCursor."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = MOCK_TOOL_READ
mock_list_tools.return_value = ([mock_tool], "next-cursor")
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "list_tools",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["nextCursor"] == "next-cursor"
assert body["tools"][0]["name"] == "test_tool"
@patch("mcpgateway.main.resource_service.list_server_resources", new_callable=AsyncMock)
def test_rpc_resources_list_with_server_id(self, mock_list_resources, test_client, auth_headers):
"""Test listing resources via JSON-RPC for a specific server."""
mock_resource = MagicMock()
mock_resource.model_dump.return_value = {"uri": "res://1"}
mock_list_resources.return_value = [mock_resource]
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/list",
"params": {"server_id": "server-1"},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["resources"][0]["uri"] == "res://1"
def test_rpc_resources_read_missing_uri(self, test_client, auth_headers):
"""Test resources/read error when uri is missing."""
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/read",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["error"]["code"] == -32602
assert "Missing resource URI" in body["error"]["message"]
@patch("mcpgateway.main.resource_service.read_resource", new_callable=AsyncMock)
@patch("mcpgateway.main.gateway_service.forward_request", new_callable=AsyncMock)
def test_rpc_resources_read_fallback(self, mock_forward, mock_read, test_client, auth_headers):
"""Test resources/read fallback to gateway when local content missing."""
mock_read.side_effect = ValueError("no local content")
mock_forward.return_value = {"contents": [{"uri": "res://remote", "text": "ok"}]}
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/read",
"params": {"uri": "res://remote"},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["contents"][0]["uri"] == "res://remote"
mock_forward.assert_called_once()
@patch("mcpgateway.main.get_user_email", return_value="user_1")
@patch("mcpgateway.main.resource_service.subscribe_resource", new_callable=AsyncMock)
@patch("mcpgateway.main.resource_service.unsubscribe_resource", new_callable=AsyncMock)
def test_rpc_resources_subscribe_unsubscribe(self, mock_unsubscribe, mock_subscribe, _mock_get_user_email, test_client, auth_headers):
"""Test resources/subscribe and resources/unsubscribe JSON-RPC methods."""
subscribe_req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/subscribe",
"params": {"uri": "res://1"},
}
unsubscribe_req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/unsubscribe",
"params": {"uri": "res://1"},
}
response = test_client.post("/rpc/", json=subscribe_req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
response = test_client.post("/rpc/", json=unsubscribe_req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
mock_subscribe.assert_called_once()
mock_unsubscribe.assert_called_once()
@patch("mcpgateway.main.resource_service.list_resource_templates", new_callable=AsyncMock)
def test_rpc_resource_templates_list(self, mock_list_templates, test_client, auth_headers):
"""Test resources/templates/list JSON-RPC method."""
mock_template = MagicMock()
mock_template.model_dump.return_value = {"uri": "tpl://1"}
mock_list_templates.return_value = [mock_template]
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "resources/templates/list",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["resourceTemplates"][0]["uri"] == "tpl://1"
@patch("mcpgateway.main.prompt_service.list_prompts", new_callable=AsyncMock)
def test_rpc_prompts_list_next_cursor(self, mock_list_prompts, test_client, auth_headers):
"""Test prompts/list JSON-RPC method with nextCursor."""
mock_prompt = MagicMock()
mock_prompt.model_dump.return_value = {"name": "prompt-1"}
mock_list_prompts.return_value = ([mock_prompt], "next-cursor")
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "prompts/list",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["nextCursor"] == "next-cursor"
assert body["prompts"][0]["name"] == "prompt-1"
@patch("mcpgateway.main.gateway_service.list_gateways", new_callable=AsyncMock)
def test_rpc_list_gateways(self, mock_list_gateways, test_client, auth_headers):
"""Test list_gateways JSON-RPC method."""
mock_gateway = MagicMock()
mock_gateway.model_dump.return_value = {"id": "gateway-1"}
mock_list_gateways.return_value = ([mock_gateway], None)
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "list_gateways",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["gateways"][0]["id"] == "gateway-1"
@patch("mcpgateway.main.root_service.list_roots", new_callable=AsyncMock)
def test_rpc_list_roots(self, mock_list_roots, test_client, auth_headers):
"""Test list_roots JSON-RPC method."""
mock_root = MagicMock()
mock_root.model_dump.return_value = {"uri": "root://1"}
mock_list_roots.return_value = [mock_root]
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "list_roots",
"params": {},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["roots"][0]["uri"] == "root://1"
@patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
@patch("mcpgateway.main.cancellation_service.cancel_run", new_callable=AsyncMock)
def test_rpc_notification_cancelled(self, mock_cancel_run, mock_notify, test_client, auth_headers):
"""Test notifications/cancelled JSON-RPC method."""
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "notifications/cancelled",
"params": {"requestId": "123", "reason": "user"},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
mock_cancel_run.assert_called_once_with("123", reason="user")
mock_notify.assert_called_once()
@patch("mcpgateway.main.gateway_service.forward_request", new_callable=AsyncMock)
@patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock)
def test_rpc_tools_call_fallback_to_gateway(self, mock_invoke, mock_forward, test_client, auth_headers):
"""Test tools/call fallback to gateway when tool invocation fails."""
mock_invoke.side_effect = ValueError("no local tool")
mock_forward.return_value = {"content": [{"type": "text", "text": "fallback"}]}
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "tools/call",
"params": {"name": "missing_tool", "arguments": {}},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()["result"]
assert body["content"][0]["text"] == "fallback"
mock_forward.assert_called_once()
def test_rpc_elicitation_disabled(self, test_client, auth_headers, monkeypatch):
"""Test elicitation/create JSON-RPC when feature disabled."""
monkeypatch.setattr(settings, "mcpgateway_elicitation_enabled", False)
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "elicitation/create",
"params": {"message": "Need input", "requestedSchema": {"type": "object"}},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
body = response.json()
assert body["error"]["code"] == -32601
@patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
def test_rpc_notifications_initialized(self, mock_notify, test_client, auth_headers):
"""Test notifications/initialized JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "notifications/initialized"}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
mock_notify.assert_called_once()
@patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
def test_rpc_notifications_message(self, mock_notify, test_client, auth_headers):
"""Test notifications/message JSON-RPC method."""
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "notifications/message",
"params": {"data": "hello", "level": "info", "logger": "test"},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
mock_notify.assert_called_once()
@patch("mcpgateway.main.sampling_handler.create_message", new_callable=AsyncMock)
def test_rpc_sampling_create_message(self, mock_sampling, test_client, auth_headers):
"""Test sampling/createMessage JSON-RPC method."""
mock_sampling.return_value = {"messageId": "abc"}
req = {
"jsonrpc": "2.0",
"id": "test-id",
"method": "sampling/createMessage",
"params": {"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]},
}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"]["messageId"] == "abc"
mock_sampling.assert_called_once()
def test_rpc_sampling_other_method(self, test_client, auth_headers):
"""Test sampling/* catch-all JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "sampling/unknown", "params": {}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
@patch("mcpgateway.main.completion_service.handle_completion", new_callable=AsyncMock)
def test_rpc_completion_complete(self, mock_completion, test_client, auth_headers):
"""Test completion/complete JSON-RPC method."""
mock_completion.return_value = {"result": "done"}
req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/complete", "params": {"ref": {"type": "ref/prompt", "name": "p1"}}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"]["result"] == "done"
mock_completion.assert_called_once()
def test_rpc_completion_other_method(self, test_client, auth_headers):
"""Test completion/* catch-all JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/unknown", "params": {}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
@patch("mcpgateway.main.logging_service.set_level", new_callable=AsyncMock)
def test_rpc_logging_set_level(self, mock_set_level, test_client, auth_headers):
"""Test logging/setLevel JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "logging/setLevel", "params": {"level": "debug"}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
mock_set_level.assert_called_once()
def test_rpc_logging_other_method(self, test_client, auth_headers):
"""Test logging/* catch-all JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "logging/other", "params": {}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
@patch("mcpgateway.main.root_service.list_roots", new_callable=AsyncMock)
def test_rpc_roots_list_method(self, mock_list_roots, test_client, auth_headers):
"""Test roots/list JSON-RPC method."""
mock_root = MagicMock()
mock_root.model_dump.return_value = {"uri": "root://2"}
mock_list_roots.return_value = [mock_root]
req = {"jsonrpc": "2.0", "id": "test-id", "method": "roots/list", "params": {}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"]["roots"][0]["uri"] == "root://2"
def test_rpc_roots_other_method(self, test_client, auth_headers):
"""Test roots/* catch-all JSON-RPC method."""
req = {"jsonrpc": "2.0", "id": "test-id", "method": "roots/remove", "params": {}}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 200
assert response.json()["result"] == {}
@patch("mcpgateway.main.RPCRequest")
def test_rpc_invalid_request(self, mock_rpc_request, test_client, auth_headers):
"""Test RPC error handling for invalid requests."""
mock_rpc_request.side_effect = ValueError("Invalid method")
req = {"jsonrpc": "1.0", "id": "test-id", "method": "invalid_method"}
response = test_client.post("/rpc/", json=req, headers=auth_headers)
assert response.status_code == 422
body = response.json()
assert "Method invalid" in body.get("message")
def test_rpc_invalid_json(self, test_client, auth_headers):
"""Test RPC error handling for malformed JSON."""
headers = auth_headers
headers["content-type"] = "application/json"
response = test_client.post("/rpc/", content="invalid json", headers=headers)
assert response.status_code == 400
body = response.json()
assert body["error"]["code"] == -32700
assert body["error"]["message"] == "Parse error"
@patch("mcpgateway.main.logging_service.set_level")
def test_set_log_level_endpoint(self, mock_set_level, test_client, auth_headers):
"""Test setting the application log level."""
req = {"level": "debug"} # lowercase to match enum
response = test_client.post("/logging/setLevel", json=req, headers=auth_headers)
assert response.status_code == 200
mock_set_level.assert_called_once()
# ----------------------------------------------------- #
# WebSocket & SSE Tests #
# ----------------------------------------------------- #
class TestRealtimeEndpoints:
"""Tests for real-time communication: WebSocket, SSE, message handling, etc."""
@patch("mcpgateway.main.settings")
@patch("mcpgateway.main.ResilientHttpClient") # stub network calls
def test_websocket_endpoint(self, mock_client, mock_settings, test_client):
# Standard
from types import SimpleNamespace
"""Test WebSocket connection and message handling."""
# Configure mock settings for auth disabled
mock_settings.mcp_client_auth_enabled = False
mock_settings.auth_required = False
mock_settings.federation_timeout = 30
mock_settings.skip_ssl_verify = False
mock_settings.port = 4444
# ----- set up async context-manager dummy -----
mock_instance = mock_client.return_value
mock_instance.__aenter__.return_value = mock_instance
mock_instance.__aexit__.return_value = False
async def dummy_post(*_args, **_kwargs):
# minimal object that looks like an httpx.Response
return SimpleNamespace(text='{"jsonrpc":"2.0","id":1,"result":{}}')
mock_instance.post = dummy_post
# ---------------------------------------------
with test_client.websocket_connect("/ws") as websocket:
websocket.send_text('{"jsonrpc":"2.0","method":"ping","id":1}')
data = websocket.receive_text()
response = json.loads(data)
assert response == {"jsonrpc": "2.0", "id": 1, "result": {}}
@patch("mcpgateway.main.update_url_protocol", new=lambda url: url)
@patch("mcpgateway.main.session_registry.add_session")
@patch("mcpgateway.main.session_registry.respond")
@patch("mcpgateway.main.SSETransport")
def test_sse_endpoint(self, mock_transport_class, mock_respond, mock_add_session, test_client, auth_headers):
"""Test SSE connection establishment."""
mock_transport = MagicMock()
mock_transport.session_id = "test-session"
mock_transport.create_sse_response.return_value = MagicMock()
mock_transport_class.return_value = mock_transport
test_client.get("/sse", headers=auth_headers)
# Note: This test may need adjustment based on actual SSE implementation
# The exact assertion will depend on how SSE responses are structured
mock_transport_class.assert_called_once()
@patch("mcpgateway.main.session_registry.broadcast")
def test_message_endpoint(self, mock_broadcast, test_client, auth_headers):
"""Test message broadcasting to SSE sessions."""
message = {"type": "test", "data": "hello"}
response = test_client.post("/message?session_id=test-session", json=message, headers=auth_headers)
assert response.status_code == 202
mock_broadcast.assert_called_once()
@patch("mcpgateway.main._read_request_json")
def test_message_endpoint_invalid_payload(self, mock_read, test_client, auth_headers):
"""Test message endpoint invalid JSON handling."""
mock_read.side_effect = ValueError("Invalid payload")
response = test_client.post("/message?session_id=test-session", json={"bad": True}, headers=auth_headers)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_websocket_forwards_auth_token_to_rpc(self, monkeypatch):
"""Test that WebSocket forwards JWT token to /rpc endpoint.
This ensures auth credentials are propagated so /rpc doesn't reject
with 401 when AUTH_REQUIRED=true.
"""
# First-Party
from mcpgateway import main as mcpgateway_main
# Track headers passed to the RPC call
captured_headers = {}
class MockResponse:
text = '{"jsonrpc":"2.0","id":1,"result":{}}'
class MockClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def post(self, url, json, headers):
captured_headers.update(headers)
return MockResponse()
monkeypatch.setattr(mcpgateway_main.settings, "auth_required", True)
monkeypatch.setattr(mcpgateway_main.settings, "mcp_client_auth_enabled", True)
monkeypatch.setattr(mcpgateway_main.settings, "federation_timeout", 30)
monkeypatch.setattr(mcpgateway_main.settings, "skip_ssl_verify", False)
monkeypatch.setattr(mcpgateway_main.settings, "port", 4444)
monkeypatch.setattr(mcpgateway_main.settings, "app_root_path", "")
monkeypatch.setattr(mcpgateway_main, "ResilientHttpClient", lambda **kwargs: MockClient())
monkeypatch.setattr(mcpgateway_main, "verify_jwt_token", AsyncMock(return_value=None))
# Create mock websocket with token in query params
websocket = AsyncMock()
websocket.query_params = {"token": "test-jwt-token"}
websocket.headers = {}
# Track messages
messages_received = []
websocket.receive_text = AsyncMock(side_effect=[
'{"jsonrpc":"2.0","method":"test","id":1}',
WebSocketDisconnect(),
])
websocket.send_text = AsyncMock(side_effect=lambda msg: messages_received.append(msg))
await mcpgateway_main.websocket_endpoint(websocket)
# Verify auth token was forwarded to /rpc
assert "Authorization" in captured_headers, "Authorization header should be forwarded to /rpc"
assert captured_headers["Authorization"] == "Bearer test-jwt-token"
@pytest.mark.asyncio
async def test_websocket_forwards_proxy_user_to_rpc(self, monkeypatch):
"""Test that WebSocket forwards proxy user header to /rpc endpoint."""
# First-Party
from mcpgateway import main as mcpgateway_main
# Track headers passed to the RPC call
captured_headers = {}
class MockResponse:
text = '{"jsonrpc":"2.0","id":1,"result":{}}'
class MockClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def post(self, url, json, headers):
captured_headers.update(headers)
return MockResponse()
monkeypatch.setattr(mcpgateway_main.settings, "auth_required", True)
monkeypatch.setattr(mcpgateway_main.settings, "mcp_client_auth_enabled", False)
monkeypatch.setattr(mcpgateway_main.settings, "trust_proxy_auth", True)
monkeypatch.setattr(mcpgateway_main.settings, "proxy_user_header", "X-Forwarded-User")
monkeypatch.setattr(mcpgateway_main.settings, "federation_timeout", 30)
monkeypatch.setattr(mcpgateway_main.settings, "skip_ssl_verify", False)
monkeypatch.setattr(mcpgateway_main.settings, "port", 4444)
monkeypatch.setattr(mcpgateway_main.settings, "app_root_path", "")
monkeypatch.setattr(mcpgateway_main, "ResilientHttpClient", lambda **kwargs: MockClient())
# Create mock websocket with proxy user header
# Note: Use exact case matching settings.proxy_user_header since we're using a plain dict
websocket = AsyncMock()
websocket.query_params = {}
websocket.headers = {"X-Forwarded-User": "proxy-user@example.com"}
# Track messages
websocket.receive_text = AsyncMock(side_effect=[
'{"jsonrpc":"2.0","method":"test","id":1}',
WebSocketDisconnect(),
])
websocket.send_text = AsyncMock()
await mcpgateway_main.websocket_endpoint(websocket)
# Verify proxy user header was forwarded to /rpc
assert "X-Forwarded-User" in captured_headers, "Proxy user header should be forwarded to /rpc"
assert captured_headers["X-Forwarded-User"] == "proxy-user@example.com"
@pytest.mark.asyncio
async def test_websocket_disconnect_on_accept(self, monkeypatch):
"""Test WebSocket disconnect handling."""
# First-Party
from mcpgateway import main as mcpgateway_main
monkeypatch.setattr(mcpgateway_main.settings, "auth_required", False)
monkeypatch.setattr(mcpgateway_main.settings, "mcp_client_auth_enabled", False)
websocket = AsyncMock()
websocket.query_params = {}
websocket.headers = {}
websocket.accept.side_effect = WebSocketDisconnect()
await mcpgateway_main.websocket_endpoint(websocket)
@pytest.mark.asyncio
async def test_server_sse_cancelled_cleanup(self):
"""Test server SSE cancellation cleanup path."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummyTransport:
def __init__(self, *_args, **_kwargs):
self.session_id = "sess-cancel"
async def connect(self):
return None
async def create_sse_response(self, *_args, **_kwargs):
raise asyncio.CancelledError()
request = _make_request("/servers/1/sse")
with (
patch("mcpgateway.main.SSETransport", DummyTransport),
patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.register_respond_task"),
patch("mcpgateway.main.session_registry.remove_session", new_callable=AsyncMock, side_effect=Exception("cleanup")),
patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
):
with pytest.raises(asyncio.CancelledError):
await mcpgateway_main.sse_endpoint(request, "1", user={"email": "user@example.com"})
@pytest.mark.asyncio
async def test_server_sse_failure_cleanup(self):
"""Test server SSE failure cleanup path."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummyTransport:
def __init__(self, *_args, **_kwargs):
self.session_id = "sess-fail"
async def connect(self):
return None
async def create_sse_response(self, *_args, **_kwargs):
raise RuntimeError("boom")
request = _make_request("/servers/1/sse")
with (
patch("mcpgateway.main.SSETransport", DummyTransport),
patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.register_respond_task"),
patch("mcpgateway.main.session_registry.remove_session", new_callable=AsyncMock, side_effect=Exception("cleanup")),
patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
):
with pytest.raises(HTTPException) as excinfo:
await mcpgateway_main.sse_endpoint(request, "1", user={"email": "user@example.com"})
assert excinfo.value.status_code == 500
@pytest.mark.asyncio
async def test_utility_sse_cancelled_cleanup(self):
"""Test utility SSE cancellation cleanup path."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummyTransport:
def __init__(self, *_args, **_kwargs):
self.session_id = "util-cancel"
async def connect(self):
return None
async def create_sse_response(self, *_args, **_kwargs):
raise asyncio.CancelledError()
request = _make_request("/sse")
with (
patch("mcpgateway.main.SSETransport", DummyTransport),
patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.register_respond_task"),
patch("mcpgateway.main.session_registry.remove_session", new_callable=AsyncMock, side_effect=Exception("cleanup")),
patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
):
with pytest.raises(asyncio.CancelledError):
await mcpgateway_main.utility_sse_endpoint(request, user={"email": "user@example.com"})
@pytest.mark.asyncio
async def test_utility_sse_failure_cleanup(self):
"""Test utility SSE failure cleanup path."""
# First-Party
from mcpgateway import main as mcpgateway_main
class DummyTransport:
def __init__(self, *_args, **_kwargs):
self.session_id = "util-fail"
async def connect(self):
return None
async def create_sse_response(self, *_args, **_kwargs):
raise RuntimeError("boom")
request = _make_request("/sse")
with (
patch("mcpgateway.main.SSETransport", DummyTransport),
patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
patch("mcpgateway.main.session_registry.register_respond_task"),
patch("mcpgateway.main.session_registry.remove_session", new_callable=AsyncMock, side_effect=Exception("cleanup")),
patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
):
with pytest.raises(HTTPException) as excinfo:
await mcpgateway_main.utility_sse_endpoint(request, user={"email": "user@example.com"})
assert excinfo.value.status_code == 500
# ----------------------------------------------------- #
# Metrics & Monitoring Tests #
# ----------------------------------------------------- #
class TestMetricsEndpoints:
"""Tests for metrics collection, aggregation, and reset functionality."""
@patch("mcpgateway.main.prompt_service.aggregate_metrics")
@patch("mcpgateway.main.server_service.aggregate_metrics")
@patch("mcpgateway.main.resource_service.aggregate_metrics")
@patch("mcpgateway.main.tool_service.aggregate_metrics")
def test_get_metrics(self, mock_tool, mock_resource, mock_server, mock_prompt, test_client, auth_headers):
"""Test retrieving aggregated metrics for all entity types."""
mock_tool.return_value = {"total": 5}
mock_resource.return_value = {"total": 3}
mock_server.return_value = {"total": 2}
mock_prompt.return_value = {"total": 1}
response = test_client.get("/metrics", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "tools" in data and "resources" in data
assert "servers" in data and "prompts" in data
# A2A agents may or may not be present based on configuration
# @patch("mcpgateway.main.a2a_service")
# @patch("mcpgateway.main.prompt_service.reset_metrics")
# @patch("mcpgateway.main.server_service.reset_metrics")
# @patch("mcpgateway.main.resource_service.reset_metrics")
# @patch("mcpgateway.main.tool_service.reset_metrics")
# def test_reset_all_metrics(self, mock_tool_reset, mock_resource_reset, mock_server_reset, mock_prompt_reset, mock_a2a_service, test_client, auth_headers):
# """Test resetting metrics for all entity types."""
# # Mock A2A service with reset_metrics method
# mock_a2a_service.reset_metrics = MagicMock()
#
# response = test_client.post("/metrics/reset", headers=auth_headers)
# assert response.status_code == 200
#
# # Verify all services had their metrics reset
# mock_tool_reset.assert_called_once()
# mock_resource_reset.assert_called_once()
# mock_server_reset.assert_called_once()
# mock_prompt_reset.assert_called_once()
# mock_a2a_service.reset_metrics.assert_called_once()
@patch("mcpgateway.main.tool_service.reset_metrics")
def test_reset_specific_entity_metrics(self, mock_tool_reset, test_client, auth_headers):
"""Test resetting metrics for a specific entity type."""
response = test_client.post("/metrics/reset?entity=tool&entity_id=1", headers=auth_headers)
assert response.status_code == 200
mock_tool_reset.assert_called_once_with(ANY, 1)
def test_reset_invalid_entity_metrics(self, test_client, auth_headers):
"""Test error handling for invalid entity type in metrics reset."""
response = test_client.post("/metrics/reset?entity=invalid", headers=auth_headers)
assert response.status_code == 400
# ----------------------------------------------------- #
# A2A Agent API Tests #
# ----------------------------------------------------- #
class TestA2AAgentEndpoints:
"""Test A2A agent API endpoints."""
@patch("mcpgateway.main.a2a_service")
def test_list_a2a_agents(self, mock_service, test_client, auth_headers):
"""Test listing A2A agents."""
mock_service.list_agents = AsyncMock(return_value=([_make_a2a_agent_read()], None))
response = test_client.get("/a2a", headers=auth_headers)
assert response.status_code == 200
assert response.json()[0]["name"] == "agent-1"
mock_service.list_agents.assert_called_once()
@patch("mcpgateway.main.a2a_service")
def test_get_a2a_agent(self, mock_service, test_client, auth_headers):
"""Test getting specific A2A agent."""
mock_service.get_agent = AsyncMock(return_value=_make_a2a_agent_read())
response = test_client.get("/a2a/agent-1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["name"] == "agent-1"
mock_service.get_agent.assert_called_once()
@patch("mcpgateway.main.a2a_service")
def test_create_a2a_agent(self, mock_service, test_client, auth_headers):
"""Test creating A2A agent."""
mock_service.register_agent = AsyncMock(return_value=_make_a2a_agent_read())
response = test_client.post("/a2a", json=_a2a_create_payload(), headers=auth_headers)
assert response.status_code == 201
mock_service.register_agent.assert_called_once()
@pytest.mark.parametrize(
"exc,status_code",
[
("name_conflict", 409),
("agent_error", 400),
],
)
@patch("mcpgateway.main.a2a_service")
def test_create_a2a_agent_error_branches(self, mock_service, exc, status_code, test_client, auth_headers):
"""Test create A2A agent error handling."""
# First-Party
from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError
error = A2AAgentNameConflictError("conflict") if exc == "name_conflict" else A2AAgentError("bad")
mock_service.register_agent = AsyncMock(side_effect=error)
response = test_client.post("/a2a", json=_a2a_create_payload(), headers=auth_headers)
assert response.status_code == status_code
@patch("mcpgateway.main.a2a_service")
def test_update_a2a_agent(self, mock_service, test_client, auth_headers):
"""Test updating A2A agent."""
mock_service.update_agent = AsyncMock(return_value=_make_a2a_agent_read(name="agent-1", description="updated"))
response = test_client.put("/a2a/agent-1", json={"description": "Updated description"}, headers=auth_headers)
assert response.status_code == 200
mock_service.update_agent.assert_called_once()
@patch("mcpgateway.main.a2a_service")
def test_set_a2a_agent_state(self, mock_service, test_client, auth_headers):
"""Test toggling A2A agent status."""
mock_service.set_agent_state = AsyncMock(return_value=_make_a2a_agent_read(enabled=False))
response = test_client.post("/a2a/agent-1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
mock_service.set_agent_state.assert_called_once()
@patch("mcpgateway.main.a2a_service")
def test_delete_a2a_agent(self, mock_service, test_client, auth_headers):
"""Test deleting A2A agent."""
mock_service.delete_agent = AsyncMock(return_value=None)
response = test_client.delete("/a2a/agent-1", headers=auth_headers)
assert response.status_code == 200
mock_service.delete_agent.assert_called_once()
@patch("mcpgateway.main.a2a_service")
def test_invoke_a2a_agent(self, mock_service, test_client, auth_headers):
"""Test invoking A2A agent."""
mock_service.invoke_agent = AsyncMock(return_value={"response": "Agent response", "status": "success"})
response = test_client.post(
"/a2a/agent-1/invoke",
json={"parameters": {"query": "test"}, "interaction_type": "query"},
headers=auth_headers,
)
assert response.status_code == 200
mock_service.invoke_agent.assert_called_once()
# ----------------------------------------------------- #
# Middleware & Security Tests #
# ----------------------------------------------------- #
class TestMiddlewareAndSecurity:
"""Tests for middleware functionality, authentication, CORS, path rewriting, etc."""
def test_docs_auth_middleware_protected_path(self, test_client):
"""Test that documentation paths require authentication."""
response = test_client.get("/docs", follow_redirects=False)
assert response.status_code == 401
def test_docs_auth_middleware_unprotected_path(self, test_client):
"""Test that non-documentation paths bypass docs auth middleware."""
response = test_client.get("/health")
assert response.status_code == 200
def test_openapi_protected(self, test_client):
"""Test that OpenAPI spec endpoint requires authentication."""
response = test_client.get("/openapi.json")
assert response.status_code == 401
def test_redoc_protected(self, test_client):
"""Test that ReDoc endpoint requires authentication."""
response = test_client.get("/redoc")
assert response.status_code == 401
def test_cors_headers(self, test_client, auth_headers):
"""Test that CORS headers are properly set."""
response = test_client.options("/tools/", headers=auth_headers)
# CORS is handled by FastAPI middleware, exact behavior depends on configuration
assert response.status_code in [200, 405] # Either handled or method not allowed
# ----------------------------------------------------- #
# Error Handling & Edge Cases #
# ----------------------------------------------------- #
class TestErrorHandling:
def test_docs_with_invalid_jwt(self, test_client):
"""Test /docs with an invalid JWT returns 401."""
headers = {"Authorization": "Bearer invalid.token.value"}
response = test_client.get("/docs", headers=headers)
assert response.status_code == 401
def test_docs_with_expired_jwt(self, test_client):
"""Test /docs with an expired JWT returns 401."""
expired_payload = {"sub": "test_user", "exp": datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=1)}
# First-Party
from mcpgateway.config import settings
key = settings.jwt_secret_key
print(f"[DEBUG] settings.jwt_secret_key type: {type(key)}, value: {key}")
if hasattr(key, "get_secret_value") and callable(getattr(key, "get_secret_value", None)):
key = key.get_secret_value()
print(f"[DEBUG] settings.jwt_secret_key after possible unwrap: {type(key)}, value: {key}")
expired_token = jwt.encode(expired_payload, key, algorithm=settings.jwt_algorithm)
headers = {"Authorization": f"Bearer {expired_token}"}
response = test_client.get("/docs", headers=headers)
assert response.status_code == 401
def test_post_on_get_only_endpoint(self, test_client, auth_headers):
"""Test POST on a GET-only endpoint returns 405."""
response = test_client.post("/health", headers=auth_headers)
assert response.status_code == 405
def test_delete_on_docs(self, test_client, auth_headers):
"""Test DELETE on /docs returns 405."""
response = test_client.delete("/docs", headers=auth_headers)
assert response.status_code == 405
def test_missing_query_param(self, test_client, auth_headers):
"""Test endpoint requiring query param returns 422 if missing."""
# /message?session_id=... requires session_id
message = {"type": "test", "data": "hello"}
response = test_client.post("/message", json=message, headers=auth_headers)
assert response.status_code == 400
def test_invalid_json_body(self, test_client, auth_headers):
"""Test handling of malformed JSON in request bodies."""
headers = auth_headers
headers["content-type"] = "application/json"
response = test_client.post("/protocol/initialize", content="invalid json", headers=headers)
assert response.status_code == 400 # body cannot be parsed, so 400
@patch("mcpgateway.main.server_service.get_server")
def test_server_not_found(self, mock_get, test_client, auth_headers):
"""Test proper error response when server is not found."""
# First-Party
from mcpgateway.services.server_service import ServerNotFoundError
mock_get.side_effect = ServerNotFoundError("Server not found")
response = test_client.get("/servers/999", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.resource_service.read_resource")
def test_resource_not_found(self, mock_read, test_client, auth_headers):
"""Test proper error response when resource is not found."""
# First-Party
from mcpgateway.services.resource_service import ResourceNotFoundError
mock_read.side_effect = ResourceNotFoundError("Resource not found")
response = test_client.get("/resources/nonexistent", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.tool_service.register_tool")
def test_tool_name_conflict(self, mock_register, test_client, auth_headers):
"""Test handling of tool name conflicts during registration."""
# First-Party
from mcpgateway.services.tool_service import ToolNameConflictError
mock_register.side_effect = ToolNameConflictError("Tool name already exists")
req = {"tool": {"name": "existing_tool", "url": "http://example.com"}, "team_id": None, "visibility": "private"}
response = test_client.post("/tools/", json=req, headers=auth_headers)
assert response.status_code == 409
def test_missing_required_fields(self, test_client, auth_headers):
"""Test validation errors for missing required fields."""
req = {"description": "Missing required name field"}
response = test_client.post("/tools/", json=req, headers=auth_headers)
assert response.status_code == 422 # Validation error
def test_openapi_json_with_auth(self, test_client, auth_headers):
"""Test GET /openapi.json with authentication returns 200 and OpenAPI spec."""
response = test_client.get("/openapi.json", headers=auth_headers)
assert response.status_code == 200
assert "openapi" in response.json()
def test_docs_with_auth(self, test_client, auth_headers):
"""Test GET /docs with authentication returns 200 or redirect."""
response = test_client.get("/docs", headers=auth_headers)
assert response.status_code == 200
def test_redoc_with_auth(self, test_client, auth_headers):
"""Test GET /redoc with authentication returns 200 or redirect."""
response = test_client.get("/redoc", headers=auth_headers)
assert response.status_code == 200
# --------------------------------------------------------------------------- #
# jsonpath_modifier #
# --------------------------------------------------------------------------- #
@pytest.fixture(scope="module")
def sample_people():
return [
{"name": "Ada", "id": 1},
{"name": "Bob", "id": 2},
]
def test_jsonpath_modifier_basic_match(sample_people):
# First-Party
from mcpgateway.main import jsonpath_modifier
# Pull out names directly
names = jsonpath_modifier(sample_people, "$[*].name")
assert names == ["Ada", "Bob"]
# Same query but with a mapping
mapped = jsonpath_modifier(sample_people, "$[*]", mappings={"n": "$.name"})
assert mapped == [{"n": "Ada"}, {"n": "Bob"}]
def test_jsonpath_modifier_single_dict_collapse():
# First-Party
from mcpgateway.main import jsonpath_modifier
person = {"name": "Zoe", "id": 10}
out = jsonpath_modifier(person, "$")
assert out == person # single-item dict collapses to dict, not list
def test_jsonpath_modifier_invalid_expressions(sample_people):
# First-Party
from mcpgateway.main import jsonpath_modifier
with pytest.raises(HTTPException):
jsonpath_modifier(sample_people, "$[") # invalid main expr
with pytest.raises(HTTPException):
jsonpath_modifier(sample_people, "$[*]", mappings={"bad": "$["}) # invalid mapping expr
# ----------------------------------------------------- #
# Transform data with mappings
# ----------------------------------------------------- #
class TestTransformDataWithMappings:
def test_transform_data_with_mappings_valid_mapping(self, sample_people):
from mcpgateway.main import transform_data_with_mappings
mapping = {"n": "$.name"}
result = transform_data_with_mappings(sample_people, mapping)
assert result == [{"n": "Ada"}, {"n": "Bob"}]
def test_transform_data_with_mappings_invalid_mapping(self, sample_people):
from mcpgateway.main import transform_data_with_mappings
with pytest.raises(HTTPException):
transform_data_with_mappings(sample_people, {"bad": "$["})
# ----------------------------------------------------- #
# Plugin Exception Handler Tests #
# ----------------------------------------------------- #
class TestPluginExceptionHandlers:
"""Tests for plugin exception handlers: PluginViolationError and PluginError."""
def test_plugin_violation_exception_handler_with_full_violation(self):
"""Test plugin_violation_exception_handler with complete violation details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Invalid input",
description="The input contains prohibited content",
code="PROHIBITED_CONTENT",
details={"field": "message", "value": "sensitive_data"},
)
violation._plugin_name = "content_filter"
exc = PluginViolationError(message="Policy violation detected", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32602
assert "Plugin Violation:" in content["error"]["message"]
assert "The input contains prohibited content" in content["error"]["message"]
assert content["error"]["data"]["description"] == "The input contains prohibited content"
assert content["error"]["data"]["details"] == {"field": "message", "value": "sensitive_data"}
assert content["error"]["data"]["plugin_error_code"] == "PROHIBITED_CONTENT"
assert content["error"]["data"]["plugin_name"] == "content_filter"
def test_plugin_violation_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_violation_exception_handler with custom MCP error code."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Rate limit exceeded",
description="Too many requests from this client",
code="RATE_LIMIT",
details={"requests": 100, "limit": 50},
mcp_error_code=-32000, # Custom error code
)
violation._plugin_name = "rate_limiter"
exc = PluginViolationError(message="Rate limit violation", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32000
assert "Too many requests from this client" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "RATE_LIMIT"
assert content["error"]["data"]["plugin_name"] == "rate_limiter"
def test_plugin_violation_exception_handler_with_minimal_violation(self):
"""Test plugin_violation_exception_handler with minimal violation details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Violation occurred",
description="Minimal violation",
code="MIN_VIOLATION",
details={},
)
exc = PluginViolationError(message="Minimal violation", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "Minimal violation" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "MIN_VIOLATION"
def test_plugin_violation_exception_handler_without_violation_object(self):
"""Test plugin_violation_exception_handler when violation object is None."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
exc = PluginViolationError(message="Generic plugin violation", violation=None)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "A plugin violation occurred" in content["error"]["message"]
assert content["error"]["data"] == {}
def test_plugin_exception_handler_with_full_error(self):
"""Test plugin_exception_handler with complete error details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(
message="Plugin execution failed",
code="EXECUTION_ERROR",
plugin_name="data_processor",
details={"error_type": "timeout", "duration": 30},
)
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32603
assert "Plugin Error:" in content["error"]["message"]
assert "Plugin execution failed" in content["error"]["message"]
assert content["error"]["data"]["details"] == {"error_type": "timeout", "duration": 30}
assert content["error"]["data"]["plugin_error_code"] == "EXECUTION_ERROR"
assert content["error"]["data"]["plugin_name"] == "data_processor"
def test_plugin_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_exception_handler with custom MCP error code."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(
message="Custom error occurred",
code="CUSTOM_ERROR",
plugin_name="custom_plugin",
details={"context": "test"},
mcp_error_code=-32001, # Custom MCP error code
)
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32001
assert "Custom error occurred" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "CUSTOM_ERROR"
def test_plugin_exception_handler_with_minimal_error(self):
"""Test plugin_exception_handler with minimal error details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(message="Minimal error", plugin_name="minimal_plugin")
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32603
assert "Minimal error" in content["error"]["message"]
assert content["error"]["data"]["plugin_name"] == "minimal_plugin"
def test_plugin_exception_handler_with_empty_code(self):
"""Test plugin_exception_handler when error has empty code field."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(
message="Error without code",
code="",
plugin_name="test_plugin",
details={"info": "test"},
)
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32603
assert "Error without code" in content["error"]["message"]
# Empty code should not be included in data
assert "plugin_error_code" not in content["error"]["data"] or content["error"]["data"]["plugin_error_code"] == ""
# --------------------------------------------------------------------------- #
# Cache Behavior Tests #
# --------------------------------------------------------------------------- #
class TestJsonPathCaching:
"""Tests for JSONPath caching (#1812)."""
def test_jsonpath_caching_works(self):
"""Verify JSONPath parsing is cached."""
from mcpgateway.main import jsonpath_modifier, _parse_jsonpath
_parse_jsonpath.cache_clear()
result1 = jsonpath_modifier([{"a": 1}, {"a": 2}], "$[*].a")
assert result1 == [1, 2]
result2 = jsonpath_modifier([{"a": 3}], "$[*].a")
assert result2 == [3]
info = _parse_jsonpath.cache_info()
assert info.hits == 1
def test_mappings_parsed_once_per_request(self):
"""Verify mappings are parsed once per request, not per item."""
from mcpgateway.main import transform_data_with_mappings, _parse_jsonpath
_parse_jsonpath.cache_clear()
data = [{"x": 1}, {"x": 2}, {"x": 3}]
mappings = {"y": "$.x"}
result = transform_data_with_mappings(data, mappings)
assert result == [{"y": 1}, {"y": 2}, {"y": 3}]
info = _parse_jsonpath.cache_info()
assert info.misses == 1 # Only one parse for "$.x"
def test_different_jsonpath_cached_separately(self):
"""Verify different JSONPath expressions get separate cache entries."""
from mcpgateway.main import jsonpath_modifier, _parse_jsonpath
_parse_jsonpath.cache_clear()
result1 = jsonpath_modifier({"a": 1, "b": 2}, "$.a")
result2 = jsonpath_modifier({"a": 1, "b": 2}, "$.b")
assert result1 == [1]
assert result2 == [2]
info = _parse_jsonpath.cache_info()
assert info.misses == 2
# ----------------------------------------------------- #
# Token Teams Helper Function Tests (Issue #1915) #
# ----------------------------------------------------- #
class TestNormalizeTokenTeams:
"""Tests for _normalize_token_teams helper function."""
def test_normalize_token_teams_none(self):
"""Test that None input returns empty list."""
from mcpgateway.main import _normalize_token_teams
assert _normalize_token_teams(None) == []
def test_normalize_token_teams_empty_list(self):
"""Test that empty list input returns empty list."""
from mcpgateway.main import _normalize_token_teams
assert _normalize_token_teams([]) == []
def test_normalize_token_teams_string_ids(self):
"""Test that string team IDs are passed through unchanged."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams(["team_a", "team_b", "team_c"])
assert result == ["team_a", "team_b", "team_c"]
def test_normalize_token_teams_dict_format(self):
"""Test that dict format with id key extracts the ID."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams([{"id": "team_a", "name": "Team A"}, {"id": "team_b", "name": "Team B"}])
assert result == ["team_a", "team_b"]
def test_normalize_token_teams_mixed_format(self):
"""Test that mixed string and dict formats are handled correctly."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams([{"id": "t1", "name": "Team 1"}, "t2", {"id": "t3"}])
assert result == ["t1", "t2", "t3"]
def test_normalize_token_teams_dict_without_id(self):
"""Test that dicts without id key are skipped."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams([{"name": "No ID Team"}, {"id": "valid_team"}])
assert result == ["valid_team"]
def test_normalize_token_teams_dict_with_empty_id(self):
"""Test that dicts with empty id value are skipped."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams([{"id": "", "name": "Empty ID"}, {"id": "valid"}])
assert result == ["valid"]
def test_normalize_token_teams_preserves_order(self):
"""Test that team order is preserved."""
from mcpgateway.main import _normalize_token_teams
result = _normalize_token_teams(["z_team", "a_team", "m_team"])
assert result == ["z_team", "a_team", "m_team"]
class TestGetTokenTeamsFromRequest:
"""Tests for _get_token_teams_from_request helper function."""
def test_get_token_teams_with_valid_cached_payload(self):
"""Test extraction of teams from cached JWT payload."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token_string", {"sub": "user@example.com", "teams": ["team_a", "team_b"]})
result = _get_token_teams_from_request(mock_request)
assert result == ["team_a", "team_b"]
def test_get_token_teams_with_dict_teams_payload(self):
"""Test extraction and normalization of dict format teams."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": [{"id": "t1", "name": "Team 1"}]})
result = _get_token_teams_from_request(mock_request)
assert result == ["t1"]
def test_get_token_teams_no_cached_payload_returns_empty_list(self):
"""Test that missing cached payload returns [] (public-only, secure default)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = None
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: No JWT = public-only (secure default)
def test_get_token_teams_no_teams_in_payload_returns_empty_list(self):
"""Test that payload without teams key returns [] (public-only, secure default)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"sub": "user@example.com"})
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: Missing teams = public-only (secure default)
def test_get_token_teams_empty_teams_returns_empty_list(self):
"""Test that payload with empty teams returns empty list (not None)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"sub": "user@example.com", "teams": []})
result = _get_token_teams_from_request(mock_request)
assert result == [] # Empty list = JWT exists but no teams
def test_get_token_teams_null_teams_non_admin_returns_empty_list(self):
"""Test that payload with teams: null (non-admin) returns [] (public-only)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"sub": "user@example.com", "teams": None})
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: Null teams + non-admin = public-only
def test_get_token_teams_null_teams_admin_returns_none(self):
"""Test that payload with teams: null + is_admin=true returns None (admin bypass)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"sub": "admin@example.com", "teams": None, "is_admin": True})
result = _get_token_teams_from_request(mock_request)
assert result is None # Admin with explicit null teams = admin bypass
def test_get_token_teams_invalid_tuple_format_returns_empty_list(self):
"""Test that non-tuple cached payload returns [] (public-only, secure default)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = "not_a_tuple"
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: Invalid format = public-only (secure default)
def test_get_token_teams_short_tuple_returns_empty_list(self):
"""Test that tuple with wrong length returns [] (public-only, secure default)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("only_one_element",)
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: Invalid format = public-only (secure default)
def test_get_token_teams_none_payload_in_tuple_returns_empty_list(self):
"""Test that None payload in tuple returns [] (public-only, secure default)."""
from mcpgateway.main import _get_token_teams_from_request
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", None)
result = _get_token_teams_from_request(mock_request)
assert result == [] # SECURITY: No payload = public-only (secure default)
class TestGetRpcFilterContext:
"""Tests for _get_rpc_filter_context helper function."""
def test_get_rpc_filter_context_dict_user(self):
"""Test with dict user containing email and is_admin."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
# is_admin must be in the token payload, not the user dict (security fix)
mock_request.state._jwt_verified_payload = ("token", {"teams": ["t1", "t2"], "is_admin": True})
user = {"email": "test@example.com", "is_admin": True} # User's is_admin is ignored
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "test@example.com"
assert teams == ["t1", "t2"]
assert is_admin is True # From token payload, not user dict
def test_get_rpc_filter_context_dict_user_sub_field(self):
"""Test that sub field is used if email is not present."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": []})
user = {"sub": "user@sub.com"}
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "user@sub.com"
assert teams == []
assert is_admin is False
def test_get_rpc_filter_context_object_user(self):
"""Test with user object having email and is_admin attributes."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": ["team_x"]})
class UserObject:
email = "obj@example.com"
is_admin = False
email, teams, is_admin = _get_rpc_filter_context(mock_request, UserObject())
assert email == "obj@example.com"
assert teams == ["team_x"]
assert is_admin is False
def test_get_rpc_filter_context_nested_is_admin(self):
"""Test that nested user.is_admin is extracted from token payload."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
# is_admin must be in token payload - use non-empty teams to allow admin bypass
mock_request.state._jwt_verified_payload = ("token", {"teams": ["team_x"], "user": {"is_admin": True}})
user = {"email": "nested@example.com", "user": {"is_admin": True}}
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "nested@example.com"
assert is_admin is True # From token payload's nested user.is_admin
def test_get_rpc_filter_context_empty_teams_disables_admin(self):
"""Test that empty teams array disables admin bypass even when is_admin is true."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
# Token has is_admin but empty teams - admin bypass should be disabled
mock_request.state._jwt_verified_payload = ("token", {"teams": [], "is_admin": True})
user = {"email": "admin@example.com", "is_admin": True}
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "admin@example.com"
assert teams == []
assert is_admin is False # Disabled for empty-team tokens (public-only access)
def test_get_rpc_filter_context_string_user(self):
"""Test with string user (fallback to str conversion)."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": ["t1"]})
user = "plain_username"
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "plain_username"
assert teams == ["t1"]
assert is_admin is False
def test_get_rpc_filter_context_none_user(self):
"""Test with None user."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": []})
email, teams, is_admin = _get_rpc_filter_context(mock_request, None)
assert email is None
assert teams == []
assert is_admin is False
def test_get_rpc_filter_context_admin_not_in_dict(self):
"""Test that is_admin defaults to False if not present."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = ("token", {"teams": ["t1"]})
user = {"email": "user@example.com"}
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "user@example.com"
assert is_admin is False
def test_get_rpc_filter_context_no_jwt_returns_empty_teams(self):
"""Test that missing JWT payload returns [] for teams (public-only, secure default)."""
from mcpgateway.main import _get_rpc_filter_context
mock_request = MagicMock()
mock_request.state._jwt_verified_payload = None # No JWT - e.g., plugin auth
user = {"email": "plugin_user@example.com", "is_admin": False}
email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
assert email == "plugin_user@example.com"
assert teams == [] # SECURITY: No JWT = public-only (secure default)
assert is_admin is False