"""Silero VAD for speech detection."""
import time
from pathlib import Path
import numpy as np
class SileroVAD:
"""Minimal Silero VAD ONNX wrapper for speech detection at 16kHz.
Uses a lightweight ONNX model (~1.6MB) to detect speech probability
in 512-sample chunks. Runs on CPU with ~1ms inference time.
"""
ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"
ONNX_MODEL_PATH = Path(__file__).parent.parent.parent.parent / "models" / "silero_vad.onnx"
CHUNK_SIZE = 512 # Silero expects 512 samples at 16kHz
    CONTEXT_SIZE = 64  # Trailing samples carried over as context into the next chunk
RESET_INTERVAL = 5.0 # Reset state every N seconds
def __init__(self):
self._session = None
self._state = None
self._context = None
self._last_reset_time = time.time()
def _ensure_loaded(self) -> None:
"""Lazy load the Silero VAD model."""
if self._session is not None:
return
import onnxruntime as ort
# Download model if needed
model_path = self.ONNX_MODEL_PATH
model_path.parent.mkdir(parents=True, exist_ok=True)
if not model_path.exists():
print("Downloading Silero VAD model...")
import urllib.request
urllib.request.urlretrieve(self.ONNX_MODEL_URL, str(model_path))
print("Silero VAD model downloaded.")
        # Create a single-threaded CPU inference session
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
self._session = ort.InferenceSession(
str(model_path), providers=["CPUExecutionProvider"], sess_options=opts
)
self._init_states()
def _init_states(self) -> None:
"""Initialize/reset VAD internal state."""
        self._state = np.zeros((2, 1, 128), dtype=np.float32)  # recurrent state fed back into the model
        self._context = np.zeros((1, self.CONTEXT_SIZE), dtype=np.float32)  # last 64 samples of the previous chunk
self._last_reset_time = time.time()
def _maybe_reset(self) -> None:
"""Reset state periodically to avoid drift."""
if (time.time() - self._last_reset_time) >= self.RESET_INTERVAL:
self._init_states()
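
    # Streaming model of operation (matching what the upstream silero-vad OnnxWrapper
    # appears to do): each predict() call feeds the model the 64-sample tail of the
    # previous chunk plus 512 new samples, and threads a (2, 1, 128) recurrent state
    # between calls; _init_states() zeroes both buffers.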
def predict(self, chunk: np.ndarray) -> float:
"""Compute speech probability for a 512-sample chunk.
Args:
chunk: Float32 audio samples (512 samples at 16kHz)
Returns:
Speech probability (0.0 to 1.0)
"""
self._ensure_loaded()
        # Shape to (1, num_samples) and validate the chunk length
x = chunk.reshape(1, -1).astype(np.float32)
if x.shape[1] != self.CHUNK_SIZE:
raise ValueError(f"Expected {self.CHUNK_SIZE} samples, got {x.shape[1]}")
        # Prepend the trailing context from the previous chunk (model input is 64 + 512 samples)
        x = np.concatenate((self._context, x), axis=1)
        # Run inference; the model returns the speech probability and the updated recurrent state
ort_inputs = {
"input": x,
"state": self._state,
"sr": np.array(16000, dtype=np.int64),
}
out, self._state = self._session.run(None, ort_inputs)
        # Keep the last CONTEXT_SIZE samples as context for the next call
        self._context = x[:, -self.CONTEXT_SIZE :]
self._maybe_reset()
return float(out[0][0])
    def detect_speech(
        self, audio: np.ndarray, sample_rate: int = 16000, threshold: float = 0.5
    ) -> dict:
"""Detect if audio contains speech.
Args:
audio: Float32 audio samples
sample_rate: Audio sample rate (will resample to 16kHz if different)
threshold: Speech probability threshold
Returns:
Dictionary with speech detection results
"""
self._ensure_loaded()
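        # Note: predict() is stateful, so recurrent state and context from earlier
        # calls carry over into this pass; the periodic reset in _maybe_reset()
        # bounds how long that state persists.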
# Resample if needed
if sample_rate != 16000:
from scipy import signal
num_samples = int(len(audio) * 16000 / sample_rate)
audio = signal.resample(audio, num_samples).astype(np.float32)
        # Process in non-overlapping 512-sample chunks; a trailing partial chunk is dropped
speech_probs = []
for i in range(0, len(audio) - self.CHUNK_SIZE + 1, self.CHUNK_SIZE):
chunk = audio[i:i + self.CHUNK_SIZE]
prob = self.predict(chunk)
speech_probs.append(prob)
if not speech_probs:
return {"is_speech": False, "probability": 0.0, "max_probability": 0.0}
avg_prob = sum(speech_probs) / len(speech_probs)
max_prob = max(speech_probs)
return {
"is_speech": max_prob > threshold,
"probability": avg_prob,
"max_probability": max_prob,
}
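

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's public API).
    # Runs the detector on one second of synthetic noise and one second of silence,
    # both of which should score well below a 0.5 speech threshold. Note that the
    # first call downloads the ONNX model to ONNX_MODEL_PATH.
    rng = np.random.default_rng(0)
    noise = (0.05 * rng.standard_normal(16000)).astype(np.float32)  # 1 s at 16 kHz
    silence = np.zeros(16000, dtype=np.float32)

    vad = SileroVAD()
    for name, audio in [("noise", noise), ("silence", silence)]:
        result = vad.detect_speech(audio, sample_rate=16000, threshold=0.5)
        print(
            f"{name}: speech={result['is_speech']} "
            f"avg={result['probability']:.3f} max={result['max_probability']:.3f}"
        )

    # Streaming-style use: feed 512-sample chunks directly to predict().
    chunk = noise[: SileroVAD.CHUNK_SIZE]
    print(f"single-chunk probability: {vad.predict(chunk):.3f}")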