We provide all the information about MCP servers via our MCP API.
curl -X GET 'https://glama.ai/api/mcp/v1/servers/KuudoAI/amazon_ads_mcp'
If you have feedback or need assistance with the MCP directory API, please join our Discord server.
"""Tests for security refactoring."""
import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from amazon_ads_mcp.auth.oauth_state_store import OAuthStateStore
from amazon_ads_mcp.auth.secure_token_store import SecureTokenStore
from amazon_ads_mcp.exceptions import (
OAuthStateError,
TimeoutError,
APIError,
ToolExecutionError,
)
from amazon_ads_mcp.utils.async_compat import (
CompatibleEventLoopPolicy,
ensure_event_loop,
run_async_in_sync,
AsyncContextManager,
)
from amazon_ads_mcp.utils.response_wrapper import ResponseWrapper
from amazon_ads_mcp.utils.sampling_wrapper import SamplingHandlerWrapper
class TestOAuthStateStore:
    """Exercise OAuthStateStore: signed state generation and single-use validation."""

    @staticmethod
    def _make_store(**options):
        """Build a store signed with the shared test secret, forwarding extra options."""
        return OAuthStateStore(secret_key="test_secret", **options)

    def test_generate_state(self):
        """A generated state exists, contains a '.' signature separator, and is long."""
        store = self._make_store()
        state = store.generate_state(
            auth_url="https://example.com/auth",
            user_agent="TestAgent/1.0",
            ip_address="192.168.1.1",
        )
        assert state is not None
        assert "." in state  # signature separator
        assert len(state) > 40  # reasonably long token

    def test_validate_state_success(self):
        """A fresh state validates when presented with the same user agent."""
        store = self._make_store()
        state = store.generate_state(
            auth_url="https://example.com/auth",
            user_agent="TestAgent/1.0",
        )
        ok, problem = store.validate_state(state, user_agent="TestAgent/1.0")
        assert ok is True
        assert problem is None

    def test_validate_state_invalid(self):
        """A state string the store never issued is rejected with the generic error."""
        store = self._make_store()
        ok, problem = store.validate_state("invalid_state")
        assert ok is False
        assert problem == "Invalid or expired state"

    def test_validate_state_tampered(self):
        """Replacing the signature segment makes the state fail validation."""
        store = self._make_store()
        state = store.generate_state(auth_url="https://example.com/auth")
        payload, _signature = state.rsplit(".", 1)
        forged = f"{payload}.tampered_signature"
        ok, problem = store.validate_state(forged)
        assert ok is False
        # The forged token no longer matches any stored entry, so the lookup
        # may fail before the signature check is ever reached.
        assert problem in ("Invalid or expired state", "Invalid state signature")

    def test_validate_state_reuse_prevention(self):
        """A state is single-use: the second validation attempt is refused."""
        store = self._make_store()
        state = store.generate_state(auth_url="https://example.com/auth")

        first_ok, _ = store.validate_state(state)
        assert first_ok is True

        second_ok, second_problem = store.validate_state(state)
        assert second_ok is False
        assert "already used" in second_problem

    def test_state_expiration(self):
        """A state whose stored expiry lies in the past is rejected as expired."""
        store = self._make_store()
        state = store.generate_state(
            auth_url="https://example.com/auth",
            ttl_minutes=0,
        )
        # Move the stored entry's expiry into the past explicitly rather than
        # relying on ttl_minutes=0 alone to have elapsed.
        record = store._memory_store[state]
        record.expires_at = datetime.now(timezone.utc) - timedelta(minutes=1)

        ok, problem = store.validate_state(state)
        assert ok is False
        assert "expired" in problem.lower()

    def test_persistence(self, tmp_path):
        """A state issued by one store validates in a second store on the same file."""
        store_path = tmp_path / "oauth_states.json"
        writer = self._make_store(store_path=store_path)
        state = writer.generate_state(auth_url="https://example.com/auth")

        reader = self._make_store(store_path=store_path)
        ok, _ = reader.validate_state(state)
        assert ok is True
class TestSecureTokenStore:
    """Exercise SecureTokenStore: encrypted, file-backed token storage."""

    @staticmethod
    def _open(storage_path, encryption_key="test_password"):
        """Create a store over ``storage_path`` keyed with ``encryption_key``."""
        return SecureTokenStore(
            storage_path=storage_path,
            encryption_key=encryption_key,
        )

    def test_store_and_retrieve(self, tmp_path):
        """A stored token comes back with its value and type intact."""
        store = self._open(tmp_path / "tokens.enc")
        store.store_token(
            token_id="test_token",
            token_value="secret_value_123",
            token_type="refresh",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )

        fetched = store.get_token("test_token")
        assert fetched is not None
        assert fetched["value"] == "secret_value_123"
        assert fetched["type"] == "refresh"

    def test_encryption(self, tmp_path):
        """Neither the token value nor its id appears in plaintext on disk."""
        storage_path = tmp_path / "tokens.enc"
        store = self._open(storage_path)
        store.store_token(
            token_id="sensitive_token",
            token_value="super_secret_value",
            token_type="access",
        )

        on_disk = storage_path.read_bytes()
        assert b"super_secret_value" not in on_disk
        assert b"sensitive_token" not in on_disk  # ids are encrypted too

    def test_expiration(self, tmp_path):
        """A token stored with a past expiry is not returned."""
        store = self._open(tmp_path / "tokens.enc")
        an_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        store.store_token(
            token_id="expired_token",
            token_value="old_value",
            expires_at=an_hour_ago,
        )

        assert store.get_token("expired_token") is None

    def test_persistence_across_instances(self, tmp_path):
        """A token written by one instance is readable by a fresh instance."""
        storage_path = tmp_path / "tokens.enc"

        writer = self._open(storage_path)
        writer.store_token(
            token_id="persistent_token",
            token_value="persistent_value",
        )

        reader = self._open(storage_path)
        fetched = reader.get_token("persistent_token")
        assert fetched is not None
        assert fetched["value"] == "persistent_value"

    def test_wrong_key_fails(self, tmp_path):
        """Opening the file with the wrong key starts empty instead of crashing."""
        storage_path = tmp_path / "tokens.enc"

        writer = self._open(storage_path, encryption_key="correct_password")
        writer.store_token(token_id="test", token_value="value")

        intruder = self._open(storage_path, encryption_key="wrong_password")
        assert intruder.get_token("test") is None
class TestAsyncCompatibility:
    """Test async compatibility utilities."""

    def test_compatible_event_loop_policy(self):
        """CompatibleEventLoopPolicy creates an open loop on demand.

        The event loop policy is process-global. Cleanup therefore runs in a
        ``finally`` block and reinstates the policy that was active before the
        test. A failing assertion must not leak the custom policy or an open
        loop into later tests.
        """
        previous_policy = asyncio.get_event_loop_policy()
        asyncio.set_event_loop_policy(CompatibleEventLoopPolicy())
        loop = None
        try:
            # The policy should create a loop when one is requested.
            loop = asyncio.get_event_loop()
            assert loop is not None
            assert not loop.is_closed()
        finally:
            if loop is not None:
                loop.close()
            # Detach the loop from the custom policy before swapping it out.
            asyncio.set_event_loop(None)
            asyncio.set_event_loop_policy(previous_policy)

    def test_ensure_event_loop(self):
        """ensure_event_loop returns an open loop when no current loop is set."""
        # Clear any existing loop so ensure_event_loop must create one.
        try:
            existing = asyncio.get_event_loop()
            existing.close()
        except RuntimeError:
            pass
        asyncio.set_event_loop(None)

        loop = ensure_event_loop()
        try:
            assert loop is not None
            assert not loop.is_closed()
        finally:
            # Always release the loop, even if an assertion above failed.
            if loop is not None:
                loop.close()
            asyncio.set_event_loop(None)

    def test_run_async_in_sync(self):
        """run_async_in_sync drives a coroutine function to completion from sync code."""
        async def async_func(value):
            await asyncio.sleep(0.01)
            return value * 2

        result = run_async_in_sync(async_func, 21)
        assert result == 42

    def test_async_context_manager(self):
        """AsyncContextManager.run executes a coroutine and returns its result."""
        async def async_task():
            await asyncio.sleep(0.01)
            return "completed"

        with AsyncContextManager() as ctx:
            result = ctx.run(async_task())
            assert result == "completed"
class TestResponseWrapper:
    """Exercise ResponseWrapper's access to, and rewriting of, JSON responses."""

    @staticmethod
    def _json_response(body):
        """Return a 200 ``httpx.Response`` carrying raw JSON ``body`` bytes."""
        import httpx

        return httpx.Response(
            200,
            headers={"content-type": "application/json"},
            content=body,
        )

    def test_response_wrapper_basic(self):
        """The wrapper exposes the status code and decoded JSON of the response."""
        wrapped = ResponseWrapper(self._json_response(b'{"key": "value"}'))
        assert wrapped.status_code == 200
        assert wrapped.json() == {"key": "value"}

    def test_response_wrapper_modification(self):
        """set_json replaces both the decoded payload and the raw content bytes."""
        wrapped = ResponseWrapper(self._json_response(b'{"old": "value"}'))
        wrapped.set_json({"new": "value"})
        assert wrapped.json() == {"new": "value"}
        assert wrapped.content == b'{"new": "value"}'

    def test_response_wrapper_modify_json(self):
        """modify_json applies a transform function to the decoded payload."""
        wrapped = ResponseWrapper(self._json_response(b'{"count": 10}'))

        def double_count(payload):
            return {**payload, "count": payload["count"] * 2}

        wrapped.modify_json(double_count)
        assert wrapped.json() == {"count": 20}
class TestStructuredExceptions:
    """Exercise the structured exception types and the metadata they expose."""

    def test_oauth_state_error(self):
        """OAuthStateError carries its code and message, and serializes both."""
        exc = OAuthStateError("Invalid state")
        assert exc.code == "OAUTH_STATE_ERROR"
        assert exc.message == "Invalid state"

        payload = exc.to_dict()
        assert payload["error"] == "OAUTH_STATE_ERROR"
        assert payload["message"] == "Invalid state"

    def test_timeout_error(self):
        """TimeoutError records the operation that timed out in its details."""
        exc = TimeoutError("Request timed out", operation="list_campaigns")
        assert exc.code == "TIMEOUT_ERROR"
        assert exc.details["operation"] == "list_campaigns"

    def test_api_error(self):
        """APIError exposes the HTTP status and echoes status and body in details."""
        exc = APIError(
            "API request failed",
            status_code=404,
            response_body="Not found",
        )
        details = exc.details
        assert exc.code == "API_ERROR"
        assert exc.status_code == 404
        assert details["status_code"] == 404
        assert details["response_body"] == "Not found"

    def test_tool_execution_error(self):
        """ToolExecutionError names the tool and records the wrapped error's type."""
        cause = ValueError("Original error")
        exc = ToolExecutionError(
            "Tool failed",
            tool_name="test_tool",
            original_error=cause,
        )
        details = exc.details
        assert exc.code == "TOOL_EXECUTION_ERROR"
        assert exc.tool_name == "test_tool"
        assert details["tool"] == "test_tool"
        assert "ValueError" in details["error_type"]
class TestSamplingWrapper:
    """Exercise SamplingHandlerWrapper's fallback when the client cannot sample."""

    @staticmethod
    def _ctx_without_sampling():
        """Return a mock context whose ``sample`` always raises 'does not support sampling'."""
        ctx = MagicMock()
        ctx.sample = AsyncMock(side_effect=Exception("does not support sampling"))
        return ctx

    @pytest.mark.asyncio
    async def test_sampling_wrapper_with_handler(self):
        """With a server-side handler set, the wrapper falls back to it."""
        async def mock_handler(messages, params, context):
            return MagicMock(content="sampled response")

        sampler = SamplingHandlerWrapper()
        sampler.set_handler(mock_handler)
        assert sampler.has_handler() is True

        ctx = self._ctx_without_sampling()
        ctx.request_context = {}

        reply = await sampler.sample(messages="test message", ctx=ctx)
        assert reply == "sampled response"

    @pytest.mark.asyncio
    async def test_sampling_wrapper_no_handler(self):
        """Without a handler, the wrapper raises and says no fallback is configured."""
        sampler = SamplingHandlerWrapper()
        assert sampler.has_handler() is False

        ctx = self._ctx_without_sampling()
        # The pattern has no regex metacharacters, so re.search == substring test.
        with pytest.raises(Exception, match="no server-side fallback is configured"):
            await sampler.sample(messages="test message", ctx=ctx)