"""Audio recording with VAD (Voice Activity Detection) support.
Provides AudioRecorder class for capturing audio input with various
VAD backends including Smart Turn semantic turn detection.
"""

from __future__ import annotations

import time
from collections import deque
from typing import TYPE_CHECKING, Callable, Optional

import numpy as np
import sounddevice as sd

if TYPE_CHECKING:
    from localvoicemode.speech.vad import SmartTurnVAD


class AudioRecorder:
    """Records audio with optional VAD (Voice Activity Detection).

    Supports two recording modes:

    - VAD mode: Automatically detects speech start/end using Smart Turn VAD
    - PTT mode: Push-to-talk recording with manual start/stop

    Args:
        sample_rate: Audio sample rate in Hz (default: 16000)
        vad_factory: Optional factory function that returns a SmartTurnVAD
            instance. Used for dependency injection in tests.
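
    Example:
        Illustrative usage sketch (assumes a working default input device
        and locally available VAD models):

            recorder = AudioRecorder(sample_rate=16000)
            audio = recorder.record_vad(max_duration=30.0)  # VAD mode

            recorder.start_recording()                      # PTT mode
            ...  # wait until the push-to-talk key is released
            audio = recorder.stop_recording()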
"""

    def __init__(
        self,
        sample_rate: int = 16000,
        vad_factory: Optional[Callable[[], "SmartTurnVAD"]] = None,
    ):
        self.sample_rate = sample_rate
        self._vad_factory = vad_factory
        self.recording = False
        self.audio_buffer = []
        self._stream = None
        self.current_level = 0.0  # Current audio level (0.0 to 1.0)
        self.level_history = [0.0] * 50  # History for waveform visualization
        self._level_callback = None

    def set_level_callback(self, callback: Callable[[float, list], None]) -> None:
        """Set callback for real-time audio level updates.

        Args:
            callback: Function called with (level, history) on each update
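
        Example:
            A minimal sketch of a compatible callback (any callable with
            this signature works; the name is illustrative):

                def on_level(level: float, history: list) -> None:
                    print(f"level={level:.2f}")

                recorder.set_level_callback(on_level)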
"""
self._level_callback = callback

    def _update_level(self, rms: float) -> None:
        """Update audio level and history.

        Args:
            rms: Root mean square of audio chunk
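
        Example:
            With the assumed 0.3 full-scale RMS, an input of 0.15 maps to
            a level of 0.5, and anything at or above 0.3 clips to 1.0.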
"""
# Normalize RMS to 0-1 range (assuming max RMS around 0.3)
level = min(1.0, rms / 0.3)
self.current_level = level
# Update history (shift left, add new value)
self.level_history.pop(0)
self.level_history.append(level)
# Call callback if set
if self._level_callback:
self._level_callback(level, self.level_history)

    def record_vad(self, max_duration: float = 30.0) -> np.ndarray:
        """Record with Smart Turn semantic voice activity detection.

        Uses Silero VAD for speech detection, then runs the Smart Turn model
        to determine whether the speaker has finished their turn.

        Args:
            max_duration: Maximum recording duration in seconds

        Returns:
            Recorded audio as float32 numpy array
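
        Example:
            A minimal sketch (assumes a default input device and available
            Smart Turn models):

                audio = recorder.record_vad(max_duration=10.0)
                print(f"captured {audio.size / recorder.sample_rate:.1f}s")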
"""
return self._record_vad_smart_turn(max_duration)

    def _record_vad_smart_turn(self, max_duration: float = 30.0) -> np.ndarray:
        """Implementation of record_vad using Smart Turn endpoint detection.

        Uses Silero VAD to detect speech, then runs the Smart Turn model
        after a stretch of silence to decide whether the speaker has
        finished their turn.

        Args:
            max_duration: Maximum recording duration in seconds

        Returns:
            Recorded audio as float32 numpy array
        """
        # Get VAD instance: use factory if provided, otherwise lazy import
        if self._vad_factory is not None:
            smart_turn = self._vad_factory()
        else:
            # Lazy import to avoid circular dependency
            from localvoicemode.speech.vad import get_smart_turn_vad

            smart_turn = get_smart_turn_vad()
        silero = smart_turn.silero

        # Silero VAD expects exactly 512 samples per chunk
        chunk_size = 512  # SileroVAD.CHUNK_SIZE
        chunk_duration = chunk_size / self.sample_rate

        # Settings
        vad_threshold = 0.5  # Silero speech probability threshold
        silence_check_duration = 0.8  # Check Smart Turn after 800ms silence
        max_silence_chunks = int(silence_check_duration / chunk_duration)
        pre_speech_duration = 0.2  # Keep 200ms before speech trigger
        pre_speech_chunks = int(pre_speech_duration / chunk_duration)
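        # At the default 16 kHz rate each 512-sample chunk is 32ms, so
        # max_silence_chunks = int(0.8 / 0.032) = 25 and
        # pre_speech_chunks = int(0.2 / 0.032) = 6.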

        # Buffers
        pre_buffer: deque = deque(maxlen=pre_speech_chunks)
        audio_chunks = []
        started_speaking = False
        silence_chunks = 0

        with sd.InputStream(
            samplerate=self.sample_rate, channels=1, dtype=np.float32
        ) as stream:
            start_time = time.time()
            while time.time() - start_time < max_duration:
                chunk, _ = stream.read(chunk_size)
                chunk = chunk.flatten()

                # Calculate RMS for visualization
                rms = np.sqrt(np.mean(chunk**2))
                self._update_level(rms)

                # Run Silero VAD for speech detection
                speech_prob = silero.predict(chunk)
                is_speech = speech_prob > vad_threshold

                if not started_speaking:
                    if is_speech:
                        # Speech started - include the buffered pre-speech
                        # audio plus the triggering chunk (appending the
                        # trigger chunk to pre_buffer first would record
                        # it twice)
                        audio_chunks = list(pre_buffer)
                        audio_chunks.append(chunk)
                        started_speaking = True
                        silence_chunks = 0
                    else:
                        # Buffer audio until speech starts
                        pre_buffer.append(chunk)
                else:
                    # Already speaking - accumulate audio
                    audio_chunks.append(chunk)
                    if is_speech:
                        silence_chunks = 0
                    else:
                        silence_chunks += 1

                    # After silence threshold, check Smart Turn
                    if silence_chunks >= max_silence_chunks:
                        # Pause stream during Smart Turn inference
                        stream.stop()

                        # Concatenate audio for Smart Turn analysis
                        current_audio = np.concatenate(audio_chunks)
                        result = smart_turn.predict_endpoint(current_audio)

                        if result["prediction"] == 1:
                            # Turn complete - return audio
                            self.current_level = 0.0
                            return current_audio

                        # Turn incomplete - reset silence counter and resume
                        silence_chunks = 0
                        stream.start()

        # Max duration reached
        self.current_level = 0.0
        if audio_chunks:
            return np.concatenate(audio_chunks)
        return np.array([], dtype=np.float32)

    def start_recording(self) -> None:
        """Start continuous recording (for PTT mode)."""
        self.recording = True
        self.audio_buffer = []

        def callback(indata, frames, time_info, status):
            if self.recording:
                self.audio_buffer.append(indata.copy())

        self._stream = sd.InputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype=np.float32,
            callback=callback,
        )
        self._stream.start()

    def stop_recording(self) -> np.ndarray:
        """Stop recording and return audio.

        Returns:
            Recorded audio as float32 numpy array
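
        Example:
            A push-to-talk sketch where a 2-second sleep stands in for a
            real key hold (illustrative only):

                recorder.start_recording()
                time.sleep(2.0)
                audio = recorder.stop_recording()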
"""
self.recording = False
if self._stream:
self._stream.stop()
self._stream.close()
self._stream = None
if self.audio_buffer:
return np.concatenate(self.audio_buffer).flatten()
return np.array([], dtype=np.float32)
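

if __name__ == "__main__":
    # Minimal smoke test, not part of the library API: record one
    # VAD-delimited utterance and report its duration. Assumes a working
    # default input device and locally available Silero / Smart Turn models.
    recorder = AudioRecorder(sample_rate=16000)
    print("Speak now...")
    captured = recorder.record_vad(max_duration=15.0)
    print(f"Captured {captured.size / recorder.sample_rate:.2f} s of audio")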