"""Smart Turn v3 - Semantic turn detection using Whisper Tiny encoder.
Uses the Smart Turn model from pipecat-ai to detect when a speaker has
finished their turn, not just detecting silence. This is more robust than
simple VAD because it understands speech patterns, pauses, and intent.
Model: https://huggingface.co/onnx-community/smart-turn-v3-ONNX
"""
import logging
from pathlib import Path

import numpy as np

logger = logging.getLogger(__name__)

class SmartTurnVAD:
    """Smart Turn v3 ONNX wrapper for semantic turn detection.

    Takes 16 kHz mono audio and predicts whether the speaker has finished
    their turn. Uses Whisper's feature extractor for mel spectrograms.

    Key differences from Silero VAD:
    - Semantic: understands when the speaker is DONE, not just silent
    - Handles "um", pauses, and thinking silences naturally
    - Supports 23 languages
    - ~10-65 ms inference time on CPU
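
    Example (illustrative sketch; ``audio`` is assumed to be a float32
    NumPy array of 16 kHz mono samples, e.g. from a microphone buffer):

        vad = SmartTurnVAD()
        result = vad.detect_turn_complete(audio, sample_rate=16000)
        if result["is_complete"]:
            ...  # speaker appears done; hand off to the next stage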
"""
MODEL_URL = "https://huggingface.co/onnx-community/smart-turn-v3-ONNX/resolve/main/onnx/model_quantized.onnx"
MODEL_PATH = Path(__file__).parent.parent.parent.parent / "models" / "smart_turn_v3.onnx"
# Smart Turn expects up to 8 seconds of audio
MAX_AUDIO_SECONDS = 8
SAMPLE_RATE = 16000
def __init__(self):
self._session = None
self._feature_extractor = None
    def _ensure_loaded(self) -> None:
        """Lazily load the Smart Turn model and feature extractor."""
        if self._session is not None:
            return

        import onnxruntime as ort

        # Import the feature extractor first so a missing dependency fails
        # cleanly, before we download the model or build the ONNX session.
        try:
            from transformers import WhisperFeatureExtractor
        except ImportError:
            raise ImportError(
                "The transformers package is required for Smart Turn. "
                "Install it with: pip install transformers"
            )

        # Download the model if it is not already cached locally
        model_path = self.MODEL_PATH
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if not model_path.exists():
            logger.info("Downloading Smart Turn v3 model...")
            import urllib.request
            urllib.request.urlretrieve(self.MODEL_URL, str(model_path))
            logger.info("Smart Turn v3 model downloaded.")

        # Create the ONNX session
        opts = ort.SessionOptions()
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.inter_op_num_threads = 1
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self._session = ort.InferenceSession(
            str(model_path),
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

        # Whisper feature extractor configured for 8-second windows
        self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
        logger.info("Smart Turn v3 loaded successfully")
    def _truncate_or_pad(self, audio: np.ndarray, n_seconds: int = 8) -> np.ndarray:
        """Truncate to the last n seconds, or pad with zeros at the beginning.

        Smart Turn works best with the END of the utterance, so we keep
        the last n_seconds and pad at the beginning if needed.
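
        For example, at 16 kHz with n_seconds=8 the target is 128000
        samples: a 10-second clip keeps only its last 128000 samples,
        while a 3-second clip gains 80000 leading zeros.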
"""
target_len = n_seconds * self.SAMPLE_RATE
if len(audio) > target_len:
# Keep the last n_seconds (end of utterance is most important)
return audio[-target_len:]
elif len(audio) < target_len:
# Pad with zeros at the beginning
padding = np.zeros(target_len - len(audio), dtype=np.float32)
return np.concatenate([padding, audio])
else:
return audio
    def predict(self, audio: np.ndarray, sample_rate: int = 16000) -> dict:
        """Predict whether the speaker has finished their turn.

        Args:
            audio: Float32 audio samples.
            sample_rate: Audio sample rate (resampled to 16 kHz if different).

        Returns:
            Dictionary with:
            - is_complete: True if the speaker appears done (turn complete)
            - probability: Probability of turn completion (0.0 to 1.0)
        """
        self._ensure_loaded()

        # Resample to 16 kHz if needed
        if sample_rate != self.SAMPLE_RATE:
            from scipy import signal
            num_samples = int(len(audio) * self.SAMPLE_RATE / sample_rate)
            audio = signal.resample(audio, num_samples).astype(np.float32)

        # Truncate/pad to the 8-second window the model expects
        audio = self._truncate_or_pad(audio, n_seconds=self.MAX_AUDIO_SECONDS)

        # Extract Whisper log-mel features
        inputs = self._feature_extractor(
            audio,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=self.MAX_AUDIO_SECONDS * self.SAMPLE_RATE,
            truncation=True,
            do_normalize=True,
        )

        # Get mel features and ensure the correct shape
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)  # Add batch dim

        # Run ONNX inference
        outputs = self._session.run(None, {"input_features": input_features})

        # Extract the probability (the model outputs a sigmoid probability)
        probability = float(outputs[0][0].item())
        return {
            "is_complete": probability > 0.5,
            "probability": probability,
        }
    def detect_turn_complete(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        threshold: float = 0.5,
    ) -> dict:
        """Detect whether the speaker has finished their turn (API-compatible with SileroVAD).

        Args:
            audio: Float32 audio samples.
            sample_rate: Audio sample rate.
            threshold: Completion probability threshold (default 0.5).

        Returns:
            Dictionary with:
            - is_speech: True if the speaker appears NOT done (still speaking)
            - is_complete: True if the speaker appears done (turn complete)
            - probability: Turn completion probability
        """
        result = self.predict(audio, sample_rate)

        # Return in a format compatible with the existing VAD interface.
        # Note: is_speech is inverted; if the turn is complete, the user
        # is treated as no longer speaking.
        is_complete = result["probability"] > threshold
        return {
            "is_speech": not is_complete,  # Compatible with SileroVAD interface
            "is_complete": is_complete,
            "probability": result["probability"],
            "max_probability": result["probability"],  # For compatibility
        }
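

if __name__ == "__main__":
    # Minimal smoke test, kept as a sketch: it feeds one second of synthetic
    # silence through the detector. Real usage would pass captured microphone
    # audio; note that the first run needs network access to download the
    # ONNX weights, plus the onnxruntime and transformers packages.
    logging.basicConfig(level=logging.INFO)

    vad = SmartTurnVAD()
    silence = np.zeros(SmartTurnVAD.SAMPLE_RATE, dtype=np.float32)  # 1 s of silence
    result = vad.detect_turn_complete(silence, sample_rate=16000)
    print(
        f"is_complete={result['is_complete']} "
        f"probability={result['probability']:.3f}"
    )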