"""Voice Activity Detection (VAD) modules.
Provides Silero VAD for basic speech detection and Smart Turn VAD
for semantic turn completion detection.
"""
from __future__ import annotations
import os
import time
import warnings
from pathlib import Path
from typing import Optional
import numpy as np
# Base directory for model storage (several levels above this file)
SCRIPT_DIR = Path(__file__).parent.parent.parent.parent.parent.absolute()
class SileroVAD:
"""Minimal Silero VAD ONNX wrapper for speech detection at 16kHz.
Uses a lightweight ONNX model (~1.6MB) to detect speech probability
in 512-sample chunks. Runs on CPU with ~1ms inference time.
Attributes:
CHUNK_SIZE: Expected input chunk size (512 samples at 16kHz)
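    Example:
        Minimal usage sketch (assumes 16 kHz mono float32 audio; the ONNX
        model is downloaded on first use):

            import numpy as np

            vad = SileroVAD()
            chunk = np.zeros(512, dtype=np.float32)  # one 32 ms chunk of silence
            prob = vad.predict(chunk)  # near 0.0 for silence, near 1.0 for speech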
"""
ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"
ONNX_MODEL_PATH = SCRIPT_DIR / "models" / "silero_vad.onnx"
    CHUNK_SIZE = 512  # Silero expects 512-sample chunks (32 ms) at 16 kHz
    CONTEXT_SIZE = 64  # samples of left context carried between chunks
    RESET_INTERVAL = 5.0  # Reset internal state every N seconds to avoid drift
def __init__(self):
self._session = None
self._state = None
self._context = None
self._last_reset_time = time.time()
def _ensure_loaded(self) -> None:
"""Lazy load the Silero VAD model."""
if self._session is not None:
return
import onnxruntime as ort
# Download model if needed
model_path = self.ONNX_MODEL_PATH
model_path.parent.mkdir(parents=True, exist_ok=True)
if not model_path.exists():
print("Downloading Silero VAD model...")
import urllib.request
urllib.request.urlretrieve(self.ONNX_MODEL_URL, str(model_path))
print("Silero VAD model downloaded.")
# Create ONNX session
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
self._session = ort.InferenceSession(
str(model_path), providers=["CPUExecutionProvider"], sess_options=opts
)
self._init_states()
def _init_states(self) -> None:
"""Initialize/reset VAD internal state."""
self._state = np.zeros((2, 1, 128), dtype=np.float32)
self._context = np.zeros((1, self.CONTEXT_SIZE), dtype=np.float32)
self._last_reset_time = time.time()
def _maybe_reset(self) -> None:
"""Reset state periodically to avoid drift."""
if (time.time() - self._last_reset_time) >= self.RESET_INTERVAL:
self._init_states()
def predict(self, chunk: np.ndarray) -> float:
"""Compute speech probability for a 512-sample chunk.
Args:
chunk: Float32 audio samples (512 samples at 16kHz)
Returns:
Speech probability (0.0 to 1.0)
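        Example:
            Streaming sketch (`audio` is a hypothetical 16 kHz float32 buffer),
            scored in non-overlapping 512-sample windows:

                vad = SileroVAD()
                probs = []
                for i in range(0, len(audio) - 511, 512):
                    probs.append(vad.predict(audio[i : i + 512]))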
"""
self._ensure_loaded()
# Ensure correct shape
x = chunk.reshape(1, -1).astype(np.float32)
if x.shape[1] != self.CHUNK_SIZE:
raise ValueError(f"Expected {self.CHUNK_SIZE} samples, got {x.shape[1]}")
# Concatenate context
x = np.concatenate((self._context, x), axis=1)
# Run inference
ort_inputs = {
"input": x,
"state": self._state,
"sr": np.array(16000, dtype=np.int64),
}
out, self._state = self._session.run(None, ort_inputs)
# Update context
self._context = x[:, -self.CONTEXT_SIZE :]
self._maybe_reset()
return float(out[0][0])
class SmartTurnVAD:
"""Smart Turn v3 - Semantic Voice Activity Detection.
Uses Whisper Tiny encoder to understand when a speaker has finished
their turn, rather than just detecting silence. This enables:
    - Faster response times (no need for a long silence threshold)
- Natural conversation flow (handles pauses, "um", etc.)
- Multilingual support (23 languages)
The model runs on CPU in ~12ms and is only 8MB (quantized).
Attributes:
SAMPLE_RATE: Expected sample rate (16000 Hz)
MAX_DURATION: Maximum audio duration for analysis (8 seconds)
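    Example:
        Illustrative sketch (uses placeholder silence; real callers pass the
        user's buffered 16 kHz float32 utterance):

            import numpy as np

            turn_vad = SmartTurnVAD(smart_turn_threshold=0.5)
            utterance = np.zeros(2 * 16000, dtype=np.float32)  # 2 s placeholder
            result = turn_vad.predict_endpoint(utterance)
            turn_complete = result["prediction"] == 1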
"""
MODEL_REPO = "pipecat-ai/smart-turn-v3"
MODEL_FILE = "smart-turn-v3.2-cpu.onnx"
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
MODEL_PATH = SCRIPT_DIR / "models" / "smart-turn-v3.2-cpu.onnx"
# Audio settings
SAMPLE_RATE = 16000
MAX_DURATION = 8 # seconds
    MAX_SAMPLES = MAX_DURATION * SAMPLE_RATE  # 128000 samples
def __init__(self, smart_turn_threshold: float = 0.5):
"""Initialize SmartTurnVAD.
Args:
smart_turn_threshold: Probability threshold for turn completion (0.0-1.0)
"""
self._session = None
self._feature_extractor = None
self._silero = SileroVAD() # For initial speech detection
self._threshold = smart_turn_threshold
def _ensure_loaded(self) -> None:
"""Lazy load the Smart Turn model."""
if self._session is not None:
return
import onnxruntime as ort
# Download model if needed
model_path = self.MODEL_PATH
model_path.parent.mkdir(parents=True, exist_ok=True)
if not model_path.exists():
print("Downloading Smart Turn v3 model (8MB)...")
import urllib.request
urllib.request.urlretrieve(self.MODEL_URL, str(model_path))
print("Smart Turn v3 model downloaded.")
# Load WhisperFeatureExtractor
# Suppress transformers warnings during import
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from transformers import WhisperFeatureExtractor
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
# Create ONNX session
opts = ort.SessionOptions()
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
opts.inter_op_num_threads = 1
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self._session = ort.InferenceSession(
str(model_path), providers=["CPUExecutionProvider"], sess_options=opts
)
print("Smart Turn v3 loaded successfully")
def _prepare_audio(self, audio: np.ndarray) -> np.ndarray:
"""Truncate to last 8s or pad with zeros at beginning."""
if len(audio) > self.MAX_SAMPLES:
return audio[-self.MAX_SAMPLES :]
elif len(audio) < self.MAX_SAMPLES:
padding = self.MAX_SAMPLES - len(audio)
return np.pad(audio, (padding, 0), mode="constant", constant_values=0)
return audio
def predict_endpoint(self, audio: np.ndarray) -> dict:
"""Predict whether an audio segment represents a completed turn.
Args:
audio: Float32 audio at 16kHz
Returns:
Dictionary with:
- prediction: 1 for complete, 0 for incomplete
- probability: Confidence score (0.0 to 1.0)
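        Example:
            Sketch of a typical caller pattern, gating on the bundled Silero VAD
            before running the heavier turn model; `buffer` is a hypothetical
            16 kHz float32 array of the user's audio and 0.3 is an illustrative
            silence cutoff, not a value prescribed by this module:

                turn_vad = get_smart_turn_vad()
                if turn_vad.silero.predict(buffer[-512:]) < 0.3:  # brief pause detected
                    result = turn_vad.predict_endpoint(buffer)
                    turn_complete = result["prediction"] == 1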
"""
self._ensure_loaded()
# Prepare audio (truncate/pad to 8 seconds)
audio = self._prepare_audio(audio.astype(np.float32))
# Extract Whisper features
inputs = self._feature_extractor(
audio,
sampling_rate=self.SAMPLE_RATE,
return_tensors="np",
padding="max_length",
max_length=self.MAX_SAMPLES,
truncation=True,
do_normalize=True,
)
# Prepare input
input_features = inputs.input_features.squeeze(0).astype(np.float32)
input_features = np.expand_dims(input_features, axis=0)
# Run inference
outputs = self._session.run(None, {"input_features": input_features})
# Extract probability
probability = float(outputs[0][0].item())
prediction = 1 if probability > self._threshold else 0
return {"prediction": prediction, "probability": probability}
@property
def silero(self) -> SileroVAD:
"""Access the Silero VAD for speech detection."""
return self._silero
# Global Smart Turn instance (lazy loaded)
_smart_turn_vad: Optional[SmartTurnVAD] = None
def get_smart_turn_vad() -> SmartTurnVAD:
"""Get or create the global SmartTurnVAD instance."""
global _smart_turn_vad
if _smart_turn_vad is None:
# Get threshold from environment variable if available
threshold = float(os.environ.get("VOICE_SMART_TURN_THRESHOLD", "0.5"))
_smart_turn_vad = SmartTurnVAD(smart_turn_threshold=threshold)
return _smart_turn_vad
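

if __name__ == "__main__":
    # Illustrative smoke test, not part of the public API: exercises both VADs on
    # synthetic silence (real deployments feed live microphone audio). The first
    # run downloads the ONNX models, so network access is required.
    vad = get_smart_turn_vad()  # threshold read from VOICE_SMART_TURN_THRESHOLD if set
    silence = np.zeros(SmartTurnVAD.SAMPLE_RATE, dtype=np.float32)  # 1 s of silence
    speech_prob = vad.silero.predict(silence[: SileroVAD.CHUNK_SIZE])
    endpoint = vad.predict_endpoint(silence)
    print(f"speech_prob={speech_prob:.3f} endpoint={endpoint}")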