"""Safety manager for secure file operations."""
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import List, Optional
import os
from .config import Config
logger = logging.getLogger(__name__)
class SafetyManager:
    """Manages backup operations and security checks for file operations.

    Responsibilities:
      * Validate paths against the configured protected directories.
      * Gate read/write/update/delete operations (system-file and size limits).
      * Create, restore, list and clean up timestamped backups.
      * Enforce the configured file-extension allow-list.
    """

    def __init__(self, config: "Config"):
        """Initialize SafetyManager with configuration.

        Args:
            config: Application configuration; must provide
                ``backup_directory``, ``protected_paths``, ``enable_backups``,
                ``max_file_size`` and ``allowed_extensions``.

        Raises:
            OSError: If the backup directory cannot be created.
        """
        self.config = config
        self.backup_dir = Path(config.backup_directory).resolve()
        self._ensure_backup_directory()

    def _ensure_backup_directory(self) -> None:
        """Ensure backup directory exists, creating parents as needed."""
        try:
            self.backup_dir.mkdir(parents=True, exist_ok=True)
            logger.debug(f"Backup directory ready: {self.backup_dir}")
        except Exception as e:
            logger.error(f"Failed to create backup directory {self.backup_dir}: {e}")
            raise

    def validate_path(self, path: str) -> bool:
        """Validate if a path is safe to operate on.

        A path is rejected when it resolves to (or under) any entry in
        ``config.protected_paths``, or when it cannot be resolved at all.

        Args:
            path: Candidate filesystem path (absolute or relative).

        Returns:
            True if the path lies outside every protected directory.
        """
        try:
            resolved_path = Path(path).resolve()
            # Check if path is in protected directories
            for protected in self.config.protected_paths:
                protected_path = Path(protected).resolve()
                try:
                    # relative_to() succeeds only when resolved_path is the
                    # protected directory itself or lives underneath it.
                    resolved_path.relative_to(protected_path)
                    logger.warning(f"Access denied to protected path: {resolved_path}")
                    return False
                except ValueError:
                    # Path is not under this protected directory; keep checking.
                    continue
            return True
        except Exception as e:
            logger.error(f"Path validation failed for {path}: {e}")
            return False

    def is_safe_operation(self, operation: str, path: str) -> bool:
        """Check if an operation is safe to perform on the given path.

        Builds on :meth:`validate_path` with per-operation rules:
          * ``write``/``update``/``delete``: must not target a system file.
          * ``read``: existing regular files must not exceed
            ``config.max_file_size``.

        Unknown operation names are considered safe once generic path
        validation has passed.
        """
        if not self.validate_path(path):
            return False
        resolved_path = Path(path).resolve()
        # Additional checks based on operation type
        if operation in ("write", "update", "delete"):
            # Check if we're trying to modify system files
            if self._is_system_file(resolved_path):
                logger.warning(f"Blocked {operation} on system file: {resolved_path}")
                return False
        if operation == "read":
            # Check file size for read operations
            if resolved_path.exists() and resolved_path.is_file():
                try:
                    file_size = resolved_path.stat().st_size
                    if file_size > self.config.max_file_size:
                        logger.warning(f"File too large for read: {file_size} bytes")
                        return False
                except OSError:
                    # If we can't even stat the file, it's probably not safe.
                    return False
        return True

    def _is_system_file(self, path: Path) -> bool:
        """Check if a path points into a well-known system directory.

        This is a string-prefix heuristic over common Unix, macOS and
        Windows system locations, not an exhaustive list.
        """
        # str.startswith accepts a tuple of prefixes — one C-level call
        # instead of a Python-level any() loop.
        system_patterns = (
            "/etc/",
            "/usr/bin/",
            "/usr/sbin/",
            "/bin/",
            "/sbin/",
            "/System/",  # macOS
            "C:\\Windows\\",  # Windows
            "C:\\Program Files\\",  # Windows
        )
        return str(path).startswith(system_patterns)

    def create_backup(self, file_path: str) -> Optional[str]:
        """Create a timestamped backup of the specified file or directory.

        Returns:
            The backup path as a string, or None when backups are disabled,
            the source is missing or of an unsupported type, or the copy
            fails.
        """
        if not self.config.enable_backups:
            logger.debug("Backups disabled, skipping backup creation")
            return None
        try:
            source_path = Path(file_path).resolve()
            if not source_path.exists():
                logger.debug(f"Source file doesn't exist, no backup needed: {source_path}")
                return None
            # Millisecond-precision timestamp keeps backup names unique even
            # for rapid successive backups of the same file.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]  # microseconds to milliseconds
            backup_name = f"{source_path.name}.{timestamp}.backup"
            backup_path = self.backup_dir / backup_name
            # Create backup
            if source_path.is_file():
                shutil.copy2(source_path, backup_path)
            elif source_path.is_dir():
                shutil.copytree(source_path, backup_path)
            else:
                # Sockets, FIFOs, device nodes etc. are not backed up.
                logger.warning(f"Cannot backup non-regular file: {source_path}")
                return None
            logger.info(f"Backup created: {backup_path}")
            return str(backup_path)
        except Exception as e:
            logger.error(f"Failed to create backup for {file_path}: {e}")
            return None

    def restore_from_backup(self, backup_path: str, original_path: str) -> bool:
        """Restore a file or directory from backup.

        Removes whatever currently exists at ``original_path`` and replaces
        it with a copy of the backup.

        Returns:
            True on success; False when the backup is missing, of an
            unsupported type, or the restore fails.
        """
        try:
            backup_file = Path(backup_path)
            original_file = Path(original_path)
            if not backup_file.exists():
                logger.error(f"Backup file doesn't exist: {backup_path}")
                return False
            # Restore from backup — determine the backup type BEFORE removing
            # the original, and fail explicitly for unsupported types.
            if backup_file.is_file():
                restore = shutil.copy2
            elif backup_file.is_dir():
                restore = shutil.copytree
            else:
                # BUG FIX: previously this case fell through both branches,
                # logged success and returned True after having already
                # deleted the original — a destructive false positive.
                logger.error(f"Unsupported backup type, cannot restore: {backup_path}")
                return False
            # Remove the corrupted original if it exists
            if original_file.exists():
                if original_file.is_file():
                    original_file.unlink()
                elif original_file.is_dir():
                    shutil.rmtree(original_file)
            restore(backup_file, original_file)
            logger.info(f"Restored {original_path} from backup {backup_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to restore from backup {backup_path}: {e}")
            return False

    def cleanup_old_backups(self, max_age_days: int = 30) -> int:
        """Clean up old backup files.

        Deletes every ``*.backup`` entry in the backup directory whose
        modification time is older than ``max_age_days`` days. Directory
        backups (created for directory sources) are removed as well.

        Returns:
            Number of backups removed; 0 when the backup directory is
            missing or cleanup fails entirely.
        """
        if not self.backup_dir.exists():
            return 0
        try:
            cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 3600)
            cleaned_count = 0
            for backup_file in self.backup_dir.iterdir():
                if not backup_file.name.endswith('.backup'):
                    continue
                try:
                    if backup_file.stat().st_mtime >= cutoff_time:
                        continue
                    # BUG FIX: directory backups produced by create_backup()
                    # were never reclaimed by the old is_file()-only check.
                    if backup_file.is_dir():
                        shutil.rmtree(backup_file)
                    else:
                        backup_file.unlink()
                    cleaned_count += 1
                    logger.debug(f"Cleaned up old backup: {backup_file}")
                except OSError as e:
                    logger.warning(f"Failed to clean backup {backup_file}: {e}")
            logger.info(f"Cleaned up {cleaned_count} old backup files")
            return cleaned_count
        except Exception as e:
            logger.error(f"Failed to cleanup old backups: {e}")
            return 0

    def get_backup_info(self, original_path: str) -> List[dict]:
        """Get information about available backups for a file.

        Returns:
            A list of dicts with ``path``, ``created``, ``size`` and
            ``modified`` keys, newest first; empty when there are no
            backups or lookup fails.
        """
        try:
            original_file = Path(original_path)
            backups = []
            if not self.backup_dir.exists():
                return backups
            # Backup names follow "<original-name>.<timestamp>.backup".
            pattern = f"{original_file.name}.*.backup"
            for backup_file in self.backup_dir.glob(pattern):
                try:
                    stat_info = backup_file.stat()
                    backups.append({
                        "path": str(backup_file),
                        # NOTE(review): st_ctime is creation time on Windows
                        # but inode-change time on Unix.
                        "created": datetime.fromtimestamp(stat_info.st_ctime).isoformat(),
                        "size": stat_info.st_size,
                        "modified": datetime.fromtimestamp(stat_info.st_mtime).isoformat(),
                    })
                except OSError:
                    # Backup vanished or is unreadable; skip it.
                    continue
            # Sort by creation time, newest first (ISO strings sort correctly).
            backups.sort(key=lambda x: x["created"], reverse=True)
            return backups
        except Exception as e:
            logger.error(f"Failed to get backup info for {original_path}: {e}")
            return []

    def validate_file_extension(self, path: str) -> bool:
        """Validate if file extension is allowed.

        Comparison is case-insensitive; a ``None`` allow-list means every
        extension is permitted.
        """
        allowed = self.config.allowed_extensions
        if allowed is None:
            return True  # All extensions allowed
        extension = Path(path).suffix.lower()
        return extension in {ext.lower() for ext in allowed}