# -*- coding: utf-8 -*-
"""Location: ./tests/unit/mcpgateway/test_main.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
Comprehensive tests for the main API endpoints with full coverage.
"""
# Standard
import asyncio
from copy import deepcopy
import datetime
import json
import os
from unittest.mock import ANY, AsyncMock, MagicMock, patch
# Third-Party
from fastapi import HTTPException
from fastapi.testclient import TestClient
import jwt
from pydantic import BaseModel, SecretStr, ValidationError
import pytest
import sqlalchemy as sa
from starlette.requests import Request
from starlette.websockets import WebSocketDisconnect
# First-Party
from mcpgateway.config import settings
from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities
import mcpgateway.db as db_mod
from mcpgateway.schemas import (
GatewayRead,
PromptRead,
ResourceRead,
ServerRead,
ToolRead,
)
# --------------------------------------------------------------------------- #
# Constants #
# --------------------------------------------------------------------------- #
# MCP protocol version used by these tests; can be overridden through the
# PROTOCOL_VERSION environment variable.
PROTOCOL_VERSION = os.getenv("PROTOCOL_VERSION", "2025-11-25")
# Signing secret that the ``test_client`` fixture installs as
# ``settings.jwt_secret_key`` (long enough to avoid short-key warnings).
TEST_JWT_SECRET = "unit-test-jwt-secret-key-with-minimum-32-bytes"
# Mock data templates with complete field structures
# Shared execution-metrics payload embedded in the mock read models below.
MOCK_METRICS = {
    "total_executions": 10,
    "successful_executions": 8,
    "failed_executions": 2,
    "failure_rate": 0.2,
    "min_response_time": 0.1,
    "max_response_time": 2.5,
    "avg_response_time": 1.2,
    "last_execution_time": "2023-01-01T00:00:00+00:00",
}
# Field values for ``ServerRead(**MOCK_SERVER_READ)``; the associated_* lists
# reference linked entities by string id.
MOCK_SERVER_READ = {
    "id": "1",
    "name": "test_server",
    "description": "A test server",
    "icon": "server-icon",
    "created_at": "2023-01-01T00:00:00+00:00",
    "updated_at": "2023-01-01T00:00:00+00:00",
    "enabled": True,
    "associated_tools": ["101"],
    "associated_resources": ["201"],
    "associated_prompts": ["301"],
    "metrics": MOCK_METRICS,
}
# Tool payload written mostly with camelCase keys (a few, e.g.
# ``original_description``, are already snake_case); presumably mirrors the
# camelCase aliases of ToolRead — confirm. ``camel_to_snake_tool`` derives a
# snake_case copy from it.
MOCK_TOOL_READ = {
    "id": "1",
    "name": "test_tool",
    "originalName": "test_tool",
    "customName": "test_tool",
    "url": "http://example.com/tools/test",
    "description": "A test tool",
    "original_description": "A test tool original",
    "requestType": "POST",
    "integrationType": "MCP",
    "headers": {"Content-Type": "application/json"},
    "inputSchema": {"type": "object", "properties": {"param": {"type": "string"}}},
    "annotations": {},
    "jsonpathFilter": None,
    "auth": {"auth_type": "none"},
    "createdAt": "2023-01-01T00:00:00+00:00",
    "updatedAt": "2023-01-01T00:00:00+00:00",
    "enabled": True,
    "reachable": True,
    "gatewayId": "gateway-1",
    "executionCount": 5,
    "metrics": MOCK_METRICS,
    "gatewaySlug": "gateway-1",
    "customNameSlug": "test-tool",
}
# camelCase → snake_case key map for the tool fields whose two spellings
# differ. Keys absent from a given payload are simply skipped.
_TOOL_KEY_MAP = {
    "originalName": "original_name",
    "customName": "custom_name",
    "requestType": "request_type",
    "integrationType": "integration_type",
    "inputSchema": "input_schema",
    "jsonpathFilter": "jsonpath_filter",
    "createdAt": "created_at",
    "updatedAt": "updated_at",
    "gatewayId": "gateway_id",
    "executionCount": "execution_count",
    "gatewaySlug": "gateway_slug",
    "originalNameSlug": "original_name_slug",
    "customNameSlug": "custom_name_slug",
}


def camel_to_snake_tool(d: dict) -> dict:
    """Return a snake_case copy of a camelCase tool payload.

    The input is deep-copied, so neither it nor its nested containers are
    mutated. Keys listed in ``_TOOL_KEY_MAP`` are renamed (each renamed key
    moves to the end of the dict). All other keys are kept verbatim.

    Args:
        d: Tool payload, typically ``MOCK_TOOL_READ``.

    Returns:
        The converted copy. If an ``id`` is present, it is coerced to ``str``.
    """
    out = deepcopy(d)
    # Read models expect string ids; only coerce when an id was supplied.
    if "id" in out:
        out["id"] = str(out["id"])
    for camel, snake in _TOOL_KEY_MAP.items():
        if camel in out:
            out[snake] = out.pop(camel)
    return out
# Snake_case copy of MOCK_TOOL_READ produced by camel_to_snake_tool (which
# also coerces the id to str).
MOCK_TOOL_READ_SNAKE = camel_to_snake_tool(MOCK_TOOL_READ)
# Resource payload with snake_case keys; tests return it from a mocked
# ``model_dump``.
MOCK_RESOURCE_READ = {
    "id": "39334ce0ed2644d79ede8913a66930c9",
    "uri": "test/resource",
    "name": "Test Resource",
    "description": "A test resource",
    "mime_type": "text/plain",
    "size": 12,
    "created_at": "2023-01-01T00:00:00+00:00",
    "updated_at": "2023-01-01T00:00:00+00:00",
    "enabled": True,
    "metrics": MOCK_METRICS,
}
# Field values for ``PromptRead(**MOCK_PROMPT_READ)``.
MOCK_PROMPT_READ = {
    "id": "ca627760127d409080fdefc309147e08",
    "name": "test_prompt",
    "original_name": "test_prompt",
    "custom_name": "test_prompt",
    "custom_name_slug": "test-prompt",
    "display_name": "Test Prompt",
    "description": "A test prompt",
    "template": "Hello {name}",
    "arguments": [],
    "created_at": "2023-01-01T00:00:00+00:00",
    "updated_at": "2023-01-01T00:00:00+00:00",
    "enabled": True,
    "metrics": MOCK_METRICS,
}
# Gateway payload with snake_case keys and no authentication configured.
MOCK_GATEWAY_READ = {
    "id": "1",
    "name": "test_gateway",
    "url": "http://example.com",
    "description": "A test gateway",
    "transport": "SSE",
    "created_at": "2023-01-01T00:00:00+00:00",
    "updated_at": "2023-01-01T00:00:00+00:00",
    "enabled": True,
    "reachable": True,
    "auth_type": None,
}
# Minimal root entry: a URI plus a display name.
MOCK_ROOT = {
    "uri": "/test",
    "name": "Test Root",
}
class _ValidationModel(BaseModel):
    """Throwaway pydantic model whose only job is to fail validation on demand."""

    value: int


def _make_validation_error() -> ValidationError:
    """Produce a genuine pydantic ``ValidationError`` instance.

    Feeding a non-numeric string into an ``int`` field guarantees a failure.
    If it somehow validates, an ``AssertionError`` signals the broken premise.
    """
    captured = None
    try:
        _ValidationModel(value="bad")
    except ValidationError as failure:
        captured = failure
    if captured is None:
        raise AssertionError("Expected validation error")
    return captured


# Pre-built exceptions used as ``side_effect`` values in the error-path tests.
VALIDATION_ERROR = _make_validation_error()
INTEGRITY_ERROR = sa.exc.IntegrityError("stmt", {}, Exception("orig"))
def _make_a2a_agent_read(**overrides):
    """Return a fresh dict of A2A agent read fields with ``overrides`` applied last.

    ``created_at`` and ``updated_at`` share one timezone-aware UTC timestamp
    taken when the function is called.
    """
    timestamp = datetime.datetime.now(datetime.timezone.utc)
    defaults = dict(
        id="agent-1",
        name="agent-1",
        slug="agent-1",
        description="Test agent",
        endpoint_url="https://example.com/agent",
        agent_type="generic",
        protocol_version=PROTOCOL_VERSION,
        capabilities={},
        config={},
        enabled=True,
        reachable=True,
        created_at=timestamp,
        updated_at=timestamp,
        last_interaction=None,
        tags=[],
        metrics=None,
        passthrough_headers=None,
        auth_type=None,
        auth_value=None,
        oauth_config=None,
    )
    # Later mapping wins: overrides replace defaults in place, new keys append.
    return {**defaults, **overrides}
def _tool_create_payload():
    """Build a fresh, minimal ``POST /tools/`` request body with private visibility."""
    tool_spec = dict(name="test_tool", url="http://example.com", description="A test tool")
    return dict(tool=tool_spec, team_id=None, visibility="private")
def _server_create_payload():
    """Build a fresh, minimal ``POST /servers/`` request body with public visibility."""
    server_spec = dict(name="test-server")
    return dict(server=server_spec, team_id=None, visibility="public")
def _a2a_create_payload():
    """Build a fresh, minimal A2A agent-creation payload with public visibility."""
    agent_spec = dict(name="agent-1", endpoint_url="https://example.com/agent", agent_type="generic")
    return dict(agent=agent_spec, team_id=None, visibility="public")
def _make_request(path: str = "/", headers: dict | None = None, method: str = "GET") -> Request:
    """Build a bare Starlette ``Request`` for calling endpoint functions directly.

    Args:
        path: Request path placed verbatim in the ASGI scope.
        headers: Optional header mapping. Names are lower-cased, and both names
            and values are stringified and encoded to bytes, as ASGI expects.
        method: HTTP method placed in the scope (defaults to ``"GET"``).

    Returns:
        A ``Request`` over a minimal HTTP scope, with no receive channel/body.
    """
    raw_headers = [(str(key).lower().encode(), str(value).encode()) for key, value in (headers or {}).items()]
    scope = {
        "type": "http",
        # ``method`` and ``query_string`` are read with ``scope[...]`` by
        # ``Request.method`` / ``Request.query_params``; without them those
        # accessors raise KeyError.
        "method": method,
        "scheme": "http",
        "server": ("testserver", 80),
        "path": path,
        "query_string": b"",
        "headers": raw_headers,
    }
    return Request(scope)
# --------------------------------------------------------------------------- #
# Fixtures #
# --------------------------------------------------------------------------- #
@pytest.fixture(autouse=True)
def reset_db(app_with_temp_db):
    """Clear the temp DB between tests when using the module-scoped app.

    Autouse, so it runs before every test. It deletes all rows from every
    table registered on ``db_mod.Base.metadata``, in reverse dependency
    order, inside a single ``engine.begin()`` transaction. On SQLite it also
    clears ``sqlite_sequence`` (AUTOINCREMENT counters) and toggles
    foreign-key enforcement around the deletes. It is a no-op when
    ``db_mod.engine`` is ``None``.
    """
    engine = db_mod.engine
    if engine is None:
        # No engine configured: nothing to clean, just hand control to the test.
        yield
        return
    with engine.begin() as conn:
        if engine.dialect.name == "sqlite":
            conn.exec_driver_sql("PRAGMA foreign_keys=OFF")
        # sorted_tables is ordered by foreign-key dependency (parents first),
        # so reversing it deletes child rows before the rows they reference.
        for table in reversed(db_mod.Base.metadata.sorted_tables):
            conn.execute(table.delete())
        if engine.dialect.name == "sqlite":
            try:
                # SQLite only creates sqlite_sequence once an AUTOINCREMENT
                # table exists, so failing here is expected and ignored.
                conn.exec_driver_sql("DELETE FROM sqlite_sequence")
            except sa.exc.DatabaseError:
                pass
            # NOTE(review): SQLite documents PRAGMA foreign_keys as a no-op
            # while a transaction is open; after the DELETEs above this may
            # not re-enable enforcement on this pooled connection — confirm
            # how the project's SQLite engine manages transactions.
            conn.exec_driver_sql("PRAGMA foreign_keys=ON")
    yield
@pytest.fixture
def test_client(app_with_temp_db):
    """
    Return a TestClient whose dependency graph bypasses real authentication.

    Every FastAPI dependency on ``require_auth`` is overridden to return the
    static user name ``"test_user"``. This keeps the protected endpoints
    accessible without needing to furnish JWTs in every request.

    Also overrides RBAC dependencies to bypass permission checks for tests:

    * ``get_current_user`` returns an admin ``EmailUser``;
    * ``get_current_user_with_permissions`` returns an admin user dict;
    * ``PermissionService.check_permission`` always returns ``True``;
    * ``mcpgateway.main.require_docs_auth_override`` is replaced by a stand-in
      that really decodes JWTs with the (temporarily swapped-in) test secret;
    * the auth middleware's ``security_logger`` is replaced by a ``MagicMock``
      to prevent database access.

    Every override is undone in a ``finally`` block. An exception raised
    part-way through setup therefore cannot leak patched global state into
    later tests.
    """
    # Third-Party
    from fastapi import status

    # First-Party
    from mcpgateway.auth import get_current_user
    from mcpgateway.db import EmailUser
    from mcpgateway.middleware.rbac import get_current_user_with_permissions
    from mcpgateway.services.permission_service import PermissionService
    from mcpgateway.utils.verify_credentials import require_auth

    def _plain_secret(key):
        """Return the raw string behind a ``SecretStr``; other values pass through."""
        if callable(getattr(key, "get_secret_value", None)):
            return key.get_secret_value()
        return key

    # Mock user object for RBAC system
    mock_user = EmailUser(
        email="test_user@example.com",
        full_name="Test User",
        is_admin=True,  # Give admin privileges for tests
        is_active=True,
        auth_provider="test",
    )

    # Mock security_logger to prevent database access
    mock_sec_logger = MagicMock()
    mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None)
    mock_sec_logger.log_security_event = MagicMock(return_value=None)
    sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger)

    async def mock_require_auth_override(auth_header=None, jwt_token=None):
        """Validate a JWT from ``jwt_token`` or a ``Bearer`` header and return its ``sub``."""
        token = jwt_token
        if not token and auth_header and auth_header.startswith("Bearer "):
            token = auth_header[len("Bearer ") :]
        if not token:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization required")
        try:
            payload = jwt.decode(
                token,
                _plain_secret(settings.jwt_secret_key),
                algorithms=[settings.jwt_algorithm],
                options={"verify_aud": False},
            )
        # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
        except jwt.ExpiredSignatureError as exc:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired") from exc
        except jwt.InvalidTokenError as exc:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
        username = payload.get("sub")
        if not username:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
        return username

    # Patch the auth function used by DocsAuthMiddleware
    docs_patcher = patch("mcpgateway.main.require_docs_auth_override", mock_require_auth_override)

    # Override get_current_user_with_permissions for RBAC system
    def mock_get_current_user_with_permissions(request=None, credentials=None, jwt_token=None):
        return {"email": "test_user@example.com", "full_name": "Test User", "is_admin": True, "ip_address": "127.0.0.1", "user_agent": "test"}

    # Mock with correct async signature matching the real method
    async def mock_check_permission(self, user_email: str, permission: str, resource_type=None, resource_id=None, team_id=None, ip_address=None, user_agent=None) -> bool:
        return True

    # Remember the genuine method once, so a later fixture use never stores a mock.
    if not hasattr(PermissionService, "_original_check_permission"):
        PermissionService._original_check_permission = PermissionService.check_permission

    original_jwt_secret = settings.jwt_secret_key
    try:
        # Override old auth system
        app_with_temp_db.dependency_overrides[require_auth] = lambda: "test_user"
        # Use a strong JWT secret during tests to avoid short-key warnings,
        # keeping the SecretStr-vs-str type of the configured value.
        if callable(getattr(original_jwt_secret, "get_secret_value", None)):
            settings.jwt_secret_key = SecretStr(TEST_JWT_SECRET)
        else:
            settings.jwt_secret_key = TEST_JWT_SECRET
        sec_patcher.start()
        docs_patcher.start()
        # Override the core auth function used by RBAC system
        app_with_temp_db.dependency_overrides[get_current_user] = lambda credentials=None, db=None: mock_user
        app_with_temp_db.dependency_overrides[get_current_user_with_permissions] = mock_get_current_user_with_permissions
        PermissionService.check_permission = mock_check_permission

        yield TestClient(app_with_temp_db)
    finally:
        # Every step below is safe even if the matching setup step never ran:
        # dict.pop has a default, and stopping an unstarted patcher is a no-op.
        settings.jwt_secret_key = original_jwt_secret
        app_with_temp_db.dependency_overrides.pop(require_auth, None)
        app_with_temp_db.dependency_overrides.pop(get_current_user, None)
        app_with_temp_db.dependency_overrides.pop(get_current_user_with_permissions, None)
        docs_patcher.stop()
        sec_patcher.stop()
        PermissionService.check_permission = PermissionService._original_check_permission
@pytest.fixture
def mock_jwt_token():
    """Sign a JWT for ``test_user@example.com`` with the active gateway settings.

    The claims carry ``iss="mcpgateway"`` and ``aud="mcpgateway-api"``. The
    token is signed with ``settings.jwt_secret_key`` (unwrapped when it
    exposes ``get_secret_value``) using ``settings.jwt_algorithm``.
    """
    signing_key = settings.jwt_secret_key
    unwrap = getattr(signing_key, "get_secret_value", None)
    if callable(unwrap):
        signing_key = unwrap()
    claims = dict(
        sub="test_user@example.com",
        email="test_user@example.com",
        iss="mcpgateway",
        aud="mcpgateway-api",
    )
    return jwt.encode(claims, signing_key, algorithm=settings.jwt_algorithm)
@pytest.fixture
def auth_headers(mock_jwt_token):
    """Wrap ``mock_jwt_token`` in an ``Authorization: Bearer …`` header dict.

    Still accepted by the overridden auth dependency of ``test_client``.
    """
    bearer_value = "Bearer {}".format(mock_jwt_token)
    return dict(Authorization=bearer_value)
# ========================================================================== #
# TEST CLASSES #
# ========================================================================== #
# ----------------------------------------------------- #
# Health & Infrastructure Tests #
# ----------------------------------------------------- #
class TestHealthAndInfrastructure:
    """Tests for health checks, readiness, and basic infrastructure endpoints."""

    class _FailingSession:
        """Stand-in DB session whose ``execute`` and ``rollback`` both raise.

        It drives the error branches of the health and readiness probes and
        records whether ``invalidate()`` was called.
        """

        def __init__(self):
            self.invalidate_called = False

        def execute(self, *_args, **_kwargs):
            raise Exception("boom")

        def commit(self):
            pass

        def rollback(self):
            raise Exception("rollback failed")

        def invalidate(self):
            self.invalidate_called = True

        def close(self):
            pass

    def test_health_check(self, test_client):
        """Test the basic health check endpoint."""
        response = test_client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"

    def test_ready_check(self, test_client):
        """Test the readiness check endpoint."""
        response = test_client.get("/ready")
        assert response.status_code == 200
        assert response.json()["status"] == "ready"

    def test_health_check_db_error(self):
        """Test health check error path with rollback failure."""
        # First-Party
        from mcpgateway import main as mcpgateway_main

        session = self._FailingSession()
        with patch("mcpgateway.main.SessionLocal", return_value=session):
            response = mcpgateway_main.healthcheck()
        assert response["status"] == "unhealthy"
        assert session.invalidate_called is True

    @pytest.mark.asyncio
    async def test_ready_check_db_error(self):
        """Test readiness check error path with rollback failure."""
        # First-Party
        from mcpgateway import main as mcpgateway_main

        session = self._FailingSession()
        with (
            patch("mcpgateway.main.SessionLocal", return_value=session),
            # Run the probe's thread-offloaded work inline so the fake session is used.
            patch("mcpgateway.main.asyncio.to_thread", side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs)),
        ):
            response = await mcpgateway_main.readiness_check()
        assert response.status_code == 503
        assert session.invalidate_called is True

    def test_root_redirect(self, test_client):
        """Test that root path behavior depends on UI configuration."""
        response = test_client.get("/", follow_redirects=False)
        # Check if UI is enabled
        if settings.mcpgateway_ui_enabled:
            # When UI is enabled, should redirect to admin with trailing slash
            assert response.status_code == 303
            assert response.headers["location"] == f"{settings.app_root_path}/admin/"
        else:
            # When UI is disabled, should return API info
            assert response.status_code == 200
            data = response.json()
            assert data["name"] == "ContextForge"
            assert data["ui_enabled"] is False

    def test_static_files(self, test_client):
        """Requesting a static file that does not exist returns 404.

        The real filesystem is used on purpose. Forcing ``os.path.exists`` to
        ``True`` would contradict the "file absent" premise. Replacing
        ``builtins.open`` would affect everything that opens a file during the
        request.
        """
        response = test_client.get("/static/test.css")
        assert response.status_code == 404  # route registered, file absent
# ----------------------------------------------------- #
# Protocol & MCP Core Tests #
# ----------------------------------------------------- #
class TestProtocolEndpoints:
    """Tests for MCP protocol operations: initialize, ping, notifications, etc."""

    # JSON body of a ``notifications/cancelled`` message for request id "123".
    # Only ever serialized by the TestClient, never mutated.
    _CANCEL_REQUEST = {"method": "notifications/cancelled", "params": {"requestId": "123"}}
    # JSON body of a completion request against a prompt reference.
    _COMPLETION_REQUEST = {"ref": {"type": "ref/prompt", "name": "test"}}

    @staticmethod
    def _assert_cancel_denied(response, mock_cancel_run, mock_notify):
        """Assert a cancellation was refused with 403 and triggered no side effects."""
        assert response.status_code == 403
        assert response.json()["detail"] == "Not authorized to cancel this run"
        mock_cancel_run.assert_not_awaited()
        mock_notify.assert_not_awaited()

    @patch("mcpgateway.main.session_registry.handle_initialize_logic")
    def test_initialize_endpoint(self, mock_handle_initialize, test_client, auth_headers):
        """Test MCP protocol initialization."""
        mock_capabilities = ServerCapabilities(
            prompts={"listChanged": True},
            resources={"subscribe": True, "listChanged": True},
            tools={"listChanged": True},
            logging={},
            roots={"listChanged": True},
            sampling={},
        )
        mock_result = InitializeResult(
            protocolVersion=PROTOCOL_VERSION,
            capabilities=mock_capabilities,
            serverInfo={"name": "ContextForge", "version": "1.0.0"},
            instructions="ContextForge providing federated tools, resources and prompts.",
        )
        mock_handle_initialize.return_value = mock_result
        req = {
            "protocol_version": PROTOCOL_VERSION,
            "capabilities": {},
            "client_info": {"name": "Test Client", "version": "1.0.0"},
        }
        response = test_client.post("/protocol/initialize", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["protocolVersion"] == PROTOCOL_VERSION
        mock_handle_initialize.assert_called_once()

    def test_ping_endpoint(self, test_client, auth_headers):
        """Test MCP ping endpoint."""
        req = {"jsonrpc": "2.0", "method": "ping", "id": "test-id"}
        response = test_client.post("/protocol/ping", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body == {"jsonrpc": "2.0", "id": "test-id", "result": {}}

    def test_ping_invalid_method(self, test_client, auth_headers):
        """Test ping endpoint with invalid method."""
        req = {"jsonrpc": "2.0", "method": "invalid", "id": "test-id"}
        response = test_client.post("/protocol/ping", json=req, headers=auth_headers)
        # Implementation raises 5xx for unsupported method
        assert response.status_code == 500

    @patch("mcpgateway.main.logging_service.notify")
    def test_handle_notification_initialized(self, mock_notify, test_client, auth_headers):
        """Test handling client initialized notification."""
        req = {"method": "notifications/initialized"}
        response = test_client.post("/protocol/notifications", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_notify.assert_called_once()

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.cancellation_service.get_status", new_callable=AsyncMock)
    @patch("mcpgateway.main.cancellation_service.cancel_run", new_callable=AsyncMock)
    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    def test_handle_notification_cancelled(self, mock_notify, mock_cancel_run, mock_get_status, mock_get_context, test_client, auth_headers):
        """Test handling request cancelled notification."""
        mock_get_context.return_value = ("test_user@example.com", [], False)
        mock_get_status.return_value = {"owner_email": "test_user@example.com", "owner_team_ids": []}
        response = test_client.post("/protocol/notifications", json=self._CANCEL_REQUEST, headers=auth_headers)
        assert response.status_code == 200
        mock_cancel_run.assert_awaited_once_with("123", reason=None)
        mock_notify.assert_awaited_once()

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.cancellation_service.get_status", new_callable=AsyncMock)
    @patch("mcpgateway.main.cancellation_service.cancel_run", new_callable=AsyncMock)
    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    def test_handle_notification_cancelled_denied_for_non_owner(self, mock_notify, mock_cancel_run, mock_get_status, mock_get_context, test_client, auth_headers):
        """Test cancellation notification denied for non-owner/non-admin users."""
        mock_get_context.return_value = ("viewer@example.com", [], False)
        mock_get_status.return_value = {"owner_email": "owner@example.com", "owner_team_ids": []}
        response = test_client.post("/protocol/notifications", json=self._CANCEL_REQUEST, headers=auth_headers)
        self._assert_cancel_denied(response, mock_cancel_run, mock_notify)

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.cancellation_service.get_status", new_callable=AsyncMock)
    @patch("mcpgateway.main.cancellation_service.cancel_run", new_callable=AsyncMock)
    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    def test_handle_notification_alias_cancelled_enforces_authorization(self, mock_notify, mock_cancel_run, mock_get_status, mock_get_context, test_client, auth_headers):
        """The /notifications alias must enforce the same cancellation authorization rules."""
        mock_get_context.return_value = ("viewer@example.com", [], False)
        mock_get_status.return_value = {"owner_email": "owner@example.com", "owner_team_ids": []}
        response = test_client.post("/notifications", json=self._CANCEL_REQUEST, headers=auth_headers)
        self._assert_cancel_denied(response, mock_cancel_run, mock_notify)

    @patch("mcpgateway.main.logging_service.notify")
    def test_handle_notification_message(self, mock_notify, test_client, auth_headers):
        """Test handling log message notification."""
        req = {
            "method": "notifications/message",
            "params": {"data": "Test message", "level": "info", "logger": "test"},
        }
        response = test_client.post("/protocol/notifications", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_notify.assert_called_once()

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion")
    def test_handle_completion_endpoint(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """Test completion handling endpoint."""
        mock_filter_context.return_value = ("scoped@example.com", ["team-1"], False)
        mock_completion.return_value = {"result": "completion_result"}
        req = self._COMPLETION_REQUEST
        response = test_client.post("/protocol/completion/complete", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_completion.assert_called_once_with(ANY, req, user_email="scoped@example.com", token_teams=["team-1"])

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion")
    def test_handle_completion_endpoint_admin_bypass(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """Protocol completion should preserve explicit admin bypass context."""
        mock_filter_context.return_value = ("admin@example.com", None, True)
        mock_completion.return_value = {"result": "completion_result"}
        req = self._COMPLETION_REQUEST
        response = test_client.post("/protocol/completion/complete", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_completion.assert_called_once_with(ANY, req, user_email=None, token_teams=None)

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion")
    def test_handle_completion_endpoint_defaults_to_public_scope_when_token_teams_none(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """Protocol completion should treat token_teams=None as public-only for non-admin context."""
        mock_filter_context.return_value = ("viewer@example.com", None, False)
        mock_completion.return_value = {"result": "completion_result"}
        req = self._COMPLETION_REQUEST
        response = test_client.post("/protocol/completion/complete", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_completion.assert_called_once_with(ANY, req, user_email="viewer@example.com", token_teams=[])

    @patch("mcpgateway.main.sampling_handler.create_message")
    def test_handle_sampling_endpoint(self, mock_sampling, test_client, auth_headers):
        """Test sampling message creation endpoint."""
        mock_sampling.return_value = {"messageId": "123"}
        req = {"messages": [{"role": "user", "content": {"type": "text", "text": "Hello"}}]}
        response = test_client.post("/protocol/sampling/createMessage", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_sampling.assert_called_once()
# ----------------------------------------------------- #
# Server Management Tests #
# ----------------------------------------------------- #
class TestServerEndpoints:
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_not_found(self, mock_update, test_client, auth_headers):
"""Test update_server returns 404 if server not found."""
# First-Party
from mcpgateway.services.server_service import ServerNotFoundError
mock_update.side_effect = ServerNotFoundError("Server not found")
req = {"description": "Updated description"}
response = test_client.put("/servers/999", json=req, headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_validation_error(self, mock_create, test_client, auth_headers):
"""Test create_server returns 422 for missing required fields."""
mock_create.side_effect = None # Let validation error happen
req = {"description": "Missing name"}
response = test_client.post("/servers/", json=req, headers=auth_headers)
assert response.status_code == 422
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_service_error(self, mock_create, test_client, auth_headers):
"""Test create_server returns 400 for service errors."""
# First-Party
from mcpgateway.services.server_service import ServerError
mock_create.side_effect = ServerError("Bad server")
response = test_client.post("/servers/", json=_server_create_payload(), headers=auth_headers)
assert response.status_code == 400
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_validation_and_integrity_errors(self, mock_create, exc, status_code, test_client, auth_headers):
"""Test create_server returns correct status for validation/integrity errors."""
mock_create.side_effect = exc
response = test_client.post("/servers/", json=_server_create_payload(), headers=auth_headers)
assert response.status_code == status_code
"""Tests for virtual server management: CRUD operations, status toggles, etc."""
@patch("mcpgateway.main.server_service.list_servers")
def test_list_servers_endpoint(self, mock_list_servers, test_client, auth_headers):
"""Test listing all servers."""
mock_list_servers.return_value = ([ServerRead(**MOCK_SERVER_READ)], None)
response = test_client.get("/servers/", headers=auth_headers)
assert response.status_code == 200
data = response.json()
# Default response is a plain list (include_pagination=False by default)
assert isinstance(data, list)
assert len(data) == 1 and data[0]["name"] == "test_server"
mock_list_servers.assert_called_once()
@patch("mcpgateway.main.server_service.get_server")
def test_get_server_endpoint(self, mock_get, test_client, auth_headers):
"""Test retrieving a specific server."""
mock_get.return_value = ServerRead(**MOCK_SERVER_READ)
response = test_client.get("/servers/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["name"] == "test_server"
mock_get.assert_called_once()
@patch("mcpgateway.main.server_service.register_server")
def test_create_server_endpoint(self, mock_create, test_client, auth_headers):
"""Test creating a new server."""
mock_create.return_value = ServerRead(**MOCK_SERVER_READ)
req = {"server": {"name": "test_server", "description": "A test server"}, "team_id": None, "visibility": "private"}
response = test_client.post("/servers/", json=req, headers=auth_headers)
assert response.status_code == 201
mock_create.assert_called_once()
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_endpoint(self, mock_update, test_client, auth_headers):
"""Test updating an existing server."""
mock_update.return_value = ServerRead(**MOCK_SERVER_READ)
req = {"description": "Updated description"}
response = test_client.put("/servers/1", json=req, headers=auth_headers)
assert response.status_code == 200
mock_update.assert_called_once()
@pytest.mark.parametrize(
"exc,status_code",
[
(VALIDATION_ERROR, 422),
(INTEGRITY_ERROR, 409),
],
)
@patch("mcpgateway.main.server_service.update_server")
def test_update_server_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
"""Test update_server error branches for validation/integrity errors."""
mock_update.side_effect = exc
req = {"description": "Updated description"}
response = test_client.put("/servers/1", json=req, headers=auth_headers)
assert response.status_code == status_code
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state(self, mock_toggle, test_client, auth_headers):
"""Test setting server active/inactive state."""
updated_server = MOCK_SERVER_READ.copy()
updated_server["enabled"] = False
mock_toggle.return_value = ServerRead(**updated_server)
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 200
mock_toggle.assert_called_once()
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state_permission_error(self, mock_toggle, test_client, auth_headers):
"""Test server state change forbidden error."""
mock_toggle.side_effect = PermissionError("Forbidden")
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 403
@patch("mcpgateway.main.server_service.set_server_state")
def test_set_server_state_not_found(self, mock_toggle, test_client, auth_headers):
"""Test server state change not found error."""
# First-Party
from mcpgateway.services.server_service import ServerNotFoundError
mock_toggle.side_effect = ServerNotFoundError("Missing")
response = test_client.post("/servers/1/state?activate=false", headers=auth_headers)
assert response.status_code == 404
@patch("mcpgateway.main.server_service.delete_server")
@patch("mcpgateway.main.server_service.get_server")
def test_delete_server_endpoint(self, mock_get, mock_delete, test_client, auth_headers):
"""Test permanently deleting a server."""
mock_get.return_value = ServerRead(**MOCK_SERVER_READ)
mock_delete.return_value = None
response = test_client.delete("/servers/1", headers=auth_headers)
assert response.status_code == 200
assert response.json()["status"] == "success"
@patch("mcpgateway.main.server_service.get_server")
def test_delete_server_not_found(self, mock_get, test_client, auth_headers):
"""Test deleting a non-existent server returns 404."""
from mcpgateway.services.server_service import ServerNotFoundError
mock_get.side_effect = ServerNotFoundError("Server not found: nonexistent-id")
response = test_client.delete("/servers/nonexistent-id", headers=auth_headers)
assert response.status_code == 404
assert "Server not found" in response.json()["detail"]
@patch("mcpgateway.main.tool_service.list_server_tools")
def test_server_get_tools(self, mock_list_tools, test_client, auth_headers):
"""Test listing tools associated with a server."""
mock_tool = MagicMock()
mock_tool.model_dump.return_value = MOCK_TOOL_READ
mock_list_tools.return_value = [mock_tool]
response = test_client.get("/servers/1/tools", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list_tools.assert_called_once()
@patch("mcpgateway.main.resource_service.list_server_resources")
def test_server_get_resources(self, mock_list_resources, test_client, auth_headers):
"""Test listing resources associated with a server."""
mock_resource = MagicMock()
mock_resource.model_dump.return_value = MOCK_RESOURCE_READ
mock_list_resources.return_value = [mock_resource]
response = test_client.get("/servers/1/resources", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
mock_list_resources.assert_called_once()
@patch("mcpgateway.main.prompt_service.list_server_prompts")
def test_server_get_prompts(self, mock_list_prompts, test_client, auth_headers):
    """Test listing prompts associated with a server.

    ``PromptRead`` is already imported at module level, so no local import is
    needed; the service returns a real read model so response serialization is
    exercised end to end.
    """
    mock_list_prompts.return_value = [PromptRead(**MOCK_PROMPT_READ)]
    response = test_client.get("/servers/1/prompts", headers=auth_headers)
    assert response.status_code == 200
    data = response.json()
    assert len(data) == 1
    mock_list_prompts.assert_called_once()
# ----------------------------------------------------- #
# Tool Management Tests #
# ----------------------------------------------------- #
class TestToolEndpoints:
    """Tests for tool management: registration, invocation, updates, etc."""

    @patch("mcpgateway.main.tool_service.update_tool")
    def test_update_tool_not_found(self, mock_update, test_client, auth_headers):
        """Test update_tool returns 404 if tool not found."""
        # First-Party
        from mcpgateway.services.tool_service import ToolNotFoundError

        mock_update.side_effect = ToolNotFoundError("Tool not found")
        req = {"description": "Updated description"}
        response = test_client.put("/tools/999", json=req, headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.tool_service.register_tool")
    def test_create_tool_validation_error(self, mock_create, test_client, auth_headers):
        """Test create_tool returns 422 for missing required fields.

        The service method is patched (not the route handler, which FastAPI
        bound when the app was built, so patching it has no effect). That lets
        us prove request validation rejects the payload before any
        registration is attempted.
        """
        req = {"description": "Missing name and url"}
        response = test_client.post("/tools/", json=req, headers=auth_headers)
        assert response.status_code == 422
        mock_create.assert_not_called()

    @pytest.mark.parametrize(
        "exc,status_code",
        [
            (VALIDATION_ERROR, 422),
            (INTEGRITY_ERROR, 409),
        ],
    )
    @patch("mcpgateway.main.tool_service.register_tool")
    def test_create_tool_service_errors(self, mock_create, exc, status_code, test_client, auth_headers):
        """Test create_tool returns correct status for validation/integrity errors."""
        mock_create.side_effect = exc
        response = test_client.post("/tools/", json=_tool_create_payload(), headers=auth_headers)
        assert response.status_code == status_code

    @patch("mcpgateway.main.tool_service.list_tools")
    def test_list_tools_endpoint(self, mock_list_tools, test_client, auth_headers):
        """Test listing all registered tools."""
        tool_read = ToolRead(**MOCK_TOOL_READ_SNAKE)
        mock_list_tools.return_value = ([tool_read], None)
        response = test_client.get("/tools/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        # Default response is a plain list (include_pagination=False by default)
        assert isinstance(data, list)
        assert len(data) == 1 and data[0]["name"] == "test_tool"
        mock_list_tools.assert_called_once()

    @patch("mcpgateway.main.tool_service.register_tool")
    def test_create_tool_endpoint(self, mock_create, test_client, auth_headers):
        """Test registering a new tool with an embedded ``tool`` body."""
        mock_create.return_value = MOCK_TOOL_READ_SNAKE
        req = {"tool": {"name": "test_tool", "url": "http://example.com", "description": "A test tool"}, "team_id": None, "visibility": "private"}
        response = test_client.post("/tools/", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_create.assert_called_once()

    @patch("mcpgateway.main.tool_service.get_tool")
    def test_get_tool_endpoint(self, mock_get, test_client, auth_headers):
        """Test retrieving a single tool by id."""
        mock_get.return_value = MOCK_TOOL_READ_SNAKE
        response = test_client.get("/tools/1", headers=auth_headers)
        assert response.status_code == 200
        mock_get.assert_called_once()

    @patch("mcpgateway.main.tool_service.update_tool")
    def test_update_tool_endpoint(self, mock_update, test_client, auth_headers):
        """Test updating an existing tool's description."""
        updated = {**MOCK_TOOL_READ_SNAKE, "description": "Updated description"}
        mock_update.return_value = updated
        req = {"description": "Updated description"}
        response = test_client.put("/tools/1", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @pytest.mark.parametrize(
        "exc,status_code",
        [
            (VALIDATION_ERROR, 422),
            (INTEGRITY_ERROR, 409),
        ],
    )
    @patch("mcpgateway.main.tool_service.update_tool")
    def test_update_tool_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
        """Test update_tool error branches for validation/integrity errors."""
        mock_update.side_effect = exc
        req = {"description": "Updated description"}
        response = test_client.put("/tools/1", json=req, headers=auth_headers)
        assert response.status_code == status_code

    @patch("mcpgateway.main.tool_service.set_tool_state")
    def test_set_tool_state(self, mock_toggle, test_client, auth_headers):
        """Test setting tool active/inactive state."""
        mock_tool = MagicMock()
        mock_tool.model_dump.return_value = {"id": 1, "name": "test", "is_active": False}
        mock_toggle.return_value = mock_tool
        response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

    @patch("mcpgateway.main.tool_service.set_tool_state")
    def test_set_tool_state_permission_error(self, mock_toggle, test_client, auth_headers):
        """Test tool state change forbidden error."""
        mock_toggle.side_effect = PermissionError("Forbidden")
        response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 403

    @patch("mcpgateway.main.tool_service.set_tool_state")
    def test_set_tool_state_not_found(self, mock_toggle, test_client, auth_headers):
        """Test tool state change not found error."""
        # First-Party
        from mcpgateway.services.tool_service import ToolNotFoundError

        mock_toggle.side_effect = ToolNotFoundError("Missing")
        response = test_client.post("/tools/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.tool_service.delete_tool")
    def test_delete_tool_endpoint(self, mock_delete, test_client, auth_headers):
        """Test permanently deleting a tool."""
        mock_delete.return_value = None
        response = test_client.delete("/tools/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Resource Management Tests #
# ----------------------------------------------------- #
class TestResourceEndpoints:
    """Tests for resource management: reading, creation, caching, etc."""

    @patch("mcpgateway.main.resource_service.update_resource")
    def test_update_resource_not_found(self, mock_update, test_client, auth_headers):
        """Test update_resource returns 404 if resource not found."""
        # First-Party
        from mcpgateway.services.resource_service import ResourceNotFoundError

        mock_update.side_effect = ResourceNotFoundError("Resource not found")
        req = {"description": "Updated description"}
        response = test_client.put("/resources/nonexistent", json=req, headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.resource_service.register_resource")
    def test_create_resource_validation_error(self, mock_create, test_client, auth_headers):
        """Test create_resource returns 422 for missing required fields.

        Request-body validation must reject the payload before the service is
        reached, so the patched ``register_resource`` must stay uncalled.
        """
        req = {"description": "Missing uri and name"}
        response = test_client.post("/resources/", json=req, headers=auth_headers)
        assert response.status_code == 422
        mock_create.assert_not_called()

    @pytest.mark.parametrize(
        "exc,status_code",
        [
            (VALIDATION_ERROR, 422),
            (INTEGRITY_ERROR, 409),
        ],
    )
    @patch("mcpgateway.main.resource_service.update_resource")
    def test_update_resource_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
        """Test update_resource error branches for validation/integrity errors."""
        mock_update.side_effect = exc
        req = {"description": "Updated description"}
        response = test_client.put("/resources/1", json=req, headers=auth_headers)
        assert response.status_code == status_code

    @patch("mcpgateway.main.resource_service.list_resources")
    def test_list_resources_endpoint(self, mock_list_resources, test_client, auth_headers):
        """Test listing all available resources."""
        mock_list_resources.return_value = ([ResourceRead(**MOCK_RESOURCE_READ)], None)
        response = test_client.get("/resources/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        # Default response is a plain list (include_pagination=False by default)
        assert isinstance(data, list)
        assert len(data) == 1 and data[0]["name"] == "Test Resource"
        mock_list_resources.assert_called_once()

    @patch("mcpgateway.main.resource_service.register_resource")
    def test_create_resource_endpoint(self, mock_create, test_client, auth_headers):
        """Test registering a new resource."""
        mock_create.return_value = ResourceRead(**MOCK_RESOURCE_READ)
        req = {"resource": {"uri": "test/resource", "name": "Test Resource", "description": "A test resource", "content": "Hello world"}, "team_id": None, "visibility": "private"}
        response = test_client.post("/resources/", json=req, headers=auth_headers)
        assert response.status_code == 200  # route returns 200 on success
        mock_create.assert_called_once()

    @patch("mcpgateway.main.resource_service.read_resource")
    def test_read_resource_endpoint(self, mock_read_resource, test_client, auth_headers):
        """Test reading resource content."""
        # First-Party
        from mcpgateway import main as mcpgateway_main

        # Clear the resource cache so a value cached by an earlier test cannot
        # short-circuit the patched service call.
        mcpgateway_main.resource_cache.clear()
        mock_read_resource.return_value = ResourceContent(
            type="resource",
            id="1",
            uri="test/resource",
            mime_type="text/plain",
            text="This is test content",
        )
        response = test_client.get("/resources/1", headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["uri"] == "test/resource" and body["text"] == "This is test content"
        mock_read_resource.assert_called_once()

    @patch("mcpgateway.main.resource_service.update_resource")
    def test_update_resource_endpoint(self, mock_update, test_client, auth_headers):
        """Test updating an existing resource."""
        mock_update.return_value = ResourceRead(**MOCK_RESOURCE_READ)
        resource_id = mock_update.return_value.id
        req = {"description": "Updated description"}
        response = test_client.put(f"/resources/{resource_id}", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @patch("mcpgateway.main.resource_service.delete_resource")
    def test_delete_resource_endpoint(self, mock_delete, test_client, auth_headers):
        """Test deleting a resource."""
        mock_delete.return_value = None
        # Use the same resource_id as in test_update_resource_endpoint
        resource_id = MOCK_RESOURCE_READ["id"]
        response = test_client.delete(f"/resources/{resource_id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

    @patch("mcpgateway.main.resource_service.list_resource_templates")
    def test_list_resource_templates(self, mock_list, test_client, auth_headers):
        """Test listing available resource templates."""
        mock_list.return_value = []
        response = test_client.get("/resources/templates/list", headers=auth_headers)
        assert response.status_code == 200
        mock_list.assert_called_once()

    @patch("mcpgateway.main.resource_service.set_resource_state")
    def test_set_resource_state(self, mock_toggle, test_client, auth_headers):
        """Test setting resource active/inactive state."""
        mock_resource = MagicMock()
        mock_resource.model_dump.return_value = {"id": "1", "enabled": False}
        mock_toggle.return_value = mock_resource
        response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200

    @patch("mcpgateway.main.resource_service.set_resource_state")
    def test_set_resource_state_permission_error(self, mock_toggle, test_client, auth_headers):
        """Test resource state change forbidden error."""
        mock_toggle.side_effect = PermissionError("Forbidden")
        response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 403

    @patch("mcpgateway.main.resource_service.set_resource_state")
    def test_set_resource_state_not_found(self, mock_toggle, test_client, auth_headers):
        """Test resource state change not found error."""
        # First-Party
        from mcpgateway.services.resource_service import ResourceNotFoundError

        mock_toggle.side_effect = ResourceNotFoundError("Missing")
        response = test_client.post("/resources/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.resource_service.subscribe_events")
    def test_subscribe_resource_events(self, mock_subscribe, test_client, auth_headers):
        """Test subscribing to resource change events via SSE."""
        mock_subscribe.return_value = iter(["data: test\n\n"])
        response = test_client.post("/resources/subscribe", headers=auth_headers)
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
        mock_subscribe.assert_called_once_with(user_email=None, token_teams=None)
# ----------------------------------------------------- #
# Prompt Management Tests #
# ----------------------------------------------------- #
class TestPromptEndpoints:
    """Tests for prompt template management: creation, rendering, arguments, etc."""

    @patch("mcpgateway.main.prompt_service.delete_prompt")
    def test_delete_prompt_not_found(self, mock_delete, test_client, auth_headers):
        """Test delete_prompt returns 404 if prompt not found."""
        # First-Party
        from mcpgateway.services.prompt_service import PromptNotFoundError

        mock_delete.side_effect = PromptNotFoundError("Prompt not found")
        response = test_client.delete("/prompts/nonexistent", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.prompt_service.update_prompt")
    def test_update_prompt_not_found(self, mock_update, test_client, auth_headers):
        """Test update_prompt returns 404 if prompt not found."""
        # First-Party
        from mcpgateway.services.prompt_service import PromptNotFoundError

        mock_update.side_effect = PromptNotFoundError("Prompt not found")
        req = {"description": "Updated description"}
        response = test_client.put("/prompts/nonexistent", json=req, headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.prompt_service.register_prompt")
    def test_create_prompt_validation_error(self, mock_create, test_client, auth_headers):
        """Test create_prompt returns 422 for missing required fields.

        Request-body validation must reject the payload before the service is
        reached, so the patched ``register_prompt`` must stay uncalled.
        """
        req = {"description": "Missing name and template"}
        response = test_client.post("/prompts/", json=req, headers=auth_headers)
        assert response.status_code == 422
        mock_create.assert_not_called()

    @patch("mcpgateway.main.prompt_service.get_prompt")
    def test_get_prompt_no_args_secondary(self, mock_get, test_client, auth_headers):
        """Test getting a prompt without arguments."""
        mock_get.return_value = {"name": "test", "template": "Hello"}
        response = test_client.get("/prompts/test", headers=auth_headers)
        assert response.status_code == 200
        mock_get.assert_called_once_with(ANY, "test", {}, user=None, server_id=None, token_teams=None, plugin_context_table=None, plugin_global_context=ANY)

    @patch("mcpgateway.main.prompt_service.update_prompt")
    def test_update_prompt_endpoint_secondary(self, mock_update, test_client, auth_headers):
        """Test updating an existing prompt."""
        updated = {**MOCK_PROMPT_READ, "description": "Updated description"}
        mock_update.return_value = PromptRead(**updated)
        req = {"description": "Updated description"}
        response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @pytest.mark.parametrize(
        "exc,status_code",
        [
            (VALIDATION_ERROR, 422),
            (INTEGRITY_ERROR, 409),
        ],
    )
    @patch("mcpgateway.main.prompt_service.update_prompt")
    def test_update_prompt_validation_and_integrity_errors(self, mock_update, exc, status_code, test_client, auth_headers):
        """Test update_prompt error branches for validation/integrity errors."""
        mock_update.side_effect = exc
        req = {"description": "Updated description"}
        response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
        assert response.status_code == status_code

    @patch("mcpgateway.main.prompt_service.delete_prompt")
    def test_delete_prompt_endpoint_secondary(self, mock_delete, test_client, auth_headers):
        """Test deleting a prompt."""
        mock_delete.return_value = None
        response = test_client.delete("/prompts/test_prompt", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_delete.assert_called_once()

    @patch("mcpgateway.main.prompt_service.set_prompt_state")
    def test_set_prompt_state_secondary(self, mock_toggle, test_client, auth_headers):
        """Test setting prompt active/inactive state."""
        mock_prompt = MagicMock()
        mock_prompt.model_dump.return_value = {"id": 1, "enabled": False}
        mock_toggle.return_value = mock_prompt
        response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_toggle.assert_called_once()

    @patch("mcpgateway.main.prompt_service.set_prompt_state")
    def test_set_prompt_state_permission_error(self, mock_toggle, test_client, auth_headers):
        """Test prompt state change forbidden error."""
        mock_toggle.side_effect = PermissionError("Forbidden")
        response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 403

    @patch("mcpgateway.main.prompt_service.set_prompt_state")
    def test_set_prompt_state_not_found(self, mock_toggle, test_client, auth_headers):
        """Test prompt state change not found error."""
        # First-Party
        from mcpgateway.services.prompt_service import PromptNotFoundError

        mock_toggle.side_effect = PromptNotFoundError("Missing")
        response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.prompt_service.list_prompts")
    def test_list_prompts_endpoint(self, mock_list_prompts, test_client, auth_headers):
        """Test listing all available prompts."""
        prompt_read = PromptRead(**MOCK_PROMPT_READ)
        mock_list_prompts.return_value = ([prompt_read], None)
        response = test_client.get("/prompts/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        # Default response is a plain list (include_pagination=False by default)
        assert isinstance(data, list)
        assert len(data) == 1
        mock_list_prompts.assert_called_once()

    @patch("mcpgateway.main.prompt_service.register_prompt")
    def test_create_prompt_endpoint(self, mock_create, test_client, auth_headers):
        """Test creating a new prompt template."""
        # Return an actual model instance
        mock_create.return_value = PromptRead(**MOCK_PROMPT_READ)
        req = {"prompt": {"name": "test_prompt", "template": "Hello {name}", "description": "A test prompt"}, "team_id": None, "visibility": "private"}
        response = test_client.post("/prompts/", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_create.assert_called_once()

    @patch("mcpgateway.main.prompt_service.get_prompt")
    def test_get_prompt_with_args(self, mock_get, test_client, auth_headers):
        """Test getting a prompt with template arguments."""
        mock_get.return_value = {
            "messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}],
            "description": "A test prompt",
        }
        req = {"name": "value"}
        response = test_client.post("/prompts/test_prompt", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["messages"][0]["content"]["text"] == "Rendered prompt"
        mock_get.assert_called_once()

    @patch("mcpgateway.main.prompt_service.get_prompt")
    def test_get_prompt_no_args(self, mock_get, test_client, auth_headers):
        """Test getting a prompt without arguments."""
        mock_get.return_value = {"name": "test", "template": "Hello"}
        response = test_client.get("/prompts/test", headers=auth_headers)
        assert response.status_code == 200
        mock_get.assert_called_once_with(ANY, "test", {}, user=None, server_id=None, token_teams=None, plugin_context_table=None, plugin_global_context=ANY)

    @patch("mcpgateway.main.prompt_service.update_prompt")
    def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers):
        """Test updating an existing prompt."""
        updated = {**MOCK_PROMPT_READ, "description": "Updated description"}
        mock_update.return_value = PromptRead(**updated)  # <- real model
        req = {"description": "Updated description"}
        response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @patch("mcpgateway.main.prompt_service.delete_prompt")
    def test_delete_prompt_endpoint(self, mock_delete, test_client, auth_headers):
        """Test deleting a prompt."""
        mock_delete.return_value = None
        response = test_client.delete("/prompts/test_prompt", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

    @patch("mcpgateway.main.prompt_service.set_prompt_state")
    def test_set_prompt_state(self, mock_toggle, test_client, auth_headers):
        """Test setting prompt active/inactive state."""
        mock_prompt = MagicMock()
        mock_prompt.model_dump.return_value = {"id": 1, "enabled": False}
        mock_toggle.return_value = mock_prompt
        response = test_client.post("/prompts/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Gateway Federation Tests #
# ----------------------------------------------------- #
class TestGatewayEndpoints:
    """Tests for gateway federation: registration, discovery, forwarding, etc."""

    @patch("mcpgateway.main.gateway_service.list_gateways")
    def test_list_gateways_endpoint_secondary(self, mock_list, test_client, auth_headers):
        """Test listing all registered gateways."""
        gateway_read = GatewayRead(**MOCK_GATEWAY_READ)
        mock_list.return_value = ([gateway_read], None)
        response = test_client.get("/gateways/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        # Default response is a plain list (include_pagination=False by default)
        assert isinstance(data, list)
        assert len(data) == 1
        mock_list.assert_called_once()

    @patch("mcpgateway.main.gateway_service.register_gateway")
    def test_create_gateway_endpoint_secondary(self, mock_create, test_client, auth_headers):
        """Test registering a new gateway."""
        mock_create.return_value = MOCK_GATEWAY_READ
        req = {"name": "test_gateway", "url": "http://example.com"}
        response = test_client.post("/gateways/", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_create.assert_called_once()

    @patch("mcpgateway.main.gateway_service.get_gateway")
    def test_get_gateway_endpoint_secondary(self, mock_get, test_client, auth_headers):
        """Test retrieving a specific gateway."""
        mock_get.return_value = MOCK_GATEWAY_READ
        response = test_client.get("/gateways/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "test_gateway"
        mock_get.assert_called_once()

    @patch("mcpgateway.main.gateway_service.update_gateway")
    def test_update_gateway_endpoint_secondary(self, mock_update, test_client, auth_headers):
        """Test updating an existing gateway."""
        mock_update.return_value = MOCK_GATEWAY_READ
        req = {"description": "Updated description"}
        response = test_client.put("/gateways/1", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @patch("mcpgateway.main.gateway_service.delete_gateway")
    @patch("mcpgateway.main.gateway_service.get_gateway")
    def test_delete_gateway_endpoint_no_resources(self, mock_get, mock_delete, test_client, auth_headers):
        """Test deleting a gateway that doesn't have resources."""
        mock_delete.return_value = None
        mock_get.return_value.capabilities = {}
        response = test_client.delete("/gateways/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_delete.assert_called_once()

    @patch("mcpgateway.main.gateway_service.delete_gateway")
    @patch("mcpgateway.main.gateway_service.get_gateway")
    @patch("mcpgateway.main.invalidate_resource_cache")
    def test_delete_gateway_endpoint_with_resources(self, mock_invalidate_cache, mock_get, mock_delete, test_client, auth_headers):
        """Test deleting a gateway that does have resources."""
        mock_delete.return_value = None
        mock_get.return_value = MagicMock()
        # Advertising resource capabilities should trigger cache invalidation.
        mock_get.return_value.capabilities = {"resources": {"some": "thing"}}
        response = test_client.delete("/gateways/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_delete.assert_called_once()
        mock_invalidate_cache.assert_called_once()

    @patch("mcpgateway.main.gateway_service.set_gateway_state")
    def test_set_gateway_state_secondary(self, mock_toggle, test_client, auth_headers):
        """Test setting gateway active/inactive state."""
        mock_gateway = MagicMock()
        mock_gateway.model_dump.return_value = {"id": "1", "is_active": False}
        mock_toggle.return_value = mock_gateway
        response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_toggle.assert_called_once()

    @patch("mcpgateway.main.gateway_service.set_gateway_state")
    def test_set_gateway_state_permission_error(self, mock_toggle, test_client, auth_headers):
        """Test gateway state change forbidden error."""
        mock_toggle.side_effect = PermissionError("Forbidden")
        response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 403

    @patch("mcpgateway.main.gateway_service.set_gateway_state")
    def test_set_gateway_state_not_found(self, mock_toggle, test_client, auth_headers):
        """Test gateway state change not found error."""
        # First-Party
        from mcpgateway.services.gateway_service import GatewayNotFoundError

        mock_toggle.side_effect = GatewayNotFoundError("Missing")
        response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.gateway_service.list_gateways")
    def test_list_gateways_endpoint(self, mock_list, test_client, auth_headers):
        """Test listing all registered gateways."""
        gateway_read = GatewayRead(**MOCK_GATEWAY_READ)
        mock_list.return_value = ([gateway_read], None)
        response = test_client.get("/gateways/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        # Default response is a plain list (include_pagination=False by default)
        assert isinstance(data, list)
        assert len(data) == 1
        mock_list.assert_called_once()

    @patch("mcpgateway.main.gateway_service.register_gateway")
    def test_create_gateway_endpoint(self, mock_create, test_client, auth_headers):
        """Test registering a new gateway."""
        mock_create.return_value = MOCK_GATEWAY_READ
        req = {"name": "test_gateway", "url": "http://example.com"}
        response = test_client.post("/gateways/", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_create.assert_called_once()

    @patch("mcpgateway.main.gateway_service.get_gateway")
    def test_get_gateway_endpoint(self, mock_get, test_client, auth_headers):
        """Test retrieving a specific gateway."""
        mock_get.return_value = MOCK_GATEWAY_READ
        response = test_client.get("/gateways/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "test_gateway"
        mock_get.assert_called_once()

    @patch("mcpgateway.main.gateway_service.update_gateway")
    def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers):
        """Test updating an existing gateway."""
        mock_update.return_value = MOCK_GATEWAY_READ
        req = {"description": "Updated description"}
        response = test_client.put("/gateways/1", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_update.assert_called_once()

    @patch("mcpgateway.main.gateway_service.delete_gateway")
    @patch("mcpgateway.main.gateway_service.get_gateway")
    def test_delete_gateway_endpoint(self, mock_get, mock_delete, test_client, auth_headers):
        """Test deleting a gateway actually invokes the service delete."""
        mock_delete.return_value = None
        mock_get.return_value.capabilities = {}
        response = test_client.delete("/gateways/1", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
        mock_delete.assert_called_once()

    @patch("mcpgateway.main.gateway_service.set_gateway_state")
    def test_set_gateway_state(self, mock_toggle, test_client, auth_headers):
        """Test setting gateway active/inactive state."""
        mock_gateway = MagicMock()
        mock_gateway.model_dump.return_value = {"id": "1", "is_active": False}
        mock_toggle.return_value = mock_gateway
        response = test_client.post("/gateways/1/state?activate=false", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
# ----------------------------------------------------- #
# Tag Endpoints Tests #
# ----------------------------------------------------- #
class TestTagEndpoints:
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_all_tags", new_callable=AsyncMock)
def test_list_tags_passes_token_scope(self, mock_get_tags, mock_filter_context, test_client, auth_headers):
"""Tag list endpoint should pass scoped visibility context to service."""
mock_filter_context.return_value = ("scoped@example.com", ["team-1"], False)
mock_get_tags.return_value = []
response = test_client.get("/tags", headers=auth_headers)
assert response.status_code == 200
mock_get_tags.assert_awaited_once_with(
ANY,
entity_types=None,
include_entities=False,
user_email="scoped@example.com",
token_teams=["team-1"],
)
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_entities_by_tag", new_callable=AsyncMock)
def test_get_entities_by_tag_passes_public_only_scope(self, mock_get_entities, mock_filter_context, test_client, auth_headers):
"""Tag entity lookup should preserve public-only token semantics."""
mock_filter_context.return_value = ("admin@example.com", [], False)
mock_get_entities.return_value = []
response = test_client.get("/tags/test/entities", headers=auth_headers)
assert response.status_code == 200
mock_get_entities.assert_awaited_once_with(
ANY,
tag_name="test",
entity_types=None,
user_email="admin@example.com",
token_teams=[],
)
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_all_tags", new_callable=AsyncMock)
def test_list_tags_admin_bypass_passes_unrestricted_scope(self, mock_get_tags, mock_filter_context, test_client, auth_headers):
"""Explicit admin bypass token should pass unrestricted scope to tag service."""
mock_filter_context.return_value = ("admin@example.com", None, True)
mock_get_tags.return_value = []
response = test_client.get("/tags", headers=auth_headers)
assert response.status_code == 200
mock_get_tags.assert_awaited_once_with(
ANY,
entity_types=None,
include_entities=False,
user_email=None,
token_teams=None,
)
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_all_tags", new_callable=AsyncMock)
def test_list_tags_defaults_to_public_scope_when_token_teams_none(self, mock_get_tags, mock_filter_context, test_client, auth_headers):
"""Non-admin token_teams=None should be normalized to public-only scope."""
mock_filter_context.return_value = ("viewer@example.com", None, False)
mock_get_tags.return_value = []
response = test_client.get("/tags", headers=auth_headers)
assert response.status_code == 200
mock_get_tags.assert_awaited_once_with(
ANY,
entity_types=None,
include_entities=False,
user_email="viewer@example.com",
token_teams=[],
)
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_entities_by_tag", new_callable=AsyncMock)
def test_get_entities_by_tag_admin_bypass_passes_unrestricted_scope(self, mock_get_entities, mock_filter_context, test_client, auth_headers):
    """Admin bypass context should pass unrestricted scope to tag entity lookup."""
    mock_filter_context.return_value = ("admin@example.com", None, True)
    mock_get_entities.return_value = []

    resp = test_client.get("/tags/test/entities", headers=auth_headers)

    assert resp.status_code == 200
    # Admin bypass: no user/team scoping is forwarded.
    unrestricted = {"user_email": None, "token_teams": None}
    mock_get_entities.assert_awaited_once_with(ANY, tag_name="test", entity_types=None, **unrestricted)
@patch("mcpgateway.main._get_rpc_filter_context")
@patch("mcpgateway.main.tag_service.get_entities_by_tag", new_callable=AsyncMock)
def test_get_entities_by_tag_defaults_to_public_scope_when_token_teams_none(self, mock_get_entities, mock_filter_context, test_client, auth_headers):
    """Non-admin token_teams=None should be normalized to public-only for tag entity lookup."""
    viewer = "viewer@example.com"
    mock_filter_context.return_value = (viewer, None, False)
    mock_get_entities.return_value = []

    resp = test_client.get("/tags/test/entities", headers=auth_headers)

    assert resp.status_code == 200
    # A non-admin with token_teams=None must be narrowed to public-only ([]).
    public_only = {"user_email": viewer, "token_teams": []}
    mock_get_entities.assert_awaited_once_with(ANY, tag_name="test", entity_types=None, **public_only)
@patch("mcpgateway.main.tag_service.get_all_tags")
def test_list_tags_error(self, mock_get_tags, test_client, auth_headers):
    """Test tag list error handling.

    A failure raised by the tag service should surface as HTTP 500. The
    call assertion ensures the 500 really comes from the patched service
    and not from something that failed before it was reached.
    """
    mock_get_tags.side_effect = Exception("Tag failure")
    response = test_client.get("/tags", headers=auth_headers)
    assert response.status_code == 500
    mock_get_tags.assert_called_once()
@patch("mcpgateway.main.tag_service.get_entities_by_tag")
def test_get_entities_by_tag_error(self, mock_get_entities, test_client, auth_headers):
    """Test tag entity lookup error handling.

    A failure raised by the tag service should surface as HTTP 500. The
    call assertion ensures the 500 really comes from the patched service
    and not from something that failed before it was reached.
    """
    mock_get_entities.side_effect = Exception("Entity lookup failure")
    response = test_client.get("/tags/test/entities", headers=auth_headers)
    assert response.status_code == 500
    mock_get_entities.assert_called_once()
# ----------------------------------------------------- #
# Root Management Tests #
# ----------------------------------------------------- #
class TestRootEndpoints:
    """Tests for root directory management: registration, listing, changes, etc."""

    @staticmethod
    def _request_as_non_admin(method, path, headers, payload=None):
        """Send ``method path`` as a non-admin user whose RBAC permission checks all fail.

        Overrides ``get_current_user_with_permissions`` with a non-admin identity
        and patches ``PermissionService.check_permission`` to return False. When
        the request is done, any override that was already installed (e.g. by a
        fixture) is restored; otherwise the temporary override is removed.

        Args:
            method: Lower-case TestClient method name ("get", "post", ...).
            path: Request path.
            headers: Headers to send.
            payload: Optional JSON body; omitted from the request when None.

        Returns:
            The TestClient response.
        """
        # First-Party
        from mcpgateway.main import app
        from mcpgateway.middleware.rbac import get_current_user_with_permissions

        def non_admin_user(_request=None, _credentials=None, _jwt_token=None):
            return {"email": "non-admin@example.com", "is_admin": False}

        overrides = app.dependency_overrides
        had_previous = get_current_user_with_permissions in overrides
        previous = overrides.get(get_current_user_with_permissions)
        overrides[get_current_user_with_permissions] = non_admin_user
        try:
            client = TestClient(app)
            with patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new=AsyncMock(return_value=False)):
                send = getattr(client, method)
                if payload is None:
                    return send(path, headers=headers)
                return send(path, json=payload, headers=headers)
        finally:
            # Restore rather than blindly pop, so a fixture-installed override survives.
            if had_previous:
                overrides[get_current_user_with_permissions] = previous
            else:
                overrides.pop(get_current_user_with_permissions, None)

    @patch("mcpgateway.main.root_service.list_roots")
    def test_list_roots_endpoint(self, mock_list, test_client, auth_headers):
        """Test listing all registered roots."""
        # First-Party
        from mcpgateway.common.models import Root

        mock_list.return_value = [Root(uri="file:///test", name="Test Root")]  # valid URI
        response = test_client.get("/roots/", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        mock_list.assert_called_once()

    def test_list_roots_endpoint_requires_admin_permission(self, test_client, auth_headers):
        """Root listing should require admin.system_config permission."""
        # test_client is requested presumably for its app/DB setup side effects;
        # the request itself goes through a fresh client in the helper.
        response = self._request_as_non_admin("get", "/roots/", auth_headers)
        assert response.status_code == 403

    @pytest.mark.parametrize(
        ("method", "path", "payload"),
        [
            ("get", "/roots/export?uri=file:///test", None),
            ("get", "/roots/changes", None),
            ("get", "/roots/file%3A%2F%2F%2Ftest", None),
            ("post", "/roots/", {"uri": "file:///test", "name": "Test Root"}),
            ("put", "/roots/file%3A%2F%2F%2Ftest", {"uri": "file:///test", "name": "Updated Root"}),
            ("delete", "/roots/file%3A%2F%2F%2Ftest", None),
        ],
    )
    def test_root_management_endpoints_require_admin_permission(self, method, path, payload, auth_headers):
        """Every root-management route must reject callers lacking admin permission with 403."""
        response = self._request_as_non_admin(method, path, auth_headers, payload)
        assert response.status_code == 403

    @patch("mcpgateway.main.root_service.add_root")
    def test_add_root_endpoint(self, mock_add, test_client, auth_headers):
        """Test adding a new root directory."""
        # First-Party
        from mcpgateway.common.models import Root

        mock_add.return_value = Root(uri="file:///test", name="Test Root")  # valid URI
        req = {"uri": "file:///test", "name": "Test Root"}  # valid body
        response = test_client.post("/roots/", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_add.assert_called_once()

    @patch("mcpgateway.main.root_service.remove_root")
    def test_remove_root_endpoint(self, mock_remove, test_client, auth_headers):
        """Test removing a root directory."""
        mock_remove.return_value = None
        response = test_client.delete("/roots/%2Ftest", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

    @patch("mcpgateway.main.root_service.subscribe_changes")
    def test_subscribe_root_changes(self, mock_subscribe, test_client, auth_headers):
        """Test subscribing to root directory changes via SSE."""

        async def mock_async_gen():
            yield {"event": "test"}

        mock_subscribe.return_value = mock_async_gen()
        response = test_client.get("/roots/changes", headers=auth_headers)
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# ----------------------------------------------------- #
# JSON-RPC & Utility Tests #
# ----------------------------------------------------- #
class TestRPCEndpoints:
    """Tests for JSON-RPC functionality and utility endpoints."""

    @patch("mcpgateway.main.tool_service.invoke_tool")
    def test_rpc_tool_invocation(self, mock_invoke_tool, test_client, auth_headers):
        """Test tool invocation via JSON-RPC."""
        mock_invoke_tool.return_value = {"content": [{"type": "text", "text": "Tool response"}], "is_error": False}
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "tools/call", "params": {"name": "test_tool", "arguments": {"param": "value"}}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["result"]["content"][0]["text"] == "Tool response"
        mock_invoke_tool.assert_called_once_with(
            db=ANY,
            name="test_tool",
            arguments={"param": "value"},
            request_headers=ANY,
            app_user_email="test_user@example.com",  # caller identity is taken from the authenticated JWT/RBAC user
            user_email="test_user@example.com",
            token_teams=[],
            server_id=None,
            plugin_context_table=None,
            plugin_global_context=ANY,
            meta_data=None,
        )

    def test_rpc_tool_invocation_requires_tools_execute(self, test_client, auth_headers):
        """tools/call must return a JSON-RPC -32003 error when tools.execute is denied."""
        req = {"jsonrpc": "2.0", "id": "test-id-deny", "method": "tools/call", "params": {"name": "test_tool", "arguments": {"param": "value"}}}

        # Grant every permission except tools.execute.
        async def _has_permission(_self, permission):
            return permission != "tools.execute"

        with patch("mcpgateway.main.PermissionChecker.has_permission", new=_has_permission):
            response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32003
        assert "tools.execute" in body["error"]["message"]

    def test_rpc_legacy_tool_invocation_requires_tools_execute(self, test_client, auth_headers):
        """Legacy direct tool-name methods must also require tools.execute."""
        req = {"jsonrpc": "2.0", "id": "test-id-legacy-deny", "method": "legacy_tool", "params": {"param": "value"}}

        # Grant every permission except tools.execute.
        async def _has_permission(_self, permission):
            return permission != "tools.execute"

        with patch("mcpgateway.main.PermissionChecker.has_permission", new=_has_permission):
            response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32003
        assert "tools.execute" in body["error"]["message"]

    @patch("mcpgateway.main.prompt_service.get_prompt")
    def test_rpc_prompt_get(self, mock_get_prompt, test_client, auth_headers):
        """Test prompt retrieval via JSON-RPC."""
        mock_get_prompt.return_value = {
            "messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}],
            "description": "A test prompt",
        }
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "prompts/get",
            "params": {"name": "test_prompt", "arguments": {"param": "value"}},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["result"]["messages"][0]["content"]["text"] == "Rendered prompt"
        mock_get_prompt.assert_called_once_with(
            ANY,  # db
            "test_prompt",  # name
            {"param": "value"},  # arguments
            user="test_user@example.com",
            server_id=None,
            token_teams=[],
            plugin_context_table=None,
            plugin_global_context=ANY,
            _meta_data=None,
        )

    @patch("mcpgateway.main.tool_service.list_tools")
    def test_rpc_list_tools(self, mock_list_tools, test_client, auth_headers):
        """Test listing tools via JSON-RPC."""
        mock_tool = MagicMock()
        mock_tool.model_dump.return_value = MOCK_TOOL_READ
        mock_list_tools.return_value = ([mock_tool], None)
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "tools/list",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert isinstance(body["result"]["tools"], list)
        mock_list_tools.assert_called_once()

    @patch("mcpgateway.main.tool_service.list_server_tools", new_callable=AsyncMock)
    def test_rpc_list_tools_with_server_id(self, mock_list_tools, test_client, auth_headers):
        """Test listing tools via JSON-RPC for a specific server."""
        mock_tool = MagicMock()
        mock_tool.model_dump.return_value = MOCK_TOOL_READ
        mock_list_tools.return_value = [mock_tool]
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "tools/list",
            "params": {"server_id": "server-1"},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["result"]["tools"][0]["name"] == "test_tool"
        mock_list_tools.assert_called_once()

    @patch("mcpgateway.main.tool_service.list_tools")
    def test_rpc_legacy_list_tools_next_cursor(self, mock_list_tools, test_client, auth_headers):
        """Test legacy list_tools JSON-RPC method with nextCursor."""
        mock_tool = MagicMock()
        mock_tool.model_dump.return_value = MOCK_TOOL_READ
        mock_list_tools.return_value = ([mock_tool], "next-cursor")
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "list_tools",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["nextCursor"] == "next-cursor"
        assert body["tools"][0]["name"] == "test_tool"

    @patch("mcpgateway.main.resource_service.list_server_resources", new_callable=AsyncMock)
    def test_rpc_resources_list_with_server_id(self, mock_list_resources, test_client, auth_headers):
        """Test listing resources via JSON-RPC for a specific server."""
        mock_resource = MagicMock()
        mock_resource.model_dump.return_value = {"uri": "res://1"}
        mock_list_resources.return_value = [mock_resource]
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/list",
            "params": {"server_id": "server-1"},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["resources"][0]["uri"] == "res://1"

    def test_rpc_resources_read_missing_uri(self, test_client, auth_headers):
        """Test resources/read error when uri is missing."""
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/read",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32602
        assert "Missing resource URI" in body["error"]["message"]

    @patch("mcpgateway.main.resource_service.read_resource", new_callable=AsyncMock)
    def test_rpc_resources_read_missing_resource_error(self, mock_read, test_client, auth_headers):
        """Test resources/read returns error when local content missing (gateway forwarding removed)."""
        mock_read.side_effect = ValueError("no local content")
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/read",
            "params": {"uri": "res://remote"},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert body["error"]["code"] == -32002
        assert "Resource not found" in body["error"]["message"]

    @patch("mcpgateway.main.get_user_email", return_value="user_1")
    @patch("mcpgateway.main.resource_service.subscribe_resource", new_callable=AsyncMock)
    @patch("mcpgateway.main.resource_service.unsubscribe_resource", new_callable=AsyncMock)
    def test_rpc_resources_subscribe_unsubscribe(self, mock_unsubscribe, mock_subscribe, _mock_get_user_email, test_client, auth_headers):
        """Test resources/subscribe and resources/unsubscribe JSON-RPC methods."""
        subscribe_req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/subscribe",
            "params": {"uri": "res://1"},
        }
        unsubscribe_req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/unsubscribe",
            "params": {"uri": "res://1"},
        }
        response = test_client.post("/rpc/", json=subscribe_req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        response = test_client.post("/rpc/", json=unsubscribe_req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        mock_subscribe.assert_called_once()
        mock_unsubscribe.assert_called_once()

    @patch("mcpgateway.main.resource_service.list_resource_templates", new_callable=AsyncMock)
    def test_rpc_resource_templates_list(self, mock_list_templates, test_client, auth_headers):
        """Test resources/templates/list JSON-RPC method."""
        mock_template = MagicMock()
        mock_template.model_dump.return_value = {"uri": "tpl://1"}
        mock_list_templates.return_value = [mock_template]
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "resources/templates/list",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["resourceTemplates"][0]["uri"] == "tpl://1"

    @patch("mcpgateway.main.prompt_service.list_prompts", new_callable=AsyncMock)
    def test_rpc_prompts_list_next_cursor(self, mock_list_prompts, test_client, auth_headers):
        """Test prompts/list JSON-RPC method with nextCursor."""
        mock_prompt = MagicMock()
        mock_prompt.model_dump.return_value = {"name": "prompt-1"}
        mock_list_prompts.return_value = ([mock_prompt], "next-cursor")
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "prompts/list",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["nextCursor"] == "next-cursor"
        assert body["prompts"][0]["name"] == "prompt-1"

    @patch("mcpgateway.main.gateway_service.list_gateways", new_callable=AsyncMock)
    def test_rpc_list_gateways(self, mock_list_gateways, test_client, auth_headers):
        """Test list_gateways JSON-RPC method."""
        mock_gateway = MagicMock()
        mock_gateway.model_dump.return_value = {"id": "gateway-1"}
        mock_list_gateways.return_value = ([mock_gateway], None)
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "list_gateways",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["gateways"][0]["id"] == "gateway-1"

    @patch("mcpgateway.main.root_service.list_roots", new_callable=AsyncMock)
    def test_rpc_list_roots(self, mock_list_roots, test_client, auth_headers):
        """Test list_roots JSON-RPC method."""
        mock_root = MagicMock()
        mock_root.model_dump.return_value = {"uri": "root://1"}
        mock_list_roots.return_value = [mock_root]
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "list_roots",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()["result"]
        assert body["roots"][0]["uri"] == "root://1"

    @patch("mcpgateway.main.PermissionChecker.has_permission", new_callable=AsyncMock, return_value=False)
    def test_rpc_list_roots_requires_admin_permission(self, _mock_permission, test_client, auth_headers):
        """list_roots RPC must enforce admin.system_config permission."""
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "list_roots",
            "params": {},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32003
        assert "admin.system_config" in body["error"]["message"]

    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    @patch("mcpgateway.main.cancellation_service.get_status", new_callable=AsyncMock)
    @patch("mcpgateway.main.cancellation_service.cancel_run", new_callable=AsyncMock)
    def test_rpc_notification_cancelled(self, mock_cancel_run, mock_get_status, mock_notify, test_client, auth_headers):
        """Test notifications/cancelled JSON-RPC method."""
        # The run is owned by the authenticated test user, so cancellation is permitted.
        mock_get_status.return_value = {"owner_email": "test_user@example.com", "owner_team_ids": []}
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "notifications/cancelled",
            "params": {"requestId": "123", "reason": "user"},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        mock_cancel_run.assert_called_once_with("123", reason="user")
        mock_notify.assert_called_once()

    @patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock)
    def test_rpc_tools_call_missing_tool_error(self, mock_invoke, test_client, auth_headers):
        """Test tools/call raises error when tool not found (gateway forwarding removed)."""
        mock_invoke.side_effect = ValueError("no local tool")
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "tools/call",
            "params": {"name": "missing_tool", "arguments": {}},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert body["error"]["code"] == -32601  # Method not found (Tool not found)

    def test_rpc_elicitation_disabled(self, test_client, auth_headers, monkeypatch):
        """Test elicitation/create JSON-RPC when feature disabled."""
        monkeypatch.setattr(settings, "mcpgateway_elicitation_enabled", False)
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "elicitation/create",
            "params": {"message": "Need input", "requestedSchema": {"type": "object"}},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32601

    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    def test_rpc_notifications_initialized(self, mock_notify, test_client, auth_headers):
        """Test notifications/initialized JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "notifications/initialized"}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        mock_notify.assert_called_once()

    @patch("mcpgateway.main.logging_service.notify", new_callable=AsyncMock)
    def test_rpc_notifications_message(self, mock_notify, test_client, auth_headers):
        """Test notifications/message JSON-RPC method."""
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "notifications/message",
            "params": {"data": "hello", "level": "info", "logger": "test"},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        mock_notify.assert_called_once()

    @patch("mcpgateway.main.sampling_handler.create_message", new_callable=AsyncMock)
    def test_rpc_sampling_create_message(self, mock_sampling, test_client, auth_headers):
        """Test sampling/createMessage JSON-RPC method."""
        mock_sampling.return_value = {"messageId": "abc"}
        req = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "method": "sampling/createMessage",
            "params": {"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]},
        }
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"]["messageId"] == "abc"
        mock_sampling.assert_called_once()

    def test_rpc_sampling_other_method(self, test_client, auth_headers):
        """Test sampling/* catch-all JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "sampling/unknown", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion", new_callable=AsyncMock)
    def test_rpc_completion_complete(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """Test completion/complete JSON-RPC method."""
        mock_filter_context.return_value = ("rpc-user@example.com", ["team-2"], False)
        mock_completion.return_value = {"result": "done"}
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/complete", "params": {"ref": {"type": "ref/prompt", "name": "p1"}}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"]["result"] == "done"
        mock_completion.assert_awaited_once_with(ANY, req["params"], user_email="rpc-user@example.com", token_teams=["team-2"])

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion", new_callable=AsyncMock)
    def test_rpc_completion_complete_admin_bypass(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """RPC completion should preserve explicit admin bypass context."""
        mock_filter_context.return_value = ("admin@example.com", None, True)
        mock_completion.return_value = {"result": "done"}
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/complete", "params": {"ref": {"type": "ref/prompt", "name": "p1"}}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"]["result"] == "done"
        mock_completion.assert_awaited_once_with(ANY, req["params"], user_email=None, token_teams=None)

    @patch("mcpgateway.main._get_rpc_filter_context")
    @patch("mcpgateway.main.completion_service.handle_completion", new_callable=AsyncMock)
    def test_rpc_completion_complete_defaults_to_public_scope_when_token_teams_none(self, mock_completion, mock_filter_context, test_client, auth_headers):
        """RPC completion should normalize non-admin token_teams=None to public-only."""
        mock_filter_context.return_value = ("viewer@example.com", None, False)
        mock_completion.return_value = {"result": "done"}
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/complete", "params": {"ref": {"type": "ref/prompt", "name": "p1"}}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"]["result"] == "done"
        mock_completion.assert_awaited_once_with(ANY, req["params"], user_email="viewer@example.com", token_teams=[])

    def test_rpc_completion_other_method(self, test_client, auth_headers):
        """Test completion/* catch-all JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "completion/unknown", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}

    @patch("mcpgateway.main.logging_service.set_level", new_callable=AsyncMock)
    def test_rpc_logging_set_level(self, mock_set_level, test_client, auth_headers):
        """Test logging/setLevel JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "logging/setLevel", "params": {"level": "debug"}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}
        mock_set_level.assert_called_once()

    def test_rpc_logging_other_method(self, test_client, auth_headers):
        """Test logging/* catch-all JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "logging/other", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}

    @patch("mcpgateway.main.root_service.list_roots", new_callable=AsyncMock)
    def test_rpc_roots_list_method(self, mock_list_roots, test_client, auth_headers):
        """Test roots/list JSON-RPC method."""
        mock_root = MagicMock()
        mock_root.model_dump.return_value = {"uri": "root://2"}
        mock_list_roots.return_value = [mock_root]
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "roots/list", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"]["roots"][0]["uri"] == "root://2"

    @patch("mcpgateway.main.PermissionChecker.has_permission", new_callable=AsyncMock, return_value=False)
    def test_rpc_roots_list_requires_admin_permission(self, _mock_permission, test_client, auth_headers):
        """roots/list RPC must enforce admin.system_config permission."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "roots/list", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["error"]["code"] == -32003

    def test_rpc_roots_other_method(self, test_client, auth_headers):
        """Test roots/* catch-all JSON-RPC method."""
        req = {"jsonrpc": "2.0", "id": "test-id", "method": "roots/remove", "params": {}}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["result"] == {}

    @patch("mcpgateway.main.RPCRequest")
    def test_rpc_invalid_request(self, mock_rpc_request, test_client, auth_headers):
        """Test RPC error handling for invalid requests."""
        mock_rpc_request.side_effect = ValueError("Invalid method")
        req = {"jsonrpc": "1.0", "id": "test-id", "method": "invalid_method"}
        response = test_client.post("/rpc/", json=req, headers=auth_headers)
        assert response.status_code == 422
        body = response.json()
        assert "Method invalid" in body.get("message")

    def test_rpc_invalid_json(self, test_client, auth_headers):
        """Test RPC error handling for malformed JSON."""
        # Build a copy instead of mutating the fixture's dict, so the extra
        # header cannot leak into any other consumer of ``auth_headers``.
        headers = {**auth_headers, "content-type": "application/json"}
        response = test_client.post("/rpc/", content="invalid json", headers=headers)
        assert response.status_code == 400
        body = response.json()
        assert body["error"]["code"] == -32700
        assert body["error"]["message"] == "Parse error"

    @patch("mcpgateway.main.logging_service.set_level")
    def test_set_log_level_endpoint(self, mock_set_level, test_client, auth_headers):
        """Test setting the application log level."""
        req = {"level": "debug"}  # lowercase to match enum
        response = test_client.post("/logging/setLevel", json=req, headers=auth_headers)
        assert response.status_code == 200
        mock_set_level.assert_called_once()
# ----------------------------------------------------- #
# WebSocket & SSE Tests #
# ----------------------------------------------------- #
class TestRealtimeEndpoints:
"""Tests for real-time communication: WebSocket, SSE, message handling, etc."""
@pytest.fixture(autouse=True)
def enable_ws_relay(self, monkeypatch):
    """Switch on the WebSocket relay setting for every test in this class.

    ``monkeypatch`` undoes the change automatically at teardown.
    """
    # First-Party
    import mcpgateway.main as gateway_main

    monkeypatch.setattr(gateway_main.settings, "mcpgateway_ws_relay_enabled", True)
@patch("mcpgateway.main.settings")
@patch("mcpgateway.main.ResilientHttpClient")  # stub network calls
def test_websocket_endpoint(self, mock_client, mock_settings, test_client):
    """Test WebSocket connection and message handling.

    The WebSocket handler's outbound HTTP call is stubbed to return a fixed
    JSON-RPC result. The test checks that this result is relayed back
    verbatim over the socket.
    """
    # Standard
    from types import SimpleNamespace

    # Configure mock settings for auth disabled
    mock_settings.mcp_client_auth_enabled = False
    mock_settings.auth_required = False
    mock_settings.federation_timeout = 30
    mock_settings.skip_ssl_verify = False
    mock_settings.port = 4444

    # ----- set up async context-manager dummy -----
    mock_instance = mock_client.return_value
    mock_instance.__aenter__.return_value = mock_instance
    mock_instance.__aexit__.return_value = False

    async def dummy_post(*_args, **_kwargs):
        # minimal object that looks like an httpx.Response
        return SimpleNamespace(text='{"jsonrpc":"2.0","id":1,"result":{}}')

    mock_instance.post = dummy_post
    # ---------------------------------------------

    with test_client.websocket_connect("/ws") as websocket:
        websocket.send_text('{"jsonrpc":"2.0","method":"ping","id":1}')
        data = websocket.receive_text()
        response = json.loads(data)
        assert response == {"jsonrpc": "2.0", "id": 1, "result": {}}
@patch("mcpgateway.main.update_url_protocol", new=lambda url: url)
@patch("mcpgateway.main.session_registry.add_session")
@patch("mcpgateway.main.session_registry.respond")
@patch("mcpgateway.main.SSETransport")
def test_sse_endpoint(self, mock_transport_class, mock_respond, mock_add_session, test_client, auth_headers):
    """Opening ``/sse`` should construct exactly one SSE transport.

    Only transport construction is asserted; the HTTP response itself is not
    inspected.
    """
    transport = MagicMock()
    transport.session_id = "test-session"
    transport.create_sse_response.return_value = MagicMock()
    mock_transport_class.return_value = transport

    # NOTE: assertions on the response may need adjusting once the SSE
    # response structure is pinned down; for now only construction is checked.
    test_client.get("/sse", headers=auth_headers)

    mock_transport_class.assert_called_once()
@patch("mcpgateway.main.session_registry.broadcast")
def test_message_endpoint(self, mock_broadcast, test_client, auth_headers):
    """A message posted by the session's owner is accepted (202) and broadcast once."""
    payload = {"type": "test", "data": "hello"}
    # The session is reported as owned by the authenticated test user.
    owner_lookup = AsyncMock(return_value="test_user@example.com")

    with patch("mcpgateway.main.session_registry.get_session_owner", new=owner_lookup):
        resp = test_client.post("/message?session_id=test-session", json=payload, headers=auth_headers)

    assert resp.status_code == 202
    mock_broadcast.assert_called_once()
@patch("mcpgateway.main._read_request_json")
def test_message_endpoint_invalid_payload(self, mock_read, test_client, auth_headers):
    """A body that fails to parse must yield HTTP 400 even for the session owner."""
    mock_read.side_effect = ValueError("Invalid payload")
    owner_lookup = AsyncMock(return_value="test_user@example.com")

    with patch("mcpgateway.main.session_registry.get_session_owner", new=owner_lookup):
        resp = test_client.post("/message?session_id=test-session", json={"bad": True}, headers=auth_headers)

    assert resp.status_code == 400
def test_message_endpoint_rejects_non_owner(self, test_client, auth_headers):
    """Posting into a session owned by someone else must be refused with 403."""
    payload = {"type": "test", "data": "hello"}
    foreign_owner = AsyncMock(return_value="other@example.com")

    with patch("mcpgateway.main.session_registry.get_session_owner", new=foreign_owner):
        with patch("mcpgateway.main._get_request_identity", return_value=("test_user@example.com", False)):
            resp = test_client.post("/message?session_id=test-session", json=payload, headers=auth_headers)

    assert resp.status_code == 403
def test_message_endpoint_rejects_unknown_owner_metadata(self, test_client, auth_headers):
    """An existing session whose owner cannot be looked up is refused (fail closed) with 403."""
    payload = {"type": "test", "data": "hello"}
    no_owner = AsyncMock(return_value=None)
    session_present = AsyncMock(return_value=True)

    with (
        patch("mcpgateway.main.session_registry.get_session_owner", new=no_owner),
        patch("mcpgateway.main.session_registry.session_exists", new=session_present),
    ):
        resp = test_client.post("/message?session_id=test-session", json=payload, headers=auth_headers)

    assert resp.status_code == 403
    body = resp.json()
    assert body["detail"] == "Session owner metadata unavailable"
def test_server_message_endpoint_rejects_unknown_owner_metadata(self, test_client, auth_headers):
    """The per-server message route also fails closed (403) when session ownership is unknown."""
    payload = {"type": "test", "data": "hello"}
    endpoint = "/servers/test-server/message?session_id=test-session"
    no_owner = AsyncMock(return_value=None)
    session_present = AsyncMock(return_value=True)

    with (
        patch("mcpgateway.main.session_registry.get_session_owner", new=no_owner),
        patch("mcpgateway.main.session_registry.session_exists", new=session_present),
    ):
        resp = test_client.post(endpoint, json=payload, headers=auth_headers)

    assert resp.status_code == 403
    body = resp.json()
    assert body["detail"] == "Session owner metadata unavailable"
@pytest.mark.asyncio
async def test_websocket_forwards_auth_token_to_rpc(self, monkeypatch):
    """Test that WebSocket forwards JWT token to /rpc endpoint.

    This ensures auth credentials are propagated so /rpc doesn't reject
    with 401 when AUTH_REQUIRED=true.
    """
    # First-Party
    from mcpgateway import main as mcpgateway_main

    # Headers passed to the relayed /rpc POST are recorded here.
    captured_headers = {}

    class MockResponse:
        text = '{"jsonrpc":"2.0","id":1,"result":{}}'

    class MockClient:
        """Stand-in for ResilientHttpClient: an async context manager whose post() records headers."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            pass

        async def post(self, url, json, headers):
            captured_headers.update(headers)
            return MockResponse()

    monkeypatch.setattr(mcpgateway_main.settings, "auth_required", True)
    monkeypatch.setattr(mcpgateway_main.settings, "mcp_client_auth_enabled", True)
    monkeypatch.setattr(mcpgateway_main.settings, "federation_timeout", 30)
    monkeypatch.setattr(mcpgateway_main.settings, "skip_ssl_verify", False)
    monkeypatch.setattr(mcpgateway_main.settings, "port", 4444)
    monkeypatch.setattr(mcpgateway_main.settings, "app_root_path", "")
    monkeypatch.setattr(mcpgateway_main.settings, "mcpgateway_ws_relay_enabled", True)
    monkeypatch.setattr(mcpgateway_main, "ResilientHttpClient", lambda **kwargs: MockClient())
    monkeypatch.setattr(mcpgateway_main, "_authenticate_websocket_user", AsyncMock(return_value=("test-jwt-token", None)))

    # Mock websocket carrying the token in its query params.
    websocket = AsyncMock()
    websocket.query_params = {"token": "test-jwt-token"}
    websocket.headers = {}
    # One JSON-RPC frame, then a disconnect so the endpoint stops receiving and returns.
    websocket.receive_text = AsyncMock(side_effect=[
        '{"jsonrpc":"2.0","method":"test","id":1}',
        WebSocketDisconnect(),
    ])
    # Outgoing frames are not inspected by this test (the previous capture list was never asserted).
    websocket.send_text = AsyncMock()

    await mcpgateway_main.websocket_endpoint(websocket)

    # Verify auth token was forwarded to /rpc
    assert "Authorization" in captured_headers, "Authorization header should be forwarded to /rpc"
    assert captured_headers["Authorization"] == "Bearer test-jwt-token"
@pytest.mark.asyncio
async def test_websocket_forwards_proxy_user_to_rpc(self, monkeypatch):
    """With trusted proxy auth, the WebSocket relay passes the proxy user header through to /rpc."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    forwarded = {}
    rpc_reply = '{"jsonrpc":"2.0","id":1,"result":{}}'

    class _FakeRpcResponse:
        text = rpc_reply

    class _FakeHttpClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            pass

        async def post(self, url, json, headers):
            forwarded.update(headers)
            return _FakeRpcResponse()

    # Applied in insertion order, exactly like individual setattr calls.
    overrides = {
        "auth_required": True,
        "mcp_client_auth_enabled": False,
        "trust_proxy_auth": True,
        "proxy_user_header": "X-Forwarded-User",
        "federation_timeout": 30,
        "skip_ssl_verify": False,
        "port": 4444,
        "app_root_path": "",
        "mcpgateway_ws_relay_enabled": True,
    }
    for attr, value in overrides.items():
        monkeypatch.setattr(mcpgateway_main.settings, attr, value)
    monkeypatch.setattr(mcpgateway_main, "ResilientHttpClient", lambda **kwargs: _FakeHttpClient())
    monkeypatch.setattr(mcpgateway_main, "_authenticate_websocket_user", AsyncMock(return_value=(None, "proxy-user@example.com")))

    # headers is a plain (case-sensitive) dict, so the key must match settings.proxy_user_header exactly.
    ws = AsyncMock()
    ws.query_params = {}
    ws.headers = {"X-Forwarded-User": "proxy-user@example.com"}
    ws.receive_text = AsyncMock(side_effect=['{"jsonrpc":"2.0","method":"test","id":1}', WebSocketDisconnect()])
    ws.send_text = AsyncMock()

    await mcpgateway_main.websocket_endpoint(ws)

    assert "X-Forwarded-User" in forwarded, "Proxy user header should be forwarded to /rpc"
    assert forwarded["X-Forwarded-User"] == "proxy-user@example.com"
@pytest.mark.asyncio
async def test_websocket_disconnect_on_accept(self, monkeypatch):
    """If accept() raises WebSocketDisconnect, the endpoint returns instead of propagating it."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    for flag in ("auth_required", "mcp_client_auth_enabled"):
        monkeypatch.setattr(mcpgateway_main.settings, flag, False)

    ws = AsyncMock()
    ws.query_params = {}
    ws.headers = {}
    ws.accept.side_effect = WebSocketDisconnect()

    # Completing without raising is the assertion.
    await mcpgateway_main.websocket_endpoint(ws)
@pytest.mark.asyncio
async def test_server_sse_cancelled_cleanup(self):
    """Cancellation while building the server SSE response propagates, with session cleanup patched to fail."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    class _CancellingTransport:
        def __init__(self, *_args, **_kwargs):
            self.session_id = "sess-cancel"

        async def connect(self):
            return None

        async def create_sse_response(self, *_args, **_kwargs):
            raise asyncio.CancelledError()

    req = _make_request("/servers/1/sse")
    failing_cleanup = patch(
        "mcpgateway.main.session_registry.remove_session",
        new_callable=AsyncMock,
        side_effect=Exception("cleanup"),
    )

    with (
        patch("mcpgateway.main.SSETransport", _CancellingTransport),
        patch("mcpgateway.main.server_service.get_server", new_callable=AsyncMock),
        patch("mcpgateway.main._enforce_scoped_resource_access"),
        patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.register_respond_task"),
        failing_cleanup,
        patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
        pytest.raises(asyncio.CancelledError),
    ):
        await mcpgateway_main.sse_endpoint(req, "1", db=MagicMock(), user={"email": "user@example.com"})
@pytest.mark.asyncio
async def test_server_sse_failure_cleanup(self):
    """An unexpected error while building the server SSE response surfaces as HTTP 500, with cleanup patched to fail."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    class _ExplodingTransport:
        def __init__(self, *_args, **_kwargs):
            self.session_id = "sess-fail"

        async def connect(self):
            return None

        async def create_sse_response(self, *_args, **_kwargs):
            raise RuntimeError("boom")

    req = _make_request("/servers/1/sse")
    failing_cleanup = patch(
        "mcpgateway.main.session_registry.remove_session",
        new_callable=AsyncMock,
        side_effect=Exception("cleanup"),
    )

    with (
        patch("mcpgateway.main.SSETransport", _ExplodingTransport),
        patch("mcpgateway.main.server_service.get_server", new_callable=AsyncMock),
        patch("mcpgateway.main._enforce_scoped_resource_access"),
        patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.register_respond_task"),
        failing_cleanup,
        patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
        pytest.raises(HTTPException) as caught,
    ):
        await mcpgateway_main.sse_endpoint(req, "1", db=MagicMock(), user={"email": "user@example.com"})

    assert caught.value.status_code == 500
@pytest.mark.asyncio
async def test_utility_sse_cancelled_cleanup(self):
    """Cancellation inside the utility /sse endpoint propagates, with session cleanup patched to fail."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    class _CancellingTransport:
        def __init__(self, *_args, **_kwargs):
            self.session_id = "util-cancel"

        async def connect(self):
            return None

        async def create_sse_response(self, *_args, **_kwargs):
            raise asyncio.CancelledError()

    req = _make_request("/sse")
    failing_cleanup = patch(
        "mcpgateway.main.session_registry.remove_session",
        new_callable=AsyncMock,
        side_effect=Exception("cleanup"),
    )

    with (
        patch("mcpgateway.main.SSETransport", _CancellingTransport),
        patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.register_respond_task"),
        failing_cleanup,
        patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
        pytest.raises(asyncio.CancelledError),
    ):
        await mcpgateway_main.utility_sse_endpoint(req, user={"email": "user@example.com"})
@pytest.mark.asyncio
async def test_utility_sse_failure_cleanup(self):
    """An unexpected error inside the utility /sse endpoint becomes HTTP 500, with cleanup patched to fail."""
    # First-Party
    from mcpgateway import main as mcpgateway_main

    class _ExplodingTransport:
        def __init__(self, *_args, **_kwargs):
            self.session_id = "util-fail"

        async def connect(self):
            return None

        async def create_sse_response(self, *_args, **_kwargs):
            raise RuntimeError("boom")

    req = _make_request("/sse")
    failing_cleanup = patch(
        "mcpgateway.main.session_registry.remove_session",
        new_callable=AsyncMock,
        side_effect=Exception("cleanup"),
    )

    with (
        patch("mcpgateway.main.SSETransport", _ExplodingTransport),
        patch("mcpgateway.main.session_registry.add_session", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.respond", new_callable=AsyncMock),
        patch("mcpgateway.main.session_registry.register_respond_task"),
        failing_cleanup,
        patch("mcpgateway.middleware.rbac.PermissionService.check_permission", new_callable=AsyncMock, return_value=True),
        pytest.raises(HTTPException) as caught,
    ):
        await mcpgateway_main.utility_sse_endpoint(req, user={"email": "user@example.com"})

    assert caught.value.status_code == 500
# ----------------------------------------------------- #
# Metrics & Monitoring Tests #
# ----------------------------------------------------- #
class TestMetricsEndpoints:
    """Tests for metrics collection, aggregation, and reset functionality."""

    @patch("mcpgateway.main.prompt_service.aggregate_metrics")
    @patch("mcpgateway.main.server_service.aggregate_metrics")
    @patch("mcpgateway.main.resource_service.aggregate_metrics")
    @patch("mcpgateway.main.tool_service.aggregate_metrics")
    def test_get_metrics(self, mock_tool, mock_resource, mock_server, mock_prompt, test_client, auth_headers):
        """GET /metrics returns a section for each core entity type."""
        mock_tool.return_value = {"total": 5}
        mock_resource.return_value = {"total": 3}
        mock_server.return_value = {"total": 2}
        mock_prompt.return_value = {"total": 1}

        response = test_client.get("/metrics", headers=auth_headers)

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data and "resources" in data
        assert "servers" in data and "prompts" in data
        # A2A agent metrics may or may not be present depending on configuration, so they are not asserted.

    # Previously this test was a commented-out block (dead code that was never syntax-checked).
    # Kept as a real, explicitly skipped test so it stays compiled and shows up in test reports.
    @pytest.mark.skip(reason="Previously disabled; re-enable after verifying POST /metrics/reset resets every service, including A2A")
    @patch("mcpgateway.main.a2a_service")
    @patch("mcpgateway.main.prompt_service.reset_metrics")
    @patch("mcpgateway.main.server_service.reset_metrics")
    @patch("mcpgateway.main.resource_service.reset_metrics")
    @patch("mcpgateway.main.tool_service.reset_metrics")
    def test_reset_all_metrics(self, mock_tool_reset, mock_resource_reset, mock_server_reset, mock_prompt_reset, mock_a2a_service, test_client, auth_headers):
        """Test resetting metrics for all entity types."""
        # Mock A2A service with reset_metrics method
        mock_a2a_service.reset_metrics = MagicMock()

        response = test_client.post("/metrics/reset", headers=auth_headers)
        assert response.status_code == 200

        # Verify all services had their metrics reset
        mock_tool_reset.assert_called_once()
        mock_resource_reset.assert_called_once()
        mock_server_reset.assert_called_once()
        mock_prompt_reset.assert_called_once()
        mock_a2a_service.reset_metrics.assert_called_once()

    @patch("mcpgateway.main.tool_service.reset_metrics")
    def test_reset_specific_entity_metrics(self, mock_tool_reset, test_client, auth_headers):
        """Test resetting metrics for a specific entity type."""
        response = test_client.post("/metrics/reset?entity=tool&entity_id=1", headers=auth_headers)
        assert response.status_code == 200
        # entity_id from the query string is expected to reach the service as an int.
        mock_tool_reset.assert_called_once_with(ANY, 1)

    def test_reset_invalid_entity_metrics(self, test_client, auth_headers):
        """Test error handling for invalid entity type in metrics reset."""
        response = test_client.post("/metrics/reset?entity=invalid", headers=auth_headers)
        assert response.status_code == 400
# ----------------------------------------------------- #
# A2A Agent API Tests #
# ----------------------------------------------------- #
class TestA2AAgentEndpoints:
    """Route-level tests for the /a2a agent API, with ``a2a_service`` replaced by a mock."""

    @patch("mcpgateway.main.a2a_service")
    def test_list_a2a_agents(self, mock_service, test_client, auth_headers):
        """GET /a2a returns the agents supplied by the service's list_agents."""
        agent = _make_a2a_agent_read()
        mock_service.list_agents = AsyncMock(return_value=([agent], None))

        resp = test_client.get("/a2a", headers=auth_headers)

        assert resp.status_code == 200
        first = resp.json()[0]
        assert first["name"] == "agent-1"
        mock_service.list_agents.assert_called_once()

    @patch("mcpgateway.main.a2a_service")
    def test_get_a2a_agent(self, mock_service, test_client, auth_headers):
        """GET /a2a/{id} returns the agent supplied by the service's get_agent."""
        agent = _make_a2a_agent_read()
        mock_service.get_agent = AsyncMock(return_value=agent)

        resp = test_client.get("/a2a/agent-1", headers=auth_headers)

        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "agent-1"
        mock_service.get_agent.assert_called_once()

    @patch("mcpgateway.main.a2a_service")
    def test_create_a2a_agent(self, mock_service, test_client, auth_headers):
        """POST /a2a registers the agent through the service and answers 201."""
        agent = _make_a2a_agent_read()
        mock_service.register_agent = AsyncMock(return_value=agent)

        resp = test_client.post("/a2a", json=_a2a_create_payload(), headers=auth_headers)

        assert resp.status_code == 201
        mock_service.register_agent.assert_called_once()

    @pytest.mark.parametrize(
        "exc,status_code",
        [
            ("name_conflict", 409),
            ("agent_error", 400),
        ],
    )
    @patch("mcpgateway.main.a2a_service")
    def test_create_a2a_agent_error_branches(self, mock_service, exc, status_code, test_client, auth_headers):
        """Registration errors map to 409 (name conflict) or 400 (generic agent error)."""
        # First-Party
        from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError

        if exc == "name_conflict":
            error = A2AAgentNameConflictError("conflict")
        else:
            error = A2AAgentError("bad")
        mock_service.register_agent = AsyncMock(side_effect=error)

        resp = test_client.post("/a2a", json=_a2a_create_payload(), headers=auth_headers)

        assert resp.status_code == status_code

    @patch("mcpgateway.main.a2a_service")
    def test_update_a2a_agent(self, mock_service, test_client, auth_headers):
        """PUT /a2a/{id} succeeds and calls the service's update_agent once."""
        updated = _make_a2a_agent_read(name="agent-1", description="updated")
        mock_service.update_agent = AsyncMock(return_value=updated)
        changes = {"description": "Updated description"}

        resp = test_client.put("/a2a/agent-1", json=changes, headers=auth_headers)

        assert resp.status_code == 200
        mock_service.update_agent.assert_called_once()

    @patch("mcpgateway.main.a2a_service")
    def test_set_a2a_agent_state(self, mock_service, test_client, auth_headers):
        """POST /a2a/{id}/state?activate=false succeeds and calls set_agent_state once."""
        disabled = _make_a2a_agent_read(enabled=False)
        mock_service.set_agent_state = AsyncMock(return_value=disabled)

        resp = test_client.post("/a2a/agent-1/state?activate=false", headers=auth_headers)

        assert resp.status_code == 200
        mock_service.set_agent_state.assert_called_once()

    @patch("mcpgateway.main.a2a_service")
    def test_delete_a2a_agent(self, mock_service, test_client, auth_headers):
        """DELETE /a2a/{id} succeeds and calls delete_agent once."""
        mock_service.delete_agent = AsyncMock(return_value=None)

        resp = test_client.delete("/a2a/agent-1", headers=auth_headers)

        assert resp.status_code == 200
        mock_service.delete_agent.assert_called_once()

    @patch("mcpgateway.main.a2a_service")
    def test_invoke_a2a_agent(self, mock_service, test_client, auth_headers):
        """POST /a2a/{id}/invoke succeeds and calls invoke_agent once."""
        service_reply = {"response": "Agent response", "status": "success"}
        mock_service.invoke_agent = AsyncMock(return_value=service_reply)
        invocation = {"parameters": {"query": "test"}, "interaction_type": "query"}

        resp = test_client.post("/a2a/agent-1/invoke", json=invocation, headers=auth_headers)

        assert resp.status_code == 200
        mock_service.invoke_agent.assert_called_once()
# ----------------------------------------------------- #
# Middleware & Security Tests #
# ----------------------------------------------------- #
class TestMiddlewareAndSecurity:
    """Middleware behaviour: authentication gating of the docs/OpenAPI routes and CORS preflight handling."""

    def test_docs_auth_middleware_protected_path(self, test_client):
        """Unauthenticated /docs is rejected with 401 rather than redirected."""
        resp = test_client.get("/docs", follow_redirects=False)
        assert resp.status_code == 401

    def test_docs_auth_middleware_unprotected_path(self, test_client):
        """Routes outside the docs set (here /health) are reachable without credentials."""
        resp = test_client.get("/health")
        assert resp.status_code == 200

    def test_openapi_protected(self, test_client):
        """Unauthenticated /openapi.json is rejected with 401."""
        resp = test_client.get("/openapi.json")
        assert resp.status_code == 401

    def test_redoc_protected(self, test_client):
        """Unauthenticated /redoc is rejected with 401."""
        resp = test_client.get("/redoc")
        assert resp.status_code == 401

    def test_cors_headers(self, test_client, auth_headers):
        """An OPTIONS request to /tools/ is either served or refused as a disallowed method."""
        resp = test_client.options("/tools/", headers=auth_headers)
        # The outcome depends on how the CORS middleware is configured.
        acceptable = (200, 405)
        assert resp.status_code in acceptable
# ----------------------------------------------------- #
# Error Handling & Edge Cases #
# ----------------------------------------------------- #
class TestErrorHandling:
    """Error responses: docs auth failures, wrong HTTP methods, malformed input, and service not-found/conflict mapping."""

    def test_docs_with_invalid_jwt(self, test_client):
        """Test /docs with an invalid JWT returns 401."""
        headers = {"Authorization": "Bearer invalid.token.value"}
        response = test_client.get("/docs", headers=headers)
        assert response.status_code == 401

    def test_docs_with_expired_jwt(self, test_client):
        """Test /docs with an expired (but correctly signed) JWT returns 401."""
        expired_payload = {"sub": "test_user", "exp": datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=1)}
        # First-Party
        from mcpgateway.config import settings

        key = settings.jwt_secret_key
        # The key may be a pydantic SecretStr; unwrap it if so. Never print it:
        # doing so would leak the signing secret into test/CI logs.
        if callable(getattr(key, "get_secret_value", None)):
            key = key.get_secret_value()
        expired_token = jwt.encode(expired_payload, key, algorithm=settings.jwt_algorithm)
        headers = {"Authorization": f"Bearer {expired_token}"}
        response = test_client.get("/docs", headers=headers)
        assert response.status_code == 401

    def test_post_on_get_only_endpoint(self, test_client, auth_headers):
        """Test POST on a GET-only endpoint returns 405."""
        response = test_client.post("/health", headers=auth_headers)
        assert response.status_code == 405

    def test_delete_on_docs(self, test_client, auth_headers):
        """Test DELETE on /docs returns 405."""
        response = test_client.delete("/docs", headers=auth_headers)
        assert response.status_code == 405

    def test_missing_query_param(self, test_client, auth_headers):
        """Test POST /message without the required session_id query param returns 400."""
        # /message?session_id=... requires session_id
        message = {"type": "test", "data": "hello"}
        response = test_client.post("/message", json=message, headers=auth_headers)
        assert response.status_code == 400

    def test_invalid_json_body(self, test_client, auth_headers):
        """Test handling of malformed JSON in request bodies."""
        # Build a new dict instead of mutating the shared auth_headers fixture in place.
        headers = {**auth_headers, "content-type": "application/json"}
        response = test_client.post("/protocol/initialize", content="invalid json", headers=headers)
        assert response.status_code == 400  # body cannot be parsed, so 400

    @patch("mcpgateway.main.server_service.get_server")
    def test_server_not_found(self, mock_get, test_client, auth_headers):
        """Test proper error response when server is not found."""
        # First-Party
        from mcpgateway.services.server_service import ServerNotFoundError

        mock_get.side_effect = ServerNotFoundError("Server not found")
        response = test_client.get("/servers/999", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.resource_service.read_resource")
    def test_resource_not_found(self, mock_read, test_client, auth_headers):
        """Test proper error response when resource is not found."""
        # First-Party
        from mcpgateway.services.resource_service import ResourceNotFoundError

        mock_read.side_effect = ResourceNotFoundError("Resource not found")
        response = test_client.get("/resources/nonexistent", headers=auth_headers)
        assert response.status_code == 404

    @patch("mcpgateway.main.tool_service.register_tool")
    def test_tool_name_conflict(self, mock_register, test_client, auth_headers):
        """Test handling of tool name conflicts during registration."""
        # First-Party
        from mcpgateway.services.tool_service import ToolNameConflictError

        mock_register.side_effect = ToolNameConflictError("Tool name already exists")
        req = {"tool": {"name": "existing_tool", "url": "http://example.com"}, "team_id": None, "visibility": "private"}
        response = test_client.post("/tools/", json=req, headers=auth_headers)
        assert response.status_code == 409

    def test_missing_required_fields(self, test_client, auth_headers):
        """Test validation errors for missing required fields."""
        req = {"description": "Missing required name field"}
        response = test_client.post("/tools/", json=req, headers=auth_headers)
        assert response.status_code == 422  # Validation error

    def test_openapi_json_with_auth(self, test_client, auth_headers):
        """Test GET /openapi.json with authentication returns 200 and OpenAPI spec."""
        response = test_client.get("/openapi.json", headers=auth_headers)
        assert response.status_code == 200
        assert "openapi" in response.json()

    def test_docs_with_auth(self, test_client, auth_headers):
        """Test GET /docs with authentication returns 200."""
        response = test_client.get("/docs", headers=auth_headers)
        assert response.status_code == 200

    def test_redoc_with_auth(self, test_client, auth_headers):
        """Test GET /redoc with authentication returns 200."""
        response = test_client.get("/redoc", headers=auth_headers)
        assert response.status_code == 200
# --------------------------------------------------------------------------- #
# jsonpath_modifier #
# --------------------------------------------------------------------------- #
@pytest.fixture(scope="module")
def sample_people():
    """Two ``{"name", "id"}`` records (Ada=1, Bob=2) shared by the JSONPath tests.

    The fixture is module-scoped, so the same list object is handed to every
    test in this module; tests must not mutate it.
    """
    names = ("Ada", "Bob")
    return [{"name": name, "id": idx} for idx, name in enumerate(names, start=1)]
def test_jsonpath_modifier_basic_match(sample_people):
    """A wildcard path extracts values; mappings reshape each match into a new dict."""
    # First-Party
    from mcpgateway.main import jsonpath_modifier

    assert jsonpath_modifier(sample_people, "$[*].name") == ["Ada", "Bob"]

    remapped = jsonpath_modifier(sample_people, "$[*]", mappings={"n": "$.name"})
    assert remapped == [{"n": "Ada"}, {"n": "Bob"}]
def test_jsonpath_modifier_single_dict_collapse():
    """Querying the root of a lone dict yields that dict, not a one-element list."""
    # First-Party
    from mcpgateway.main import jsonpath_modifier

    record = {"name": "Zoe", "id": 10}
    assert jsonpath_modifier(record, "$") == record
def test_jsonpath_modifier_invalid_expressions(sample_people):
    """A malformed main expression or a malformed mapping expression raises HTTPException."""
    # First-Party
    from mcpgateway.main import jsonpath_modifier

    bad_calls = (
        lambda: jsonpath_modifier(sample_people, "$["),  # invalid main expr
        lambda: jsonpath_modifier(sample_people, "$[*]", mappings={"bad": "$["}),  # invalid mapping expr
    )
    for call in bad_calls:
        with pytest.raises(HTTPException):
            call()
# ----------------------------------------------------- #
# Transform data with mappings
# ----------------------------------------------------- #
class TestTransformDataWithMappings:
    """Direct tests of ``transform_data_with_mappings`` using the ``sample_people`` fixture."""

    def test_transform_data_with_mappings_valid_mapping(self, sample_people):
        """Each record is projected through the mapping's JSONPath expressions."""
        # First-Party
        from mcpgateway.main import transform_data_with_mappings

        projected = transform_data_with_mappings(sample_people, {"n": "$.name"})
        assert projected == [{"n": "Ada"}, {"n": "Bob"}]

    def test_transform_data_with_mappings_invalid_mapping(self, sample_people):
        """A syntactically broken mapping expression surfaces as an HTTPException."""
        # First-Party
        from mcpgateway.main import transform_data_with_mappings

        broken_mapping = {"bad": "$["}
        with pytest.raises(HTTPException):
            transform_data_with_mappings(sample_people, broken_mapping)
# ----------------------------------------------------- #
# Plugin Exception Handler Tests #
# ----------------------------------------------------- #
class TestPluginExceptionHandlers:
"""Tests for plugin exception handlers: PluginViolationError and PluginError."""
def test_plugin_violation_exception_handler_with_full_violation(self):
"""Test plugin_violation_exception_handler with complete violation details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Invalid input",
description="The input contains prohibited content",
code="PROHIBITED_CONTENT",
details={"field": "message", "value": "sensitive_data"},
)
violation._plugin_name = "content_filter"
exc = PluginViolationError(message="Policy violation detected", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32602
assert "Plugin Violation:" in content["error"]["message"]
assert "The input contains prohibited content" in content["error"]["message"]
assert content["error"]["data"]["description"] == "The input contains prohibited content"
assert content["error"]["data"]["details"] == {"field": "message", "value": "sensitive_data"}
assert content["error"]["data"]["plugin_error_code"] == "PROHIBITED_CONTENT"
assert content["error"]["data"]["plugin_name"] == "content_filter"
def test_plugin_violation_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_violation_exception_handler with custom MCP error code."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Rate limit exceeded",
description="Too many requests from this client",
code="RATE_LIMIT",
details={"requests": 100, "limit": 50},
mcp_error_code=-32000, # Custom error code
)
violation._plugin_name = "rate_limiter"
exc = PluginViolationError(message="Rate limit violation", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32000
assert "Too many requests from this client" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "RATE_LIMIT"
assert content["error"]["data"]["plugin_name"] == "rate_limiter"
def test_plugin_violation_exception_handler_with_minimal_violation(self):
"""Test plugin_violation_exception_handler with minimal violation details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation
violation = PluginViolation(
reason="Violation occurred",
description="Minimal violation",
code="MIN_VIOLATION",
details={},
)
exc = PluginViolationError(message="Minimal violation", violation=violation)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "Minimal violation" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "MIN_VIOLATION"
def test_plugin_violation_exception_handler_without_violation_object(self):
"""Test plugin_violation_exception_handler when violation object is None."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
exc = PluginViolationError(message="Generic plugin violation", violation=None)
result = asyncio.run(plugin_violation_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "A plugin violation occurred" in content["error"]["message"]
assert content["error"]["data"] == {}
def test_plugin_exception_handler_with_full_error(self):
"""Test plugin_exception_handler with complete error details."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(
message="Plugin execution failed",
code="EXECUTION_ERROR",
plugin_name="data_processor",
details={"error_type": "timeout", "duration": 30},
)
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32603
assert "Plugin Error:" in content["error"]["message"]
assert "Plugin execution failed" in content["error"]["message"]
assert content["error"]["data"]["details"] == {"error_type": "timeout", "duration": 30}
assert content["error"]["data"]["plugin_error_code"] == "EXECUTION_ERROR"
assert content["error"]["data"]["plugin_name"] == "data_processor"
def test_plugin_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_exception_handler with custom MCP error code."""
# Standard
import asyncio
# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel
error = PluginErrorModel(
message="Custom error occurred",
code="CUSTOM_ERROR",
plugin_name="custom_plugin",
details={"context": "test"},
mcp_error_code=-32001, # Custom MCP error code
)
exc = PluginError(error=error)
result = asyncio.run(plugin_exception_handler(None, exc))
assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32001
assert "Custom error occurred" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "CUSTOM_ERROR"
def test_plugin_exception_handler_with_minimal_error(self):
    """Handler copes with an error model built from only message + plugin_name.

    With no ``mcp_error_code`` supplied the JSON-RPC code is -32603, and the
    plugin name is still reported in ``error.data``.
    """
    # First-Party
    from mcpgateway.main import plugin_exception_handler
    from mcpgateway.plugins.framework.errors import PluginError
    from mcpgateway.plugins.framework.models import PluginErrorModel

    minimal = PluginErrorModel(message="Minimal error", plugin_name="minimal_plugin")
    response = asyncio.run(plugin_exception_handler(None, PluginError(error=minimal)))

    assert response.status_code == 200
    rpc_error = json.loads(response.body.decode())["error"]
    assert rpc_error["code"] == -32603
    assert "Minimal error" in rpc_error["message"]
    assert rpc_error["data"]["plugin_name"] == "minimal_plugin"
def test_plugin_exception_handler_with_empty_code(self):
    """Test plugin_exception_handler when error has empty code field.

    An empty plugin ``code`` must never surface as a non-empty
    ``plugin_error_code``. The handler may either omit the key from
    ``error.data`` or carry it through as ``""``; both are accepted here.
    """
    # First-Party
    from mcpgateway.main import plugin_exception_handler
    from mcpgateway.plugins.framework.errors import PluginError
    from mcpgateway.plugins.framework.models import PluginErrorModel

    error = PluginErrorModel(
        message="Error without code",
        code="",
        plugin_name="test_plugin",
        details={"info": "test"},
    )
    exc = PluginError(error=error)
    # asyncio and json come from the module-level imports.
    result = asyncio.run(plugin_exception_handler(None, exc))
    assert result.status_code == 200
    content = json.loads(result.body.decode())
    assert content["error"]["code"] == -32603
    assert "Error without code" in content["error"]["message"]
    # Key absent or empty string are both fine; only a non-empty value is wrong.
    # NOTE(review): tighten to `"plugin_error_code" not in data` if the handler
    # is meant to omit empty codes entirely.
    assert content["error"]["data"].get("plugin_error_code", "") == ""
# --------------------------------------------------------------------------- #
# Cache Behavior Tests #
# --------------------------------------------------------------------------- #
class TestJsonPathCaching:
    """Tests for the memoized JSONPath parser ``_parse_jsonpath`` (#1812).

    ``_parse_jsonpath`` exposes ``cache_clear()`` / ``cache_info()``, so each test
    resets the cache first and then inspects the hit/miss counters.
    """

    def test_jsonpath_caching_works(self):
        """Verify JSONPath parsing is cached."""
        from mcpgateway.main import jsonpath_modifier, _parse_jsonpath

        _parse_jsonpath.cache_clear()
        result1 = jsonpath_modifier([{"a": 1}, {"a": 2}], "$[*].a")
        assert result1 == [1, 2]
        result2 = jsonpath_modifier([{"a": 3}], "$[*].a")
        assert result2 == [3]
        info = _parse_jsonpath.cache_info()
        # The second call with the same expression must be served from cache.
        assert info.hits == 1

    def test_mappings_parsed_once_per_request(self):
        """Verify mappings are parsed once per request, not per item.

        ``misses == 1`` alone cannot distinguish the two: a per-item lookup would
        also miss once and then hit the cache for the remaining items. Asserting
        ``hits == 0`` as well pins that "$.x" is looked up only once for the batch.
        """
        from mcpgateway.main import transform_data_with_mappings, _parse_jsonpath

        _parse_jsonpath.cache_clear()
        data = [{"x": 1}, {"x": 2}, {"x": 3}]
        mappings = {"y": "$.x"}
        result = transform_data_with_mappings(data, mappings)
        assert result == [{"y": 1}, {"y": 2}, {"y": 3}]
        info = _parse_jsonpath.cache_info()
        assert info.misses == 1  # Only one parse for "$.x"
        assert info.hits == 0  # ...and no per-item re-lookups of the same expression

    def test_different_jsonpath_cached_separately(self):
        """Verify different JSONPath expressions get separate cache entries."""
        from mcpgateway.main import jsonpath_modifier, _parse_jsonpath

        _parse_jsonpath.cache_clear()
        result1 = jsonpath_modifier({"a": 1, "b": 2}, "$.a")
        result2 = jsonpath_modifier({"a": 1, "b": 2}, "$.b")
        assert result1 == [1]
        assert result2 == [2]
        info = _parse_jsonpath.cache_info()
        # Two distinct expressions -> two independent cache misses.
        assert info.misses == 2
# ----------------------------------------------------- #
# Token Teams Helper Function Tests (Issue #1915) #
# ----------------------------------------------------- #
class TestNormalizeTokenTeams:
    """Unit tests for the ``_normalize_token_teams`` helper.

    The cases below show the helper flattening a token's ``teams`` claim into a
    list of team-ID strings: strings pass through, dicts contribute their ``id``,
    and dicts with a missing or empty ``id`` are dropped, with order preserved.
    """

    def test_normalize_token_teams_none(self):
        """``None`` normalizes to an empty list."""
        from mcpgateway.main import _normalize_token_teams

        normalized = _normalize_token_teams(None)
        assert normalized == []

    def test_normalize_token_teams_empty_list(self):
        """An empty claim stays an empty list."""
        from mcpgateway.main import _normalize_token_teams

        normalized = _normalize_token_teams([])
        assert normalized == []

    def test_normalize_token_teams_string_ids(self):
        """Plain string IDs come back as-is."""
        from mcpgateway.main import _normalize_token_teams

        assert _normalize_token_teams(["team_a", "team_b", "team_c"]) == ["team_a", "team_b", "team_c"]

    def test_normalize_token_teams_dict_format(self):
        """Dict entries are reduced to their ``id`` value."""
        from mcpgateway.main import _normalize_token_teams

        teams_claim = [{"id": "team_a", "name": "Team A"}, {"id": "team_b", "name": "Team B"}]
        assert _normalize_token_teams(teams_claim) == ["team_a", "team_b"]

    def test_normalize_token_teams_mixed_format(self):
        """Strings and dicts may be mixed within a single claim."""
        from mcpgateway.main import _normalize_token_teams

        teams_claim = [{"id": "t1", "name": "Team 1"}, "t2", {"id": "t3"}]
        assert _normalize_token_teams(teams_claim) == ["t1", "t2", "t3"]

    def test_normalize_token_teams_dict_without_id(self):
        """Dicts lacking an ``id`` key are dropped."""
        from mcpgateway.main import _normalize_token_teams

        teams_claim = [{"name": "No ID Team"}, {"id": "valid_team"}]
        assert _normalize_token_teams(teams_claim) == ["valid_team"]

    def test_normalize_token_teams_dict_with_empty_id(self):
        """Dicts whose ``id`` is the empty string are dropped."""
        from mcpgateway.main import _normalize_token_teams

        teams_claim = [{"id": "", "name": "Empty ID"}, {"id": "valid"}]
        assert _normalize_token_teams(teams_claim) == ["valid"]

    def test_normalize_token_teams_preserves_order(self):
        """Output order follows input order (no sorting is applied)."""
        from mcpgateway.main import _normalize_token_teams

        unsorted_ids = ["z_team", "a_team", "m_team"]
        assert _normalize_token_teams(unsorted_ids) == ["z_team", "a_team", "m_team"]
class TestGetTokenTeamsFromRequest:
    """Unit tests for ``_get_token_teams_from_request``.

    The helper reads ``request.state._jwt_verified_payload``; in these fixtures a
    well-formed value is a ``(token, payload)`` tuple. Every missing or malformed
    case exercised here yields ``[]`` (public-only). Only an admin token with
    ``teams: null`` yields ``None`` (admin bypass).
    """

    @staticmethod
    def _request_with(cached_payload):
        """Return a MagicMock request whose state caches *cached_payload*."""
        request = MagicMock()
        request.state._jwt_verified_payload = cached_payload
        return request

    def test_get_token_teams_with_valid_cached_payload(self):
        """Teams listed in the cached JWT payload are returned."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token_string", {"sub": "user@example.com", "teams": ["team_a", "team_b"]}))
        assert _get_token_teams_from_request(request) == ["team_a", "team_b"]

    def test_get_token_teams_with_dict_teams_payload(self):
        """Dict-shaped team entries are normalized down to their IDs."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", {"teams": [{"id": "t1", "name": "Team 1"}]}))
        assert _get_token_teams_from_request(request) == ["t1"]

    def test_get_token_teams_no_cached_payload_returns_empty_list(self):
        """No cached payload at all -> ``[]`` (secure public-only default)."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(None)
        # SECURITY: absence of a JWT must never widen access.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_no_teams_in_payload_returns_empty_list(self):
        """Payload without a ``teams`` key -> ``[]`` (secure public-only default)."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", {"sub": "user@example.com"}))
        # SECURITY: a missing claim is treated as public-only.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_empty_teams_returns_empty_list(self):
        """``teams: []`` yields an empty list, not ``None``."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", {"sub": "user@example.com", "teams": []}))
        # A JWT is present but grants no team membership.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_null_teams_non_admin_returns_empty_list(self):
        """``teams: null`` for a non-admin -> ``[]`` (public-only)."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", {"sub": "user@example.com", "teams": None}))
        # SECURITY: null teams without admin must not become unrestricted.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_null_teams_admin_returns_none(self):
        """``teams: null`` with ``is_admin: true`` -> ``None`` (admin bypass)."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", {"sub": "admin@example.com", "teams": None, "is_admin": True}))
        assert _get_token_teams_from_request(request) is None

    def test_get_token_teams_invalid_tuple_format_returns_empty_list(self):
        """A cached value that is not a tuple -> ``[]``."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with("not_a_tuple")
        # SECURITY: unrecognized cache shape falls back to public-only.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_short_tuple_returns_empty_list(self):
        """A one-element tuple (wrong length) -> ``[]``."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("only_one_element",))
        # SECURITY: unrecognized cache shape falls back to public-only.
        assert _get_token_teams_from_request(request) == []

    def test_get_token_teams_none_payload_in_tuple_returns_empty_list(self):
        """A tuple whose payload slot is ``None`` -> ``[]``."""
        from mcpgateway.main import _get_token_teams_from_request

        request = self._request_with(("token", None))
        # SECURITY: no decoded payload means public-only.
        assert _get_token_teams_from_request(request) == []
class TestGetRpcFilterContext:
"""Tests for _get_rpc_filter_context helper function."""
def test_get_rpc_filter_context_dict_user(self):
    """Dict user: email comes from the dict, admin status from the token payload.

    Per the security fix, ``is_admin`` must be present in the verified token
    payload; the user dict's own ``is_admin`` is not what grants it.
    """
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    request.state._jwt_verified_payload = ("token", {"teams": ["t1", "t2"], "is_admin": True})
    caller = {"email": "test@example.com", "is_admin": True}  # this is_admin is ignored

    ctx_email, ctx_teams, ctx_admin = _get_rpc_filter_context(request, caller)

    assert ctx_email == "test@example.com"
    assert ctx_teams == ["t1", "t2"]
    assert ctx_admin is True  # sourced from the token payload
def test_get_rpc_filter_context_dict_user_sub_field(self):
    """Without an ``email`` key, the dict user's ``sub`` serves as the email."""
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    request.state._jwt_verified_payload = ("token", {"teams": []})

    email, teams, is_admin = _get_rpc_filter_context(request, {"sub": "user@sub.com"})

    assert (email, teams) == ("user@sub.com", [])
    assert is_admin is False
def test_get_rpc_filter_context_object_user(self):
    """Attribute-style user objects: email is read from the ``email`` attribute."""
    from mcpgateway.main import _get_rpc_filter_context

    # Class-level attributes, mirroring a simple ORM-like user object.
    class FakeUser:
        email = "obj@example.com"
        is_admin = False

    request = MagicMock()
    request.state._jwt_verified_payload = ("token", {"teams": ["team_x"]})

    email, teams, is_admin = _get_rpc_filter_context(request, FakeUser())

    assert email == "obj@example.com"
    assert teams == ["team_x"]
    assert is_admin is False
def test_get_rpc_filter_context_nested_is_admin(self):
    """Test that nested user.is_admin is extracted from token payload.

    The token carries a non-empty ``teams`` list (an empty list disables the
    admin bypass) plus a nested ``user.is_admin`` flag.
    """
    from mcpgateway.main import _get_rpc_filter_context

    mock_request = MagicMock()
    # is_admin must be in token payload - use non-empty teams to allow admin bypass
    mock_request.state._jwt_verified_payload = ("token", {"teams": ["team_x"], "user": {"is_admin": True}})
    user = {"email": "nested@example.com", "user": {"is_admin": True}}
    email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
    assert email == "nested@example.com"
    # The admin bypass relies on teams being non-empty, so pin that they pass through.
    assert teams == ["team_x"]
    assert is_admin is True  # From token payload's nested user.is_admin
def test_get_rpc_filter_context_empty_teams_disables_admin(self):
    """An admin token with ``teams: []`` loses the admin bypass (public-only access)."""
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    # is_admin is set in the token, yet the empty team list must take precedence.
    request.state._jwt_verified_payload = ("token", {"teams": [], "is_admin": True})
    admin_user = {"email": "admin@example.com", "is_admin": True}

    email, teams, is_admin = _get_rpc_filter_context(request, admin_user)

    assert email == "admin@example.com"
    assert teams == []
    assert is_admin is False
def test_get_rpc_filter_context_string_user(self):
    """A bare string user falls back to ``str(user)`` for the email."""
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    request.state._jwt_verified_payload = ("token", {"teams": ["t1"]})

    email, teams, is_admin = _get_rpc_filter_context(request, "plain_username")

    assert email == "plain_username"
    assert teams == ["t1"]
    assert is_admin is False
def test_get_rpc_filter_context_none_user(self):
    """No user object: email is ``None`` and no admin rights are granted."""
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    request.state._jwt_verified_payload = ("token", {"teams": []})

    email, teams, is_admin = _get_rpc_filter_context(request, None)

    assert email is None
    assert teams == []
    assert is_admin is False
def test_get_rpc_filter_context_admin_not_in_dict(self):
    """Test that is_admin defaults to False if not present.

    Neither the token payload nor the user dict carries ``is_admin``.
    """
    from mcpgateway.main import _get_rpc_filter_context

    mock_request = MagicMock()
    mock_request.state._jwt_verified_payload = ("token", {"teams": ["t1"]})
    user = {"email": "user@example.com"}
    email, teams, is_admin = _get_rpc_filter_context(mock_request, user)
    assert email == "user@example.com"
    # Teams from the token should pass through untouched for a non-admin caller.
    assert teams == ["t1"]
    assert is_admin is False
def test_get_rpc_filter_context_no_jwt_returns_empty_teams(self):
    """Missing JWT payload (e.g. plugin auth) means public-only: ``teams == []``."""
    from mcpgateway.main import _get_rpc_filter_context

    request = MagicMock()
    request.state._jwt_verified_payload = None  # no verified JWT cached on the request
    plugin_user = {"email": "plugin_user@example.com", "is_admin": False}

    email, teams, is_admin = _get_rpc_filter_context(request, plugin_user)

    assert email == "plugin_user@example.com"
    assert teams == []  # SECURITY: the secure default is public-only, not unrestricted
    assert is_admin is False