# background_trainer.py
"""
Production-safe background training system.
Trains models without affecting live performance.
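
Example usage (illustrative sketch; ``classifier`` and ``config`` come from
the host application and are not defined in this module):

    trainer = ProductionSafeTrainer(setfit_classifier=classifier, config=config)
    await trainer.schedule_training(new_examples)  # returns quickly; work is queued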
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta

import numpy as np
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer

from ..models import TrainingExample

logger = logging.getLogger(__name__)

class ProductionSafeTrainer:
"""
Background trainer that safely retrains models without affecting production.
Uses separate thread/process isolation and atomic model swaps.
"""
def __init__(self, setfit_classifier, config):
"""
Initialize background trainer.
Args:
setfit_classifier: SetFit classifier to retrain
config: Training configuration
"""
self.setfit_classifier = setfit_classifier
self.config = config
# Training queue and worker
self.training_queue = asyncio.Queue()
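        # Note: the queue is unbounded and drained by a single worker task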
self.training_worker = None
self.is_training = False
self.last_training_time = None
# Training metrics
self.training_sessions = 0
self.successful_trainings = 0
self.failed_trainings = 0
# Background worker will be started on first use
self._worker_started = False
    def _start_background_worker(self):
        """Start the background training worker (requires a running event loop)."""
        if self.training_worker is None:
            self.training_worker = asyncio.create_task(self._training_worker_loop())
            logger.info("Background training worker started")
async def schedule_training(self, training_examples: List[TrainingExample]):
"""
Schedule model training without blocking production.
Args:
training_examples: New training examples to learn from
"""
if not self.config.auto_retrain_enabled:
logger.debug("Auto-retraining disabled, skipping training")
return
if self.is_training:
logger.debug("Training already in progress, skipping")
return
# Check if enough time has passed since last training
if self.last_training_time:
time_since_last = datetime.utcnow() - self.last_training_time
min_interval = timedelta(hours=self.config.retraining_schedule_hours)
if time_since_last < min_interval:
logger.debug(f"Too soon to retrain, waiting {min_interval - time_since_last}")
return
# Start worker if not started
if not self._worker_started:
self._start_background_worker()
self._worker_started = True
# Add to training queue
try:
await self.training_queue.put(training_examples)
logger.info(f"Scheduled training with {len(training_examples)} new examples")
except Exception as e:
logger.error(f"Failed to schedule training: {e}")
async def _training_worker_loop(self):
"""Background worker loop for safe training"""
while True:
try:
# Wait for training request
training_examples = await self.training_queue.get()
# Perform safe training
await self._train_safely(training_examples)
# Mark task done
self.training_queue.task_done()
# Add delay to prevent rapid retraining
await asyncio.sleep(60) # 1 minute cooldown
except asyncio.CancelledError:
logger.info("Training worker cancelled")
break
except Exception as e:
logger.error(f"Training worker error: {e}")
await asyncio.sleep(300) # 5 minute delay on error
async def _train_safely(self, training_examples: List[TrainingExample]):
"""
Train model safely without affecting production performance.
Args:
training_examples: New training examples
"""
if self.is_training:
return
self.is_training = True
self.training_sessions += 1
try:
logger.info(f"Starting safe background training session #{self.training_sessions}")
# Step 1: Prepare comprehensive training data
all_training_data = await self._prepare_comprehensive_training_data(training_examples)
# Step 2: Train in isolated environment
success = await self._train_in_isolation(all_training_data)
if success:
self.successful_trainings += 1
self.last_training_time = datetime.utcnow()
logger.info("Background training completed successfully")
else:
self.failed_trainings += 1
logger.warning("Background training failed")
except Exception as e:
self.failed_trainings += 1
logger.error(f"Safe training failed: {e}")
finally:
self.is_training = False
async def _prepare_comprehensive_training_data(self, new_examples: List[TrainingExample]) -> Dict[str, List[str]]:
"""
Prepare comprehensive training data combining existing and new examples.
Args:
new_examples: New training examples from LLM learning
Returns:
Dict[str, List[str]]: Complete training dataset
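
        Example shape (illustrative values only):
            {"inventory_inquiry": ["show products", "list stock"],
             "sales_inquiry": ["sales data", "revenue this month"]}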
"""
# Start with base training data
training_data = self.setfit_classifier._get_initial_training_data()
# Add new examples from learning buffer
for example in new_examples:
intent = example.intent
query = example.query.strip().lower()
if intent not in training_data:
training_data[intent] = []
# Avoid duplicates
if query not in training_data[intent]:
training_data[intent].append(query)
        # Warn about under-represented intents (note: this does not rebalance the dataset)
        min_examples = 5
        for intent, examples in training_data.items():
            if len(examples) < min_examples:
                logger.warning(f"Intent '{intent}' has only {len(examples)} examples (minimum: {min_examples})")
logger.info(f"Prepared training data: {sum(len(examples) for examples in training_data.values())} total examples")
return training_data
async def _train_in_isolation(self, training_data: Dict[str, List[str]]) -> bool:
"""
Train model in isolated process to avoid affecting production.
Args:
training_data: Complete training dataset
Returns:
bool: True if training succeeded
"""
try:
            # Hand the blocking training call to the default thread-pool executor
            result = await asyncio.get_running_loop().run_in_executor(
                None,  # default ThreadPoolExecutor
                self._blocking_train_operation,
                training_data
            )
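            # A ProcessPoolExecutor could give true CPU isolation, but the model
            # swap in _blocking_train_operation mutates shared state and would
            # not carry across a process boundary without extra plumbing.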
return result
except Exception as e:
logger.error(f"Isolated training failed: {e}")
return False
def _blocking_train_operation(self, training_data: Dict[str, List[str]]) -> bool:
"""
Blocking training operation (runs in separate thread).
Args:
training_data: Training dataset
Returns:
bool: Training success
"""
try:
# Prepare data
texts, labels = self.setfit_classifier._prepare_training_data(training_data)
# Create label mappings
unique_labels = sorted(set(labels))
new_label_to_intent = {i: label for i, label in enumerate(unique_labels)}
new_intent_to_label = {label: i for i, label in enumerate(unique_labels)}
# Convert labels
numeric_labels = [new_intent_to_label[label] for label in labels]
# Create dataset
train_dataset = Dataset.from_dict({
"text": texts,
"label": numeric_labels
})
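            # SetFit's trainer looks for "text" and "label" columns by default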
# Create new model instance (don't affect current model)
new_model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-mpnet-base-v2"
)
# Train new model
trainer = SetFitTrainer(
model=new_model,
train_dataset=train_dataset,
num_iterations=15,
num_epochs=1,
batch_size=16
)
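            # num_iterations sets how many contrastive sentence pairs SetFit
            # generates per example; more pairs lengthen training but can help
            # with small datasets.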
trainer.train()
# Validate new model
if self._validate_model_quality(new_model, training_data):
# Atomic update of production model
self._atomic_model_update(new_model, new_label_to_intent, new_intent_to_label)
return True
else:
logger.warning("New model validation failed - keeping current model")
return False
except Exception as e:
logger.error(f"Blocking train operation failed: {e}")
return False
def _validate_model_quality(self, new_model, training_data: Dict[str, List[str]]) -> bool:
"""
Validate new model quality before deployment.
Args:
new_model: Newly trained model
training_data: Training dataset used
Returns:
bool: True if model quality is acceptable
"""
try:
# Test model on sample queries
test_queries = [
("show products", "inventory_inquiry"),
("sales data", "sales_inquiry"),
("customer info", "customer_inquiry"),
("order status", "order_inquiry")
]
            confident_predictions = 0
            total_predictions = len(test_queries)
            for query, _expected_intent in test_queries:
                try:
                    predictions = new_model.predict_proba([query])[0]
                    predicted_label = int(np.argmax(predictions))
                    confidence = float(predictions[predicted_label])
                    # Simplified check: without mapping predicted_label back to
                    # an intent name, this only confirms the model is confident,
                    # not that it matches _expected_intent.
                    if confidence > 0.5:
                        confident_predictions += 1
                except Exception as e:
                    logger.warning(f"Validation test failed for '{query}': {e}")
            pass_rate = confident_predictions / total_predictions
            logger.info(f"New model validation pass rate: {pass_rate:.2f}")
            # Require a minimum pass rate before promoting the new model
            return pass_rate >= 0.7
except Exception as e:
logger.error(f"Model validation failed: {e}")
return False
    def _atomic_model_update(self, new_model, new_label_to_intent, new_intent_to_label):
        """
        Swap the production model for the new one with minimal disruption.

        Each attribute rebind is atomic under CPython's GIL, but the three
        together are not a single transaction: a concurrent reader could
        briefly see the new model paired with the old mappings.

        Args:
            new_model: New trained model
            new_label_to_intent: New label mappings
            new_intent_to_label: New intent mappings
        """
        # Capture the current references up front so rollback is always possible
        old_model = self.setfit_classifier.model
        old_label_to_intent = self.setfit_classifier.label_to_intent
        old_intent_to_label = self.setfit_classifier.intent_to_label
        try:
            # Rebind model and mappings back to back
            self.setfit_classifier.model = new_model
            self.setfit_classifier.label_to_intent = new_label_to_intent
            self.setfit_classifier.intent_to_label = new_intent_to_label
            logger.info("Model swapped in without interrupting service")
        except Exception as e:
            # Restore the previous model and mappings on failure
            self.setfit_classifier.model = old_model
            self.setfit_classifier.label_to_intent = old_label_to_intent
            self.setfit_classifier.intent_to_label = old_intent_to_label
            logger.error(f"Model update failed, rolled back: {e}")
            raise
def get_training_status(self) -> Dict[str, Any]:
"""Get current training status and metrics"""
return {
"is_training": self.is_training,
"training_sessions": self.training_sessions,
"successful_trainings": self.successful_trainings,
"failed_trainings": self.failed_trainings,
"last_training_time": self.last_training_time.isoformat() if self.last_training_time else None,
"queue_size": self.training_queue.qsize(),
"auto_retrain_enabled": self.config.auto_retrain_enabled,
"next_training_eligible": self._get_next_training_time()
}
def _get_next_training_time(self) -> Optional[str]:
"""Get next eligible training time"""
if not self.last_training_time:
return "immediately"
next_time = self.last_training_time + timedelta(hours=self.config.retraining_schedule_hours)
if datetime.utcnow() >= next_time:
return "immediately"
else:
return next_time.isoformat()
async def shutdown(self):
"""Safely shutdown background trainer"""
if self.training_worker:
self.training_worker.cancel()
try:
await self.training_worker
except asyncio.CancelledError:
pass
logger.info("Background trainer shutdown completed")