"""
LLM Client - Unified interface for multiple LLM backends.
Supports local models via transformers, LM Studio, and external APIs.
"""
import asyncio
import logging
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, Optional, List
import aiohttp
logger = logging.getLogger(__name__)
@dataclass
class LLMConfig:
"""Configuration for LLM backend."""
    backend: str  # "transformers", "lm_studio", "openai" ("anthropic" planned but not implemented here)
model_name: str
api_base: Optional[str] = None
api_key: Optional[str] = None
temperature: float = 0.7
max_tokens: int = 1024
timeout: int = 30
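# Example configurations (illustrative values only; the model names and URL below simply
# mirror defaults that appear elsewhere in this module):
#   LLMConfig(backend="transformers", model_name="Qwen/Qwen2-1.5B-Instruct")
#   LLMConfig(backend="lm_studio", model_name="local-model", api_base="http://localhost:1234")
#   LLMConfig(backend="openai", model_name="gpt-3.5-turbo", api_key="sk-...")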
class LLMBackend(ABC):
"""Abstract base class for LLM backends."""
@abstractmethod
async def initialize(self, config: LLMConfig) -> bool:
"""Initialize the backend."""
pass
@abstractmethod
async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text from prompt."""
pass
@abstractmethod
async def cleanup(self):
"""Cleanup resources."""
pass
async def get_model_info(self) -> Dict[str, Any]:
"""Get information about the model/backend. Override if needed."""
return {"backend": type(self).__name__, "status": "available"}
class TransformersBackend(LLMBackend):
"""Local transformers backend."""
def __init__(self):
self.tokenizer = None
self.model = None
self.device = None
self.torch = None
self._initialized = False
async def initialize(self, config: LLMConfig) -> bool:
"""Initialize transformers model."""
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
self.torch = torch
except ImportError:
logger.error("transformers not available")
return False
try:
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
self.model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.float16,
device_map="auto"
)
self.device = next(self.model.parameters()).device
self._initialized = True
logger.info(f"Transformers model {config.model_name} loaded")
return True
except Exception as e:
logger.error(f"Failed to load transformers model: {e}")
return False
async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text using local model."""
if not self._initialized or not self.torch or not self.tokenizer or not self.model:
raise RuntimeError("Backend not initialized")
max_tokens = kwargs.get('max_tokens', 1024)
temperature = kwargs.get('temperature', 0.7)
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with self.torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
        # Return only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
    async def cleanup(self):
        """Cleanup model resources."""
        # Drop references (rather than del) so repeated cleanup calls stay safe
        self.model = None
        self.tokenizer = None
        if self.torch is not None and self.torch.cuda.is_available():
            self.torch.cuda.empty_cache()
        self._initialized = False
async def get_model_info(self) -> Dict[str, Any]:
"""Get information about the transformers model."""
if not self._initialized:
return {"backend": "TransformersBackend", "status": "not_initialized"}
return {
"backend": "TransformersBackend",
"status": "initialized",
"device": str(self.device) if self.device else "unknown",
"model_loaded": self.model is not None,
"tokenizer_loaded": self.tokenizer is not None
}
class LMStudioBackend(LLMBackend):
"""LM Studio HTTP API backend."""
def __init__(self):
self.session = None
self.api_base = None
self.available_models = []
self.selected_model = None
self._initialized = False
async def initialize(self, config: LLMConfig) -> bool:
"""Initialize LM Studio client."""
try:
# Try to discover LM Studio instances if not explicitly configured
if not config.api_base or config.api_base == "auto":
discovered_url = await self._discover_lm_studio()
if discovered_url:
self.api_base = discovered_url
logger.info(f"Auto-discovered LM Studio at: {self.api_base}")
else:
# Fallback to default localhost
self.api_base = "http://localhost:1234"
logger.info(f"Using default LM Studio URL: {self.api_base}")
else:
self.api_base = config.api_base.rstrip('/')
# Create session for health check and model discovery
async with aiohttp.ClientSession() as session:
# Test connection
try:
async with session.get(f"{self.api_base}/health", timeout=aiohttp.ClientTimeout(total=5)) as resp:
if resp.status != 200:
logger.warning(f"LM Studio health check returned {resp.status}, continuing anyway")
except Exception as e:
logger.warning(f"LM Studio health check failed: {e}, continuing anyway")
# Query available models
try:
async with session.get(f"{self.api_base}/v1/models", timeout=aiohttp.ClientTimeout(total=10)) as resp:
if resp.status == 200:
data = await resp.json()
self.available_models = data.get("data", [])
logger.info(f"Found {len(self.available_models)} models in LM Studio")
# Select best model for our use case
self.selected_model = self._select_best_model()
if self.selected_model:
logger.info(f"Selected model: {self.selected_model}")
else:
logger.warning("No suitable model found, using fallback")
else:
logger.warning(f"Failed to query models: {resp.status}")
except Exception as e:
logger.warning(f"Failed to query LM Studio models: {e}")
# Initialize session for generation
self.session = aiohttp.ClientSession()
self._initialized = True
logger.info("LM Studio backend initialized")
return True
except Exception as e:
logger.error(f"LM Studio initialization failed: {e}")
return False
async def _discover_lm_studio(self) -> Optional[str]:
"""Discover LM Studio instances on the local network."""
        logger.info("Discovering LM Studio instances...")
# Common LM Studio ports to check
common_ports = [1234, 1235, 1236, 8080, 8000, 3000]
# Check localhost first (most common)
for port in common_ports:
if await self._check_lm_studio_at_url(f"http://localhost:{port}"):
return f"http://localhost:{port}"
# TODO: Network discovery for other machines on local network
# This would involve:
# 1. Getting local network interfaces
# 2. Scanning the local subnet for open ports
# 3. Testing each candidate for LM Studio API endpoints
# 4. Respecting network privacy and security policies
logger.info("No LM Studio instances found via auto-discovery")
return None
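    # Hedged sketch of the network-discovery TODO above; the helper below is illustrative,
    # is not called anywhere yet, and its defaults (a /24 subnet, port 1234) are assumptions.
    async def _scan_local_subnet(self, ports: Optional[List[int]] = None) -> List[str]:
        """Probe hosts on the local /24 subnet for LM Studio endpoints (illustrative only).
        A production version would enumerate interfaces properly, bound concurrency, and
        respect the network privacy/security policies noted in the TODO."""
        ports = ports or [1234]
        found: List[str] = []
        local_ip = socket.gethostbyname(socket.gethostname())  # may resolve to 127.0.0.1
        subnet = ".".join(local_ip.split(".")[:3])
        for host in range(1, 255):
            for port in ports:
                url = f"http://{subnet}.{host}:{port}"
                if await self._check_lm_studio_at_url(url):
                    found.append(url)
        return found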
async def _check_lm_studio_at_url(self, url: str) -> bool:
"""Check if LM Studio is running at the given URL."""
try:
async with aiohttp.ClientSession() as session:
# Try health endpoint first
try:
async with session.get(f"{url}/health", timeout=aiohttp.ClientTimeout(total=2)) as resp:
if resp.status == 200:
logger.debug(f"โ
LM Studio found at {url} (health check)")
return True
except:
pass
# Try models endpoint (more reliable)
try:
async with session.get(f"{url}/v1/models", timeout=aiohttp.ClientTimeout(total=2)) as resp:
if resp.status == 200:
data = await resp.json()
if "data" in data and isinstance(data["data"], list):
logger.debug(f"โ
LM Studio found at {url} (models endpoint)")
return True
except:
pass
return False
except Exception as e:
logger.debug(f"Checked {url}: {e}")
return False
def _select_best_model(self) -> Optional[str]:
"""Select the best model for code generation tasks."""
if not self.available_models:
return None
# Model preferences for code generation (in order of preference)
preferred_patterns = [
# Code-specific models
r'codex', r'code(llama)?', r'star.*coder', r'deepseek.*coder',
# Instruction-tuned models (good for following instructions)
            r'instruct', r'chat', r'tuned', r'\bit\b',
# General capable models
r'mixtral', r'llama.*\b(70|34|13|8)\b', r'qwen.*\b(14|7|4)\b',
# Fallback to any model
r'.*'
]
        import re
        for pattern in preferred_patterns:
            matches = [
                model.get("id", "") for model in self.available_models
                if re.search(pattern, model.get("id", ""), re.IGNORECASE)
            ]
            if not matches:
                continue
            # Among models matching this pattern, prefer ones advertising a larger context window
            for model_id in matches:
                if any(x in model_id.lower() for x in ['32k', '16k', '8k']):
                    return model_id
            return matches[0]
        # If no pattern matches, use the first available model
        return self.available_models[0].get("id") if self.available_models else None
async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text via LM Studio API."""
if not self._initialized or not self.session:
raise RuntimeError("Backend not initialized")
max_tokens = kwargs.get('max_tokens', 1024)
temperature = kwargs.get('temperature', 0.7)
# Use selected model or fallback
model = self.selected_model or "local-model"
payload = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
"temperature": temperature
}
try:
async with self.session.post(
f"{self.api_base}/v1/chat/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=60) # Longer timeout for generation
) as resp:
if resp.status == 200:
data = await resp.json()
return data["choices"][0]["message"]["content"]
else:
error_text = await resp.text()
raise Exception(f"LM Studio API error {resp.status}: {error_text}")
except Exception as e:
logger.error(f"LM Studio generation failed: {e}")
raise
async def get_model_info(self) -> Dict[str, Any]:
"""Get information about available models."""
if not self._initialized:
return {"error": "Backend not initialized"}
return {
"available_models": self.available_models,
"selected_model": self.selected_model,
"model_count": len(self.available_models)
}
async def cleanup(self):
"""Cleanup HTTP session."""
if self.session:
await self.session.close()
self.session = None
self._initialized = False
class OpenAIBackend(LLMBackend):
"""OpenAI API backend."""
def __init__(self):
        self.session = None
        self.api_key = None
        self.api_base = None
        self._initialized = False
async def initialize(self, config: LLMConfig) -> bool:
"""Initialize OpenAI client."""
if not config.api_key:
logger.error("OpenAI requires api_key")
return False
        self.api_key = config.api_key
        # Store the full API base (including /v1) and build request URLs from it explicitly;
        # joining an absolute path like "/chat/completions" onto an aiohttp base_url would
        # otherwise drop the "/v1" prefix.
        self.api_base = (config.api_base or "https://api.openai.com/v1").rstrip('/')
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            self.session = aiohttp.ClientSession(
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=config.timeout)
            )
self._initialized = True
logger.info("OpenAI client initialized")
return True
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}")
return False
async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text via OpenAI API."""
if not self._initialized or not self.session:
raise RuntimeError("Backend not initialized")
max_tokens = kwargs.get('max_tokens', 1024)
temperature = kwargs.get('temperature', 0.7)
model = kwargs.get('model', "gpt-3.5-turbo")
payload = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
"temperature": temperature
}
try:
            async with self.session.post(f"{self.api_base}/chat/completions", json=payload) as resp:
if resp.status == 200:
data = await resp.json()
return data["choices"][0]["message"]["content"]
else:
error_text = await resp.text()
raise Exception(f"OpenAI API error {resp.status}: {error_text}")
except Exception as e:
logger.error(f"OpenAI generation failed: {e}")
raise
async def cleanup(self):
"""Cleanup HTTP session."""
if self.session:
await self.session.close()
self.session = None
self._initialized = False
class LLMClient:
"""Local-first LLM client with intelligent fallback support."""
def __init__(self):
self.local_backend: Optional[TransformersBackend] = None
self.primary_backend: Optional[LLMBackend] = None
self.fallback_backends: List[LLMBackend] = []
self.local_config: Optional[LLMConfig] = None
self.primary_config: Optional[LLMConfig] = None
self._health_status: Dict[str, bool] = {}
self._last_health_check: Dict[str, float] = {}
self._health_check_interval = 60 # seconds
async def initialize(self, configs: List[LLMConfig]) -> bool:
"""Initialize with multiple configs, local first."""
if not configs:
logger.error("No LLM configurations provided")
return False
# Find local transformers config (must be first)
local_configs = [c for c in configs if c.backend == "transformers"]
if not local_configs:
logger.error("Local transformers backend required as fallback")
return False
self.local_config = local_configs[0]
self.local_backend = TransformersBackend()
# Initialize local backend first
if not await self.local_backend.initialize(self.local_config):
logger.error("Failed to initialize local backend - critical failure")
return False
logger.info(f"Local backend initialized: {self.local_config.model_name}")
self._health_status["local"] = True
self._last_health_check["local"] = asyncio.get_event_loop().time()
# Initialize other backends if available
for config in configs[1:]:
try:
if config.backend == "lm_studio":
backend = LMStudioBackend()
elif config.backend == "openai":
backend = OpenAIBackend()
else:
logger.warning(f"Unknown backend: {config.backend}")
continue
if await backend.initialize(config):
if not self.primary_backend:
self.primary_backend = backend
self.primary_config = config
logger.info(f"Primary backend initialized: {config.backend}")
                    else:
                        self.fallback_backends.append(backend)
                        logger.info(f"Fallback backend initialized: {config.backend}")
                        # Register health under the fallback_{i} key that generate() and
                        # get_backend_status() look up, so fallbacks are actually usable.
                        fallback_key = f"fallback_{len(self.fallback_backends) - 1}"
                        self._health_status[fallback_key] = True
                        self._last_health_check[fallback_key] = asyncio.get_event_loop().time()
                    self._health_status[config.backend] = True
                    self._last_health_check[config.backend] = asyncio.get_event_loop().time()
else:
logger.warning(f"Failed to initialize {config.backend} backend")
except Exception as e:
logger.warning(f"Error initializing {config.backend}: {e}")
return True
async def generate(self, prompt: str, **kwargs) -> str:
"""Generate text using intelligent backend selection."""
current_time = asyncio.get_event_loop().time()
# Check if we should use primary backend
if (self.primary_backend and self.primary_config and
await self._is_backend_healthy(self.primary_config.backend, current_time)):
try:
generation_kwargs = {
"max_tokens": kwargs.get('max_tokens', self.primary_config.max_tokens),
"temperature": kwargs.get('temperature', self.primary_config.temperature),
**kwargs
}
result = await self.primary_backend.generate(prompt, **generation_kwargs)
logger.debug(f"Generated using {self.primary_config.backend}")
return result
except Exception as e:
logger.warning(f"Primary backend failed: {e}")
self._health_status[self.primary_config.backend] = False
self._last_health_check[self.primary_config.backend] = current_time
# Try fallback backends
for i, backend in enumerate(self.fallback_backends):
backend_name = f"fallback_{i}"
if await self._is_backend_healthy(backend_name, current_time):
try:
# Use primary config settings for consistency
generation_kwargs = {
"max_tokens": kwargs.get('max_tokens', self.primary_config.max_tokens if self.primary_config else 1024),
"temperature": kwargs.get('temperature', self.primary_config.temperature if self.primary_config else 0.7),
**kwargs
}
result = await backend.generate(prompt, **generation_kwargs)
logger.debug(f"Generated using fallback backend {i}")
return result
except Exception as e:
logger.warning(f"Fallback backend {i} failed: {e}")
self._health_status[backend_name] = False
self._last_health_check[backend_name] = current_time
# Always fall back to local
if (self.local_backend and self.local_config and
await self._is_backend_healthy("local", current_time)):
try:
generation_kwargs = {
"max_tokens": kwargs.get('max_tokens', self.local_config.max_tokens),
"temperature": kwargs.get('temperature', self.local_config.temperature),
**kwargs
}
result = await self.local_backend.generate(prompt, **generation_kwargs)
logger.debug("Generated using local backend (fallback)")
return result
except Exception as e:
logger.error(f"Local backend failed: {e}")
self._health_status["local"] = False
self._last_health_check["local"] = current_time
raise RuntimeError("All LLM backends failed")
async def get_backend_info(self) -> Dict[str, Any]:
"""Get information about all backends and their models."""
info = {
"primary_backend": None,
"fallback_backends": [],
"local_backend": None,
"backend_details": {}
}
if self.primary_backend:
info["primary_backend"] = type(self.primary_backend).__name__
if hasattr(self.primary_backend, 'get_model_info'):
info["backend_details"]["primary"] = await self.primary_backend.get_model_info()
for i, backend in enumerate(self.fallback_backends):
backend_name = f"fallback_{i}"
info["fallback_backends"].append(type(backend).__name__)
if hasattr(backend, 'get_model_info'):
info["backend_details"][backend_name] = await backend.get_model_info()
if self.local_backend:
info["local_backend"] = type(self.local_backend).__name__
if hasattr(self.local_backend, 'get_model_info'):
info["backend_details"]["local"] = await self.local_backend.get_model_info()
return info
async def _is_backend_healthy(self, backend_name: str, current_time: float) -> bool:
"""Check if backend is healthy based on heuristics."""
# Check if we need to refresh health status
last_check = self._last_health_check.get(backend_name, 0)
if current_time - last_check > self._health_check_interval:
await self._check_backend_health(backend_name)
return self._health_status.get(backend_name, False)
async def _check_backend_health(self, backend_name: str):
"""Check backend health with heuristics."""
try:
if backend_name == "local":
# Local backend is always considered healthy unless explicitly failed
if not self._health_status.get("local", True):
# Try to reinitialize local backend
if self.local_backend and self.local_config:
self._health_status["local"] = await self.local_backend.initialize(self.local_config)
else:
# For remote backends, we could implement health checks
# For now, assume they're healthy if previously marked as such
pass
self._last_health_check[backend_name] = asyncio.get_event_loop().time()
except Exception as e:
logger.warning(f"Health check failed for {backend_name}: {e}")
self._health_status[backend_name] = False
self._last_health_check[backend_name] = asyncio.get_event_loop().time()
def get_backend_status(self) -> Dict[str, Any]:
"""Get status of all backends."""
return {
"local": {
"healthy": self._health_status.get("local", False),
"model": self.local_config.model_name if self.local_config else None
},
"primary": {
"backend": self.primary_config.backend if self.primary_config else None,
"healthy": self._health_status.get(self.primary_config.backend, False) if self.primary_config else False,
"model": self.primary_config.model_name if self.primary_config else None
},
"fallbacks": [
{
"backend": f"fallback_{i}",
"healthy": self._health_status.get(f"fallback_{i}", False)
}
for i in range(len(self.fallback_backends))
]
}
async def cleanup(self):
"""Cleanup all backend resources."""
if self.local_backend:
await self.local_backend.cleanup()
if self.primary_backend:
await self.primary_backend.cleanup()
for backend in self.fallback_backends:
await backend.cleanup()
@classmethod
def create_from_config(cls, config_dict: Dict[str, Any]) -> tuple["LLMClient", List[LLMConfig]]:
"""Create LLMClient and configs from configuration dictionary."""
client = cls()
# Build list of configs, local first
configs = []
# Always include local config
local_config = LLMConfig(
backend="transformers",
model_name=config_dict.get("local_model", "Qwen/Qwen2-1.5B-Instruct"),
temperature=config_dict.get("temperature", 0.7),
max_tokens=config_dict.get("max_tokens", 1024),
timeout=config_dict.get("timeout", 30)
)
configs.append(local_config)
# Add LM Studio if configured
if config_dict.get("lm_studio_url"):
lm_config = LLMConfig(
backend="lm_studio",
model_name="local-model",
api_base=config_dict["lm_studio_url"],
temperature=config_dict.get("temperature", 0.7),
max_tokens=config_dict.get("max_tokens", 1024),
timeout=config_dict.get("timeout", 30)
)
configs.append(lm_config)
# Add OpenAI if configured
if config_dict.get("openai_api_key"):
openai_config = LLMConfig(
backend="openai",
model_name=config_dict.get("openai_model", "gpt-3.5-turbo"),
api_key=config_dict["openai_api_key"],
api_base=config_dict.get("openai_api_base"),
temperature=config_dict.get("temperature", 0.7),
max_tokens=config_dict.get("max_tokens", 1024),
timeout=config_dict.get("timeout", 30)
)
configs.append(openai_config)
return client, configs
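

if __name__ == "__main__":
    # Minimal usage sketch. The values below are illustrative assumptions that mirror
    # defaults used elsewhere in this module; actually running it requires transformers
    # and torch for the mandatory local backend (and optionally a reachable LM Studio).
    logging.basicConfig(level=logging.INFO)

    async def _demo():
        client, configs = LLMClient.create_from_config({
            "local_model": "Qwen/Qwen2-1.5B-Instruct",
            "lm_studio_url": "auto",  # "auto" triggers LM Studio auto-discovery
            "max_tokens": 256,
        })
        if not await client.initialize(configs):
            print("Initialization failed (the local transformers backend is required)")
            return
        try:
            print(await client.generate("Summarize what this module does in one sentence."))
            print(client.get_backend_status())
        finally:
            await client.cleanup()

    asyncio.run(_demo())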