OpenAI-Compatible LLM Inference Server on Modal

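Before the full listing, a minimal client sketch may help orient the reader: once the app is deployed, the `inference_api` endpoint speaks the OpenAI chat-completions protocol, so any HTTP client can call it with the default Bearer key defined below. This is only a sketch; the base URL is a placeholder for whatever URL Modal assigns to `inference_api` at deploy time, and the model name and token values simply reuse defaults from the listing.

# Hypothetical client for the deployed API (not part of the server code).
# BASE_URL is a placeholder: substitute the URL Modal prints for inference_api.
import json
import urllib.request

BASE_URL = "https://<your-workspace>--inference-api.modal.run"  # placeholder
API_KEY = "sk-modal-llm-api-key"  # default key created by the server below

payload = json.dumps({
    "model": "phi-4",  # or "auto" to let the server pick a model
    "messages": [{"role": "user", "content": "What can you tell me about Modal?"}],
    "max_tokens": 256,
    "stream": False,
}).encode("utf-8")

req = urllib.request.Request(
    f"{BASE_URL}/v1/chat/completions",
    data=payload,
    headers={
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    },
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    result = json.loads(resp.read().decode())
    print(result["choices"][0]["message"]["content"])

For streaming, the same endpoint accepts `"stream": true` and returns server-sent events, as the `/v1/chat/completions` handler in the listing shows.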
import modal
import logging
import time
import uuid
import json
import asyncio
import hashlib
import threading
import concurrent.futures
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union, AsyncIterator
from datetime import datetime, timedelta
from collections import deque
from fastapi import FastAPI, Request, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Create FastAPI app
api_app = FastAPI(
    title="Advanced LLM Inference API",
    description="Enterprise-grade OpenAI-compatible LLM serving API with multiple model support, streaming, and advanced caching",
    version="1.1.0"
)

# Add CORS middleware
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify specific origins instead of wildcard
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security setup
security = HTTPBearer()

# Token bucket rate limiter
class TokenBucket:
    """
    Token bucket algorithm for rate limiting.
    Each user gets a bucket that fills at a constant rate.
    """

    def __init__(self):
        self.buckets = {}
        self.lock = threading.Lock()

    def _get_bucket(self, user_id, rate_limit):
        """Get or create a bucket for a user"""
        now = time.time()
        if user_id not in self.buckets:
            # Initialize with full bucket
            self.buckets[user_id] = {
                "tokens": rate_limit,
                "last_refill": now,
                "rate": rate_limit / 60.0  # tokens per second
            }
            return self.buckets[user_id]

        bucket = self.buckets[user_id]
        # Update rate if it changed
        bucket["rate"] = rate_limit / 60.0
        # Refill tokens based on time elapsed
        elapsed = now - bucket["last_refill"]
        new_tokens = elapsed * bucket["rate"]
        bucket["tokens"] = min(rate_limit, bucket["tokens"] + new_tokens)
        bucket["last_refill"] = now
        return bucket

    def consume(self, user_id, tokens=1, rate_limit=60):
        """
        Consume tokens from a user's bucket.
        Returns True if tokens were consumed, False otherwise.
        """
        with self.lock:
            bucket = self._get_bucket(user_id, rate_limit)
            if bucket["tokens"] >= tokens:
                bucket["tokens"] -= tokens
                return True
            return False

# Create rate limiter
rate_limiter = TokenBucket()

# Define the container image with necessary dependencies
vllm_image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install(
        "vllm==0.7.3",  # Updated version
        "huggingface_hub[hf_transfer]==0.26.2",
        "flashinfer-python==0.2.0.post2",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0",
        "tiktoken>=0.5.1",
        extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model transfers
    .env({"VLLM_USE_V1": "1"})  # Enable V1 engine for better performance
)

# Define llama.cpp image for alternative models
llama_cpp_image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git", "build-essential", "cmake", "curl", "libcurl4-openssl-dev")
    .pip_install(
        "huggingface_hub==0.26.2",
        "hf_transfer>=0.1.4",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_commands(
        "git clone https://github.com/ggerganov/llama.cpp",
        "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=ON",
        "cmake --build llama.cpp/build --config Release -j --target llama-cli",
        "cp llama.cpp/build/bin/llama-* /usr/local/bin/"
    )
)

# Set up model configurations
MODELS_DIR = "/models"

VLLM_MODELS = {
    "llama3-8b": {
        "id": "llama3-8b",
        "name": "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
        "config": "config.json",  # Ensure this file is present in the model directory
        "revision": "a7c09948d9a632c2c840722f519672cd94af885d",
        "max_tokens": 4096,
        "loaded": False
    },
    "mistral-7b": {
        "id": "mistral-7b",
        "name": "mistralai/Mistral-7B-Instruct-v0.2",
        "revision": "main",
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "tiny-llama-1.1b": {
        "id": "tiny-llama-1.1b",
        "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "revision": "main",
        "max_tokens": 2048,
        "loaded": False
    }
}

LLAMA_CPP_MODELS = {
    "deepseek-r1": {
        "id": "deepseek-r1",
        "name": "unsloth/DeepSeek-R1-GGUF",
        "quant": "UD-IQ1_S",
        "pattern": "*UD-IQ1_S*",
        "revision": "02656f62d2aa9da4d3f0cdb34c341d30dd87c3b6",
        "gpu": "L40S:4",
        "max_tokens": 4096,
        "loaded": False
    },
    "phi-4": {
        "id": "phi-4",
        "name": "unsloth/phi-4-GGUF",
        "quant": "Q2_K",
        "pattern": "*Q2_K*",
        "revision": None,
        "gpu": "L40S:4",  # Use GPU for better performance
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "phi-2": {
        "id": "phi-2",
        "name": "TheBloke/phi-2-GGUF",
        "quant": "Q4_K_M",
        "pattern": "*Q4_K_M.gguf",
        "revision": "main",
        "gpu": None,  # Can run on CPU
        "max_tokens": 2048,
        "loaded": False
    }
}

DEFAULT_MODEL = "phi-4"

# Create volumes for caching
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
llama_cpp_cache_vol = modal.Volume.from_name("llama-cpp-cache", create_if_missing=True)
results_vol = modal.Volume.from_name("model-results", create_if_missing=True)

# Create the Modal app
app = modal.App("openai-compatible-llm-server")

# Create shared data structures
model_stats_dict = modal.Dict.from_name("model-stats", create_if_missing=True)
user_usage_dict = modal.Dict.from_name("user-usage", create_if_missing=True)
request_queue = modal.Queue.from_name("request-queue", create_if_missing=True)
response_dict = modal.Dict.from_name("response-cache", create_if_missing=True)
api_keys_dict = modal.Dict.from_name("api-keys",
create_if_missing=True) stream_queues = modal.Dict.from_name("stream-queues", create_if_missing=True) # Advanced caching system class AdvancedCache: """ Advanced caching system with TTL and LRU eviction. """ def __init__(self, max_size=1000, default_ttl=3600): self.cache = {} self.ttl_map = {} self.access_times = {} self.max_size = max_size self.default_ttl = default_ttl self.lock = threading.Lock() def get(self, key): """Get a value from the cache""" with self.lock: now = time.time() # Check if key exists and is not expired if key in self.cache: # Check TTL if key in self.ttl_map and self.ttl_map[key] < now: # Expired self._remove(key) return None # Update access time self.access_times[key] = now return self.cache[key] return None def set(self, key, value, ttl=None): """Set a value in the cache with optional TTL""" with self.lock: now = time.time() # Evict if needed if len(self.cache) >= self.max_size and key not in self.cache: self._evict_lru() # Set value self.cache[key] = value self.access_times[key] = now # Set TTL if ttl is not None: self.ttl_map[key] = now + ttl elif self.default_ttl > 0: self.ttl_map[key] = now + self.default_ttl def _remove(self, key): """Remove a key from the cache""" if key in self.cache: del self.cache[key] if key in self.ttl_map: del self.ttl_map[key] if key in self.access_times: del self.access_times[key] def _evict_lru(self): """Evict least recently used item""" if not self.access_times: return # Find oldest access time oldest_key = min(self.access_times.items(), key=lambda x: x[1])[0] self._remove(oldest_key) def clear_expired(self): """Clear all expired entries""" with self.lock: now = time.time() expired_keys = [k for k, v in self.ttl_map.items() if v < now] for key in expired_keys: self._remove(key) # Constants MAX_CACHE_AGE = 3600 # 1 hour in seconds # Create memory cache memory_cache = AdvancedCache(max_size=10000, default_ttl=MAX_CACHE_AGE) # Initialize with default key if empty if "default" not in api_keys_dict: api_keys_dict["default"] = { "key": "sk-modal-llm-api-key", "rate_limit": 60, # requests per minute "quota": 1000000, # tokens per day "created_at": datetime.now().isoformat(), "owner": "default" } # Add a default ADMIN API key if "admin" not in api_keys_dict: api_keys_dict["admin"] = { "key": "sk-modal-admin-api-key", "rate_limit": 1000, # Higher rate limit for admin "quota": 10000000, # Higher quota for admin "created_at": datetime.now().isoformat(), "owner": "admin" } # Constants DEFAULT_API_KEY = api_keys_dict["default"]["key"] MINUTES = 60 # seconds SERVER_PORT = 8000 CACHE_DIR = "/root/.cache" RESULTS_DIR = "/root/results" # Request/response models class GenerationRequest(BaseModel): request_id: str = Field(default_factory=lambda: str(uuid.uuid4())) model_id: str messages: List[Dict[str, str]] temperature: float = 0.7 max_tokens: int = 1024 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 user: Optional[str] = None stream: bool = False timestamp: float = Field(default_factory=time.time) api_key: str = DEFAULT_API_KEY class StreamChunk(BaseModel): """Model for streaming response chunks""" id: str object: str = "chat.completion.chunk" created: int model: str choices: List[Dict[str, Any]] class StreamManager: """Manages streaming responses for clients""" def __init__(self): self.streams = {} self.lock = threading.Lock() def create_stream(self, request_id): """Create a new stream for a request""" with self.lock: self.streams[request_id] = { "queue": asyncio.Queue(), "finished": False, "created_at": time.time() 
} def add_chunk(self, request_id, chunk): """Add a chunk to a stream""" with self.lock: if request_id in self.streams: stream = self.streams[request_id] if not stream["finished"]: stream["queue"].put_nowait(chunk) def finish_stream(self, request_id): """Mark a stream as finished""" with self.lock: if request_id in self.streams: self.streams[request_id]["finished"] = True # Add None to signal end of stream self.streams[request_id]["queue"].put_nowait(None) async def get_chunks(self, request_id): """Get chunks from a stream as an async generator""" if request_id not in self.streams: return stream = self.streams[request_id] queue = stream["queue"] while True: chunk = await queue.get() if chunk is None: # End of stream break yield chunk queue.task_done() # Clean up after streaming is done with self.lock: if request_id in self.streams: del self.streams[request_id] def clean_old_streams(self, max_age=3600): """Clean up old streams""" with self.lock: now = time.time() to_remove = [] for request_id, stream in self.streams.items(): if now - stream["created_at"] > max_age: to_remove.append(request_id) for request_id in to_remove: if request_id in self.streams: # Mark as finished to stop any ongoing processing self.streams[request_id]["finished"] = True # Add None to unblock any waiting consumers self.streams[request_id]["queue"].put_nowait(None) # Remove from streams del self.streams[request_id] # Create stream manager stream_manager = StreamManager() # API Authentication dependency def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)): """Verify that the API key in the authorization header is valid and check rate limits""" if credentials.scheme != "Bearer": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication scheme. Use Bearer", ) api_key = credentials.credentials valid_key = False key_info = None # Check if this is a known API key for user_id, user_data in api_keys_dict.items(): if user_data.get("key") == api_key: valid_key = True key_info = user_data break if not valid_key: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key", ) # Check rate limits user_id = key_info.get("owner", "unknown") rate_limit = key_info.get("rate_limit", 60) # Default: 60 requests per minute # Get or initialize user usage tracking if user_id not in user_usage_dict: user_usage_dict[user_id] = { "requests": [], "tokens": { "input": 0, "output": 0, "last_reset": datetime.now().isoformat() } } usage = user_usage_dict[user_id] # Check if user exceeded rate limit using token bucket algorithm if not rate_limiter.consume(user_id, tokens=1, rate_limit=rate_limit): # Calculate retry-after based on rate retry_after = int(60 / rate_limit) # seconds until at least one token is available raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=f"Rate limit exceeded. 
Maximum {rate_limit} requests per minute.", headers={"Retry-After": str(retry_after)} ) # Add current request timestamp for analytics now = datetime.now() usage["requests"].append(now.timestamp()) # Clean up old requests (older than 1 day) to prevent unbounded growth day_ago = (now - timedelta(days=1)).timestamp() usage["requests"] = [req for req in usage["requests"] if req > day_ago] # Update usage dict user_usage_dict[user_id] = usage # Return the API key and user ID return {"key": api_key, "user_id": user_id} # API Endpoints @api_app.get("/", response_class=HTMLResponse) async def index(): """Root endpoint that returns HTML with API information""" return """ <html> <head> <title>Modal LLM Inference API</title> <style> body { font-family: system-ui, sans-serif; max-width: 800px; margin: 0 auto; padding: 2rem; } h1 { color: #4a56e2; } code { background: #f4f4f8; padding: 0.2rem 0.4rem; border-radius: 3px; } </style> </head> <body> <h1>Modal LLM Inference API</h1> <p>This is an OpenAI-compatible API for LLM inference powered by Modal.</p> <p>Use the following endpoints:</p> <ul> <li><a href="/docs">/docs</a> - API documentation</li> <li><a href="/v1/models">/v1/models</a> - List available models</li> <li><code>/v1/chat/completions</code> - Chat completions endpoint</li> </ul> </body> </html> """ @api_app.get("/health") async def health_check(): """Health check endpoint""" return {"status": "healthy"} @api_app.get("/v1/models", dependencies=[Depends(verify_api_key)]) async def list_models(): """List all available models in OpenAI-compatible format""" # Combine vLLM and llama.cpp models all_models = [] for model_id, model_info in VLLM_MODELS.items(): all_models.append({ "id": model_info["id"], "object": "model", "created": 1677610602, "owned_by": "modal", "engine": "vllm", "loaded": model_info.get("loaded", False) }) for model_id, model_info in LLAMA_CPP_MODELS.items(): all_models.append({ "id": model_info["id"], "object": "model", "created": 1677610602, "owned_by": "modal", "engine": "llama.cpp", "loaded": model_info.get("loaded", False) }) return {"data": all_models, "object": "list"} # Model management endpoints class ModelLoadRequest(BaseModel): """Request model to load a specific model""" model_id: str force_reload: bool = False class HFModelLoadRequest(BaseModel): """Request to load a model directly from Hugging Face""" repo_id: str model_type: str = "vllm" # "vllm" or "llama.cpp" revision: Optional[str] = None quant: Optional[str] = None # For llama.cpp models max_tokens: int = 4096 gpu: Optional[str] = None # For llama.cpp models @api_app.post("/admin/models/load", dependencies=[Depends(verify_api_key)]) async def load_model(request: ModelLoadRequest, background_tasks: BackgroundTasks): """Load a specific model into memory""" model_id = request.model_id force_reload = request.force_reload # Check if model exists if model_id in VLLM_MODELS: model_type = "vllm" model_info = VLLM_MODELS[model_id] elif model_id in LLAMA_CPP_MODELS: model_type = "llama.cpp" model_info = LLAMA_CPP_MODELS[model_id] else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Model {model_id} not found" ) # Check if model is already loaded if model_info.get("loaded", False) and not force_reload: return { "status": "success", "message": f"Model {model_id} is already loaded", "model_id": model_id, "model_type": model_type } # Start loading the model in the background if model_type == "vllm": # Start vLLM server for this model background_tasks.add_task(serve_vllm_model.remote, model_id=model_id) # 
Update model status VLLM_MODELS[model_id]["loaded"] = True else: # llama.cpp # For llama.cpp models, we'll preload the model background_tasks.add_task(preload_llama_cpp_model, model_id) # Update model status LLAMA_CPP_MODELS[model_id]["loaded"] = True return { "status": "success", "message": f"Started loading model {model_id}", "model_id": model_id, "model_type": model_type } @api_app.post("/admin/models/load-from-hf", dependencies=[Depends(verify_api_key)]) async def load_model_from_hf(request: HFModelLoadRequest, background_tasks: BackgroundTasks): """Load a model directly from Hugging Face""" repo_id = request.repo_id model_type = request.model_type revision = request.revision # Generate a unique model_id based on the repo name repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id model_id = f"hf-{repo_name}-{uuid.uuid4().hex[:6]}" # Create model info based on type if model_type.lower() == "vllm": # Add to VLLM_MODELS VLLM_MODELS[model_id] = { "id": model_id, "name": repo_id, "revision": revision or "main", "max_tokens": request.max_tokens, "loaded": False, "hf_direct": True # Mark as directly loaded from HF } # Start vLLM server for this model background_tasks.add_task(serve_vllm_model.remote, model_id=model_id) # Update model status VLLM_MODELS[model_id]["loaded"] = True elif model_type.lower() == "llama.cpp": # For llama.cpp we need quant info quant = request.quant or "Q4_K_M" # Default quantization pattern = f"*{quant}*" # Add to LLAMA_CPP_MODELS LLAMA_CPP_MODELS[model_id] = { "id": model_id, "name": repo_id, "quant": quant, "pattern": pattern, "revision": revision, "gpu": request.gpu, # Can be None for CPU "max_tokens": request.max_tokens, "loaded": False, "hf_direct": True # Mark as directly loaded from HF } # Preload the model background_tasks.add_task(preload_llama_cpp_model, model_id) # Update model status LLAMA_CPP_MODELS[model_id]["loaded"] = True else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model type: {model_type}. 
Must be 'vllm' or 'llama.cpp'" ) return { "status": "success", "message": f"Started loading model {repo_id} as {model_id}", "model_id": model_id, "model_type": model_type, "repo_id": repo_id } @api_app.post("/admin/models/unload", dependencies=[Depends(verify_api_key)]) async def unload_model(request: ModelLoadRequest): """Unload a specific model from memory""" model_id = request.model_id # Check if model exists if model_id in VLLM_MODELS: model_type = "vllm" model_info = VLLM_MODELS[model_id] elif model_id in LLAMA_CPP_MODELS: model_type = "llama.cpp" model_info = LLAMA_CPP_MODELS[model_id] else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Model {model_id} not found" ) # Check if model is loaded if not model_info.get("loaded", False): return { "status": "success", "message": f"Model {model_id} is not loaded", "model_id": model_id, "model_type": model_type } # Update model status if model_type == "vllm": VLLM_MODELS[model_id]["loaded"] = False else: # llama.cpp LLAMA_CPP_MODELS[model_id]["loaded"] = False return { "status": "success", "message": f"Unloaded model {model_id}", "model_id": model_id, "model_type": model_type } @api_app.get("/admin/models/status/{model_id}", dependencies=[Depends(verify_api_key)]) async def get_model_status(model_id: str): """Get the status of a specific model""" # Check if model exists if model_id in VLLM_MODELS: model_type = "vllm" model_info = VLLM_MODELS[model_id] elif model_id in LLAMA_CPP_MODELS: model_type = "llama.cpp" model_info = LLAMA_CPP_MODELS[model_id] else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Model {model_id} not found" ) # Get model stats if available model_stats = model_stats_dict.get(model_id, {}) # Include HF info if available hf_info = {} if model_info.get("hf_direct"): hf_info = { "repo_id": model_info.get("name"), "revision": model_info.get("revision"), } if model_type == "llama.cpp": hf_info["quant"] = model_info.get("quant") return { "model_id": model_id, "model_type": model_type, "loaded": model_info.get("loaded", False), "stats": model_stats, "hf_info": hf_info if hf_info else None } # Admin API endpoints class APIKeyRequest(BaseModel): user_id: str rate_limit: int = 60 quota: int = 1000000 class APIKey(BaseModel): key: str user_id: str rate_limit: int quota: int created_at: str @api_app.post("/admin/api-keys", response_model=APIKey) async def create_api_key(request: APIKeyRequest, auth_info: dict = Depends(verify_api_key)): """Create a new API key for a user (admin only)""" # Check if this is an admin request if auth_info["user_id"] != "default": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only admin users can create API keys" ) # Generate a new API key new_key = f"sk-modal-{uuid.uuid4()}" user_id = request.user_id # Store the key api_keys_dict[user_id] = { "key": new_key, "rate_limit": request.rate_limit, "quota": request.quota, "created_at": datetime.now().isoformat(), "owner": user_id } # Initialize user usage if not user_usage_dict.contains(user_id): user_usage_dict[user_id] = { "requests": [], "tokens": { "input": 0, "output": 0, "last_reset": datetime.now().isoformat() } } return APIKey( key=new_key, user_id=user_id, rate_limit=request.rate_limit, quota=request.quota, created_at=datetime.now().isoformat() ) @api_app.get("/admin/api-keys") async def list_api_keys(auth_info: dict = Depends(verify_api_key)): """List all API keys (admin only)""" # Check if this is an admin request if auth_info["user_id"] != "default": raise HTTPException( 
status_code=status.HTTP_403_FORBIDDEN, detail="Only admin users can list API keys" ) # Return all keys (except the actual key values for security) keys = [] for user_id, key_info in api_keys_dict.items(): keys.append({ "user_id": user_id, "rate_limit": key_info.get("rate_limit", 60), "quota": key_info.get("quota", 1000000), "created_at": key_info.get("created_at", datetime.now().isoformat()), # Mask the actual key "key": key_info.get("key", "")[:8] + "..." if key_info.get("key") else "None" }) return {"keys": keys} @api_app.get("/admin/stats") async def get_stats(auth_info: dict = Depends(verify_api_key)): """Get usage statistics (admin only)""" # Check if this is an admin request if auth_info["user_id"] != "default": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only admin users can view stats" ) # Get model stats model_stats = {} for model_id in list(VLLM_MODELS.keys()) + list(LLAMA_CPP_MODELS.keys()): if model_id in model_stats_dict: model_stats[model_id] = model_stats_dict[model_id] # Get user stats user_stats = {} for user_id in user_usage_dict.keys(): usage = user_usage_dict[user_id] # Don't include request timestamps for brevity if "requests" in usage: usage = usage.copy() usage["request_count"] = len(usage["requests"]) del usage["requests"] user_stats[user_id] = usage # Get queue info queue_info = { "pending_requests": request_queue.len(), "active_workers": model_stats_dict.get("workers_running", 0) } return { "models": model_stats, "users": user_stats, "queue": queue_info, "timestamp": datetime.now().isoformat() } @api_app.delete("/admin/api-keys/{user_id}") async def delete_api_key(user_id: str, auth_info: dict = Depends(verify_api_key)): """Delete an API key (admin only)""" # Check if this is an admin request if auth_info["user_id"] != "default": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only admin users can delete API keys" ) # Check if the key exists if not api_keys_dict.contains(user_id): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No API key found for user {user_id}" ) # Can't delete the default key if user_id == "default": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot delete the default API key" ) # Delete the key api_keys_dict.pop(user_id) return {"status": "success", "message": f"API key deleted for user {user_id}"} @api_app.post("/v1/chat/completions") async def chat_completions(request: Request, background_tasks: BackgroundTasks, auth_info: dict = Depends(verify_api_key)): """OpenAI-compatible chat completions endpoint with request queueing, streaming and response caching""" try: json_data = await request.json() # Extract model or use default model_id = json_data.get("model", DEFAULT_MODEL) messages = json_data.get("messages", []) temperature = json_data.get("temperature", 0.7) max_tokens = json_data.get("max_tokens", 1024) stream = json_data.get("stream", False) user = json_data.get("user", auth_info["user_id"]) # Calculate a cache key based on the request parameters cache_key = calculate_cache_key(model_id, messages, temperature, max_tokens) # Check if we have a cached response in memory cache first (faster) cached_response = memory_cache.get(cache_key) if cached_response and not stream: # Don't use cache for streaming requests # Update stats update_stats(model_id, "cache_hit") return cached_response # Check if we have a cached response in Modal's persistent cache if not cached_response and cache_key in response_dict and not stream: cached_response = 
response_dict[cache_key] cache_age = time.time() - cached_response.get("timestamp", 0) # Use cached response if it's fresh enough if cache_age < MAX_CACHE_AGE: # Update stats update_stats(model_id, "cache_hit") response_data = cached_response["response"] # Also cache in memory for faster access next time memory_cache.set(cache_key, response_data) return response_data # Select best model if "auto" is specified if model_id == "auto" and len(messages) > 0: # Get the last user message last_message = None for msg in reversed(messages): if msg.get("role") == "user": last_message = msg.get("content", "") break if last_message: prompt = last_message # Select best model based on prompt and parameters model_id = select_best_model(prompt, max_tokens, temperature) logging.info(f"Auto-selected model: {model_id} for prompt") # Check if model exists if model_id not in VLLM_MODELS and model_id not in LLAMA_CPP_MODELS: # Default to the default model if specified model not found logging.warning(f"Model {model_id} not found, using default: {DEFAULT_MODEL}") model_id = DEFAULT_MODEL # Create a unique request ID request_id = str(uuid.uuid4()) # Create request object gen_request = GenerationRequest( request_id=request_id, model_id=model_id, messages=messages, temperature=temperature, max_tokens=max_tokens, top_p=json_data.get("top_p", 1.0), frequency_penalty=json_data.get("frequency_penalty", 0.0), presence_penalty=json_data.get("presence_penalty", 0.0), user=user, stream=stream, api_key=auth_info["key"] ) # For streaming requests, set up streaming response if stream: # Create a new stream stream_manager.create_stream(request_id) # Put the request in the queue await request_queue.put.aio(gen_request.model_dump()) # Update stats update_stats(model_id, "request_count") update_stats(model_id, "stream_count") # Start a background worker to process the request if needed background_tasks.add_task(ensure_worker_running) # Return a streaming response using FastAPI's StreamingResponse from fastapi.responses import StreamingResponse as FastAPIStreamingResponse return FastAPIStreamingResponse( content=stream_response(request_id, model_id, auth_info["user_id"]), media_type="text/event-stream" ) # For non-streaming, enqueue the request and wait for result # Put the request in the queue await request_queue.put.aio(gen_request.model_dump()) # Update stats update_stats(model_id, "request_count") # Start a background worker to process the request if needed background_tasks.add_task(ensure_worker_running) # Wait for the response with timeout start_time = time.time() timeout = 120 # 2-minute timeout for non-streaming requests while time.time() - start_time < timeout: # Check memory cache first (faster) response_data = memory_cache.get(request_id) if response_data: # Update stats update_stats(model_id, "success_count") estimate_tokens(messages, response_data, auth_info["user_id"], model_id) # Save to persistent cache response_dict[cache_key] = { "response": response_data, "timestamp": time.time() } # Clean up request-specific cache memory_cache.set(request_id, None) return response_data # Check persistent cache if response_dict.contains(request_id): response_data = response_dict[request_id] # Remove from response dict to save memory try: response_dict.pop(request_id) except Exception: pass # Save to cache response_dict[cache_key] = { "response": response_data, "timestamp": time.time() } # Also cache in memory memory_cache.set(cache_key, response_data) # Update stats update_stats(model_id, "success_count") estimate_tokens(messages, 
response_data, auth_info["user_id"], model_id) return response_data # Wait a bit before checking again await asyncio.sleep(0.1) # If we get here, we timed out update_stats(model_id, "timeout_count") raise HTTPException( status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out. The model may be busy. Please try again later." ) except Exception as e: logging.error(f"Error in chat completions: {str(e)}") # Update error stats if "model_id" in locals(): update_stats(model_id, "error_count") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error generating response: {str(e)}" ) async def stream_response(request_id: str, model_id: str, user_id: str) -> AsyncIterator[str]: """Stream response chunks to the client""" try: # Stream header yield "data: " + json.dumps({"object": "chat.completion.chunk"}) + "\n\n" # Stream chunks async for chunk in stream_manager.get_chunks(request_id): if chunk: yield f"data: {json.dumps(chunk)}\n\n" # Stream done yield "data: [DONE]\n\n" except Exception as e: logging.error(f"Error streaming response: {str(e)}") # Update error stats update_stats(model_id, "stream_error_count") # Send error as SSE error_json = json.dumps({"error": str(e)}) yield f"data: {error_json}\n\n" yield "data: [DONE]\n\n" async def ensure_worker_running(): """Ensure that a worker is running to process the queue""" # Check if workers are already running via a sentinel in shared dict workers_running_key = "workers_running" if not model_stats_dict.contains(workers_running_key): model_stats_dict[workers_running_key] = 0 current_workers = model_stats_dict[workers_running_key] # If no workers or too few workers, start more if current_workers < 3: # Keep up to 3 workers running # Increment worker count model_stats_dict[workers_running_key] = current_workers + 1 # Start a worker await process_queue_worker.spawn.aio() def calculate_cache_key(model_id: str, messages: List[dict], temperature: float, max_tokens: int) -> str: """Calculate a deterministic cache key for a request using SHA-256""" # Create a simplified version of the request for cache key cache_dict = { "model": model_id, "messages": messages, "temperature": round(temperature, 2), # Round to reduce variations "max_tokens": max_tokens } # Convert to a stable string representation and hash it with SHA-256 cache_str = json.dumps(cache_dict, sort_keys=True) hash_obj = hashlib.sha256(cache_str.encode()) return f"cache:{hash_obj.hexdigest()[:16]}" def update_stats(model_id: str, stat_type: str): """Update usage statistics for a model""" if not model_stats_dict.contains(model_id): model_stats_dict[model_id] = { "request_count": 0, "success_count": 0, "error_count": 0, "timeout_count": 0, "cache_hit": 0, "token_count": 0, "avg_latency": 0 } stats = model_stats_dict[model_id] stats[stat_type] = stats.get(stat_type, 0) + 1 model_stats_dict[model_id] = stats def estimate_tokens(messages: List[dict], response: dict, user_id: str, model_id: str): """Estimate token usage and update user quotas""" # Very simple token estimation based on whitespace-split words * 1.3 input_tokens = 0 for msg in messages: input_tokens += len(msg.get("content", "").split()) * 1.3 output_tokens = 0 if response and "choices" in response: for choice in response["choices"]: if "message" in choice and "content" in choice["message"]: output_tokens += len(choice["message"]["content"].split()) * 1.3 # Update model stats if model_stats_dict.contains(model_id): stats = model_stats_dict[model_id] stats["token_count"] = stats.get("token_count", 
0) + input_tokens + output_tokens model_stats_dict[model_id] = stats # Update user usage if user_id in user_usage_dict: usage = user_usage_dict[user_id] # Check if we need to reset daily counters last_reset = datetime.fromisoformat(usage["tokens"]["last_reset"]) now = datetime.now() if now.date() > last_reset.date(): # Reset daily counters usage["tokens"]["input"] = 0 usage["tokens"]["output"] = 0 usage["tokens"]["last_reset"] = now.isoformat() # Update token counts usage["tokens"]["input"] += int(input_tokens) usage["tokens"]["output"] += int(output_tokens) user_usage_dict[user_id] = usage def select_best_model(prompt: str, n_predict: int, temperature: float) -> str: """ Intelligently selects the best model based on input parameters. Args: prompt (str): The input prompt for the model. n_predict (int): The number of tokens to predict. temperature (float): The sampling temperature. Returns: str: The identifier of the best model to use. """ # Check for code generation patterns code_indicators = ["```", "def ", "class ", "function", "import ", "from ", "<script", "<style", "SELECT ", "CREATE TABLE", "const ", "let ", "var ", "function(", "=>"] is_likely_code = any(indicator in prompt for indicator in code_indicators) # Check for creative writing patterns creative_indicators = ["story", "poem", "creative", "imagine", "fiction", "narrative", "write a", "compose", "create a"] is_creative_task = any(indicator in prompt.lower() for indicator in creative_indicators) # Check for analytical/reasoning tasks analytical_indicators = ["explain", "analyze", "compare", "contrast", "reason", "evaluate", "assess", "why", "how does"] is_analytical_task = any(indicator in prompt.lower() for indicator in analytical_indicators) # Decision logic if is_likely_code: # For code generation, prefer phi-4 for all code tasks return "phi-4" # Excellent for code generation elif is_creative_task: # For creative tasks, use models with higher creativity if temperature > 0.8: return "deepseek-r1" # More creative at high temperatures else: return "phi-4" # Good balance of creativity and coherence elif is_analytical_task: # For analytical tasks, use models with strong reasoning return "phi-4" # Strong reasoning capabilities # Length-based decisions if len(prompt) > 2000: # For very long prompts, use models with good context handling return "llama3-8b" elif len(prompt) < 1000: # For shorter prompts, prefer phi-4 return "phi-4" # Temperature-based decisions if temperature < 0.5: # For deterministic outputs return "phi-4" elif temperature > 0.9: # For very creative outputs return "deepseek-r1" # Default to phi-4 instead of the standard model return "phi-4" # vLLM serving function @app.function( image=vllm_image, gpu="H100:1", allow_concurrent_inputs=100, volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, f"{CACHE_DIR}/vllm": vllm_cache_vol, }, timeout=30 * MINUTES, ) @modal.web_server(port=SERVER_PORT) def serve_vllm_model(model_id: str = DEFAULT_MODEL): """ Serves a model using vLLM with an OpenAI-compatible API. Args: model_id (str): The identifier of the model to serve. Defaults to DEFAULT_MODEL. Raises: ValueError: If the specified model_id is not found in VLLM_MODELS. """ import subprocess if model_id not in VLLM_MODELS: available_models = list(VLLM_MODELS.keys()) logging.error(f"Error: Unknown model: {model_id}. Available models: {available_models}") raise ValueError(f"Unknown model: {model_id}. 
Available models: {available_models}") model_info = VLLM_MODELS[model_id] model_name = model_info["name"] revision = model_info["revision"] logging.basicConfig(level=logging.INFO) logging.info(f"Starting vLLM server with model: {model_name}") cmd = [ "vllm", "serve", "--uvicorn-log-level=info", model_name, "--revision", revision, "--host", "0.0.0.0", "--port", str(SERVER_PORT), "--api-key", DEFAULT_API_KEY, ] # Use subprocess.run instead of Popen to ensure the server is fully started # before returning, and don't use shell=True for better process management process = subprocess.Popen(cmd) # Log that we've started the server logging.info(f"Started vLLM server with PID {process.pid}") # Define the worker that will process the queue @app.function( image=vllm_image, gpu=None, # Worker will spawn GPU functions as needed allow_concurrent_inputs=10, volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, }, timeout=30 * MINUTES, ) async def process_queue_worker(): """Worker function that processes requests from the queue""" import asyncio import time try: # Signal that we're starting a worker worker_id = str(uuid.uuid4())[:8] logging.info(f"Starting queue processing worker {worker_id}") # Process requests until timeout or empty queue empty_count = 0 max_empty_count = 10 # Stop after 10 consecutive empty polls while empty_count < max_empty_count: # Try to get a request from the queue try: request_dict = await request_queue.get.aio(timeout_ms=5000) empty_count = 0 # Reset empty counter # Process the request try: # Create request object request_id = request_dict.get("request_id") model_id = request_dict.get("model_id") messages = request_dict.get("messages", []) temperature = request_dict.get("temperature", 0.7) max_tokens = request_dict.get("max_tokens", 1024) api_key = request_dict.get("api_key", DEFAULT_API_KEY) stream_mode = request_dict.get("stream", False) logging.info(f"Worker {worker_id} processing request {request_id} for model {model_id}") # Start time for latency calculation start_time = time.time() if stream_mode: # Generate streaming response await generate_streaming_response( request_id=request_id, model_id=model_id, messages=messages, temperature=temperature, max_tokens=max_tokens, api_key=api_key ) else: # Generate non-streaming response response = await generate_response( model_id=model_id, messages=messages, temperature=temperature, max_tokens=max_tokens, api_key=api_key ) # Calculate latency latency = time.time() - start_time # Update latency stats if model_stats_dict.contains(model_id): stats = model_stats_dict[model_id] old_avg = stats.get("avg_latency", 0) old_count = stats.get("success_count", 0) # Calculate new average (moving average) if old_count > 0: new_avg = (old_avg * old_count + latency) / (old_count + 1) else: new_avg = latency stats["avg_latency"] = new_avg model_stats_dict[model_id] = stats # Store the response in both caches memory_cache.set(request_id, response) response_dict[request_id] = response logging.info(f"Worker {worker_id} completed request {request_id} in {latency:.2f}s") except Exception as e: # Log error and move on logging.error(f"Worker {worker_id} error processing request {request_id}: {str(e)}") # Create error response error_response = { "error": { "message": str(e), "type": "internal_error", "code": 500 } } # Store the error as a response memory_cache.set(request_id, error_response) response_dict[request_id] = error_response # If streaming, send error and finish stream if "stream_mode" in locals() and stream_mode: stream_manager.add_chunk(request_id, 
{ "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"content": f"Error: {str(e)}"}, "finish_reason": "error" }] }) stream_manager.finish_stream(request_id) except asyncio.TimeoutError: # No requests in queue empty_count += 1 logging.info(f"Worker {worker_id}: No requests in queue. Empty count: {empty_count}") # Clean up expired cache entries and old streams if empty_count % 5 == 0: # Every 5 empty polls memory_cache.clear_expired() stream_manager.clean_old_streams() await asyncio.sleep(1) # Wait a bit before checking again # If we get here, we've had too many consecutive empty polls logging.info(f"Worker {worker_id} shutting down due to empty queue") finally: # Signal that this worker is done workers_running_key = "workers_running" if model_stats_dict.contains(workers_running_key): current_workers = model_stats_dict[workers_running_key] model_stats_dict[workers_running_key] = max(0, current_workers - 1) logging.info(f"Worker {worker_id} shutdown. Workers remaining: {max(0, current_workers - 1)}") async def generate_streaming_response( request_id: str, model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str ): """ Generate a streaming response and send chunks to the stream manager. Args: request_id: The unique ID for this request model_id: The ID of the model to use messages: The chat messages temperature: The sampling temperature max_tokens: The maximum tokens to generate api_key: The API key for authentication """ import httpx import time import json import asyncio try: # Create response ID response_id = f"chatcmpl-{int(time.time())}" if model_id in VLLM_MODELS: # Start vLLM server for this model server_url = await serve_vllm_model.remote(model_id=model_id) # Need to wait for server startup await wait_for_server(serve_vllm_model.web_url, timeout=120) # Forward request to vLLM with streaming enabled async with httpx.AsyncClient(timeout=120.0) as client: headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "text/event-stream" } # Format request for vLLM OpenAI-compatible endpoint vllm_request = { "model": VLLM_MODELS[model_id]["name"], "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True } # Make streaming request async with client.stream( "POST", f"{serve_vllm_model.web_url}/v1/chat/completions", json=vllm_request, headers=headers ) as response: # Process streaming response buffer = "" async for chunk in response.aiter_text(): buffer += chunk # Process complete SSE messages while "\n\n" in buffer: message, buffer = buffer.split("\n\n", 1) if message.startswith("data: "): data = message[6:] # Remove "data: " prefix if data == "[DONE]": # End of stream stream_manager.finish_stream(request_id) return try: # Parse JSON data chunk_data = json.loads(data) # Forward to client stream_manager.add_chunk(request_id, chunk_data) except json.JSONDecodeError: logging.error(f"Invalid JSON in stream: {data}") # Ensure stream is finished stream_manager.finish_stream(request_id) elif model_id in LLAMA_CPP_MODELS: # For llama.cpp models, we need to simulate streaming # First convert the chat format to a prompt prompt = format_messages_to_prompt(messages) # Run llama.cpp with the prompt output = await run_llama_cpp_stream.remote( model_id=model_id, prompt=prompt, n_predict=max_tokens, temperature=temperature, request_id=request_id ) # Streaming is handled by the run_llama_cpp_stream function # 
which directly adds chunks to the stream manager # Wait for completion signal while True: if request_id in stream_queues and stream_queues[request_id] == "DONE": # Clean up stream_queues.pop(request_id) break await asyncio.sleep(0.1) else: raise ValueError(f"Unknown model: {model_id}") except Exception as e: logging.error(f"Error in streaming generation: {str(e)}") # Send error chunk stream_manager.add_chunk(request_id, { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"content": f"Error: {str(e)}"}, "finish_reason": "error" }] }) # Finish stream stream_manager.finish_stream(request_id) async def generate_response(model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str): """ Generate a response using the appropriate model based on model_id. Args: model_id: The ID of the model to use messages: The chat messages temperature: The sampling temperature max_tokens: The maximum tokens to generate api_key: The API key for authentication Returns: A response in OpenAI-compatible format """ import httpx import time import json import asyncio if model_id in VLLM_MODELS: # Start vLLM server for this model server_url = await serve_vllm_model.remote(model_id=model_id) # Need to wait for server startup await wait_for_server(serve_vllm_model.web_url, timeout=120) # Forward request to vLLM async with httpx.AsyncClient(timeout=60.0) as client: headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } # Format request for vLLM OpenAI-compatible endpoint vllm_request = { "model": VLLM_MODELS[model_id]["name"], "messages": messages, "temperature": temperature, "max_tokens": max_tokens } response = await client.post( f"{serve_vllm_model.web_url}/v1/chat/completions", json=vllm_request, headers=headers ) return response.json() elif model_id in LLAMA_CPP_MODELS: # For llama.cpp models, use the run_llama_cpp function # First convert the chat format to a prompt prompt = format_messages_to_prompt(messages) # Run llama.cpp with the prompt output = await run_llama_cpp.remote( model_id=model_id, prompt=prompt, n_predict=max_tokens, temperature=temperature ) # Format the response in the OpenAI format completion_text = output.strip() finish_reason = "stop" if len(completion_text) < max_tokens else "length" return { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion", "created": int(time.time()), "model": model_id, "choices": [ { "index": 0, "message": { "role": "assistant", "content": completion_text }, "finish_reason": finish_reason } ], "usage": { "prompt_tokens": len(prompt) // 4, # Rough estimation "completion_tokens": len(completion_text) // 4, # Rough estimation "total_tokens": (len(prompt) + len(completion_text)) // 4 # Rough estimation } } else: raise ValueError(f"Unknown model: {model_id}") def format_messages_to_prompt(messages: List[Dict[str, str]]) -> str: """ Convert chat messages to a text prompt format for llama.cpp. 
Args: messages: List of message dictionaries with role and content Returns: Formatted prompt string """ formatted_prompt = "" for message in messages: role = message.get("role", "").lower() content = message.get("content", "") if role == "system": formatted_prompt += f"<|system|>\n{content}\n" elif role == "user": formatted_prompt += f"<|user|>\n{content}\n" elif role == "assistant": formatted_prompt += f"<|assistant|>\n{content}\n" else: # For unknown roles, treat as user formatted_prompt += f"<|user|>\n{content}\n" # Add final assistant marker to prompt the model to respond formatted_prompt += "<|assistant|>\n" return formatted_prompt async def wait_for_server(url: str, timeout: int = 120, check_interval: int = 2): """ Wait for a server to be ready by checking its health endpoint. Args: url: The base URL of the server timeout: Maximum time to wait in seconds check_interval: Interval between checks in seconds Returns: True if server is ready, False otherwise """ import httpx import asyncio import time start_time = time.time() health_url = f"{url}/health" logging.info(f"Waiting for server at {url} to be ready...") while time.time() - start_time < timeout: try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(health_url) if response.status_code == 200: logging.info(f"Server at {url} is ready!") return True except Exception as e: elapsed = time.time() - start_time logging.info(f"Server not ready yet after {elapsed:.1f}s: {str(e)}") await asyncio.sleep(check_interval) logging.error(f"Timed out waiting for server at {url} after {timeout} seconds") return False @app.function( image=llama_cpp_image, gpu=None, # Will be set dynamically based on model volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol, RESULTS_DIR: results_vol, }, timeout=30 * MINUTES, ) async def run_llama_cpp_stream( model_id: str, prompt: str, n_predict: int = 1024, temperature: float = 0.7, request_id: str = None, ): """ Run streaming inference with llama.cpp for models like DeepSeek-R1 and Phi-4 """ import subprocess import os import json import time import threading from uuid import uuid4 from pathlib import Path from huggingface_hub import snapshot_download if model_id not in LLAMA_CPP_MODELS: available_models = list(LLAMA_CPP_MODELS.keys()) error_msg = f"Unknown model: {model_id}. 
Available models: {available_models}" logging.error(error_msg) if request_id: # Send error to stream stream_manager.add_chunk(request_id, { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"content": f"Error: {error_msg}"}, "finish_reason": "error" }] }) stream_manager.finish_stream(request_id) # Signal completion stream_queues[request_id] = "DONE" raise ValueError(error_msg) model_info = LLAMA_CPP_MODELS[model_id] repo_id = model_info["name"] pattern = model_info["pattern"] revision = model_info["revision"] quant = model_info["quant"] # Download model if not already cached logging.info(f"Downloading model {repo_id} if not present") try: model_path = snapshot_download( repo_id=repo_id, revision=revision, local_dir=f"{CACHE_DIR}/llama_cpp", allow_patterns=[pattern], ) except ValueError as e: if "hf_transfer" in str(e): # Fallback to standard download if hf_transfer fails logging.warning("hf_transfer failed, falling back to standard download") # Temporarily disable hf_transfer import os old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" try: model_path = snapshot_download( repo_id=repo_id, revision=revision, local_dir=f"{CACHE_DIR}/llama_cpp", allow_patterns=[pattern], ) finally: # Restore original setting os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env else: raise # Find the model file model_files = list(Path(model_path).glob(pattern)) if not model_files: error_msg = f"No model files found matching pattern {pattern}" logging.error(error_msg) if request_id: # Send error to stream stream_manager.add_chunk(request_id, { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"content": f"Error: {error_msg}"}, "finish_reason": "error" }] }) stream_manager.finish_stream(request_id) # Signal completion stream_queues[request_id] = "DONE" raise FileNotFoundError(error_msg) model_file = str(model_files[0]) logging.info(f"Using model file: {model_file}") # Set up command cmd = [ "llama-cli", "--model", model_file, "--prompt", prompt, "--n-predict", str(n_predict), "--temp", str(temperature), "--ctx-size", "8192", ] # Add GPU layers if needed if model_info["gpu"] is not None: cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU # Run inference result_id = str(uuid4()) logging.info(f"Running streaming inference with ID: {result_id}") # Create response ID for streaming response_id = f"chatcmpl-{int(time.time())}" # Function to process output in real-time and send to stream def process_output(process, request_id): content_buffer = "" last_send_time = time.time() # Send initial chunk with role if request_id: stream_manager.add_chunk(request_id, { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"role": "assistant"}, }] }) for line in iter(process.stdout.readline, b''): try: line_str = line.decode('utf-8', errors='replace') # Skip llama.cpp info lines if line_str.startswith("llama_"): continue # Add to buffer content_buffer += line_str # Send chunks at reasonable intervals or when buffer gets large now = time.time() if (now - last_send_time > 0.1 or len(content_buffer) > 20) and request_id: # Send chunk stream_manager.add_chunk(request_id, { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, 
"choices": [{ "index": 0, "delta": {"content": content_buffer}, }] }) # Reset buffer and time content_buffer = "" last_send_time = now except Exception as e: logging.error(f"Error processing output: {str(e)}") # Send any remaining content if content_buffer and request_id: stream_manager.add_chunk(request_id, { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {"content": content_buffer}, }] }) # Send final chunk with finish reason if request_id: stream_manager.add_chunk(request_id, { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id, "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop" }] }) # Finish stream stream_manager.finish_stream(request_id) # Signal completion stream_queues[request_id] = "DONE" # Start process process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=False, bufsize=1 # Line buffered ) # Start output processing thread if streaming if request_id: thread = threading.Thread(target=process_output, args=(process, request_id)) thread.daemon = True thread.start() # Return immediately for streaming return "Streaming in progress" else: # For non-streaming, collect all output stdout, stderr = collect_output(process) # Save results result_dir = Path(RESULTS_DIR) / result_id result_dir.mkdir(parents=True, exist_ok=True) (result_dir / "output.txt").write_text(stdout) (result_dir / "stderr.txt").write_text(stderr) (result_dir / "prompt.txt").write_text(prompt) logging.info(f"Results saved to {result_dir}") return stdout @app.function( image=llama_cpp_image, gpu=None, # Will be set dynamically based on model volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol, RESULTS_DIR: results_vol, }, timeout=30 * MINUTES, ) async def run_llama_cpp( model_id: str, prompt: str = "Tell me about Modal and how it helps with ML deployments.", n_predict: int = 1024, temperature: float = 0.7, ): """ Run inference with llama.cpp for models like DeepSeek-R1 and Phi-4 """ import subprocess import os from uuid import uuid4 from pathlib import Path from huggingface_hub import snapshot_download if model_id not in LLAMA_CPP_MODELS: available_models = list(LLAMA_CPP_MODELS.keys()) print(f"Error: Unknown model: {model_id}. Available models: {available_models}") raise ValueError(f"Unknown model: {model_id}. 
Available models: {available_models}") model_info = LLAMA_CPP_MODELS[model_id] repo_id = model_info["name"] pattern = model_info["pattern"] revision = model_info["revision"] quant = model_info["quant"] # Download model if not already cached logging.info(f"Downloading model {repo_id} if not present") try: model_path = snapshot_download( repo_id=repo_id, revision=revision, local_dir=f"{CACHE_DIR}/llama_cpp", allow_patterns=[pattern], ) except ValueError as e: if "hf_transfer" in str(e): # Fallback to standard download if hf_transfer fails logging.warning("hf_transfer failed, falling back to standard download") # Temporarily disable hf_transfer import os old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" try: model_path = snapshot_download( repo_id=repo_id, revision=revision, local_dir=f"{CACHE_DIR}/llama_cpp", allow_patterns=[pattern], ) finally: # Restore original setting os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env else: raise # Find the model file model_files = list(Path(model_path).glob(pattern)) if not model_files: logging.error(f"No model files found matching pattern {pattern}") raise FileNotFoundError(f"No model files found matching pattern {pattern}") model_file = str(model_files[0]) print(f"Using model file: {model_file}") # Set up command cmd = [ "llama-cli", "--model", model_file, "--prompt", prompt, "--n-predict", str(n_predict), "--temp", str(temperature), "--ctx-size", "8192", ] # Add GPU layers if needed if model_info["gpu"] is not None: cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU # Run inference result_id = str(uuid4()) print(f"Running inference with ID: {result_id}") process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=False ) stdout, stderr = collect_output(process) # Save results result_dir = Path(RESULTS_DIR) / result_id result_dir.mkdir(parents=True, exist_ok=True) (result_dir / "output.txt").write_text(stdout) (result_dir / "stderr.txt").write_text(stderr) (result_dir / "prompt.txt").write_text(prompt) print(f"Results saved to {result_dir}") return stdout @app.function( image=vllm_image, volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, }, ) def list_available_models(): """ Lists available models that can be used with this server. Returns: dict: A dictionary containing lists of available vLLM and llama.cpp models. """ print("Available vLLM models:") for model_id, model_info in VLLM_MODELS.items(): print(f"- {model_id}: {model_info['name']}") print("\nAvailable llama.cpp models:") for model_id, model_info in LLAMA_CPP_MODELS.items(): gpu_info = f"(GPU: {model_info['gpu']})" if model_info['gpu'] else "(CPU)" print(f"- {model_id}: {model_info['name']} {gpu_info}") return { "vllm": list(VLLM_MODELS.keys()), "llama_cpp": list(LLAMA_CPP_MODELS.keys()) } def collect_output(process): """ Collect output from a process while streaming it. Args: process: The process from which to collect output. Returns: tuple: A tuple containing the collected stdout and stderr as strings. 
""" import sys from queue import Queue from threading import Thread def stream_output(stream, queue, write_stream): for line in iter(stream.readline, b""): line_str = line.decode("utf-8", errors="replace") write_stream.write(line_str) write_stream.flush() queue.put(line_str) stream.close() stdout_queue = Queue() stderr_queue = Queue() stdout_thread = Thread(target=stream_output, args=(process.stdout, stdout_queue, sys.stdout)) stderr_thread = Thread(target=stream_output, args=(process.stderr, stderr_queue, sys.stderr)) stdout_thread.start() stderr_thread.start() stdout_thread.join() stderr_thread.join() process.wait() stdout_collected = "".join(list(stdout_queue.queue)) stderr_collected = "".join(list(stderr_queue.queue)) return stdout_collected, stderr_collected # Main ASGI app for Modal @app.function( image=vllm_image, gpu=None, # No GPU for the API frontend allow_concurrent_inputs=100, volumes={ f"{CACHE_DIR}/huggingface": hf_cache_vol, }, ) @modal.asgi_app() def inference_api(): """The main ASGI app that serves the FastAPI application""" return api_app @app.local_entrypoint() def main( prompt: str = "What can you tell me about Modal?", n_predict: int = 1024, temperature: float = 0.7, create_admin_key: bool = False, stream: bool = False, model: str = "auto", load_model: str = None, load_hf_model: str = None, hf_model_type: str = "vllm", ): """ Main entrypoint for testing the API """ import json import time import urllib.request # Initialize the API print(f"Starting API at {inference_api.web_url}") # Wait for API to be ready print("Checking if API is ready...") up, start, delay = False, time.time(), 10 while not up: try: with urllib.request.urlopen(inference_api.web_url + "/health") as response: if response.getcode() == 200: up = True except Exception: if time.time() - start > 5 * MINUTES: break time.sleep(delay) assert up, f"Failed health check for API at {inference_api.web_url}" print(f"API is up and running at {inference_api.web_url}") # Create a test API key if requested if create_admin_key: print("Creating a test API key...") key_request = { "user_id": "test_user", "rate_limit": 120, "quota": 2000000 } headers = { "Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key "Content-Type": "application/json", } req = urllib.request.Request( inference_api.web_url + "/admin/api-keys", data=json.dumps(key_request).encode("utf-8"), headers=headers, method="POST", ) try: with urllib.request.urlopen(req) as response: result = json.loads(response.read().decode()) print("Created API key:") print(json.dumps(result, indent=2)) # Use this key for the test message test_key = result["key"] except Exception as e: print(f"Error creating API key: {str(e)}") test_key = DEFAULT_API_KEY else: test_key = DEFAULT_API_KEY # List available models print("\nAvailable models:") try: headers = { "Authorization": f"Bearer {test_key}", "Content-Type": "application/json", } req = urllib.request.Request( inference_api.web_url + "/v1/models", headers=headers, method="GET", ) with urllib.request.urlopen(req) as response: models = json.loads(response.read().decode()) print(json.dumps(models, indent=2)) except Exception as e: print(f"Error listing models: {str(e)}") # Select best model for the prompt model = select_best_model(prompt, n_predict, temperature) # Send a test message print(f"\nSending a sample message to {inference_api.web_url}") messages = [{"role": "user", "content": prompt}] headers = { "Authorization": f"Bearer {test_key}", "Content-Type": "application/json", } payload = json.dumps({ "messages": 
messages, "model": model, "temperature": temperature, "max_tokens": n_predict, "stream": stream }) req = urllib.request.Request( inference_api.web_url + "/v1/chat/completions", data=payload.encode("utf-8"), headers=headers, method="POST", ) try: if stream: print("Streaming response:") with urllib.request.urlopen(req) as response: for line in response: line = line.decode('utf-8') if line.startswith('data: '): data = line[6:].strip() if data == '[DONE]': print("\n[DONE]") else: try: chunk = json.loads(data) if 'choices' in chunk and len(chunk['choices']) > 0: if 'delta' in chunk['choices'][0] and 'content' in chunk['choices'][0]['delta']: content = chunk['choices'][0]['delta']['content'] print(content, end='', flush=True) except json.JSONDecodeError: print(f"Error parsing: {data}") else: with urllib.request.urlopen(req) as response: result = json.loads(response.read().decode()) print("Response:") print(json.dumps(result, indent=2)) except Exception as e: print(f"Error: {str(e)}") # Check API stats print("\nChecking API stats...") headers = { "Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key "Content-Type": "application/json", } req = urllib.request.Request( inference_api.web_url + "/admin/stats", headers=headers, method="GET", ) try: with urllib.request.urlopen(req) as response: stats = json.loads(response.read().decode()) print("API Stats:") print(json.dumps(stats, indent=2)) except Exception as e: print(f"Error getting stats: {str(e)}") # Start a worker if none running try: current_workers = stats.get("queue", {}).get("active_workers", 0) if current_workers < 1: print("\nStarting a queue worker...") process_queue_worker.spawn() except Exception as e: print(f"Error starting worker: {str(e)}") print(f"\nAPI is available at {inference_api.web_url}") print(f"Documentation is at {inference_api.web_url}/docs") print(f"Default Bearer token: {DEFAULT_API_KEY}") if create_admin_key: print(f"Test Bearer token: {test_key}") # If a model was specified to load, load it if load_model: print(f"\nLoading model: {load_model}") load_url = f"{inference_api.web_url}/admin/models/load" headers = { "Authorization": f"Bearer {test_key}", "Content-Type": "application/json", } payload = json.dumps({ "model_id": load_model, "force_reload": True }) req = urllib.request.Request( load_url, data=payload.encode("utf-8"), headers=headers, method="POST", ) try: with urllib.request.urlopen(req) as response: result = json.loads(response.read().decode()) print("Load response:") print(json.dumps(result, indent=2)) # If it's a small model, wait a bit for it to load if load_model in ["tiny-llama-1.1b", "phi-2"]: print(f"Waiting for {load_model} to load...") time.sleep(10) # Check status status_url = f"{inference_api.web_url}/admin/models/status/{load_model}" status_req = urllib.request.Request( status_url, headers={"Authorization": f"Bearer {test_key}"}, method="GET", ) with urllib.request.urlopen(status_req) as status_response: status_result = json.loads(status_response.read().decode()) print("Model status:") print(json.dumps(status_result, indent=2)) # Use this model for the test model = load_model except Exception as e: print(f"Error loading model: {str(e)}") # If a HF model was specified to load directly if load_hf_model: print(f"\nLoading HF model: {load_hf_model} with type {hf_model_type}") load_url = f"{inference_api.web_url}/admin/models/load-from-hf" headers = { "Authorization": f"Bearer {test_key}", "Content-Type": "application/json", } payload = json.dumps({ "repo_id": load_hf_model, "model_type": 
hf_model_type, "max_tokens": n_predict }) req = urllib.request.Request( load_url, data=payload.encode("utf-8"), headers=headers, method="POST", ) try: with urllib.request.urlopen(req) as response: result = json.loads(response.read().decode()) print("HF Load response:") print(json.dumps(result, indent=2)) # Get the model_id from the response hf_model_id = result.get("model_id") # Wait a bit for it to start loading print(f"Waiting for {load_hf_model} to start loading...") time.sleep(5) # Check status if hf_model_id: status_url = f"{inference_api.web_url}/admin/models/status/{hf_model_id}" status_req = urllib.request.Request( status_url, headers={"Authorization": f"Bearer {test_key}"}, method="GET", ) with urllib.request.urlopen(status_req) as status_response: status_result = json.loads(status_response.read().decode()) print("Model status:") print(json.dumps(status_result, indent=2)) # Use this model for the test if hf_model_id: model = hf_model_id except Exception as e: print(f"Error loading HF model: {str(e)}") # Show curl examples print("\nExample curl commands:") # Regular completion print(f"""# Regular completion: curl -X POST {inference_api.web_url}/v1/chat/completions \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer {test_key}" \\ -d '{{ "model": "{model}", "messages": [ {{ "role": "user", "content": "Hello, how can you help me today?" }} ], "temperature": 0.7, "max_tokens": 500 }}'""") # Streaming completion print(f"""\n# Streaming completion: curl -X POST {inference_api.web_url}/v1/chat/completions \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer {test_key}" \\ -d '{{ "model": "{model}", "messages": [ {{ "role": "user", "content": "Write a short story about AI" }} ], "temperature": 0.8, "max_tokens": 1000, "stream": true }}' --no-buffer""") # List models print(f"""\n# List available models: curl -X GET {inference_api.web_url}/v1/models \\ -H "Authorization: Bearer {test_key}" """) # Model management commands print(f"""\n# Load a model: curl -X POST {inference_api.web_url}/admin/models/load \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer {test_key}" \\ -d '{{ "model_id": "phi-2", "force_reload": false }}'""") print(f"""\n# Load a model directly from Hugging Face: curl -X POST {inference_api.web_url}/admin/models/load-from-hf \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer {test_key}" \\ -d '{{ "repo_id": "microsoft/phi-2", "model_type": "vllm", "max_tokens": 4096 }}'""") print(f"""\n# Get model status: curl -X GET {inference_api.web_url}/admin/models/status/phi-2 \\ -H "Authorization: Bearer {test_key}" """) print(f"""\n# Unload a model: curl -X POST {inference_api.web_url}/admin/models/unload \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer {test_key}" \\ -d '{{ "model_id": "phi-2" }}'""") async def preload_llama_cpp_model(model_id: str): """Preload a llama.cpp model to make inference faster on first request""" if model_id not in LLAMA_CPP_MODELS: logging.error(f"Unknown model: {model_id}") return try: # Run a simple inference to load the model logging.info(f"Preloading llama.cpp model: {model_id}") await run_llama_cpp.remote( model_id=model_id, prompt="Hello, this is a test to preload the model.", n_predict=10, temperature=0.7 ) logging.info(f"Successfully preloaded llama.cpp model: {model_id}") except Exception as e: logging.error(f"Error preloading llama.cpp model {model_id}: {str(e)}") # Mark as not loaded LLAMA_CPP_MODELS[model_id]["loaded"] = False