# -*- coding: utf-8 -*-
"""Location: ./tests/security/test_security_headers.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
Security Headers and CORS Testing.
This module contains comprehensive tests for security headers middleware and CORS configuration.
"""
# Standard
from unittest.mock import patch
# Third-Party
from fastapi.testclient import TestClient
import pytest
# First-Party
from mcpgateway.config import settings
class TestSecurityHeaders:
    """Checks that responses carry the expected security headers."""

    def test_security_headers_present_on_health_endpoint(self, client: TestClient):
        """GET /health must return the fixed-value security headers and a CSP."""
        headers = client.get("/health").headers

        # Headers whose value is compared exactly.
        exact_values = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "0",
            "Referrer-Policy": "strict-origin-when-cross-origin",
        }
        for header_name, expected in exact_values.items():
            assert headers[header_name] == expected
        assert "Content-Security-Policy" in headers

        # The policy itself must contain these two directives.
        policy = headers["Content-Security-Policy"]
        for directive in ("default-src 'self'", "frame-ancestors 'none'"):
            assert directive in policy

    def test_security_headers_present_on_api_endpoints(self, client: TestClient):
        """GET /tools must return the same fixed-value security headers and a CSP."""
        # settings.auth_required is patched to False for the duration of the request.
        with patch.object(settings, "auth_required", False):
            headers = client.get("/tools").headers

            exact_values = {
                "X-Content-Type-Options": "nosniff",
                "X-Frame-Options": "DENY",
                "X-XSS-Protection": "0",
                "Referrer-Policy": "strict-origin-when-cross-origin",
            }
            for header_name, expected in exact_values.items():
                assert headers[header_name] == expected
            assert "Content-Security-Policy" in headers

    def test_sensitive_headers_removed(self, client: TestClient):
        """Headers that reveal the server stack must be absent from GET /health."""
        headers = client.get("/health").headers

        for revealing_header in ("X-Powered-By", "Server"):
            assert revealing_header not in headers

    def test_hsts_header_on_https_request(self, client: TestClient):
        """HSTS must be sent when the request carries ``X-Forwarded-Proto: https``."""
        forwarded = {"X-Forwarded-Proto": "https"}
        headers = client.get("/health", headers=forwarded).headers

        assert "Strict-Transport-Security" in headers
        hsts = headers["Strict-Transport-Security"]
        for token in ("max-age=31536000", "includeSubDomains"):
            assert token in hsts

    def test_no_hsts_header_on_http_request(self, client: TestClient):
        """HSTS must not be sent when no ``X-Forwarded-Proto`` header is supplied."""
        headers = client.get("/health").headers

        assert "Strict-Transport-Security" not in headers

    def test_content_security_policy_structure(self, client: TestClient):
        """The CSP on GET /health must list each expected directive and end with ';'."""
        policy = client.get("/health").headers["Content-Security-Policy"]

        required_directives = (
            "default-src 'self'",
            "script-src 'self'",
            "style-src 'self'",
            "img-src 'self'",
            "font-src 'self'",
            "connect-src 'self'",
            "frame-ancestors 'none'",
        )
        for directive in required_directives:
            assert directive in policy

        # The policy string is expected to be terminated by a semicolon.
        assert policy.endswith(";")
class TestCORSConfiguration:
    """Checks the CORS response headers for allowed and disallowed origins."""

    def test_cors_with_development_origins(self, client: TestClient):
        """An allowed development origin is echoed back in Access-Control-Allow-Origin."""
        dev_origins = {"http://localhost:3000", "http://localhost:8080"}
        with patch.object(settings, "environment", "development"), patch.object(settings, "allowed_origins", dev_origins):
            # A plain GET carrying an Origin header (not a preflight request).
            reply = client.get("/health", headers={"Origin": "http://localhost:3000"})

            assert reply.status_code == 200
            assert reply.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"

    def test_cors_blocks_unauthorized_origin(self, client: TestClient):
        """An origin outside the allowed set must not be echoed back."""
        with patch.object(settings, "allowed_origins", {"http://localhost:3000"}):
            reply = client.get("/health", headers={"Origin": "https://evil.com"})

            # The disallowed origin must never appear in the allow-origin header.
            assert reply.headers.get("Access-Control-Allow-Origin") != "https://evil.com"

            # The request itself still succeeds; only the CORS grant is withheld.
            assert reply.status_code == 200

    def test_cors_credentials_allowed(self, client: TestClient):
        """Access-Control-Allow-Credentials is "true" for an allowed origin."""
        with patch.object(settings, "cors_allow_credentials", True), patch.object(settings, "allowed_origins", {"http://localhost:3000"}):
            reply = client.get("/health", headers={"Origin": "http://localhost:3000"})

            assert reply.headers.get("Access-Control-Allow-Credentials") == "true"

    def test_cors_allowed_methods(self, client: TestClient):
        """A GET to /health from an allowed origin gets that origin echoed back.

        NOTE(review): despite the test name, this sends no preflight request and
        makes no assertion on Access-Control-Allow-Methods; it only checks the
        allow-origin header.
        """
        with patch.object(settings, "allowed_origins", {"http://localhost:3000"}):
            reply = client.get("/health", headers={"Origin": "http://localhost:3000"})

            assert reply.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"

    def test_cors_exposed_headers(self, client: TestClient):
        """If Access-Control-Expose-Headers is sent, it must list the expected names."""
        with patch.object(settings, "allowed_origins", {"http://localhost:3000"}):
            reply = client.get("/health", headers={"Origin": "http://localhost:3000"})

            # The allowed origin must be echoed back.
            assert reply.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"

            # The expose-headers check is skipped when the header is absent or empty.
            exposed = reply.headers.get("Access-Control-Expose-Headers", "")
            if exposed:
                for header_name in ("Content-Length", "X-Request-ID"):
                    assert header_name in exposed
class TestProductionSecurity:
    """Test security configuration in production environment."""

    def test_production_cors_requires_explicit_origins(self, client: TestClient):
        """Test that production with no configured origins has an empty origin set.

        The ``client`` fixture is requested but no HTTP call is made; the test
        only inspects the patched ``settings`` object.
        """
        with patch.object(settings, "environment", "production"), patch.object(settings, "allowed_origins", set()):
            # Should have no origins for production without explicit config
            assert not settings.allowed_origins

    def test_production_uses_https_origins(self, client: TestClient):
        """Test that every origin in the patched production set uses HTTPS.

        The ``client`` fixture is requested but no HTTP call is made; the test
        only inspects the patched ``settings`` object.
        """
        # Stand-in for the origins that would be set during initialization.
        test_origins = {"https://example.com", "https://app.example.com", "https://admin.example.com"}
        with patch.object(settings, "environment", "production"), patch.object(settings, "app_domain", "example.com"), patch.object(settings, "allowed_origins", test_origins):
            # All origins should be HTTPS
            for origin in settings.allowed_origins:
                assert origin.startswith("https://"), f"Non-HTTPS origin: {origin}"

    def test_security_headers_consistent_across_endpoints(self, client: TestClient):
        """Test security headers are present and identical across endpoints."""
        endpoints = ["/health", "/ready"]
        headers_to_check = ["X-Content-Type-Options", "X-Frame-Options", "X-XSS-Protection", "Referrer-Policy", "Content-Security-Policy"]

        responses = {endpoint: client.get(endpoint) for endpoint in endpoints}

        for header in headers_to_check:
            values = [responses[endpoint].headers.get(header) for endpoint in endpoints]
            # Without this guard a header missing from every endpoint yields
            # all-None values, which would compare equal and pass vacuously.
            assert values[0] is not None, f"Missing {header} on {endpoints[0]}"
            assert all(value == values[0] for value in values), f"Inconsistent {header} across endpoints"
class TestSecurityHeadersEdgeCases:
    """Security-header checks for error responses and unusual proxy headers."""

    def test_security_headers_on_error_responses(self, client: TestClient):
        """A request to an unregistered-looking path still gets security headers.

        The response status code is not asserted here; only the headers are.
        """
        headers = client.get("/nonexistent").headers

        assert headers["X-Content-Type-Options"] == "nosniff"
        assert headers["X-Frame-Options"] == "DENY"
        assert "Content-Security-Policy" in headers

    def test_security_headers_on_method_not_allowed(self, client: TestClient):
        """A POST to /health must yield 405 and still carry security headers."""
        reply = client.post("/health")

        assert reply.status_code == 405
        headers = reply.headers
        assert headers["X-Content-Type-Options"] == "nosniff"
        assert headers["X-Frame-Options"] == "DENY"
        assert "Content-Security-Policy" in headers

    @pytest.mark.parametrize("forwarded_proto", ["http", "https", "invalid"])
    def test_hsts_with_various_forwarded_proto_values(self, client: TestClient, forwarded_proto: str):
        """HSTS must be present exactly when X-Forwarded-Proto is "https"."""
        reply = client.get("/health", headers={"X-Forwarded-Proto": forwarded_proto})

        hsts_sent = "Strict-Transport-Security" in reply.headers
        hsts_expected = forwarded_proto == "https"
        assert hsts_sent == hsts_expected
@pytest.fixture
def client(app):
    """Return a ``TestClient`` wrapping the ``app`` fixture.

    ``app`` is not defined in this module; presumably it is supplied by a
    shared conftest — verify against the test suite's fixtures.
    """
    test_client = TestClient(app)
    return test_client