import logging
from typing import Dict, List

import numpy as np
import torch

from youtube_mcp_server.config import Config

logger = logging.getLogger("youtube-mcp-server")


class VADService:
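    """Wraps the Silero VAD model to detect speech segments in 16 kHz mono audio."""
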
    def __init__(self):
        self.model = None
        self.utils = None
        self.get_speech_timestamps = None
        self._load_model()

    def _load_model(self):
        logger.info("Loading Silero VAD model...")
        try:
            # torch.hub caches downloaded models, so with force_reload=False this
            # only hits the network on the first run.
            self.model, self.utils = torch.hub.load(
                repo_or_dir=Config.SILERO_REPO,
                model=Config.SILERO_MODEL,
                force_reload=False,
                trust_repo=True,
            )
            # utils is (get_speech_timestamps, save_audio, read_audio,
            # VADIterator, collect_chunks); only the first is needed here.
            (self.get_speech_timestamps, _, _, _, _) = self.utils
            logger.info("Silero VAD loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load Silero VAD: {e}")
            # Leave the model unset; get_segments() then falls back to returning
            # the full audio as a single segment.
            self.model = None

    def get_segments(self, audio_np: np.ndarray) -> List[Dict[str, int]]:
        """
        Detects speech segments in 16 kHz mono audio.

        Returns a list of {'start': int_sample, 'end': int_sample} dicts,
        where the values are sample indices (not seconds).
        """
        if self.model is None:
            # Fallback: return the entire audio as one segment.
            logger.warning("VAD model not available, falling back to full segment.")
            return [{'start': 0, 'end': len(audio_np)}]
        try:
            # Silero VAD expects a float32 tensor; .float() is a no-op if the
            # input array is already float32.
            wav_tensor = torch.from_numpy(audio_np).float()
            speech_timestamps = self.get_speech_timestamps(
                wav_tensor, self.model, sampling_rate=Config.SAMPLING_RATE
            )
            if not speech_timestamps:
                logger.warning("No speech detected by VAD. Returning full segment.")
                return [{'start': 0, 'end': len(audio_np)}]
            logger.info(f"VAD detected {len(speech_timestamps)} speech segments.")
            return speech_timestamps
        except Exception as e:
            logger.error(f"Error during VAD inference: {e}")
            # Degrade gracefully: treat the whole audio as one segment.
            return [{'start': 0, 'end': len(audio_np)}]
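

# Usage sketch (illustrative only; not part of the module). Assumes the audio has
# already been decoded to 16 kHz mono float32 NumPy samples by an upstream step
# (`load_audio_16k` below is a hypothetical helper, and Config.SAMPLING_RATE is
# assumed to be 16000):
#
#     vad = VADService()
#     audio_np = load_audio_16k("downloads/video.wav")
#     for seg in vad.get_segments(audio_np):
#         chunk = audio_np[seg['start']:seg['end']]
#         ...  # pass each speech chunk to the transcription stage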