# -*- coding: utf-8 -*-
"""Location: ./tests/unit/mcpgateway/services/test_gateway_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
Unit-tests for the GatewayService implementation.
These tests use only MagicMock / AsyncMock - no real network access
and no real database needed. Where the service relies on Pydantic
models or SQLAlchemy Result objects, we monkey-patch or fake just
enough behaviour to satisfy the code paths under test.
"""
# Future
from __future__ import annotations
# Standard
import asyncio
from datetime import datetime, timezone, timedelta
from typing import TypeVar
from unittest.mock import AsyncMock, MagicMock, Mock, patch
# Third-Party
import pytest
from url_normalize import url_normalize
# First-Party
# ---------------------------------------------------------------------------
# Application imports
# ---------------------------------------------------------------------------
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import Tool as DbTool
from mcpgateway.db import Resource as DbResource
from mcpgateway.db import Prompt as DbPrompt
from mcpgateway.schemas import GatewayCreate, GatewayUpdate
from mcpgateway.services.gateway_service import (
GatewayConnectionError,
GatewayError,
GatewayNameConflictError,
GatewayNotFoundError,
GatewayService,
GatewayDuplicateConflictError,
OAuthToolValidationError,
)
# ---------------------------------------------------------------------------
# Helpers & global monkey-patches
# ---------------------------------------------------------------------------
_R = TypeVar("_R")
def _make_execute_result(*, scalar: _R | None = None, scalars_list: list[_R] | None = None) -> MagicMock:
"""
Return a MagicMock that behaves like the SQLAlchemy Result object the
service expects after ``Session.execute``:
- .scalar_one_or_none() -> *scalar*
- .scalars().all() -> *scalars_list* (defaults to [])
This lets us emulate both the "fetch one" path and the "fetch many"
path with a single helper.
"""
result = MagicMock()
result.scalar_one_or_none.return_value = scalar
scalars_proxy = MagicMock()
scalars_proxy.all.return_value = scalars_list or []
result.scalars.return_value = scalars_proxy
return result
@pytest.fixture(autouse=True)
def mock_logging_services():
    """Silence audit-trail and structured logging so no test writes to the database.

    The SSL context cache is cleared first so every test starts from a clean slate.
    Yields a dict exposing both patched objects for tests that want to inspect them.
    """
    # First-Party
    from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache

    clear_ssl_context_cache()

    with patch("mcpgateway.services.gateway_service.audit_trail") as audit_mock:
        with patch("mcpgateway.services.gateway_service.structured_logger") as logger_mock:
            audit_mock.log_action = MagicMock(return_value=None)
            logger_mock.log = MagicMock(return_value=None)
            yield {"audit_trail": audit_mock, "structured_logger": logger_mock}
@pytest.fixture(autouse=True)
def _bypass_gatewayread_validation(monkeypatch):
    """Turn ``GatewayRead.model_validate`` into an identity function for every test.

    The service converts DB rows with ``GatewayRead.model_validate(db_obj)``. Here those
    rows are MagicMocks, which Pydantic rejects, so the validator is replaced by a
    pass-through. ``monkeypatch`` undoes the replacement when the test finishes.
    """
    # First-Party
    from mcpgateway.schemas import GatewayRead

    def _identity(obj):
        return obj

    monkeypatch.setattr(GatewayRead, "model_validate", staticmethod(_identity))
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def gateway_service():
    """Provide a ``GatewayService`` whose HTTP client is an ``AsyncMock``.

    Swapping the client out ensures no test in this module makes a real HTTP request.
    """
    svc = GatewayService()
    svc._http_client = AsyncMock()
    return svc
@pytest.fixture
def mock_gateway():
    """Return a minimal but realistic ``DbGateway`` MagicMock.

    The gateway is enabled and reachable and advertises ``listChanged`` for prompts,
    resources and tools. It owns one dummy tool and has empty resource, prompt and
    federated-tool lists.
    """
    gw = MagicMock(spec=DbGateway)
    gw.id = 1
    gw.name = "test_gateway"
    gw.url = "http://example.com/gateway"
    gw.description = "A test gateway"
    gw.capabilities = {"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}}
    gw.created_at = gw.updated_at = gw.last_seen = "2025-01-01T00:00:00Z"
    gw.enabled = True
    gw.reachable = True

    # One dummy tool hanging off the gateway. ``name`` is a reserved keyword of the
    # MagicMock constructor (it only labels the mock in its repr), so it must be
    # assigned after construction for ``tool.name`` to really be "dummy_tool".
    tool = MagicMock(spec=DbTool, id=101)
    tool.name = "dummy_tool"
    gw.tools = [tool]

    gw.resources = []  # Empty list for delete tests
    gw.prompts = []  # Empty list for delete tests
    gw.federated_tools = []
    gw.transport = "sse"
    gw.auth_value = {}
    gw.team_id = 1  # Ensure team_id is a real value, not a MagicMock
    return gw
@pytest.fixture
def mock_session():
    """Return a MagicMock standing in for a SQLAlchemy session.

    ``query()`` yields a fresh MagicMock. ``commit()`` and ``rollback()`` return ``None``.
    """
    sess = MagicMock()
    sess.query.return_value = MagicMock()
    sess.commit.return_value = None
    sess.rollback.return_value = None
    return sess
# ---------------------------------------------------------------------------
# Test-cases
# ---------------------------------------------------------------------------
class TestGatewayService:
"""All GatewayService happy-path and error-path unit-tests."""
# ────────────────────────────────────────────────────────────────────
# REGISTER
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
    """Successful gateway registration populates DB and returns data."""
    test_db.execute = Mock(
        side_effect=[
            _make_execute_result(scalar=None),  # name-conflict check
            _make_execute_result(scalars_list=[]),  # tool lookup
        ]
    )
    test_db.add = Mock()
    test_db.flush = Mock()  # Implementation uses flush() not commit()
    test_db.refresh = Mock()
    # Mock query for _check_gateway_uniqueness
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))

    # Internal helpers: initialization succeeds and reports no tools/resources/prompts.
    gateway_service._initialize_gateway = AsyncMock(
        return_value=(
            {
                "prompts": {"listChanged": True},
                "resources": {"listChanged": True},
                "tools": {"listChanged": True},
            },
            [],
            [],
            [],
        )
    )
    gateway_service._notify_gateway_added = AsyncMock()

    url = url_normalize("example.com")

    # Patch GatewayRead.model_validate to return a mock whose .masked() returns itself.
    mock_model = Mock()
    mock_model.masked.return_value = mock_model
    mock_model.name = "test_gateway"
    mock_model.url = url
    mock_model.description = "A test gateway"
    monkeypatch.setattr(
        "mcpgateway.services.gateway_service.GatewayRead.model_validate",
        lambda x: mock_model,
    )

    gateway_create = GatewayCreate(
        name="test_gateway",
        url=url,
        description="A test gateway",
    )

    result = await gateway_service.register_gateway(test_db, gateway_create)

    test_db.add.assert_called_once()
    test_db.flush.assert_called_once()  # Implementation uses flush() not commit()
    test_db.refresh.assert_called_once()
    gateway_service._initialize_gateway.assert_called_once()
    gateway_service._notify_gateway_added.assert_called_once()

    # ``result`` is expected to be ``mock_model``: the patched model_validate returns
    # it, and its .masked() returns itself. Check the fields we configured on it.
    assert result.name == "test_gateway"
    assert result.url == url
    assert result.description == "A test gateway"
@pytest.mark.asyncio
async def test_register_gateway_name_conflict(self, gateway_service, mock_gateway, test_db):
    """Registering a public gateway under an already-used name raises GatewayNameConflictError."""
    # The conflict lookup finds mock_gateway, which carries the very same name and slug.
    mock_gateway.name = "test_gateway"
    mock_gateway.slug = "test-gateway"
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))

    duplicate = GatewayCreate(
        name="test_gateway",
        slug="test-gateway",
        url="http://example.com/other",
        description="Another gateway",
        visibility="public",
    )

    with pytest.raises(GatewayNameConflictError) as raised:
        await gateway_service.register_gateway(test_db, duplicate)

    conflict = raised.value
    assert "Public Gateway already exists with name" in str(conflict)
    assert conflict.name == "test-gateway"
    assert conflict.gateway_id == mock_gateway.id
@pytest.mark.asyncio
async def test_register_gateway_connection_error(self, gateway_service, test_db):
    """A failed handshake with the remote gateway propagates GatewayConnectionError."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    # Initialization fails before any DB work happens.
    gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Failed to connect"))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="A test gateway")

    with pytest.raises(GatewayConnectionError, match="Failed to connect"):
        await gateway_service.register_gateway(test_db, payload)
@pytest.mark.asyncio
async def test_register_gateway_with_auth(self, gateway_service, test_db, monkeypatch):
    """Test registering gateway with authentication credentials."""
    test_db.execute = Mock(
        side_effect=[
            _make_execute_result(scalar=None),  # name-conflict check
            _make_execute_result(scalars_list=[]),  # tool lookup
        ]
    )
    test_db.add = Mock()
    test_db.flush = Mock()  # Implementation uses flush() not commit()
    test_db.refresh = Mock()
    # Mock query for _check_gateway_uniqueness
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))

    url = url_normalize("example.com")

    gateway_service._initialize_gateway = AsyncMock(
        return_value=(
            {
                "resources": {"listChanged": True},
                "tools": {"listChanged": True},
            },
            [],
            [],
            [],
        )
    )
    gateway_service._notify_gateway_added = AsyncMock()

    # model_validate returns a mock whose .masked() returns itself.
    mock_model = Mock()
    mock_model.masked.return_value = mock_model
    mock_model.name = "auth_gateway"
    mock_model.url = url
    monkeypatch.setattr(
        "mcpgateway.services.gateway_service.GatewayRead.model_validate",
        lambda x: mock_model,
    )

    gateway_create = GatewayCreate(name="auth_gateway", url=url, description="Gateway with auth", auth_type="bearer", auth_token="test-token")

    await gateway_service.register_gateway(test_db, gateway_create)

    test_db.add.assert_called_once()
    test_db.flush.assert_called_once()  # Implementation uses flush() not commit()
    gateway_service._initialize_gateway.assert_called_once()
@pytest.mark.asyncio
async def test_register_gateway_with_tools(self, gateway_service, test_db, monkeypatch):
    """Tools reported by gateway initialization end up on the DbGateway passed to add()."""
    # First-Party
    from mcpgateway.schemas import ToolCreate

    name_check = _make_execute_result(scalar=None)
    tool_lookup = _make_execute_result(scalars_list=[])
    test_db.execute = Mock(side_effect=[name_check, tool_lookup])
    test_db.add = Mock()
    # NOTE(review): the sibling register tests mock flush() because the implementation
    # uses flush() rather than commit(); this commit() mock is probably never exercised.
    test_db.commit = Mock()
    test_db.refresh = Mock()

    reported_tool = ToolCreate(
        name="test_tool",
        description="A test tool",
        integration_type="REST",
        request_type="POST",
        input_schema={"type": "object"},
    )
    capabilities = {
        "prompts": {"listChanged": True},
        "resources": {"listChanged": True},
        "tools": {"listChanged": True},
    }
    gateway_service._initialize_gateway = AsyncMock(return_value=(capabilities, [reported_tool], [], []))
    gateway_service._notify_gateway_added = AsyncMock()

    read_model = Mock()
    read_model.name = "tool_gateway"
    read_model.masked.return_value = read_model
    monkeypatch.setattr(
        "mcpgateway.services.gateway_service.GatewayRead.model_validate",
        lambda x: read_model,
    )

    await gateway_service.register_gateway(
        test_db,
        GatewayCreate(name="tool_gateway", url="http://example.com/gateway", description="Gateway with tools"),
    )

    test_db.add.assert_called_once()
    # The DbGateway handed to add() must carry exactly the tool from initialization.
    added_gateway = test_db.add.call_args[0][0]
    assert len(added_gateway.tools) == 1
    assert added_gateway.tools[0].original_name == "test_tool"
@pytest.mark.asyncio
async def test_register_gateway_inactive_name_conflict(self, gateway_service, test_db):
    """A disabled gateway still reserves its name, so re-registering that name conflicts."""
    disabled = MagicMock(spec=DbGateway)
    disabled.id = 2
    disabled.name = "test_gateway"
    disabled.slug = "test-gateway"
    disabled.enabled = False
    test_db.execute = Mock(return_value=_make_execute_result(scalar=disabled))

    payload = GatewayCreate(
        name="test_gateway",
        slug="test-gateway",
        url="http://example.com/gateway",
        description="New gateway",
        visibility="public",
    )

    with pytest.raises(GatewayNameConflictError) as raised:
        await gateway_service.register_gateway(test_db, payload)

    conflict = raised.value
    assert "Public Gateway already exists with name" in str(conflict)
    assert conflict.name == "test-gateway"
    assert conflict.enabled is False
    assert conflict.gateway_id == 2
@pytest.mark.asyncio
async def test_register_gateway_database_error(self, gateway_service, test_db):
    """An exception raised while flushing the new row propagates to the caller."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    test_db.add = Mock()
    test_db.rollback = Mock()
    # The implementation flushes rather than commits; make that step fail.
    test_db.flush = Mock(side_effect=Exception("Database error"))
    # _check_gateway_uniqueness: query(...).filter(...).all() finds no duplicates.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))
    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="Test gateway")

    with pytest.raises(Exception, match="Database error"):
        await gateway_service.register_gateway(test_db, payload)
@pytest.mark.asyncio
async def test_register_gateway_value_error(self, gateway_service, test_db):
    """A ValueError raised during initialization is propagated unchanged."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    gateway_service._initialize_gateway = AsyncMock(side_effect=ValueError("Invalid gateway configuration"))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="Test gateway")

    with pytest.raises(ValueError, match="Invalid gateway configuration"):
        await gateway_service.register_gateway(test_db, payload)
@pytest.mark.asyncio
async def test_register_gateway_runtime_error(self, gateway_service, test_db):
    """A RuntimeError raised during initialization is propagated unchanged."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    gateway_service._initialize_gateway = AsyncMock(side_effect=RuntimeError("Runtime error occurred"))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="Test gateway")

    with pytest.raises(RuntimeError, match="Runtime error occurred"):
        await gateway_service.register_gateway(test_db, payload)
@pytest.mark.asyncio
async def test_register_gateway_integrity_error(self, gateway_service, test_db):
    """A SQLAlchemy IntegrityError raised on flush surfaces as an IntegrityError."""
    # Third-Party
    from sqlalchemy.exc import IntegrityError as SQLIntegrityError

    integrity_failure = SQLIntegrityError("statement", "params", BaseException("orig"))

    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    test_db.add = Mock()
    test_db.flush = Mock(side_effect=integrity_failure)  # the implementation flushes, not commits
    # _check_gateway_uniqueness: query(...).filter(...).all() finds no duplicates.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))
    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="Test gateway")

    with pytest.raises(SQLIntegrityError):
        await gateway_service.register_gateway(test_db, payload)
@pytest.mark.asyncio
async def test_register_gateway_masked_auth_value(self, gateway_service, test_db, monkeypatch):
    """Registering with a token equal to the masked placeholder still completes."""
    test_db.execute = Mock(
        side_effect=[
            _make_execute_result(scalar=None),  # name-conflict check
            _make_execute_result(scalars_list=[]),  # tool lookup
        ]
    )
    test_db.add = Mock()
    test_db.flush = Mock()  # the implementation flushes rather than commits
    test_db.refresh = Mock()
    # _check_gateway_uniqueness: query(...).filter(...).all() finds no duplicates.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))

    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))
    gateway_service._notify_gateway_added = AsyncMock()

    read_model = Mock()
    read_model.name = "auth_gateway"
    read_model.masked.return_value = read_model
    monkeypatch.setattr(
        "mcpgateway.services.gateway_service.GatewayRead.model_validate",
        lambda x: read_model,
    )

    # Make "***MASKED***" the configured masked placeholder for the duration of the call.
    with patch("mcpgateway.services.gateway_service.settings.masked_auth_value", "***MASKED***"):
        payload = GatewayCreate(
            name="auth_gateway",
            url="http://example.com/gateway",
            description="Gateway with masked auth",
            auth_type="bearer",
            auth_token="***MASKED***",  # the placeholder itself, not a real token
        )
        await gateway_service.register_gateway(test_db, payload)

    # NOTE(review): nothing here checks that the stored auth_value was left untouched;
    # these assertions only confirm that registration ran to completion.
    test_db.add.assert_called_once()
    test_db.flush.assert_called_once()
    gateway_service._initialize_gateway.assert_called_once()
@pytest.mark.asyncio
async def test_register_gateway_exception_rollback(self, gateway_service, test_db):
    """A flush failure propagates out of register_gateway with its original message."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    test_db.add = Mock()
    test_db.rollback = Mock()
    test_db.flush = Mock(side_effect=Exception("Flush failed"))  # the implementation flushes, not commits
    # _check_gateway_uniqueness: query(...).filter(...).all() finds no duplicates.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(all=Mock(return_value=[])))))
    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))

    payload = GatewayCreate(name="test_gateway", url="http://example.com/gateway", description="Test gateway")

    with pytest.raises(Exception, match="Flush failed"):
        await gateway_service.register_gateway(test_db, payload)

    # register_gateway re-raises without calling rollback in its exception handler,
    # so rollback is deliberately not asserted here.
@pytest.mark.asyncio
async def test_register_gateway_with_existing_tools(self, gateway_service, test_db, monkeypatch):
    """Registering a gateway whose URL/credentials duplicate an existing one is rejected.

    Despite its name, this test covers a duplicate *gateway*, not duplicate tools.
    """
    already_registered = MagicMock()
    already_registered.id = 123
    already_registered.url = "http://example.com/gateway"
    already_registered.enabled = True
    already_registered.visibility = "public"
    already_registered.name = "existing_gateway"
    already_registered.team_id = None
    already_registered.owner_email = "test@example.com"

    # Only the name-conflict check goes through execute(). Duplicate detection happens
    # in _check_gateway_uniqueness, which is stubbed to report the existing gateway.
    test_db.execute = Mock(side_effect=[_make_execute_result(scalar=None)])
    gateway_service._check_gateway_uniqueness = Mock(return_value=already_registered)
    test_db.add = Mock()
    test_db.commit = Mock()
    test_db.refresh = Mock()

    payload = GatewayCreate(
        name="tool_gateway",
        url="http://example.com/gateway",  # same URL as the existing gateway
        description="Gateway with existing tools",
    )

    with pytest.raises(GatewayDuplicateConflictError) as raised:
        await gateway_service.register_gateway(test_db, payload)

    assert raised.value.gateway_id == 123
    assert raised.value.enabled is True
# ────────────────────────────────────────────────────────────────────
# Validate Gateway URL SSL Verification
# ────────────────────────────────────────────────────────────────────
@pytest.mark.skip("Yet to implement")
async def test_ssl_verification_bypass(self, gateway_service, monkeypatch):
    """Placeholder: should check that ``settings.skip_ssl_verify`` is honoured.

    TODO: implement. The test is currently marked ``skip``, so this body never runs.
    """
# ────────────────────────────────────────────────────────────────────
# LIST / GET
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_list_gateways(self, gateway_service, mock_gateway, test_db, monkeypatch):
    """list_gateways returns the active gateways together with a next cursor."""
    read_model = Mock()
    read_model.name = "test_gateway"
    read_model.masked.return_value = read_model
    monkeypatch.setattr("mcpgateway.services.gateway_service.GatewayRead.model_validate", lambda x: read_model)

    # Every execute() call yields the same result wrapping mock_gateway.
    test_db.execute = Mock(return_value=_make_execute_result(scalars_list=[mock_gateway]))

    gateways, _next_cursor = await gateway_service.list_gateways(test_db)

    # The service issues exactly two queries.
    assert test_db.execute.call_count == 2
    assert len(gateways) == 1
    assert gateways[0].name == "test_gateway"
@pytest.mark.asyncio
async def test_get_gateway(self, gateway_service, mock_gateway, test_db):
    """get_gateway looks the row up by primary key and returns its data."""
    mock_gateway.team_id = 1  # keep team_id a concrete value rather than a MagicMock
    mock_gateway.masked = Mock(return_value=mock_gateway)
    test_db.get = Mock(return_value=mock_gateway)

    fetched = await gateway_service.get_gateway(test_db, 1)

    test_db.get.assert_called_once_with(DbGateway, 1)
    assert fetched.name == "test_gateway"
    assert fetched.capabilities == mock_gateway.capabilities
@pytest.mark.asyncio
async def test_get_gateway_not_found(self, gateway_service, test_db):
    """An unknown gateway ID raises GatewayNotFoundError."""
    # The primary-key lookup finds nothing.
    test_db.get = Mock(return_value=None)

    with pytest.raises(GatewayNotFoundError):
        await gateway_service.get_gateway(test_db, 999)
@pytest.mark.asyncio
async def test_get_gateway_inactive(self, gateway_service, mock_gateway, test_db):
    """A disabled gateway is returned only when include_inactive=True."""
    mock_gateway.id = 1
    mock_gateway.enabled = False
    mock_gateway.team_id = 1  # concrete team_id rather than a MagicMock
    test_db.get = Mock(return_value=mock_gateway)

    disabled_read = Mock()
    disabled_read.id = 1
    disabled_read.enabled = False
    disabled_read.masked = Mock(return_value=disabled_read)

    # include_inactive=True: the disabled gateway comes back.
    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=disabled_read):
        fetched = await gateway_service.get_gateway(test_db, 1, include_inactive=True)
    assert fetched.id == 1
    assert not fetched.enabled

    # include_inactive=False: the same row is treated as missing.
    with pytest.raises(GatewayNotFoundError):
        await gateway_service.get_gateway(test_db, 1, include_inactive=False)
# ────────────────────────────────────────────────────────────────────
# UPDATE
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_update_gateway(self, gateway_service, mock_gateway, test_db):
    """Name, URL and description can all be changed in a single update."""
    mock_gateway.team_id = 1  # concrete team_id rather than a MagicMock

    # execute() #1 loads the gateway (selectinload); #2 is the name-conflict check (no clash).
    test_db.execute = Mock(
        side_effect=[
            _make_execute_result(scalar=mock_gateway),
            _make_execute_result(scalar=None),
        ]
    )
    test_db.commit = Mock()
    test_db.refresh = Mock()

    new_capabilities = {
        "prompts": {"subscribe": True},
        "resources": {"subscribe": True},
        "tools": {"subscribe": True},
    }
    # NOTE(review): the other tests stub _initialize_gateway with a 4-tuple
    # (capabilities, tools, resources, prompts). If update_gateway unpacks four values,
    # this 2-tuple may not exercise its success path. Confirm against the service.
    gateway_service._initialize_gateway = AsyncMock(return_value=(new_capabilities, []))
    gateway_service._notify_gateway_updated = AsyncMock()

    changes = GatewayUpdate(
        name="updated_gateway",
        url="http://example.com/updated",
        description="Updated description",
    )

    # model_validate(...).masked() both resolve to read_model.
    read_model = MagicMock()
    read_model.name = "updated_gateway"
    read_model.masked.return_value = read_model

    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        updated = await gateway_service.update_gateway(test_db, 1, changes)

    test_db.commit.assert_called_once()
    test_db.refresh.assert_called_once()
    gateway_service._initialize_gateway.assert_called_once()
    gateway_service._notify_gateway_updated.assert_called_once()
    assert mock_gateway.name == "updated_gateway"
    assert updated.name == "updated_gateway"
@pytest.mark.asyncio
async def test_update_gateway_not_found(self, gateway_service, test_db):
    """Updating an unknown gateway ID raises GatewayError naming that ID."""
    # The gateway lookup returns no row.
    test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
    changes = GatewayUpdate(name="whatever")

    with pytest.raises(GatewayError, match="Gateway not found: 999"):
        await gateway_service.update_gateway(test_db, 999, changes)
@pytest.mark.asyncio
async def test_update_gateway_name_conflict(self, gateway_service, mock_gateway, test_db):
    """Changing the name to one that already exists raises GatewayError."""
    mock_gateway.name = "original_name"
    mock_gateway.slug = "original-name"
    mock_gateway.visibility = "public"
    mock_gateway.team_id = 1  # Ensure team_id is a real value

    # ``name`` is reserved by the MagicMock constructor (it only labels the mock's repr),
    # so it has to be assigned after construction to become a real attribute value.
    # NOTE(review): ``is_active`` looks like a legacy field name; other tests use
    # ``enabled``. Confirm which one the service reads.
    conflicting = MagicMock(spec=DbGateway, id=2, slug="existing-gateway", visibility="public", is_active=True)
    conflicting.name = "existing_gateway"

    # First call returns the gateway to update (with selectinload), second returns the conflicting one
    execute_results = [_make_execute_result(scalar=mock_gateway), _make_execute_result(scalar=conflicting)]
    test_db.execute = Mock(side_effect=execute_results)
    test_db.rollback = Mock()

    gateway_update = GatewayUpdate(name="existing_gateway", slug="existing-gateway")

    with pytest.raises(GatewayError) as exc_info:
        await gateway_service.update_gateway(test_db, 1, gateway_update)

    assert "Public Gateway already exists with name" in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_gateway_with_auth_update(self, gateway_service, mock_gateway, test_db):
    """Supplying a new bearer token updates the gateway's authentication."""
    mock_gateway.auth_type = "bearer"
    mock_gateway.auth_value = "old-token-encrypted"
    mock_gateway.team_id = 1  # concrete team_id rather than a MagicMock

    # Every execute() call (gateway load, conflict checks, ...) returns mock_gateway.
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    # Team-name lookup via query(...).filter(...).first() finds nothing.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(first=Mock(return_value=None)))))

    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))
    gateway_service._notify_gateway_updated = AsyncMock()

    read_model = MagicMock()
    read_model.masked.return_value = read_model

    # Make "***MASKED***" the masked placeholder; "new-token" is a real, different token.
    with patch("mcpgateway.services.gateway_service.settings.masked_auth_value", "***MASKED***"):
        changes = GatewayUpdate(auth_type="bearer", auth_token="new-token")
        with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
            await gateway_service.update_gateway(test_db, 1, changes)

    # NOTE(review): auth_type was already "bearer" before the update, so this check
    # cannot fail. Consider also asserting on auth_value.
    assert mock_gateway.auth_type == "bearer"
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_clear_auth(self, gateway_service, mock_gateway, test_db):
    """Setting auth_type to an empty string wipes the stored credentials."""
    mock_gateway.auth_type = "bearer"
    mock_gateway.auth_value = {"token": "old-token"}
    mock_gateway.team_id = 1  # concrete team_id rather than a MagicMock

    # The same row comes back for every execute() call.
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    # Team-name lookup via query(...).filter(...).first() finds nothing.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(first=Mock(return_value=None)))))

    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, [], [], []))
    gateway_service._notify_gateway_updated = AsyncMock()

    clear_auth = GatewayUpdate(auth_type="")
    read_model = MagicMock()
    read_model.masked.return_value = read_model

    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        await gateway_service.update_gateway(test_db, 1, clear_auth)

    assert mock_gateway.auth_type == ""
    assert mock_gateway.auth_value == ""
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_url_change_with_tools(self, gateway_service, mock_gateway, test_db):
    """Test updating gateway URL and tools are refreshed."""
    # First-Party
    from mcpgateway.schemas import ToolCreate

    # Setup existing tool
    existing_tool = MagicMock()
    existing_tool.original_name = "existing_tool"
    mock_gateway.tools = [existing_tool]
    mock_gateway.team_id = 1  # Ensure team_id is a real value

    # First call returns gateway (selectinload), then conflict checks
    test_db.execute = Mock(
        side_effect=[
            _make_execute_result(scalar=mock_gateway),  # selectinload gateway
            _make_execute_result(scalar=None),  # name conflict check
            _make_execute_result(scalar=existing_tool),  # existing tool check
        ]
    )
    test_db.commit = Mock()
    test_db.refresh = Mock()
    # Mock the query for team name lookup
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(first=Mock(return_value=None)))))

    # New tools from the gateway: one already known, one brand new
    new_tools = [
        ToolCreate(name="existing_tool", description="Updated tool", integration_type="REST", request_type="POST", input_schema={"type": "object"}),
        ToolCreate(name="new_tool", description="Brand new tool", integration_type="REST", request_type="POST", input_schema={"type": "object"}),
    ]
    gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, new_tools, [], []))
    gateway_service._notify_gateway_updated = AsyncMock()

    url = GatewayService.normalize_url("http://example.com/new-url")
    gateway_update = GatewayUpdate(url=url)

    mock_gateway_read = MagicMock()
    mock_gateway_read.masked.return_value = mock_gateway_read

    # No try/except wrapper needed: pytest already reports the full traceback on failure.
    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
        await gateway_service.update_gateway(test_db, 1, gateway_update)

    assert mock_gateway.url == url
    gateway_service._initialize_gateway.assert_called_once()
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_url_initialization_failure(self, gateway_service, mock_gateway, test_db):
    """A failed re-initialization after a URL change is tolerated and the update still commits."""
    # The same row comes back for every execute() call.
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    # Team-name lookup via query(...).filter(...).first() finds nothing.
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=Mock(first=Mock(return_value=None)))))

    gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Connection failed"))
    gateway_service._notify_gateway_updated = AsyncMock()

    bad_url = GatewayService.normalize_url("http://example.com/bad-url")
    changes = GatewayUpdate(url=bad_url)
    read_model = MagicMock()
    read_model.masked.return_value = read_model

    # Expected: no exception escapes; the service only logs a warning.
    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        await gateway_service.update_gateway(test_db, 1, changes)

    assert mock_gateway.url == bad_url
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_partial_update(self, gateway_service, mock_gateway, test_db):
    """Only the fields present in the GatewayUpdate payload are modified."""
    result_proxy = _make_execute_result(scalar=mock_gateway)
    test_db.execute = Mock(return_value=result_proxy)
    test_db.commit = Mock()
    test_db.refresh = Mock()
    # Team-name lookup resolves to no team.
    team_lookup = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=team_lookup)))
    gateway_service._notify_gateway_updated = AsyncMock()

    # The payload carries the description and nothing else.
    payload = GatewayUpdate(description="New description only")
    read_model = MagicMock()
    read_model.masked.return_value = read_model

    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        await gateway_service.update_gateway(test_db, 1, payload)

    # The description changes...
    assert mock_gateway.description == "New description only"
    # ...while name and URL keep their prior values.
    assert mock_gateway.name == "test_gateway"
    assert mock_gateway.url == "http://example.com/gateway"
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_inactive_excluded(self, gateway_service, mock_gateway, test_db):
    """Updating a disabled gateway with include_inactive=False is a no-op that returns None."""
    mock_gateway.enabled = False
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    payload = GatewayUpdate(description="New description")

    outcome = await gateway_service.update_gateway(test_db, 1, payload, include_inactive=False)

    # The update logic is skipped: nothing is returned and the description is untouched.
    assert outcome is None
    assert mock_gateway.description != "New description"
@pytest.mark.asyncio
async def test_update_gateway_database_rollback(self, gateway_service, mock_gateway, test_db):
    """A failing commit is rolled back and surfaced to the caller as GatewayError."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    # commit() blows up, so the service has to roll back.
    test_db.commit = Mock(side_effect=Exception("Database error"))
    test_db.rollback = Mock()
    empty_team = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=empty_team)))
    gateway_service._notify_gateway_updated = AsyncMock()
    payload = GatewayUpdate(description="New description")

    with pytest.raises(GatewayError) as excinfo:
        await gateway_service.update_gateway(test_db, 1, payload)

    assert "Failed to update gateway" in str(excinfo.value)
    test_db.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_with_masked_auth(self, gateway_service, mock_gateway, test_db):
    """Masked placeholder credentials in an update must not overwrite the stored auth value."""
    mock_gateway.auth_type = "bearer"
    mock_gateway.auth_value = "existing-token"
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    no_team = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=no_team)))
    gateway_service._notify_gateway_updated = AsyncMock()

    placeholder = "***MASKED***"
    read_model = MagicMock()
    read_model.masked.return_value = read_model

    # The service treats settings.masked_auth_value as "unchanged".
    with patch("mcpgateway.services.gateway_service.settings.masked_auth_value", placeholder):
        # Every credential field carries the placeholder...
        payload = GatewayUpdate(
            auth_type="bearer",
            auth_token=placeholder,
            auth_password=placeholder,
            auth_header_value=placeholder,
        )
        with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
            await gateway_service.update_gateway(test_db, 1, payload)

    # ...so the previously stored token survives.
    assert mock_gateway.auth_value == "existing-token"
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_integrity_error(self, gateway_service, mock_gateway, test_db):
    """An IntegrityError raised by commit() propagates unchanged to the caller."""
    # Third-Party
    from sqlalchemy.exc import IntegrityError as SQLIntegrityError

    integrity_failure = SQLIntegrityError("statement", "params", BaseException("orig"))
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock(side_effect=integrity_failure)
    no_team = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=no_team)))
    gateway_service._notify_gateway_updated = AsyncMock()
    payload = GatewayUpdate(description="New description")

    with pytest.raises(SQLIntegrityError):
        await gateway_service.update_gateway(test_db, 1, payload)
def test_normalize_url_preserves_domain(self):
    """normalize_url keeps real hostnames intact but rewrites 127.0.0.1 to localhost."""
    expectations = {
        # Ordinary domains pass through untouched.
        "http://example.com": "http://example.com",
        "https://api.example.com:8080/path": "https://api.example.com:8080/path",
        "https://my-app.cloud-provider.region.example.com/sse": "https://my-app.cloud-provider.region.example.com/sse",
        "https://cdn.service.com/api/v1": "https://cdn.service.com/api/v1",
        # localhost stays localhost.
        "http://localhost:8000": "http://localhost:8000",
        "https://localhost/api": "https://localhost/api",
        # The loopback IP is canonicalised to localhost so the two spellings cannot duplicate.
        "http://127.0.0.1:8080/path": "http://localhost:8080/path",
        "https://127.0.0.1/sse": "https://localhost/sse",
    }
    for input_url, expected in expectations.items():
        result = GatewayService.normalize_url(input_url)
        assert result == expected, f"normalize_url({input_url}) should return {expected}, got {result}"
def test_normalize_url_prevents_localhost_duplicates(self):
    """The 127.0.0.1 and localhost spellings of one endpoint normalise to the same URL."""
    variants = ("http://127.0.0.1:8080/sse", "http://localhost:8080/sse")
    normalized = [GatewayService.normalize_url(u) for u in variants]
    canonical = "http://localhost:8080/sse"

    # Each variant collapses onto the localhost form...
    assert all(n == canonical for n in normalized), f"All localhost variants should normalize to same URL, got: {normalized}"
    # ...so at most one distinct URL can ever be stored.
    assert len(set(normalized)) == 1, "All localhost variants should produce identical normalized URLs"
@pytest.mark.asyncio
async def test_update_gateway_with_transport_change(self, gateway_service, mock_gateway, test_db):
    """Changing the transport stores the new value and commits exactly once."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    no_team = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=no_team)))

    # If re-initialisation is triggered it reports tool capabilities but no tools/resources/prompts.
    init_result = ({"tools": {"listChanged": True}}, [], [], [])
    gateway_service._initialize_gateway = AsyncMock(return_value=init_result)
    gateway_service._notify_gateway_updated = AsyncMock()

    read_model = MagicMock()
    read_model.masked.return_value = read_model
    payload = GatewayUpdate(transport="STREAMABLEHTTP")

    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        await gateway_service.update_gateway(test_db, 1, payload)

    assert mock_gateway.transport == "STREAMABLEHTTP"
    test_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_gateway_without_auth_type_attr(self, gateway_service, test_db):
    """Update a gateway mock on which ``auth_type`` is never explicitly assigned.

    NOTE(review): ``MagicMock(spec=DbGateway)`` allows any attribute that exists
    on the DbGateway class. If ``auth_type`` is a mapped class attribute, the mock
    still exposes it as an auto-created child mock, and the service's getattr
    fallback is probably not exercised. ``del gateway.auth_type`` would make the
    attribute genuinely absent; confirm against the service before changing it.
    """
    gateway = MagicMock(spec=DbGateway)
    gateway.id = 1
    gateway.name = "test_gateway"
    gateway.enabled = True
    # auth_type intentionally left unassigned.

    test_db.execute = Mock(return_value=_make_execute_result(scalar=gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    no_team = Mock(first=Mock(return_value=None))
    test_db.query = Mock(return_value=Mock(filter=Mock(return_value=no_team)))
    gateway_service._notify_gateway_updated = AsyncMock()

    payload = GatewayUpdate(description="New description")
    read_model = MagicMock()
    read_model.masked.return_value = read_model

    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        await gateway_service.update_gateway(test_db, 1, payload)

    assert gateway.description == "New description"
    test_db.commit.assert_called_once()
# ────────────────────────────────────────────────────────────────────
# TOGGLE ACTIVE / INACTIVE
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_db):
    """Deactivating an enabled gateway disables it, toggles its tools and emits a deactivation event."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()

    # db.query(...).filter(...).all() yields one tool, so toggle_tool_status is reached.
    tool_rows = MagicMock()
    tool_rows.all.return_value = [MagicMock(id=101)]
    query_chain = MagicMock()
    query_chain.filter.return_value = tool_rows
    test_db.query = Mock(return_value=query_chain)

    gateway_service._notify_gateway_activated = AsyncMock()
    gateway_service._notify_gateway_deactivated = AsyncMock()
    gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))
    tools_stub = MagicMock()
    tools_stub.toggle_tool_status = AsyncMock()
    gateway_service.tool_service = tools_stub

    # model_validate is patched to return a mock whose .masked() returns itself.
    read_model = MagicMock()
    read_model.masked.return_value = read_model
    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        result = await gateway_service.toggle_gateway_status(test_db, 1, activate=False)

    assert mock_gateway.enabled is False
    gateway_service._notify_gateway_deactivated.assert_called_once()
    assert tools_stub.toggle_tool_status.called
    assert result == read_model
@pytest.mark.asyncio
async def test_toggle_gateway_status_activate(self, gateway_service, mock_gateway, test_db):
    """Activating a disabled gateway enables it, toggles its tools and emits an activation event."""
    mock_gateway.enabled = False
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()

    # db.query(...).filter(...).all() yields one tool, so toggle_tool_status is reached.
    tool_rows = MagicMock()
    tool_rows.all.return_value = [MagicMock(id=101)]
    query_chain = MagicMock()
    query_chain.filter.return_value = tool_rows
    test_db.query = Mock(return_value=query_chain)

    gateway_service._notify_gateway_activated = AsyncMock()
    gateway_service._notify_gateway_deactivated = AsyncMock()
    gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))
    tools_stub = MagicMock()
    tools_stub.toggle_tool_status = AsyncMock()
    gateway_service.tool_service = tools_stub

    # model_validate is patched to return a mock whose .masked() returns itself.
    read_model = MagicMock()
    read_model.masked.return_value = read_model
    with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=read_model):
        result = await gateway_service.toggle_gateway_status(test_db, 1, activate=True)

    assert mock_gateway.enabled is True
    gateway_service._notify_gateway_activated.assert_called_once()
    assert tools_stub.toggle_tool_status.called
    assert result == read_model
@pytest.mark.asyncio
async def test_toggle_gateway_status_not_found(self, gateway_service, test_db):
    """Toggling an unknown gateway id raises GatewayError naming that id."""
    missing = _make_execute_result(scalar=None)
    test_db.execute = Mock(return_value=missing)

    with pytest.raises(GatewayError) as excinfo:
        await gateway_service.toggle_gateway_status(test_db, 999, activate=True)

    assert "Gateway not found: 999" in str(excinfo.value)
@pytest.mark.asyncio
async def test_toggle_gateway_status_with_tools_error(self, gateway_service, mock_gateway, test_db):
    """A failing tool toggle triggers a rollback and is wrapped in GatewayError."""
    test_db.execute = Mock(return_value=_make_execute_result(scalar=mock_gateway))
    test_db.commit = Mock()
    test_db.refresh = Mock()
    test_db.rollback = Mock()

    # db.query(...).filter(...).all() yields one tool, so toggle_tool_status is reached.
    tool_rows = MagicMock()
    tool_rows.all.return_value = [MagicMock(id=101)]
    query_chain = MagicMock()
    query_chain.filter.return_value = tool_rows
    test_db.query = Mock(return_value=query_chain)

    gateway_service._notify_gateway_deactivated = AsyncMock()
    gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))
    # The tool service fails on that single tool.
    broken_tools = MagicMock()
    broken_tools.toggle_tool_status = AsyncMock(side_effect=Exception("Tool toggle failed"))
    gateway_service.tool_service = broken_tools

    with pytest.raises(GatewayError) as excinfo:
        await gateway_service.toggle_gateway_status(test_db, 1, activate=False)

    # The wrapping message carries both the service context and the root cause.
    message = str(excinfo.value)
    assert "Failed to toggle gateway status" in message
    assert "Tool toggle failed" in message
    test_db.rollback.assert_called_once()
# ────────────────────────────────────────────────────────────────────
# DELETE
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_delete_gateway(self, gateway_service, mock_gateway, test_db):
    """Deleting a gateway runs the cascade of bulk deletes and notifies subscribers."""
    # Result of the final DELETE ... RETURNING, echoing back the deleted id.
    returning = Mock()
    returning.fetchone.return_value = (mock_gateway.id,)

    # Scripted execute() results, in order:
    #   1      initial SELECT (selectinload) -> the gateway
    #   2-4    tool metrics, tool associations, tools
    #   5-8    resource metrics, resource associations, resource subscriptions, resources
    #   9-11   prompt metrics, prompt associations, prompts
    #   12     DELETE ... RETURNING
    scripted = [_make_execute_result(scalar=mock_gateway)]
    scripted.extend(Mock() for _ in range(10))
    scripted.append(returning)
    test_db.execute = Mock(side_effect=scripted)
    test_db.commit = Mock()
    test_db.expire = Mock()  # the gateway is expired after the bulk deletes
    gateway_service._notify_gateway_deleted = AsyncMock()

    await gateway_service.delete_gateway(test_db, 1)

    gateway_service._notify_gateway_deleted.assert_called_once()
    # SELECT + bulk deletes + final delete means several execute() calls.
    assert test_db.execute.call_count >= 2
@pytest.mark.asyncio
async def test_delete_gateway_not_found(self, gateway_service, test_db):
    """Deleting an unknown gateway id raises GatewayError naming that id."""
    missing = _make_execute_result(scalar=None)
    test_db.execute = Mock(return_value=missing)

    with pytest.raises(GatewayError) as excinfo:
        await gateway_service.delete_gateway(test_db, 999)

    assert "Gateway not found: 999" in str(excinfo.value)
# ────────────────────────────────────────────────────────────────────
# FORWARD
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_forward_request(self, gateway_service, mock_gateway):
    """A successful JSON-RPC reply yields its ``result`` member and updates last_seen."""
    reply = {"jsonrpc": "2.0", "result": {"success": True, "data": "OK"}, "id": 1}
    response = AsyncMock()
    # raise_for_status/json are synchronous on the HTTP response object.
    response.raise_for_status = Mock()
    response.json = Mock(return_value=reply)
    gateway_service._http_client.post.return_value = response

    result = await gateway_service.forward_request(mock_gateway, "method", {"p": 1})

    assert result == {"success": True, "data": "OK"}
    assert mock_gateway.last_seen is not None
@pytest.mark.asyncio
async def test_forward_request_error_response(self, gateway_service, mock_gateway):
    """A JSON-RPC ``error`` member is turned into a GatewayError carrying its message."""
    reply = {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Boom"}, "id": 1}
    response = AsyncMock()
    response.raise_for_status = Mock()
    response.json = Mock(return_value=reply)
    gateway_service._http_client.post.return_value = response

    with pytest.raises(GatewayError) as excinfo:
        await gateway_service.forward_request(mock_gateway, "method", {"p": 1})

    assert "Gateway error: Boom" in str(excinfo.value)
@pytest.mark.asyncio
async def test_forward_request_connection_error(self, gateway_service, mock_gateway):
    """A transport-level failure in the HTTP client surfaces as GatewayConnectionError."""
    network_failure = Exception("Network down")
    gateway_service._http_client.post.side_effect = network_failure

    with pytest.raises(GatewayConnectionError):
        await gateway_service.forward_request(mock_gateway, "method", {})
# ────────────────────────────────────────────────────────────────────
# REDIS/INITIALIZATION COVERAGE
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_init_with_redis_unavailable(self, monkeypatch):
    """GatewayService() leaves ``_redis_client`` as None when REDIS_AVAILABLE is False.

    The module-level ``REDIS_AVAILABLE`` flag is patched to simulate a missing
    redis package. Re-importing the module inside the test cannot exercise any
    import-failure path: an already-loaded module is served from ``sys.modules``
    and its code is not executed again. The module-level ``GatewayService`` import
    is therefore used directly.
    """
    monkeypatch.setattr("mcpgateway.services.gateway_service.REDIS_AVAILABLE", False)
    # Replace the service module's ``logging`` reference with a mock during construction.
    with patch("mcpgateway.services.gateway_service.logging"):
        service = GatewayService()
        assert service._redis_client is None
@pytest.mark.asyncio
async def test_init_with_redis_enabled(self, monkeypatch):
    """initialize() attaches the Redis client and leader settings when cache_type is 'redis'.

    ``GatewayService`` comes from the module-level import. The previous
    function-local re-import returned the same cached class and was redundant.
    """
    monkeypatch.setattr("mcpgateway.services.gateway_service.REDIS_AVAILABLE", True)
    mock_redis_client = AsyncMock()
    # Awaitable stubs that succeed; presumably used for the connectivity check and
    # leader-key acquisition -- confirm against the service.
    mock_redis_client.ping = AsyncMock()
    mock_redis_client.set = AsyncMock(return_value=True)

    async def mock_get_redis_client():
        # Stand-in for the shared client factory: always hands back the mock above.
        return mock_redis_client

    with patch("mcpgateway.services.gateway_service.get_redis_client", mock_get_redis_client):
        with patch("mcpgateway.services.gateway_service.settings") as mock_settings:
            mock_settings.cache_type = "redis"
            mock_settings.redis_url = "redis://localhost:6379"
            mock_settings.redis_leader_key = "gateway_service_leader"
            mock_settings.redis_leader_ttl = 15
            mock_settings.redis_leader_heartbeat_interval = 5

            service = GatewayService()
            await service.initialize()

            assert service._redis_client is mock_redis_client
            assert isinstance(service._instance_id, str)
            assert service._leader_key == "gateway_service_leader"
            assert service._leader_ttl == 15
@pytest.mark.asyncio
async def test_init_file_cache_path_adjustment(self, monkeypatch):
    """GatewayService() constructs with cache_type='file' while os.path helpers are mocked.

    The helpers are mocked to mimic a drive-letter split, as on Windows.
    ``GatewayService`` comes from the module-level import. The previous
    function-local re-import was redundant.

    NOTE(review): this only checks that construction succeeds. It does not
    verify the adjusted lock-file path, and whether the mocked helpers are
    reached at all depends on the service's path handling (e.g. whether the
    configured path is absolute). Confirm, and assert on the resulting path if
    the service exposes it.
    """
    monkeypatch.setattr("mcpgateway.services.gateway_service.REDIS_AVAILABLE", False)
    with patch("mcpgateway.services.gateway_service.settings") as mock_settings:
        mock_settings.cache_type = "file"
        with patch("os.path.expanduser") as mock_expanduser, patch("os.path.relpath") as mock_relpath, patch("os.path.splitdrive") as mock_splitdrive:
            mock_expanduser.return_value = "/home/user/.mcpgateway/health_checks.lock"
            # Simulate a drive letter being split off the path.
            mock_splitdrive.return_value = ("C:", "/home/user/.mcpgateway/health_checks.lock")
            mock_relpath.return_value = "home/user/.mcpgateway/health_checks.lock"

            service = GatewayService()
            assert service is not None
@pytest.mark.asyncio
async def test_init_with_cache_disabled(self, monkeypatch):
    """GatewayService() has no Redis client when caching is disabled and redis is unavailable.

    ``GatewayService`` comes from the module-level import. The previous
    function-local re-import returned the same cached class and was redundant.
    """
    monkeypatch.setattr("mcpgateway.services.gateway_service.REDIS_AVAILABLE", False)
    with patch("mcpgateway.services.gateway_service.settings") as mock_settings:
        mock_settings.cache_type = "none"
        service = GatewayService()
        assert service._redis_client is None
# ────────────────────────────────────────────────────────────────────
# GATEWAY INITIALIZATION AND CONNECTION COVERAGE
# ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_initialize_gateway_with_resources_and_prompts(self, gateway_service):
    """_initialize_gateway gathers tools, resources and prompts when all three are advertised."""

    def _async_cm(entered):
        # AsyncMock usable with ``async with``: __aenter__ yields *entered*, __aexit__ returns None.
        cm = AsyncMock()
        cm.__aenter__.return_value = entered
        cm.__aexit__.return_value = None
        return cm

    def _listing(field, *dumps):
        # Fake list_* response whose ``field`` holds one mock per dict; each mock's model_dump() returns that dict.
        entries = []
        for dumped in dumps:
            entry = MagicMock()
            entry.model_dump.return_value = dumped
            entries.append(entry)
        response = MagicMock()
        setattr(response, field, entries)
        return response

    with (
        patch("mcpgateway.services.gateway_service.sse_client") as sse_client_mock,
        patch("mcpgateway.services.gateway_service.ClientSession") as client_session_mock,
        patch("mcpgateway.services.gateway_service.decode_auth") as decode_auth_mock,
    ):
        decode_auth_mock.return_value = {"Authorization": "Bearer token"}
        # sse_client(...) yields a (read, write) stream pair; ClientSession(...) yields the session.
        sse_client_mock.return_value = _async_cm((MagicMock(), MagicMock()))
        session = AsyncMock()
        client_session_mock.return_value = _async_cm(session)

        # The server advertises resources, prompts and tools.
        init_reply = MagicMock()
        init_reply.capabilities.model_dump.return_value = {"protocolVersion": "0.1.0", "resources": {"listChanged": True}, "prompts": {"listChanged": True}, "tools": {"listChanged": True}}
        session.initialize.return_value = init_reply
        session.list_tools.return_value = _listing("tools", {"name": "test_tool", "description": "Test tool", "inputSchema": {"type": "object"}})
        session.list_resources.return_value = _listing("resources", {"uri": "file://test.txt", "name": "test_resource", "description": "Test resource", "mime_type": "text/plain"})
        session.list_prompts.return_value = _listing("prompts", {"name": "test_prompt", "description": "Test prompt"})

        capabilities, tools, resources, prompts = await gateway_service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "SSE")

        assert "resources" in capabilities
        assert "prompts" in capabilities
        assert len(tools) == 1
        assert len(resources) == 1
        assert len(prompts) == 1
        assert resources[0].uri == "file://test.txt"
        # Missing content/template fields are filled with empty-string defaults.
        assert resources[0].content == ""
        assert prompts[0].template == ""
@pytest.mark.asyncio
async def test_initialize_gateway_resource_validation_error(self, gateway_service):
    """If ResourceCreate.model_validate fails, _initialize_gateway falls back to building ResourceCreate directly."""

    def _async_cm(entered):
        # AsyncMock usable with ``async with``: __aenter__ yields *entered*, __aexit__ returns None.
        cm = AsyncMock()
        cm.__aenter__.return_value = entered
        cm.__aexit__.return_value = None
        return cm

    def _listing(field, *dumps):
        # Fake list_* response whose ``field`` holds one mock per dict; each mock's model_dump() returns that dict.
        entries = []
        for dumped in dumps:
            entry = MagicMock()
            entry.model_dump.return_value = dumped
            entries.append(entry)
        response = MagicMock()
        setattr(response, field, entries)
        return response

    with (
        patch("mcpgateway.services.gateway_service.sse_client") as sse_client_mock,
        patch("mcpgateway.services.gateway_service.ClientSession") as client_session_mock,
        patch("mcpgateway.services.gateway_service.decode_auth") as decode_auth_mock,
    ):
        decode_auth_mock.return_value = {"Authorization": "Bearer token"}
        sse_client_mock.return_value = _async_cm((MagicMock(), MagicMock()))
        session = AsyncMock()
        client_session_mock.return_value = _async_cm(session)

        # Resources and tools are advertised; prompts are not.
        init_reply = MagicMock()
        init_reply.capabilities.model_dump.return_value = {"resources": {"listChanged": True}, "tools": {"listChanged": True}}
        session.initialize.return_value = init_reply
        session.list_tools.return_value = _listing("tools")

        # The resource URI is an object exposing ``unicode_string`` rather than a plain str.
        complex_uri = MagicMock()
        complex_uri.unicode_string = "file://complex.txt"
        session.list_resources.return_value = _listing("resources", {"uri": complex_uri, "name": "complex_resource", "description": "Complex resource"})

        with patch("mcpgateway.services.gateway_service.ResourceCreate") as resource_create_mock:
            # The first model_validate() call fails, forcing the fallback path.
            resource_create_mock.model_validate.side_effect = [Exception("Validation error"), MagicMock()]
            resource_create_mock.return_value = MagicMock()

            _caps, _tools, resources, _prompts = await gateway_service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "SSE")

            # The fallback constructor call produced the single resource.
            assert len(resources) == 1
            assert resource_create_mock.called
@pytest.mark.asyncio
async def test_initialize_gateway_prompt_validation_error(self, gateway_service):
    """If PromptCreate.model_validate fails, _initialize_gateway falls back to building PromptCreate directly."""

    def _async_cm(entered):
        # AsyncMock usable with ``async with``: __aenter__ yields *entered*, __aexit__ returns None.
        cm = AsyncMock()
        cm.__aenter__.return_value = entered
        cm.__aexit__.return_value = None
        return cm

    def _listing(field, *dumps):
        # Fake list_* response whose ``field`` holds one mock per dict; each mock's model_dump() returns that dict.
        entries = []
        for dumped in dumps:
            entry = MagicMock()
            entry.model_dump.return_value = dumped
            entries.append(entry)
        response = MagicMock()
        setattr(response, field, entries)
        return response

    with (
        patch("mcpgateway.services.gateway_service.sse_client") as sse_client_mock,
        patch("mcpgateway.services.gateway_service.ClientSession") as client_session_mock,
        patch("mcpgateway.services.gateway_service.decode_auth") as decode_auth_mock,
    ):
        decode_auth_mock.return_value = {"Authorization": "Bearer token"}
        sse_client_mock.return_value = _async_cm((MagicMock(), MagicMock()))
        session = AsyncMock()
        client_session_mock.return_value = _async_cm(session)

        # Prompts and tools are advertised; resources are not.
        init_reply = MagicMock()
        init_reply.capabilities.model_dump.return_value = {"prompts": {"listChanged": True}, "tools": {"listChanged": True}}
        session.initialize.return_value = init_reply
        session.list_tools.return_value = _listing("tools")
        session.list_prompts.return_value = _listing("prompts", {"name": "complex_prompt", "description": "Complex prompt"})

        with patch("mcpgateway.services.gateway_service.PromptCreate") as prompt_create_mock:
            # The first model_validate() call fails, forcing the fallback path.
            prompt_create_mock.model_validate.side_effect = [Exception("Validation error"), MagicMock()]
            prompt_create_mock.return_value = MagicMock()

            _caps, _tools, _resources, prompts = await gateway_service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "SSE")

            # The fallback constructor call produced the single prompt.
            assert len(prompts) == 1
            assert prompt_create_mock.called
@pytest.mark.asyncio
async def test_initialize_gateway_resource_fetch_failure(self, gateway_service):
    """A failing list_resources() call yields no resources but keeps the advertised capability."""

    def _async_cm(entered):
        # AsyncMock usable with ``async with``: __aenter__ yields *entered*, __aexit__ returns None.
        cm = AsyncMock()
        cm.__aenter__.return_value = entered
        cm.__aexit__.return_value = None
        return cm

    with (
        patch("mcpgateway.services.gateway_service.sse_client") as sse_client_mock,
        patch("mcpgateway.services.gateway_service.ClientSession") as client_session_mock,
        patch("mcpgateway.services.gateway_service.decode_auth") as decode_auth_mock,
    ):
        decode_auth_mock.return_value = {"Authorization": "Bearer token"}
        sse_client_mock.return_value = _async_cm((MagicMock(), MagicMock()))
        session = AsyncMock()
        client_session_mock.return_value = _async_cm(session)

        init_reply = MagicMock()
        init_reply.capabilities.model_dump.return_value = {"resources": {"listChanged": True}, "tools": {"listChanged": True}}
        session.initialize.return_value = init_reply
        empty_tools = MagicMock()
        empty_tools.tools = []
        session.list_tools.return_value = empty_tools
        # Fetching resources fails outright.
        session.list_resources.side_effect = Exception("Resource fetch failed")

        capabilities, _tools, resources, _prompts = await gateway_service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "SSE")

        assert "resources" in capabilities
        # The failure is tolerated; the resource list is simply empty.
        assert len(resources) == 0
@pytest.mark.asyncio
async def test_initialize_gateway_prompt_fetch_failure(self, gateway_service):
    """A failing list_prompts() call yields no prompts but keeps the advertised capability."""

    def _async_cm(entered):
        # AsyncMock usable with ``async with``: __aenter__ yields *entered*, __aexit__ returns None.
        cm = AsyncMock()
        cm.__aenter__.return_value = entered
        cm.__aexit__.return_value = None
        return cm

    with (
        patch("mcpgateway.services.gateway_service.sse_client") as sse_client_mock,
        patch("mcpgateway.services.gateway_service.ClientSession") as client_session_mock,
        patch("mcpgateway.services.gateway_service.decode_auth") as decode_auth_mock,
    ):
        decode_auth_mock.return_value = {"Authorization": "Bearer token"}
        sse_client_mock.return_value = _async_cm((MagicMock(), MagicMock()))
        session = AsyncMock()
        client_session_mock.return_value = _async_cm(session)

        init_reply = MagicMock()
        init_reply.capabilities.model_dump.return_value = {"prompts": {"listChanged": True}, "tools": {"listChanged": True}}
        session.initialize.return_value = init_reply
        empty_tools = MagicMock()
        empty_tools.tools = []
        session.list_tools.return_value = empty_tools
        # Fetching prompts fails outright.
        session.list_prompts.side_effect = Exception("Prompt fetch failed")

        capabilities, _tools, _resources, prompts = await gateway_service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "SSE")

        assert "prompts" in capabilities
        # The failure is tolerated; the prompt list is simply empty.
        assert len(prompts) == 0
@pytest.mark.asyncio
async def test_list_gateway_with_tags(self, gateway_service, mock_gateway):
    """Test listing gateways with tag filtering.

    ``select`` is patched to return a chainable query mock, and
    ``json_contains_expr`` is patched to return a sentinel condition. The
    test checks that:

    * the requested tags are passed to ``json_contains_expr``;
    * the resulting condition is applied through ``where()``;
    * the converted gateway rows are returned.
    """
    # Chainable query: order_by/where/limit all return the same mock.
    mock_query = MagicMock()
    mock_query.order_by.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.limit.return_value = mock_query

    session = MagicMock()
    session.execute.return_value.scalars.return_value.all.return_value = [mock_gateway]
    # Dialect name the service may inspect when building JSON expressions.
    bind = MagicMock()
    bind.dialect.name = "sqlite"  # or "postgresql" or "mysql"
    session.get_bind.return_value = bind
    # EmailTeam lookup for team names returns no rows.
    session.query.return_value.filter.return_value.all.return_value = []

    mocked_gateway_read = MagicMock()
    mocked_gateway_read.model_dump.return_value = {"id": "1", "name": "test"}
    gateway_service.convert_gateway_to_read = MagicMock(return_value=mocked_gateway_read)

    fake_condition = MagicMock()
    with (
        patch("mcpgateway.services.gateway_service.select", return_value=mock_query),
        patch("mcpgateway.services.gateway_service.json_contains_expr", return_value=fake_condition) as mock_json_contains,
    ):
        # include_inactive=True avoids the enabled filter so tag filtering is tested in isolation.
        result, _next_cursor = await gateway_service.list_gateways(session, tags=["test", "production"], include_inactive=True)

    mock_json_contains.assert_called_once()
    called_args = mock_json_contains.call_args.args
    assert called_args[0] is session  # session passed through
    # Third positional arg is the tags list (signature: session, col, values, match_any=True).
    assert called_args[2] == ["test", "production"]

    where_calls = mock_query.where.call_args_list
    assert where_calls, "where() should have been called"
    assert any(fake_condition in c.args for c in where_calls), f"fake_condition not found in where() calls: {where_calls}"

    # Each fetched row is expected to go through the (mocked) convert_gateway_to_read.
    assert isinstance(result, list)
    assert result == [mocked_gateway_read]
class TestGatewayRefresh:
    """Test suite for gateway refresh logic (internal and manual)."""

    @pytest.fixture
    def mock_db_session(self):
        """Mock ``fresh_db_session()`` context manager.

        ``__enter__`` yields a separate session mock with these settings:

        * ``commit`` and ``flush`` are plain MagicMocks;
        * ``dirty`` is an empty set;
        * ``execute`` returns a result whose scalar is None.
        """
        session = MagicMock()
        session.commit = MagicMock()
        session.flush = MagicMock()
        session.execute.return_value = _make_execute_result(scalar=None)
        # Tests that simulate updated rows replace this with a populated set.
        session.dirty = set()

        ctx = MagicMock()
        ctx.__enter__.return_value = session
        ctx.__exit__.return_value = None
        return ctx

    @pytest.fixture
    def mock_gateway_with_relations(self):
        """Enabled, reachable gateway mock with empty tools/resources/prompts relations."""
        gw = MagicMock(spec=DbGateway)
        gw.id = "gw-123"
        gw.name = "test_gateway"
        gw.url = "http://example.com"
        gw.enabled = True
        gw.reachable = True
        gw.tools = []
        gw.resources = []
        gw.prompts = []
        return gw

    @pytest.mark.asyncio
    async def test_refresh_gateway_success_all_changed(self, gateway_service, mock_gateway_with_relations, mock_db_session):
        """Test successful refresh where tools, resources, prompts are all updated."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=mock_gateway_with_relations)

        # ``name`` is a reserved MagicMock constructor argument (it only sets
        # the mock's repr), so assign ``.name`` after construction to get a
        # real string attribute.
        new_tool = MagicMock()
        new_tool.name = "tool1"
        new_prompt = MagicMock()
        new_prompt.name = "prompt1"
        new_resources = [MagicMock(uri="res1")]

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            # First tuple element is the capabilities dict.
            gateway_service._initialize_gateway = AsyncMock(return_value=({}, [new_tool], new_resources, [new_prompt]))
            gateway_service._update_or_create_tools = Mock(return_value=[MagicMock()])
            gateway_service._update_or_create_resources = Mock(return_value=[MagicMock()])
            gateway_service._update_or_create_prompts = Mock(return_value=[MagicMock()])
            # Simulated updated objects; the real code compares against a snapshot, simplified here.
            session.dirty = {MagicMock(spec=DbTool), MagicMock(spec=DbResource), MagicMock(spec=DbPrompt)}
            result = await gateway_service._refresh_gateway_tools_resources_prompts("gw-123", gateway=mock_gateway_with_relations)

        assert result["success"] is True
        assert result["tools_added"] == 1
        assert result["resources_added"] == 1
        assert result["prompts_added"] == 1

    @pytest.mark.asyncio
    async def test_refresh_gateway_no_changes(self, gateway_service, mock_gateway_with_relations, mock_db_session):
        """Test refresh with no changes detected."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=mock_gateway_with_relations)

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            # Upstream reports nothing, and the helpers are mocked to avoid real DB work.
            gateway_service._initialize_gateway = AsyncMock(return_value=({}, [], [], []))
            gateway_service._update_or_create_tools = Mock(return_value=[])
            gateway_service._update_or_create_resources = Mock(return_value=[])
            gateway_service._update_or_create_prompts = Mock(return_value=[])
            result = await gateway_service._refresh_gateway_tools_resources_prompts("gw-123", gateway=mock_gateway_with_relations)

        # Surface the service's own error message instead of a bare assertion failure.
        if not result.get("success", True):
            pytest.fail(f"Refresh failed with error: {result.get('error')}")
        assert result["success"] is True
        assert result["tools_added"] == 0
        assert result["resources_added"] == 0
        assert result["prompts_added"] == 0

    @pytest.mark.asyncio
    async def test_refresh_gateway_not_found(self, gateway_service, mock_db_session):
        """Test refresh of a gateway id that does not exist."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=None)

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            result = await gateway_service._refresh_gateway_tools_resources_prompts("non-existent-id")

        # Expected: the service logs a warning and returns its default result
        # (success=True, all counts zero) rather than raising.
        assert result["success"] is True
        assert result["tools_added"] == 0

    @pytest.mark.asyncio
    async def test_refresh_gateway_inactive(self, gateway_service, mock_gateway_with_relations):
        """Test refresh is skipped for inactive gateway.

        ``_initialize_gateway`` is replaced with an AsyncMock so the test can
        verify the upstream is never contacted. A real bound method has no
        ``called`` attribute to inspect.
        """
        mock_gateway_with_relations.enabled = False
        gateway_service._initialize_gateway = AsyncMock(return_value=({}, [], [], []))

        result = await gateway_service._refresh_gateway_tools_resources_prompts("gw-123", gateway=mock_gateway_with_relations)

        assert result["tools_added"] == 0
        gateway_service._initialize_gateway.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_refresh_gateway_connection_error(self, gateway_service, mock_gateway_with_relations):
        """Test handling of connection error during refresh."""
        gateway_service._initialize_gateway = AsyncMock(side_effect=Exception("Connection failed"))

        result = await gateway_service._refresh_gateway_tools_resources_prompts("gw-123", gateway=mock_gateway_with_relations)

        assert result["success"] is False
        assert "Connection failed" in result["error"]

    @pytest.mark.asyncio
    async def test_manual_refresh_success(self, gateway_service, mock_gateway_with_relations, mock_db_session):
        """Test successful manual refresh."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=mock_gateway_with_relations)

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            # Mock the internal refresh method (which handles the last_refresh_at update internally).
            gateway_service._refresh_gateway_tools_resources_prompts = AsyncMock(
                return_value={"success": True, "tools_added": 5, "tools_removed": 0, "resources_added": 0, "resources_removed": 0, "prompts_added": 0, "prompts_removed": 0}
            )
            result = await gateway_service.refresh_gateway_manually("gw-123")

        assert result["success"] is True
        assert result["tools_added"] == 5
        assert "duration_ms" in result
        assert "refreshed_at" in result
        gateway_service._refresh_gateway_tools_resources_prompts.assert_called_once()
        # The internal method must be told the refresh was manual.
        assert gateway_service._refresh_gateway_tools_resources_prompts.call_args.kwargs["created_via"] == "manual_refresh"

    @pytest.mark.asyncio
    async def test_manual_refresh_gateway_not_found(self, gateway_service, mock_db_session):
        """Test manual refresh raises error if gateway not found."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=None)

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            with pytest.raises(GatewayNotFoundError):
                await gateway_service.refresh_gateway_manually("non-existent-id")

    @pytest.mark.asyncio
    async def test_manual_refresh_concurrency(self, gateway_service, mock_gateway_with_relations, mock_db_session):
        """Test error when refresh lock is already held."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=mock_gateway_with_relations)

        # Hold the per-gateway lock so the manual refresh finds it taken.
        lock = gateway_service._get_refresh_lock("gw-123")
        await lock.acquire()
        try:
            with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
                with pytest.raises(GatewayError) as exc_info:
                    await gateway_service.refresh_gateway_manually("gw-123")
            # Checked outside ``pytest.raises``: code after the raising call
            # inside that block would never run.
            assert "Refresh already in progress" in str(exc_info.value)
        finally:
            lock.release()

    @pytest.mark.asyncio
    async def test_manual_refresh_passthrough_headers(self, gateway_service, mock_gateway_with_relations, mock_db_session):
        """Test manual refresh uses passthrough headers."""
        session = mock_db_session.__enter__.return_value
        session.execute.return_value = _make_execute_result(scalar=mock_gateway_with_relations)

        with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
            with patch("mcpgateway.services.gateway_service.get_passthrough_headers") as mock_get_headers:
                mock_get_headers.return_value = {"x-custom": "value"}
                # Full result dict, as expected by the refresh logging.
                gateway_service._refresh_gateway_tools_resources_prompts = AsyncMock(
                    return_value={
                        "success": True,
                        "tools_added": 0,
                        "tools_removed": 0,
                        "tools_updated": 0,
                        "resources_added": 0,
                        "resources_removed": 0,
                        "resources_updated": 0,
                        "prompts_added": 0,
                        "prompts_removed": 0,
                        "prompts_updated": 0,
                        "duration_ms": 0,
                    }
                )
                await gateway_service.refresh_gateway_manually("gw-123", request_headers={"x-foo": "bar"})

        mock_get_headers.assert_called_once()
        # The computed passthrough headers are forwarded to the internal refresh.
        assert gateway_service._refresh_gateway_tools_resources_prompts.call_args.kwargs["pre_auth_headers"] == {"x-custom": "value"}

    def test_validate_tools_partial_failure(self, gateway_service):
        """Test tool validation logs errors but returns valid tools and validation errors."""
        tools = [
            {"name": "valid_tool", "description": "valid", "inputSchema": {}},
            {"name": "invalid_tool", "integration_type": "INVALID_TYPE"},  # invalid integration_type, should fail
        ]

        valid_tools, validation_errors = gateway_service._validate_tools(tools)

        assert len(valid_tools) == 1
        assert valid_tools[0].name == "valid_tool"
        assert len(validation_errors) == 1
        assert "invalid_tool" in validation_errors[0]

    def test_validate_tools_all_invalid(self, gateway_service):
        """Test failure when all tools are invalid."""
        tools = [
            {"name": "invalid1", "integration_type": "INVALID_TYPE"},
            {"name": "invalid2", "integration_type": "INVALID_TYPE"},
        ]

        with pytest.raises(GatewayConnectionError) as exc:
            gateway_service._validate_tools(tools)
        assert "validation" in str(exc.value)

    def test_validate_tools_all_invalid_oauth(self, gateway_service):
        """Test failure when all tools are invalid in oauth context."""
        tools = [{"name": "invalid", "integration_type": "INVALID_TYPE"}]

        with pytest.raises(OAuthToolValidationError) as exc:
            gateway_service._validate_tools(tools, context="oauth")
        assert "OAuth tool fetch failed" in str(exc.value)

    def test_validate_tools_depth_limit(self, gateway_service):
        """Test handling of recursion depth error in validation."""
        # Simulate the depth error by making ToolCreate.model_validate raise ValueError.
        with patch("mcpgateway.services.gateway_service.ToolCreate.model_validate") as mock_validate:
            mock_validate.side_effect = ValueError("JSON structure exceeds maximum depth")
            # The per-tool depth error is collected rather than propagated, but
            # because no tool in the batch survives validation,
            # _validate_tools ends up raising GatewayConnectionError.
            with pytest.raises(GatewayConnectionError):
                gateway_service._validate_tools([{"name": "deep_tool"}])

    @pytest.mark.asyncio
    async def test_publish_event(self, gateway_service):
        """Test event publishing is delegated to the event service."""
        gateway_service._event_service = AsyncMock()
        event = {"type": "test", "data": "foo"}

        await gateway_service._publish_event(event)

        gateway_service._event_service.publish_event.assert_awaited_once_with(event)

    @pytest.mark.asyncio
    async def test_connect_to_sse_server_without_validation_success(self, gateway_service):
        """Test successful connection without URL validation."""
        mock_session = AsyncMock()

        mock_init_response = MagicMock()
        mock_init_response.capabilities.model_dump.return_value = {"resources": True, "prompts": True}
        mock_session.initialize.return_value = mock_init_response

        mock_list_tools = MagicMock()
        mock_list_tools.tools = [MagicMock(model_dump=MagicMock(return_value={"name": "tool1", "inputSchema": {}}))]
        mock_session.list_tools.return_value = mock_list_tools

        mock_list_resources = MagicMock()
        mock_list_resources.resources = [MagicMock(model_dump=MagicMock(return_value={"uri": "res1", "name": "res1"}))]
        mock_session.list_resources.return_value = mock_list_resources
        mock_session.list_resource_templates.return_value = MagicMock(resourceTemplates=[])

        mock_list_prompts = MagicMock()
        mock_list_prompts.prompts = [MagicMock(model_dump=MagicMock(return_value={"name": "prompt1"}))]
        mock_session.list_prompts.return_value = mock_list_prompts

        # Async context managers for the SSE transport and the client session.
        mock_sse_cm = AsyncMock()
        mock_sse_cm.__aenter__.return_value = (MagicMock(), MagicMock())
        mock_sse_cm.__aexit__.return_value = None
        mock_client_cm = AsyncMock()
        mock_client_cm.__aenter__.return_value = mock_session
        mock_client_cm.__aexit__.return_value = None

        with patch("mcpgateway.services.gateway_service.sse_client", return_value=mock_sse_cm):
            with patch("mcpgateway.services.gateway_service.ClientSession", return_value=mock_client_cm):
                capabilities, tools, resources, prompts = await gateway_service._connect_to_sse_server_without_validation("http://test.com")

        assert len(tools) == 1
        assert len(resources) == 1
        assert len(prompts) == 1
        assert capabilities["resources"] is True

    @pytest.mark.asyncio
    async def test_connect_to_sse_server_without_validation_fetch_errors(self, gateway_service):
        """Test resilience when resource/prompt fetch fails."""
        mock_session = AsyncMock()

        mock_init_response = MagicMock()
        mock_init_response.capabilities.model_dump.return_value = {"resources": True, "prompts": True}
        mock_session.initialize.return_value = mock_init_response

        mock_list_tools = MagicMock()
        mock_list_tools.tools = []
        mock_session.list_tools.return_value = mock_list_tools

        # Simulate failures of the optional listings.
        mock_session.list_resources.side_effect = Exception("Resource fetch failed")
        mock_session.list_prompts.side_effect = Exception("Prompt fetch failed")

        mock_sse_cm = AsyncMock()
        mock_sse_cm.__aenter__.return_value = (MagicMock(), MagicMock())
        mock_sse_cm.__aexit__.return_value = None
        mock_client_cm = AsyncMock()
        mock_client_cm.__aenter__.return_value = mock_session
        mock_client_cm.__aexit__.return_value = None

        with patch("mcpgateway.services.gateway_service.sse_client", return_value=mock_sse_cm):
            with patch("mcpgateway.services.gateway_service.ClientSession", return_value=mock_client_cm):
                capabilities, tools, resources, prompts = await gateway_service._connect_to_sse_server_without_validation("http://test.com")

        # Failed parts come back as empty lists instead of raising.
        assert len(resources) == 0
        assert len(prompts) == 0
        assert capabilities["resources"] is True
class TestGatewayHealth:
"""Test suite for gateway health checks and auto-refresh logic."""
@pytest.fixture
def mock_db_session(self):
    """Session mock that doubles as its own context manager.

    ``with fresh_db_session() as db`` therefore yields this very object, so
    tests can configure ``execute`` on the fixture value directly.
    """
    db = MagicMock()
    db.__enter__.return_value = db
    db.__exit__.return_value = None
    return db
@pytest.fixture
def mock_gateway_health(self):
    """Enabled, unauthenticated gateway mock ready for health checks.

    Its last refresh was one hour ago, its refresh interval is 300 seconds,
    and it carries no CA certificate or signature.
    """
    gateway = MagicMock(spec=DbGateway)
    gateway.configure_mock(
        id="gw-health-1",
        name="Health Gateway",
        url="http://health.test",
        enabled=True,
        auth_type=None,
        last_refresh_at=datetime.now(timezone.utc) - timedelta(hours=1),
        refresh_interval_seconds=300,
        ca_certificate=None,
        ca_certificate_sig=None,
    )
    return gateway
@pytest.mark.asyncio
async def test_check_health_batch_success(self, gateway_service, mock_gateway_health):
    """A batch whose single per-gateway check succeeds returns True and runs that check once."""
    gateway_service._check_single_gateway_health = AsyncMock(return_value=None)

    with patch("mcpgateway.services.gateway_service.settings") as fake_settings:
        fake_settings.max_concurrent_health_checks = 5
        fake_settings.gateway_health_check_timeout = 5
        outcome = await gateway_service.check_health_of_gateways([mock_gateway_health])

    assert outcome is True
    gateway_service._check_single_gateway_health.assert_awaited_once()
@pytest.mark.asyncio
async def test_check_health_timeout(self, gateway_service, mock_gateway_health):
    """A per-gateway check that outlives the health-check timeout counts as a failure.

    The single-gateway check sleeps 0.2 s while the configured timeout is
    0.01 s. ``check_health_of_gateways`` is therefore expected to:

    * still return True;
    * hand the gateway to ``_handle_gateway_failure`` exactly once.
    """
    gateways = [mock_gateway_health]

    # Sleeps far longer than the 0.01 s timeout configured below. The sleep is
    # bounded (not infinite) so the test still finishes if the timeout ever
    # stops being enforced.
    async def slow_check(*args, **kwargs):
        await asyncio.sleep(0.2)

    gateway_service._check_single_gateway_health = AsyncMock(side_effect=slow_check)
    gateway_service._handle_gateway_failure = AsyncMock()

    with patch("mcpgateway.services.gateway_service.settings") as mock_settings:
        mock_settings.max_concurrent_health_checks = 5
        mock_settings.gateway_health_check_timeout = 0.01  # ultra short timeout
        result = await gateway_service.check_health_of_gateways(gateways)

    assert result is True
    # The timed-out gateway is routed to the failure handler.
    gateway_service._handle_gateway_failure.assert_awaited_once_with(mock_gateway_health)
@pytest.mark.asyncio
async def test_health_triggers_auto_refresh(self, gateway_service, mock_gateway_health, mock_db_session):
    """A health check with auto-refresh enabled triggers a refresh once one is due.

    The fixture's last refresh was an hour ago and the auto-refresh interval
    is 300 s, so the refresh must be awaited once with
    ``created_via="health_check"``.
    """
    refresh = AsyncMock()
    gateway_service._refresh_gateway_tools_resources_prompts = refresh
    gateway_service.toggle_gateway_status = AsyncMock()

    # Synchronous ``.locked()`` reports the lock as free; the async
    # context-manager hooks are awaitable.
    refresh_lock = MagicMock()
    refresh_lock.locked.return_value = False
    refresh_lock.__aenter__ = AsyncMock(return_value=None)
    refresh_lock.__aexit__ = AsyncMock(return_value=None)
    gateway_service._get_refresh_lock = MagicMock(return_value=refresh_lock)

    # HTTP client acting as its own async context manager; the ping returns 200.
    http_client = AsyncMock()
    http_client.__aenter__.return_value = http_client
    http_client.__aexit__.return_value = None
    http_client.get.return_value = MagicMock(status_code=200)

    # The fixture session is its own context manager; it serves the gateway row.
    mock_db_session.__enter__().execute.return_value = _make_execute_result(scalar=mock_gateway_health)

    with patch("mcpgateway.services.gateway_service.settings") as fake_settings:
        fake_settings.auto_refresh_servers = True
        fake_settings.gateway_auto_refresh_interval = 300
        fake_settings.enable_ed25519_signing = False  # keep the unsigned path
        fake_settings.httpx_admin_read_timeout = 5.0
        with (
            patch("mcpgateway.services.http_client_service.get_isolated_http_client", return_value=http_client),
            patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session),
        ):
            await gateway_service._check_single_gateway_health(mock_gateway_health)

    refresh.assert_awaited_once()
    assert refresh.call_args.kwargs["created_via"] == "health_check"
@pytest.mark.asyncio
async def test_health_skips_refresh_disabled(self, gateway_service, mock_gateway_health, mock_db_session):
    """With ``auto_refresh_servers`` off, a health check (mocked 200 ping) never refreshes."""
    refresh = AsyncMock()
    gateway_service._refresh_gateway_tools_resources_prompts = refresh

    # HTTP client acting as its own async context manager; the ping returns 200.
    http_client = AsyncMock()
    http_client.__aenter__.return_value = http_client
    http_client.__aexit__.return_value = None
    http_client.get.return_value = MagicMock(status_code=200)

    mock_db_session.__enter__().execute.return_value = _make_execute_result(scalar=mock_gateway_health)

    with patch("mcpgateway.services.gateway_service.settings") as fake_settings:
        fake_settings.auto_refresh_servers = False  # feature switched off
        fake_settings.enable_ed25519_signing = False
        fake_settings.httpx_admin_read_timeout = 5.0
        with (
            patch("mcpgateway.services.http_client_service.get_isolated_http_client", return_value=http_client),
            patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session),
        ):
            await gateway_service._check_single_gateway_health(mock_gateway_health)

    refresh.assert_not_called()
@pytest.mark.asyncio
async def test_health_skips_refresh_throttled(self, gateway_service, mock_gateway_health, mock_db_session):
    """A gateway refreshed just now is not refreshed again within the 300 s interval."""
    mock_gateway_health.last_refresh_at = datetime.now(timezone.utc)  # refreshed this instant
    refresh = AsyncMock()
    gateway_service._refresh_gateway_tools_resources_prompts = refresh

    # HTTP client acting as its own async context manager; the ping returns 200.
    http_client = AsyncMock()
    http_client.__aenter__.return_value = http_client
    http_client.__aexit__.return_value = None
    http_client.get.return_value = MagicMock(status_code=200)

    mock_db_session.__enter__().execute.return_value = _make_execute_result(scalar=mock_gateway_health)

    with patch("mcpgateway.services.gateway_service.settings") as fake_settings:
        fake_settings.auto_refresh_servers = True
        fake_settings.gateway_auto_refresh_interval = 300
        fake_settings.enable_ed25519_signing = False
        fake_settings.httpx_admin_read_timeout = 5.0
        with (
            patch("mcpgateway.services.http_client_service.get_isolated_http_client", return_value=http_client),
            patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session),
        ):
            await gateway_service._check_single_gateway_health(mock_gateway_health)

    refresh.assert_not_called()
@pytest.mark.asyncio
async def test_health_skips_refresh_locked(self, gateway_service, mock_gateway_health, mock_db_session):
    """Test that health check skips refresh if lock is held.

    The refresh-lock mock reports ``locked() == True``. Even with
    ``auto_refresh_servers`` enabled, the refresh method must therefore
    not be called.
    """
    gateway_service._refresh_gateway_tools_resources_prompts = AsyncMock()
    # MagicMock for the synchronous ``locked()`` query; awaitable enter/exit so
    # the mock still works under ``async with`` if that path is reached.
    lock = MagicMock()
    lock.locked.return_value = True  # Lock held!
    lock.__aenter__ = AsyncMock(return_value=None)
    lock.__aexit__ = AsyncMock(return_value=None)
    gateway_service._get_refresh_lock = MagicMock(return_value=lock)
    # Mock http client (its own async context manager; GET returns a 200 response)
    mock_client = AsyncMock()
    mock_client.__aenter__.return_value = mock_client
    mock_client.__aexit__.return_value = None
    mock_client.get.return_value = MagicMock(status_code=200)
    with patch("mcpgateway.services.gateway_service.settings") as mock_settings:
        mock_settings.auto_refresh_servers = True
        # NOTE(review): gateway_auto_refresh_interval is not set here, so it stays an
        # auto-created MagicMock attribute; this presumably relies on the lock check
        # happening before any interval arithmetic — confirm.
        mock_settings.enable_ed25519_signing = False
        mock_settings.httpx_admin_read_timeout = 5.0
        with patch("mcpgateway.services.http_client_service.get_isolated_http_client", return_value=mock_client):
            with patch("mcpgateway.services.gateway_service.fresh_db_session", return_value=mock_db_session):
                # The fixture session is its own context manager, so this configures
                # the same object the service will receive.
                session = mock_db_session.__enter__()
                session.execute.return_value = _make_execute_result(scalar=mock_gateway_health)
                await gateway_service._check_single_gateway_health(mock_gateway_health)
    gateway_service._refresh_gateway_tools_resources_prompts.assert_not_called()