validator.py
"""
Security validation utilities.
"""
import logging
import re
from typing import Any, Dict
from urllib.parse import urlparse

from mcp.types import CallToolRequest

logger = logging.getLogger(__name__)


class SecurityValidator:
    """Security validation for tool calls."""

    def __init__(self):
        self.dangerous_patterns = [
            r'javascript:',
            r'data:',
            r'vbscript:',
            r'file://',
            r'ftp://',
        ]
        self.max_content_length = 10 * 1024 * 1024  # 10 MB

    async def validate_tool_call(self, request: CallToolRequest) -> bool:
        """Validate a tool call request."""
        tool_name = request.params.name
        arguments = request.params.arguments or {}

        # Basic validation
        if not tool_name or not isinstance(tool_name, str):
            raise ValueError("Invalid tool name")

        # Tool-specific validation
        if tool_name == "web_search":
            await self._validate_web_search(arguments)
        elif tool_name == "web_scrape":
            await self._validate_web_scrape(arguments)

        return True

    async def _validate_web_search(self, arguments: Dict[str, Any]):
        """Validate web search arguments."""
        query = arguments.get('query', '')
        max_results = arguments.get('max_results', 5)

        if not query or not isinstance(query, str):
            raise ValueError("Query is required and must be a string")
        if len(query) > 1000:
            raise ValueError("Query too long (max 1000 characters)")
        if not isinstance(max_results, int) or max_results < 1 or max_results > 50:
            raise ValueError("max_results must be an integer between 1 and 50")

    async def _validate_web_scrape(self, arguments: Dict[str, Any]):
        """Validate web scraping arguments."""
        url = arguments.get('url', '')
        output_format = arguments.get('format', 'markdown')

        if not url or not isinstance(url, str):
            raise ValueError("URL is required and must be a string")

        # Check for dangerous URL patterns
        for pattern in self.dangerous_patterns:
            if re.search(pattern, url, re.IGNORECASE):
                raise ValueError(f"URL contains dangerous pattern: {pattern}")

        # Only http/https URLs are allowed
        if not self._is_safe_url(url):
            raise ValueError("URL scheme must be http or https")

        if output_format not in ['text', 'markdown']:
            raise ValueError("Format must be 'text' or 'markdown'")

    def _is_safe_url(self, url: str) -> bool:
        """Check if the URL is safe (http or https only)."""
        try:
            parsed = urlparse(url)
            return parsed.scheme in ['http', 'https']
        except Exception:
            return False
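

# Minimal usage sketch (illustrative; not part of the original module). It
# assumes mcp.types also exposes CallToolRequestParams with `name` and
# `arguments` fields; adjust the request construction to your MCP SDK version.
if __name__ == "__main__":
    import asyncio

    from mcp.types import CallToolRequestParams

    async def _demo() -> None:
        validator = SecurityValidator()
        request = CallToolRequest(
            method="tools/call",
            params=CallToolRequestParams(
                name="web_search",
                arguments={"query": "python security validation", "max_results": 5},
            ),
        )
        # validate_tool_call raises ValueError on failure and returns True otherwise.
        print(await validator.validate_tool_call(request))

    asyncio.run(_demo())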