"""Tests for URL validator."""
import pytest
from mcp_server_builder.utils.url_validator import (
DEFAULT_ALLOWED_DOMAINS,
URLValidationError,
URLValidator,
validate_urls,
)
class TestURLValidator:
    """Exercise the URLValidator class directly."""

    def test_init_with_frozenset(self) -> None:
        """The validator's allow-list equals the frozenset passed in."""
        allowed = frozenset({"https://example.com/"})
        assert URLValidator(allowed)._allowed == allowed

    def test_is_url_allowed_valid(self) -> None:
        """URLs under the default allowed domains are accepted."""
        checker = URLValidator(DEFAULT_ALLOWED_DOMAINS)
        for url in (
            "https://modelcontextprotocol.io/docs",
            "https://gofastmcp.com/tutorials",
        ):
            assert checker.is_url_allowed(url)

    def test_is_url_allowed_invalid(self) -> None:
        """URLs outside the default allow-list are rejected."""
        checker = URLValidator(DEFAULT_ALLOWED_DOMAINS)
        rejected = (
            "https://evil.com/malware",
            # Same host as an accepted URL, but with the http scheme.
            "http://modelcontextprotocol.io/docs",
        )
        for url in rejected:
            assert not checker.is_url_allowed(url)

    def test_is_url_allowed_empty(self) -> None:
        """An empty string and None are both rejected."""
        is_allowed = URLValidator(DEFAULT_ALLOWED_DOMAINS).is_url_allowed
        assert not is_allowed("")
        assert not is_allowed(None)  # type: ignore

    def test_validate_urls_single_valid(self) -> None:
        """A single allowed URL string comes back as a one-element list."""
        url = "https://modelcontextprotocol.io/docs"
        validated = URLValidator(DEFAULT_ALLOWED_DOMAINS).validate_urls(url)
        assert validated == [url]

    def test_validate_urls_list_valid(self) -> None:
        """A list of allowed URLs comes back equal to the input list."""
        expected = [
            "https://modelcontextprotocol.io/docs",
            "https://gofastmcp.com/tutorials",
        ]
        validated = URLValidator(DEFAULT_ALLOWED_DOMAINS).validate_urls(expected)
        assert validated == expected

    def test_validate_urls_raises_on_invalid(self) -> None:
        """A disallowed URL raises URLValidationError mentioning its host."""
        checker = URLValidator(DEFAULT_ALLOWED_DOMAINS)
        with pytest.raises(URLValidationError) as caught:
            checker.validate_urls("https://evil.com/malware")
        # Checked after the context manager exits so the assertion actually runs.
        message = str(caught.value)
        assert "evil.com" in message
class TestValidateUrlsFunction:
    """Exercise the module-level validate_urls helper."""

    def test_relative_url_conversion(self) -> None:
        """A relative path is returned as an absolute modelcontextprotocol.io URL."""
        validated = validate_urls("/docs/concepts/tools")
        expected = ["https://modelcontextprotocol.io/docs/concepts/tools"]
        assert validated == expected

    def test_custom_domains(self) -> None:
        """A URL under a caller-supplied allowed domain is accepted."""
        only_custom = frozenset({"https://custom.com/"})
        validated = validate_urls(
            "https://custom.com/page",
            allowed_domains=only_custom,
        )
        assert validated == ["https://custom.com/page"]

    def test_custom_domains_rejects_default(self) -> None:
        """With a custom allow-list, a modelcontextprotocol.io URL is rejected."""
        only_custom = frozenset({"https://custom.com/"})
        with pytest.raises(URLValidationError):
            validate_urls(
                "https://modelcontextprotocol.io/docs",
                allowed_domains=only_custom,
            )