validation.py•17.4 kB
"""Comprehensive validation and error handling for GCP MCP server."""
import re
from typing import Any, Dict, List, Optional, Union
from datetime import datetime
from pydantic import BaseModel, Field, validator
from google.cloud.exceptions import GoogleCloudError
from .exceptions import ValidationError, GCPServiceError
class ProjectValidator:
    """Validator for GCP project IDs and configurations."""

    # 6-30 characters in total: one lowercase letter, then 4-28 lowercase
    # letters, digits or hyphens, then a final lowercase letter or digit.
    # Used with ``fullmatch`` so a trailing newline is not silently accepted
    # (``$`` in ``re.match`` also matches just before a final "\n").
    _PROJECT_ID_PATTERN = re.compile(r'[a-z][a-z0-9\-]{4,28}[a-z0-9]')

    # Upper bound on the number of project IDs accepted in one list.
    _MAX_PROJECTS = 50

    @staticmethod
    def validate_project_id(project_id: str) -> str:
        """Validate GCP project ID format.

        Args:
            project_id: Candidate project ID.

        Returns:
            The same ``project_id``, unchanged, when it is well formed.

        Raises:
            ValidationError: If the ID is empty, is not a string, or does not
                match the project ID format described in the error message.
        """
        if not project_id:
            raise ValidationError("Project ID cannot be empty")

        # Non-string input is reported as a validation failure instead of
        # leaking a TypeError from the regex engine.
        well_formed = (
            isinstance(project_id, str)
            and ProjectValidator._PROJECT_ID_PATTERN.fullmatch(project_id) is not None
        )
        if not well_formed:
            raise ValidationError(
                f"Invalid project ID format: {project_id}. "
                "Project IDs must be 6-30 characters, start with lowercase letter, "
                "contain only lowercase letters, numbers, and hyphens, "
                "and end with a letter or number."
            )
        return project_id

    @staticmethod
    def validate_project_list(projects: List[str]) -> List[str]:
        """Validate a list of project IDs.

        Args:
            projects: Non-empty collection of at most 50 project IDs.

        Returns:
            A new list holding the validated IDs in their original order.

        Raises:
            ValidationError: If the list is empty, is a bare string, has more
                than 50 entries, or contains an invalid project ID.
        """
        if not projects:
            raise ValidationError("Project list cannot be empty")

        # A bare string would otherwise be iterated character by character
        # and fail with a confusing per-character error.
        if isinstance(projects, str):
            raise ValidationError(
                "Projects must be provided as a list of project IDs, not a single string"
            )

        if len(projects) > ProjectValidator._MAX_PROJECTS:
            raise ValidationError("Too many projects specified. Maximum is 50.")

        return [ProjectValidator.validate_project_id(project) for project in projects]
class TimeRangeValidator:
    """Validator for time ranges and durations."""

    # "<digits><unit>" where unit is one of s, m, h, d, w. Applied with
    # ``fullmatch`` so that input such as "5m\n" is rejected.
    _DURATION_PATTERN = re.compile(r'(\d+)(s|m|h|d|w)')

    # A value ending in one of these letters is treated as a relative
    # duration; anything else is parsed as an ISO-8601 timestamp.
    _DURATION_SUFFIXES = ('s', 'm', 'h', 'd', 'w')

    # unit -> (largest accepted amount, error message when exceeded)
    _DURATION_LIMITS = {
        's': (3600, "Duration in seconds cannot exceed 3600 (1 hour)"),
        'm': (10080, "Duration in minutes cannot exceed 10080 (1 week)"),
        'h': (8760, "Duration in hours cannot exceed 8760 (1 year)"),
        'd': (365, "Duration in days cannot exceed 365 (1 year)"),
        'w': (52, "Duration in weeks cannot exceed 52 (1 year)"),
    }

    @staticmethod
    def validate_duration(duration: str) -> str:
        """Validate duration string format.

        Args:
            duration: Relative duration such as '30s', '5m', '2h', '7d', '1w'.

        Returns:
            The same ``duration`` string, unchanged.

        Raises:
            ValidationError: If the value is empty, is not of the form
                ``<digits><unit>``, or exceeds the per-unit maximum.
        """
        if not duration:
            raise ValidationError("Duration cannot be empty")

        # Match once and reuse the result for both the format check and the
        # number/unit extraction.
        match = (
            TimeRangeValidator._DURATION_PATTERN.fullmatch(duration)
            if isinstance(duration, str)
            else None
        )
        if match is None:
            raise ValidationError(
                f"Invalid duration format: {duration}. "
                "Use format like '30s', '5m', '2h', '7d', '1w'"
            )

        amount = int(match.group(1))
        limit, too_large_message = TimeRangeValidator._DURATION_LIMITS[match.group(2)]
        if amount > limit:
            raise ValidationError(too_large_message)

        return duration

    @staticmethod
    def _validate_time_point(value: str, label: str) -> None:
        """Check one bound of a time range; ``label`` is 'start' or 'end'."""
        if not isinstance(value, str):
            raise ValidationError(f"Invalid {label} time format: {value}")

        if value.endswith(TimeRangeValidator._DURATION_SUFFIXES):
            TimeRangeValidator.validate_duration(value)
            return

        try:
            # A 'Z' suffix is rewritten to an explicit UTC offset before
            # parsing with datetime.fromisoformat.
            datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError as err:
            raise ValidationError(f"Invalid {label} time format: {value}") from err

    @staticmethod
    def validate_time_range(start: Optional[str], end: Optional[str] = None) -> Dict[str, str]:
        """Validate time range parameters.

        Each bound is either a relative duration (see ``validate_duration``)
        or an ISO-8601 timestamp.

        Args:
            start: Required start of the range.
            end: Optional end of the range; ignored when falsy.

        Returns:
            ``{"start": start}``, plus an ``"end"`` entry when ``end`` was given.

        Raises:
            ValidationError: If ``start`` is missing or either bound is neither
                a valid duration nor a parseable ISO timestamp.
        """
        if not start:
            raise ValidationError("Start time is required")

        TimeRangeValidator._validate_time_point(start, "start")
        validated_range = {"start": start}

        if end:
            TimeRangeValidator._validate_time_point(end, "end")
            validated_range["end"] = end

        return validated_range
class MetricValidator:
    """Validator for metric queries and configurations."""

    # Basic shape check: a dotted domain with at least two labels, optionally
    # followed by a "/"-separated path. This accepts everything the earlier
    # three-label pattern accepted and additionally admits two-label domains
    # such as "kubernetes.io/container/cpu/core_usage_time", as well as
    # hyphens, dots, colons and upper-case letters in the path.
    _METRIC_TYPE_PATTERN = re.compile(
        r'[a-z][a-z0-9_\-]*(?:\.[a-z][a-z0-9_\-]*)+(?:/[A-Za-z0-9_\-.:/]*)?'
    )

    # "<digits>s"; the digits are captured so the number never has to be
    # recovered by slicing the raw string.
    _ALIGNMENT_PERIOD_PATTERN = re.compile(r'(\d+)s')

    _VALID_ALIGNERS = frozenset({
        'ALIGN_NONE', 'ALIGN_DELTA', 'ALIGN_RATE', 'ALIGN_INTERPOLATE',
        'ALIGN_NEXT_OLDER', 'ALIGN_MIN', 'ALIGN_MAX', 'ALIGN_MEAN',
        'ALIGN_COUNT', 'ALIGN_SUM', 'ALIGN_STDDEV', 'ALIGN_COUNT_TRUE',
        'ALIGN_COUNT_FALSE', 'ALIGN_FRACTION_TRUE', 'ALIGN_PERCENTILE_99',
        'ALIGN_PERCENTILE_95', 'ALIGN_PERCENTILE_50', 'ALIGN_PERCENTILE_05',
        'ALIGN_PERCENT_CHANGE',
    })

    _VALID_REDUCERS = frozenset({
        'REDUCE_NONE', 'REDUCE_MEAN', 'REDUCE_MIN', 'REDUCE_MAX',
        'REDUCE_SUM', 'REDUCE_STDDEV', 'REDUCE_COUNT', 'REDUCE_COUNT_TRUE',
        'REDUCE_COUNT_FALSE', 'REDUCE_FRACTION_TRUE', 'REDUCE_PERCENTILE_99',
        'REDUCE_PERCENTILE_95', 'REDUCE_PERCENTILE_50', 'REDUCE_PERCENTILE_05',
    })

    @staticmethod
    def validate_metric_type(metric_type: str) -> str:
        """Validate GCP metric type format.

        Args:
            metric_type: Metric type such as
                'compute.googleapis.com/instance/cpu/utilization'.

        Returns:
            The same ``metric_type`` string, unchanged.

        Raises:
            ValidationError: If the value is empty, is not a string, or does
                not have the ``domain[/path]`` shape.
        """
        if not metric_type:
            raise ValidationError("Metric type cannot be empty")

        well_formed = (
            isinstance(metric_type, str)
            and MetricValidator._METRIC_TYPE_PATTERN.fullmatch(metric_type) is not None
        )
        if not well_formed:
            raise ValidationError(
                f"Invalid metric type format: {metric_type}. "
                "Should follow pattern like 'service.googleapis.com/metric/name'"
            )
        return metric_type

    @staticmethod
    def validate_aggregation_config(config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate metric aggregation configuration.

        Only the keys ``alignment_period``, ``per_series_aligner`` and
        ``cross_series_reducer`` are inspected; other keys pass through.

        Args:
            config: Aggregation settings; a falsy value is returned as-is.

        Returns:
            The same ``config`` object, unchanged.

        Raises:
            ValidationError: If the alignment period is not '<n>s' with
                1 <= n <= 3600, or an aligner/reducer name is not recognised.
        """
        if not config:
            return config

        if "alignment_period" in config:
            period = config["alignment_period"]
            # fullmatch rejects "60s\n", which previously passed the format
            # check and then crashed in int() with a ValueError.
            match = (
                MetricValidator._ALIGNMENT_PERIOD_PATTERN.fullmatch(period)
                if isinstance(period, str)
                else None
            )
            if match is None:
                raise ValidationError(
                    f"Invalid alignment period: {period}. Must be in format like '60s'"
                )
            period_seconds = int(match.group(1))
            if not 1 <= period_seconds <= 3600:
                raise ValidationError(
                    "Alignment period must be between 1s and 3600s (1 hour)"
                )

        if "per_series_aligner" in config:
            aligner = config["per_series_aligner"]
            if aligner not in MetricValidator._VALID_ALIGNERS:
                raise ValidationError(f"Invalid per_series_aligner: {aligner}")

        if "cross_series_reducer" in config:
            reducer = config["cross_series_reducer"]
            if reducer not in MetricValidator._VALID_REDUCERS:
                raise ValidationError(f"Invalid cross_series_reducer: {reducer}")

        return config
class FilterValidator:
    """Validator for log and metric filters."""

    @staticmethod
    def validate_log_filter(filter_expr: str) -> str:
        """Screen a Cloud Logging filter expression.

        A falsy expression is returned untouched. Otherwise the expression is
        searched (case-insensitively) for a few suspicious patterns and then
        checked against a 20,000 character limit.

        Returns:
            The same ``filter_expr`` value.

        Raises:
            ValidationError: If a suspicious pattern is found or the
                expression is longer than 20,000 characters.
        """
        if filter_expr:
            suspicious = (
                r';\s*(DROP|DELETE|UPDATE|INSERT)',  # SQL-style statement after ';'
                r'<script.*?>',                      # HTML script tag
                r'javascript:',                      # javascript: URL scheme
            )
            if any(re.search(regex, filter_expr, re.IGNORECASE) for regex in suspicious):
                raise ValidationError("Potentially dangerous filter expression detected")

            if len(filter_expr) > 20000:
                raise ValidationError("Filter expression too long. Maximum length is 20,000 characters")

        return filter_expr

    @staticmethod
    def validate_resource_types(resource_types: List[str]) -> List[str]:
        """Check resource types against a list of well-known GCP types.

        Unknown types are not rejected; each one is reported through a
        structlog warning and still included in the result.

        Returns:
            The input itself when it is falsy, otherwise a new list with the
            same entries in the same order.
        """
        if not resource_types:
            return resource_types

        known_types = frozenset((
            'gce_instance', 'k8s_container', 'k8s_pod', 'k8s_node', 'k8s_cluster',
            'gae_app', 'cloud_function', 'cloud_run_revision', 'cloud_sql_database',
            'bigquery_dataset', 'bigquery_table', 'bigquery_project',
            'gcs_bucket', 'pubsub_topic', 'pubsub_subscription',
            'cloudsql_database', 'redis_instance', 'memcache_instance',
            'dataflow_job', 'dataproc_cluster', 'compute_firewall_rule',
        ))

        accepted = []
        for entry in resource_types:
            if entry in known_types:
                accepted.append(entry)
                continue
            # Custom resource types are allowed; only a warning is logged.
            import structlog
            structlog.get_logger(__name__).warning("Unknown resource type", resource_type=entry)
            accepted.append(entry)
        return accepted
class SecurityValidator:
    """Validator for security-related configurations."""

    # Patterns that look like an embedded credential: "password=...",
    # "api_key: ...", "secret=..."/"token=...", and "Bearer <token>".
    # Compiled once instead of being rebuilt for every search term.
    _CREDENTIAL_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r'(?i)(password|passwd|pwd)\s*[:=]\s*["\']?[a-zA-Z0-9!@#$%^&*()]+',
            r'(?i)(api[_-]?key|apikey)\s*[:=]\s*["\']?[a-zA-Z0-9]+',
            r'(?i)(secret|token)\s*[:=]\s*["\']?[a-zA-Z0-9]+',
            r'(?i)bearer\s+[a-zA-Z0-9\-_]+',
        )
    )

    # Longest search term accepted, in characters.
    _MAX_TERM_LENGTH = 1000

    _VALID_FRAMEWORKS = frozenset(
        {'SOC2', 'PCI-DSS', 'HIPAA', 'GDPR', 'CUSTOM', 'ISO27001', 'NIST'}
    )

    @staticmethod
    def validate_search_terms(search_terms: List[str]) -> List[str]:
        """Validate search terms for security concerns.

        Args:
            search_terms: Terms to check; a falsy value is returned as-is.

        Returns:
            A new list with the accepted terms in their original order.

        Raises:
            ValidationError: If a term looks like it contains a credential or
                is longer than 1000 characters.
        """
        if not search_terms:
            return search_terms

        validated_terms = []
        for term in search_terms:
            if any(pattern.search(term) for pattern in SecurityValidator._CREDENTIAL_PATTERNS):
                # Deliberately do not echo the term: the error text may be
                # logged or returned to a client, which would leak the very
                # credential this check is meant to keep out.
                raise ValidationError(
                    "Search term appears to contain credentials: "
                    f"[redacted, {len(term)} characters]"
                )

            if len(term) > SecurityValidator._MAX_TERM_LENGTH:
                raise ValidationError("Search term too long. Maximum length is 1000 characters")

            validated_terms.append(term)

        return validated_terms

    @staticmethod
    def validate_compliance_framework(framework: str) -> str:
        """Validate compliance framework.

        Args:
            framework: Framework name, matched case-insensitively; a falsy
                value is returned as-is.

        Returns:
            The upper-cased framework name.

        Raises:
            ValidationError: If the name is not one of the supported
                frameworks.
        """
        if not framework:
            return framework

        normalized = framework.upper()
        if normalized not in SecurityValidator._VALID_FRAMEWORKS:
            # Sorted so the message is stable; set iteration order for
            # strings varies between interpreter runs.
            options = ', '.join(sorted(SecurityValidator._VALID_FRAMEWORKS))
            raise ValidationError(
                f"Invalid compliance framework: {framework}. "
                f"Valid options: {options}"
            )
        return normalized
class RateLimitValidator:
    """Validator for rate limiting and resource constraints."""

    @staticmethod
    def validate_query_limits(
        max_results: int,
        projects_count: int,
        time_range_hours: float
    ) -> int:
        """Validate query limits to prevent resource exhaustion.

        The allowed result count shrinks as the queried time range grows:

        * up to 24 hours: 10000 per project, 50000 in total
        * more than 24 hours: 5000 per project, 25000 in total
        * more than 168 hours (1 week): 2000 per project, 10000 in total

        Args:
            max_results: Number of results requested.
            projects_count: Number of projects the query spans.
            time_range_hours: Length of the queried time range in hours.

        Returns:
            ``max_results`` when it is within the applicable limit.

        Raises:
            ValidationError: If ``max_results`` exceeds the smaller of the
                per-project limit times ``projects_count`` and the total limit.
        """
        # The widest range must be tested first; checking "> 24" before
        # "> 168" made the one-week tier unreachable.
        if time_range_hours > 168:  # more than 1 week
            max_results_per_project, max_total_results = 2000, 10000
        elif time_range_hours > 24:
            max_results_per_project, max_total_results = 5000, 25000
        else:
            max_results_per_project, max_total_results = 10000, 50000

        allowed = min(max_results_per_project * projects_count, max_total_results)

        if max_results > allowed:
            raise ValidationError(
                f"Requested results ({max_results}) exceed limits. "
                f"Maximum allowed: {allowed} "
                f"(based on {projects_count} projects and {time_range_hours:.1f}h time range)"
            )

        return max_results
def _estimate_time_range_hours(time_range: Any) -> float:
    """Estimate the length in hours of a relative time range; 1.0 if unknown."""
    seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}

    if not isinstance(time_range, dict):
        return 1.0
    start = time_range.get("start")
    if not isinstance(start, str) or not start:
        return 1.0

    unit_seconds = seconds_per_unit.get(start[-1])
    if unit_seconds is None:
        # NOTE(review): absolute (ISO timestamp) starts still fall back to the
        # 1 hour default, as before; deriving their real span is left open.
        return 1.0
    try:
        return float(start[:-1]) * unit_seconds / 3600
    except ValueError:
        return 1.0


def validate_tool_arguments(tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Comprehensive validation for tool arguments.

    Recognised keys (``projects``, ``time_range``, ``filter_expression``,
    ``resource_types``, ``search_term``, ``threat_indicators``,
    ``compliance_framework``, ``max_results`` and, for tools whose name starts
    with ``advanced_``, ``metric_queries``) are passed to the matching
    validator. Unrecognised keys are copied through untouched.

    Args:
        tool_name: Name of the tool being invoked; used for tool-specific
            checks and in error messages.
        arguments: Raw tool arguments. The mapping itself is not modified.

    Returns:
        A shallow copy of ``arguments`` with validated values substituted.

    Raises:
        ValidationError: If any validator rejects a value, or if any other
            exception occurs while validating (the original exception is
            attached as the cause).
    """
    validated_args = arguments.copy()

    try:
        # Common validations
        if "projects" in arguments:
            if arguments["projects"]:
                validated_args["projects"] = ProjectValidator.validate_project_list(
                    arguments["projects"]
                )

        if "time_range" in arguments:
            validated_args["time_range"] = TimeRangeValidator.validate_time_range(
                arguments["time_range"].get("start"),
                arguments["time_range"].get("end")
            )

        if "filter_expression" in arguments:
            validated_args["filter_expression"] = FilterValidator.validate_log_filter(
                arguments["filter_expression"]
            )

        if "resource_types" in arguments:
            validated_args["resource_types"] = FilterValidator.validate_resource_types(
                arguments.get("resource_types", [])
            )

        if "search_term" in arguments:
            # Validation only; the term itself is passed through unchanged.
            SecurityValidator.validate_search_terms([arguments["search_term"]])

        if "threat_indicators" in arguments:
            validated_args["threat_indicators"] = SecurityValidator.validate_search_terms(
                arguments.get("threat_indicators", [])
            )

        if "compliance_framework" in arguments:
            validated_args["compliance_framework"] = SecurityValidator.validate_compliance_framework(
                arguments["compliance_framework"]
            )

        # Tool-specific validations
        if tool_name.startswith("advanced_") and "metric_queries" in arguments:
            for query in arguments["metric_queries"]:
                MetricValidator.validate_metric_type(query["metric_type"])
                if "aggregation" in query:
                    MetricValidator.validate_aggregation_config(query["aggregation"])

        # Rate limiting validation
        if "max_results" in arguments:
            # An empty or None "projects" value is skipped by the project
            # validation above, so count it as the single default project
            # rather than as zero projects (which would reject every request).
            projects_count = len(arguments.get("projects") or ["default-project"])

            # Seconds and minutes are converted too, so a long range written
            # in minutes is no longer treated as one hour.
            time_range_hours = _estimate_time_range_hours(arguments.get("time_range", {}))

            validated_args["max_results"] = RateLimitValidator.validate_query_limits(
                arguments["max_results"],
                projects_count,
                time_range_hours
            )

        return validated_args

    except ValidationError:
        raise
    except Exception as e:
        raise ValidationError(f"Validation failed for {tool_name}: {str(e)}") from e
def handle_gcp_errors(func):
    """Decorator to handle GCP API errors gracefully.

    Wraps an async callable so that failures surface as ``GCPServiceError``:

    * ``ValidationError`` and ``GCPServiceError`` raised by ``func`` propagate
      unchanged instead of being re-labelled as unexpected errors.
    * ``GoogleCloudError`` is logged and converted; codes 403, 404, 429, 500
      and 503 get a user-friendly message, anything else a generic one.
    * Any other ``Exception`` is logged and converted to
      ``GCPServiceError("Unexpected error: ...")``.

    The original exception is attached as the cause in every conversion.

    Args:
        func: Coroutine function to wrap.

    Returns:
        An async wrapper with the same name, docstring and call signature.
    """
    import functools

    unavailable = "GCP service temporarily unavailable. Please try again later."
    friendly_messages = {
        403: (
            "Access denied. Please check your GCP permissions and ensure "
            "the required APIs are enabled."
        ),
        404: "Resource not found. Please verify project IDs and resource names.",
        429: "Rate limit exceeded. Please reduce query frequency or scope.",
        500: unavailable,
        503: unavailable,
    }

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except (ValidationError, GCPServiceError):
            # Already user-facing domain errors; do not wrap them again.
            raise
        except GoogleCloudError as e:
            import structlog
            logger = structlog.get_logger(__name__)
            error_code = getattr(e, 'code', None)
            logger.error("GCP API error", error=str(e), error_code=error_code)

            # Convert to user-friendly error messages
            message = friendly_messages.get(error_code)
            if message is None:
                message = f"GCP API error: {str(e)}"
            raise GCPServiceError(message) from e
        except Exception as e:
            import structlog
            logger = structlog.get_logger(__name__)
            logger.error("Unexpected error", error=str(e))
            raise GCPServiceError(f"Unexpected error: {str(e)}") from e

    return wrapper