WarpGBM MCP Service

utils.py (9.79 kB)
""" Utility functions for serialization, ONNX export, etc. """ import base64 import io import time import tempfile import os import gzip from typing import Any, Tuple, Optional import numpy as np import joblib def serialize_model_joblib(model: Any) -> str: """ Serialize a model to base64-encoded, gzip-compressed joblib format. Args: model: Any sklearn-API model (should have .to_cpu() method for GPU models) Returns: Base64-encoded, gzip-compressed string (typically 60-80% smaller) """ # Convert to CPU if needed if hasattr(model, "to_cpu"): model_cpu = model.to_cpu() else: model_cpu = model # Serialize to bytes buf = io.BytesIO() joblib.dump(model_cpu, buf) buf.seek(0) # Compress with gzip compressed = gzip.compress(buf.getvalue()) # Encode to base64 return base64.b64encode(compressed).decode("utf-8") def deserialize_model_joblib(artifact_b64: str) -> Any: """ Deserialize a base64-encoded, gzip-compressed joblib model. WARNING: This should ONLY be used for CPU models (LightGBM). WarpGBM models must be loaded with GPU-specific deserializers! Args: artifact_b64: Base64-encoded, gzip-compressed model string Returns: Deserialized model object """ # Decode base64 compressed_bytes = base64.b64decode(artifact_b64) # Decompress model_bytes = gzip.decompress(compressed_bytes) # Load from bytes buf = io.BytesIO(model_bytes) model = joblib.load(buf) return model def serialize_model_onnx(model: Any, X_sample: np.ndarray) -> str: """ Convert model to ONNX format and serialize to base64. Args: model: WarpGBM model X_sample: Sample input for shape inference Returns: Base64-encoded ONNX model """ try: # Check if model has to_onnx method if hasattr(model, "to_onnx"): onnx_model = model.to_onnx(X_sample) else: # Fallback: use skl2onnx if model is sklearn-compatible from skl2onnx import convert_sklearn from skl2onnx.common.data_types import FloatTensorType initial_type = [("float_input", FloatTensorType([None, X_sample.shape[1]]))] onnx_model = convert_sklearn(model, initial_types=initial_type) # Serialize ONNX model to bytes onnx_bytes = onnx_model.SerializeToString() # Encode to base64 return base64.b64encode(onnx_bytes).decode("utf-8") except Exception as e: # ONNX export might not be supported yet raise NotImplementedError(f"ONNX export not yet implemented for this model: {e}") def deserialize_model_onnx(artifact_b64: str): """ Deserialize a base64-encoded ONNX model. Args: artifact_b64: Base64-encoded ONNX model Returns: ONNX InferenceSession """ import onnxruntime as ort # Decode base64 onnx_bytes = base64.b64decode(artifact_b64) # Create inference session return ort.InferenceSession(onnx_bytes) def predict_onnx(session, X: np.ndarray) -> np.ndarray: """ Run inference with ONNX model. Args: session: ONNX InferenceSession X: Input features Returns: Predictions """ input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name result = session.run([output_name], {input_name: X.astype(np.float32)}) return result[0] def check_gpu_availability() -> Tuple[bool, Optional[str]]: """ Check if GPU is available for WarpGBM. 
Returns: Tuple of (is_available, gpu_name) """ try: import torch if torch.cuda.is_available(): gpu_name = torch.cuda.get_device_name(0) return True, gpu_name return False, None except ImportError: # PyTorch not installed, assume no GPU return False, None class Timer: """Simple context manager for timing operations""" def __enter__(self): self.start = time.time() return self def __exit__(self, *args): self.end = time.time() self.elapsed = self.end - self.start def create_temp_workspace(): """ Create a temporary directory for job isolation. Returns: TemporaryDirectory context manager """ return tempfile.TemporaryDirectory(prefix="warpgbm_job_") def validate_array_size(X: list, y: Optional[list] = None, max_mb: int = 50): """ Validate that input arrays are within size limits. Args: X: Feature matrix y: Target vector (optional) max_mb: Maximum size in megabytes Raises: ValueError if arrays are too large """ import sys X_size = sys.getsizeof(X) / (1024 * 1024) # Convert to MB if y is not None: y_size = sys.getsizeof(y) / (1024 * 1024) total_size = X_size + y_size else: total_size = X_size if total_size > max_mb: raise ValueError( f"Input data size ({total_size:.2f} MB) exceeds limit ({max_mb} MB). " "Consider using dataset URLs or chunking for large datasets." ) class ArtifactCache: """ Cache for model artifacts with TTL. Supports both in-memory (local) and Modal Dict (production) backends. Thread-safe for concurrent requests. """ def __init__(self, default_ttl_seconds: int = 300, backend=None): """ Args: default_ttl_seconds: Default time-to-live in seconds (default: 5 minutes) backend: Optional backend storage (e.g., Modal Dict). If None, uses in-memory dict. """ self.cache = backend if backend is not None else {} self.default_ttl = default_ttl_seconds self.use_modal = backend is not None if not self.use_modal: import threading self.lock = threading.Lock() def set(self, artifact_id: str, data: dict, ttl: Optional[int] = None): """ Store an artifact in cache. Args: artifact_id: Unique identifier for the artifact data: Dict with keys: artifact, format, model_type ttl: Time-to-live in seconds (uses default if None) """ expires_at = time.time() + (ttl or self.default_ttl) entry = { "artifact": data["artifact"], "format": data.get("format", "joblib"), "model_type": data.get("model_type"), "expires_at": expires_at, "created_at": time.time() } if self.use_modal: # Modal Dict is already thread-safe self.cache[artifact_id] = entry else: with self.lock: self.cache[artifact_id] = entry def get(self, artifact_id: str) -> Optional[dict]: """ Retrieve an artifact from cache. 
Args: artifact_id: Unique identifier for the artifact Returns: Dict with keys: artifact, format, model_type (if found and not expired), None otherwise """ if self.use_modal: # Modal Dict access entry = self.cache.get(artifact_id) if entry is None: return None # Check if expired if time.time() > entry["expires_at"]: try: del self.cache[artifact_id] except KeyError: pass return None # Return dict return { "artifact": entry["artifact"], "format": entry["format"], "model_type": entry.get("model_type") } else: # In-memory access with self.lock: if artifact_id not in self.cache: return None entry = self.cache[artifact_id] # Check if expired if time.time() > entry["expires_at"]: del self.cache[artifact_id] return None # Return dict return { "artifact": entry["artifact"], "format": entry["format"], "model_type": entry.get("model_type") } def cleanup_expired(self): """Remove all expired entries from cache.""" current_time = time.time() if self.use_modal: # Modal Dict cleanup (be careful with concurrent access) expired_keys = [ key for key, entry in self.cache.items() if current_time > entry["expires_at"] ] for key in expired_keys: try: del self.cache[key] except KeyError: pass else: with self.lock: expired_keys = [ key for key, entry in self.cache.items() if current_time > entry["expires_at"] ] for key in expired_keys: del self.cache[key] def size(self) -> int: """Return number of items in cache.""" if self.use_modal: return len(self.cache) else: with self.lock: return len(self.cache) # Global cache instance artifact_cache = ArtifactCache(default_ttl_seconds=300) # 5 minutes
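Usage sketch (not part of the module): a minimal example of how these helpers might be exercised end to end, assuming the file is importable as utils and using a plain scikit-learn Ridge regressor as a stand-in for a LightGBM/WarpGBM model; the "job-123" artifact id is purely illustrative.

# Minimal usage sketch: serialize a CPU model, cache it with a TTL, restore it, predict.
import numpy as np
from sklearn.linear_model import Ridge  # stand-in model for illustration only

from utils import (
    serialize_model_joblib,
    deserialize_model_joblib,
    validate_array_size,
    artifact_cache,
    Timer,
)

# Toy training data
X = np.random.rand(100, 4)
y = X @ np.array([1.0, -2.0, 0.5, 3.0]) + 0.1 * np.random.randn(100)

# Guard against oversized payloads before doing any work
validate_array_size(X.tolist(), y.tolist(), max_mb=50)

# Fit, serialize, and time the round-trip
with Timer() as t:
    model = Ridge().fit(X, y)
    artifact_b64 = serialize_model_joblib(model)
print(f"trained + serialized in {t.elapsed:.3f}s, artifact {len(artifact_b64)} chars")

# Park the artifact in the TTL cache and fetch it back (hypothetical job id)
artifact_cache.set("job-123", {"artifact": artifact_b64, "format": "joblib", "model_type": "ridge"})
entry = artifact_cache.get("job-123")

# Restore the model (joblib path is CPU-only per the module's warning) and predict
restored = deserialize_model_joblib(entry["artifact"])
preds = restored.predict(X[:5])
print(preds.shape)  # (5,)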
