Skip to main content
Glama
checkpoint.py (18.8 kB)
"""Checkpoint manager implementation""" import json import asyncio from pathlib import Path from typing import Optional, Dict, Any, List, Set, Tuple from datetime import datetime from dataclasses import dataclass, field import uuid from .cas import ContentAddressableStorage from ..utils.logging import get_logger from ..utils.errors import ValidationError, StorageError from ..utils.notifications import NotificationCenter, NotificationType logger = get_logger(__name__) @dataclass class CheckpointMetadata: """Metadata for a checkpoint""" checkpoint_id: str parent_id: Optional[str] created_at: datetime message: str author: str tags: List[str] = field(default_factory=list) stats: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "checkpoint_id": self.checkpoint_id, "parent_id": self.parent_id, "created_at": self.created_at.isoformat(), "message": self.message, "author": self.author, "tags": self.tags, "stats": self.stats } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': """Create from dictionary""" return cls( checkpoint_id=data["checkpoint_id"], parent_id=data.get("parent_id"), created_at=datetime.fromisoformat(data["created_at"]), message=data["message"], author=data["author"], tags=data.get("tags", []), stats=data.get("stats", {}) ) @dataclass class Checkpoint: """A checkpoint containing file snapshots""" metadata: CheckpointMetadata files: Dict[str, str] # path -> content hash def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "metadata": self.metadata.to_dict(), "files": self.files } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Checkpoint': """Create from dictionary""" return cls( metadata=CheckpointMetadata.from_dict(data["metadata"]), files=data["files"] ) class CheckpointManager: """Manages checkpoints with CAS backend Provides Git-like checkpoint functionality: - Create checkpoints of file states - Track changes between 
checkpoints - Support branching (multiple children) - Efficient storage with deduplication """ def __init__(self, storage_path: Path, notification_center: Optional[NotificationCenter] = None): """Initialize checkpoint manager Args: storage_path: Base path for checkpoint storage notification_center: Optional notification center """ self.storage_path = Path(storage_path) self.checkpoints_path = self.storage_path / "checkpoints" self.refs_path = self.storage_path / "refs" self.head_path = self.storage_path / "HEAD" # CAS for file content self.cas = ContentAddressableStorage(self.storage_path / "cas") # Notification center self.notification_center = notification_center or NotificationCenter() # In-memory caches self._checkpoints: Dict[str, Checkpoint] = {} self._checkpoint_lock = asyncio.Lock() self._refs: Dict[str, str] = {} # ref name -> checkpoint id self._refs_lock = asyncio.Lock() # Current HEAD self._head: Optional[str] = None async def initialize(self) -> None: """Initialize checkpoint manager""" # Create directories await asyncio.gather( asyncio.create_task(asyncio.to_thread(self.checkpoints_path.mkdir, parents=True, exist_ok=True)), asyncio.create_task(asyncio.to_thread(self.refs_path.mkdir, parents=True, exist_ok=True)) ) # Initialize CAS await self.cas.initialize() # Load existing checkpoints and refs await self._load_checkpoints() await self._load_refs() await self._load_head() logger.info( "checkpoint_manager_initialized", checkpoints=len(self._checkpoints), refs=len(self._refs), head=self._head ) async def create_checkpoint( self, files: Dict[str, bytes], message: str, author: str = "system", parent_id: Optional[str] = None, tags: Optional[List[str]] = None ) -> Checkpoint: """Create a new checkpoint Args: files: Dictionary of file paths to content message: Checkpoint message author: Author name parent_id: Parent checkpoint ID (if None, uses HEAD) tags: Optional tags Returns: Created checkpoint """ # Use HEAD as parent if not specified if parent_id is 
None and self._head: parent_id = self._head # Store file contents in CAS file_hashes = {} for path, content in files.items(): content_hash = await self.cas.store(content, {"path": path}) file_hashes[path] = content_hash # Create checkpoint checkpoint_id = self._generate_checkpoint_id() metadata = CheckpointMetadata( checkpoint_id=checkpoint_id, parent_id=parent_id, created_at=datetime.utcnow(), message=message, author=author, tags=tags or [], stats={ "file_count": len(files), "total_size": sum(len(content) for content in files.values()) } ) checkpoint = Checkpoint( metadata=metadata, files=file_hashes ) # Save checkpoint await self._save_checkpoint(checkpoint) # Update HEAD await self.update_head(checkpoint_id) # Send notification await self.notification_center.notify( NotificationType.CHECKPOINT, f"Checkpoint created: {message}", { "checkpoint_id": checkpoint_id, "file_count": len(files), "author": author } ) logger.info( "checkpoint_created", checkpoint_id=checkpoint_id, parent_id=parent_id, files=len(files), message=message ) return checkpoint async def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]: """Get checkpoint by ID""" async with self._checkpoint_lock: return self._checkpoints.get(checkpoint_id) async def list_checkpoints( self, limit: Optional[int] = None, since: Optional[datetime] = None, until: Optional[datetime] = None, tags: Optional[List[str]] = None ) -> List[Checkpoint]: """List checkpoints with filtering Args: limit: Maximum number of checkpoints since: Filter by created after until: Filter by created before tags: Filter by tags Returns: List of checkpoints """ checkpoints = list(self._checkpoints.values()) # Apply filters if since: checkpoints = [ cp for cp in checkpoints if cp.metadata.created_at >= since ] if until: checkpoints = [ cp for cp in checkpoints if cp.metadata.created_at <= until ] if tags: tag_set = set(tags) checkpoints = [ cp for cp in checkpoints if tag_set.intersection(cp.metadata.tags) ] # Sort by created 
date (newest first) checkpoints.sort(key=lambda cp: cp.metadata.created_at, reverse=True) # Apply limit if limit: checkpoints = checkpoints[:limit] return checkpoints async def get_checkpoint_files( self, checkpoint_id: str, paths: Optional[List[str]] = None ) -> Dict[str, bytes]: """Get files from a checkpoint Args: checkpoint_id: Checkpoint ID paths: Specific paths to retrieve (if None, gets all) Returns: Dictionary of path -> content """ checkpoint = await self.get_checkpoint(checkpoint_id) if not checkpoint: raise ValidationError("checkpoint_id", checkpoint_id, "Checkpoint not found") files = {} # Get requested paths or all paths paths_to_get = paths if paths else list(checkpoint.files.keys()) for path in paths_to_get: if path in checkpoint.files: content_hash = checkpoint.files[path] content = await self.cas.retrieve(content_hash) if content: files[path] = content else: logger.warning(f"Content missing for {path} in checkpoint {checkpoint_id}") return files async def diff_checkpoints( self, from_id: Optional[str], to_id: str ) -> Dict[str, Any]: """Get differences between checkpoints Args: from_id: Source checkpoint (if None, compares with empty) to_id: Target checkpoint Returns: Diff information """ to_checkpoint = await self.get_checkpoint(to_id) if not to_checkpoint: raise ValidationError("to_id", to_id, "Checkpoint not found") from_checkpoint = None if from_id: from_checkpoint = await self.get_checkpoint(from_id) if not from_checkpoint: raise ValidationError("from_id", from_id, "Checkpoint not found") # Get file sets from_files = from_checkpoint.files if from_checkpoint else {} to_files = to_checkpoint.files # Calculate diff added = set(to_files.keys()) - set(from_files.keys()) removed = set(from_files.keys()) - set(to_files.keys()) # Check for modified files modified = [] for path in set(from_files.keys()) & set(to_files.keys()): if from_files[path] != to_files[path]: modified.append(path) return { "from_id": from_id, "to_id": to_id, "added": list(added), 
"removed": list(removed), "modified": modified, "stats": { "total_changes": len(added) + len(removed) + len(modified) } } async def restore_checkpoint(self, checkpoint_id: str) -> Dict[str, bytes]: """Restore files from a checkpoint Args: checkpoint_id: Checkpoint to restore Returns: All files from the checkpoint """ files = await self.get_checkpoint_files(checkpoint_id) # Update HEAD await self.update_head(checkpoint_id) # Send notification await self.notification_center.notify( NotificationType.CHECKPOINT, f"Checkpoint restored: {checkpoint_id}", { "checkpoint_id": checkpoint_id, "file_count": len(files) } ) logger.info( "checkpoint_restored", checkpoint_id=checkpoint_id, files=len(files) ) return files async def delete_checkpoint(self, checkpoint_id: str) -> bool: """Delete a checkpoint Args: checkpoint_id: Checkpoint to delete Returns: True if deleted """ async with self._checkpoint_lock: checkpoint = self._checkpoints.pop(checkpoint_id, None) if not checkpoint: return False # Delete checkpoint file checkpoint_path = self.checkpoints_path / f"{checkpoint_id}.json" try: checkpoint_path.unlink() except: pass # Update HEAD if necessary if self._head == checkpoint_id: # Find a new HEAD (parent or any other checkpoint) new_head = checkpoint.metadata.parent_id if not new_head and self._checkpoints: new_head = next(iter(self._checkpoints.keys())) await self.update_head(new_head) logger.info("checkpoint_deleted", checkpoint_id=checkpoint_id) return True async def gc(self) -> Tuple[int, int]: """Garbage collect unreferenced content Returns: Tuple of (objects_removed, bytes_freed) """ # Collect all referenced content hashes referenced_hashes = set() async with self._checkpoint_lock: for checkpoint in self._checkpoints.values(): referenced_hashes.update(checkpoint.files.values()) # Run CAS garbage collection return await self.cas.gc(list(referenced_hashes)) async def create_ref(self, name: str, checkpoint_id: str) -> None: """Create or update a reference Args: name: 
Reference name (e.g., "main", "stable") checkpoint_id: Checkpoint ID """ checkpoint = await self.get_checkpoint(checkpoint_id) if not checkpoint: raise ValidationError("checkpoint_id", checkpoint_id, "Checkpoint not found") async with self._refs_lock: self._refs[name] = checkpoint_id await self._save_ref(name, checkpoint_id) logger.info( "ref_created", name=name, checkpoint_id=checkpoint_id ) async def get_ref(self, name: str) -> Optional[str]: """Get checkpoint ID for a reference""" async with self._refs_lock: return self._refs.get(name) async def delete_ref(self, name: str) -> bool: """Delete a reference""" async with self._refs_lock: if name not in self._refs: return False del self._refs[name] # Delete ref file ref_path = self.refs_path / name try: ref_path.unlink() except: pass logger.info("ref_deleted", name=name) return True async def list_refs(self) -> Dict[str, str]: """List all references""" async with self._refs_lock: return dict(self._refs) async def update_head(self, checkpoint_id: Optional[str]) -> None: """Update HEAD reference""" self._head = checkpoint_id if checkpoint_id: await self._save_head(checkpoint_id) else: # Remove HEAD file try: self.head_path.unlink() except: pass async def get_head(self) -> Optional[str]: """Get current HEAD checkpoint""" return self._head def get_stats(self) -> Dict[str, Any]: """Get checkpoint system statistics""" return { "checkpoint_count": len(self._checkpoints), "ref_count": len(self._refs), "head": self._head, "cas_stats": self.cas.get_stats() } async def _load_checkpoints(self) -> None: """Load all checkpoints from disk""" if not self.checkpoints_path.exists(): return for checkpoint_file in self.checkpoints_path.glob("*.json"): try: with open(checkpoint_file, 'r') as f: data = json.load(f) checkpoint = Checkpoint.from_dict(data) self._checkpoints[checkpoint.metadata.checkpoint_id] = checkpoint except Exception as e: logger.error(f"Failed to load checkpoint {checkpoint_file}: {e}") async def _save_checkpoint(self, 
checkpoint: Checkpoint) -> None: """Save checkpoint to disk""" async with self._checkpoint_lock: self._checkpoints[checkpoint.metadata.checkpoint_id] = checkpoint checkpoint_path = self.checkpoints_path / f"{checkpoint.metadata.checkpoint_id}.json" with open(checkpoint_path, 'w') as f: json.dump(checkpoint.to_dict(), f, indent=2) async def _load_refs(self) -> None: """Load all refs from disk""" if not self.refs_path.exists(): return for ref_file in self.refs_path.glob("*"): if ref_file.is_file(): try: with open(ref_file, 'r') as f: checkpoint_id = f.read().strip() self._refs[ref_file.name] = checkpoint_id except Exception as e: logger.error(f"Failed to load ref {ref_file}: {e}") async def _save_ref(self, name: str, checkpoint_id: str) -> None: """Save ref to disk""" ref_path = self.refs_path / name with open(ref_path, 'w') as f: f.write(checkpoint_id) async def _load_head(self) -> None: """Load HEAD reference""" if not self.head_path.exists(): return try: with open(self.head_path, 'r') as f: self._head = f.read().strip() except Exception as e: logger.error(f"Failed to load HEAD: {e}") async def _save_head(self, checkpoint_id: str) -> None: """Save HEAD reference""" with open(self.head_path, 'w') as f: f.write(checkpoint_id) def _generate_checkpoint_id(self) -> str: """Generate unique checkpoint ID""" return uuid.uuid4().hex

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/krzemienski/shannon-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server