"""Text-to-Speech (TTS) module.
Provides TTSEngine using Pocket TTS for CPU-based speech synthesis
with voice cloning support.
"""
from __future__ import annotations
import io
import contextlib
from pathlib import Path
from typing import Any, Callable, Dict, Optional
import numpy as np
import sounddevice as sd
# Resolve the project root for voice references
# (src/localvoicemode/speech/tts.py -> speech -> localvoicemode -> src -> project root)
SCRIPT_DIR = Path(__file__).parent.parent.parent.parent.absolute()
def _ensure_package(package: str, import_name: Optional[str] = None) -> None:
    """Ensure a package is importable, installing it via pip if missing."""
import sys
import subprocess
import_name = import_name or package
try:
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
):
__import__(import_name)
except ImportError:
print(f"Installing {package}...")
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "-q"])
class TTSEngine:
"""Text-to-speech engine using Pocket TTS.
Pocket TTS is a lightweight CPU-based TTS model with:
- ~100M parameters
- ~200ms latency
- Voice cloning from reference audio
Features:
- Lazy loading (model not loaded until first use)
- Voice caching (voices loaded once and reused)
- Built-in content filtering integration
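
    Example (illustrative; assumes an audio output device is available and
    that pocket-tts can be installed on first use):
        engine = TTSEngine()
        engine.load_voice()  # default.wav if present, otherwise a built-in voice
        engine.speak("Hello from the local voice assistant.")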
"""
def __init__(self):
self._model = None
self._voice_states: Dict[str, Any] = {}
self._current_voice: Optional[str] = None
def _ensure_loaded(self) -> None:
"""Lazy load the TTS model."""
if self._model is not None:
return
# We don't silence this part so the user can see download progress on first run
_ensure_package("pocket-tts", "pocket_tts")
from pocket_tts import TTSModel
self._model = TTSModel.load_model()
def load_voice(self, voice_path: Optional[Path] = None, voice_name: str = "default") -> None:
"""Load a voice from file or use built-in.
Args:
voice_path: Path to voice reference WAV file
voice_name: Name to cache the voice under
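
        Example (illustrative; the custom WAV path is hypothetical):
            engine.load_voice(Path("voice_references/narrator.wav"), voice_name="narrator")
            engine.load_voice(voice_name="narrator")  # later calls reuse the cached state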
"""
self._ensure_loaded()
# Special handling for "default" voice - always prefer local default.wav
if voice_name == "default":
default_voice = SCRIPT_DIR / "voice_references" / "default.wav"
if default_voice.exists():
# Force reload default.wav even if already cached
if "default" in self._voice_states:
del self._voice_states["default"]
try:
print(f"TTSEngine: Loading default voice from {default_voice}")
state = self._model.get_state_for_audio_prompt(str(default_voice))
self._voice_states[voice_name] = state
self._current_voice = voice_name
return
                except Exception as e:
                    print(f"TTSEngine: Failed to load default voice: {e}")
                    # Fall back to the generic voice sources below
if voice_name in self._voice_states:
self._current_voice = voice_name
return
# Try to load voice from various sources
voice_sources = []
if voice_path and voice_path.exists():
voice_sources.append(str(voice_path))
# Built-in voices (female first)
voice_sources.extend(
[
"hf://kyutai/tts-voices/voice-donations/Selfie.wav", # Female
"hf://kyutai/tts-voices/voice-donations/Mona.wav", # Female
"hf://kyutai/tts-voices/alba-mackenna/casual.wav", # Male
]
)
for source in voice_sources:
try:
state = self._model.get_state_for_audio_prompt(source)
self._voice_states[voice_name] = state
self._current_voice = voice_name
return
except Exception:
continue
raise RuntimeError(f"Failed to load any voice for {voice_name}")
def synthesize(self, text: str) -> tuple[np.ndarray, int]:
"""Synthesize text to audio.
Args:
text: Text to synthesize
Returns:
Tuple of (audio array, sample rate)
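
        Example (illustrative; plays the result through sounddevice):
            audio, sr = engine.synthesize("Testing one two three.")
            sd.play(audio, sr)
            sd.wait()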
"""
self._ensure_loaded()
if self._current_voice is None:
self.load_voice()
state = self._voice_states[self._current_voice]
audio = self._model.generate_audio(state, text)
return audio.numpy(), self._model.sample_rate
    def speak(self, text: str, level_callback: Optional[Callable[[float], None]] = None) -> None:
"""Synthesize and play text with optional level tracking.
Args:
text: Text to speak
            level_callback: Optional callable that receives the current output
                level (0.0 to 1.0) for visualization
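
        Example (illustrative; prints a crude text meter driven by the 0.0-1.0 level):
            engine.speak(
                "Hello there!",
                level_callback=lambda level: print("#" * int(level * 20)),
            )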
"""
# Clean text for TTS using filter
from .filter import get_tts_filter
tts_filter = get_tts_filter()
cleaned = tts_filter.filter_for_speech(text, max_length=500)
if not cleaned or not tts_filter.should_speak(text):
return
try:
audio, sr = self.synthesize(cleaned)
if level_callback:
# Play with level tracking
self._play_with_levels(audio, sr, level_callback)
else:
sd.play(audio, sr)
sd.wait()
except Exception as e:
print(f"TTS error: {e}")
            # Retry with truncated text, or a brief apology if the text was already short
            try:
                short = cleaned[:200] if len(cleaned) > 200 else "Sorry, I had trouble speaking."
audio, sr = self.synthesize(short)
sd.play(audio, sr)
sd.wait()
except Exception:
print("TTS failed completely, skipping audio")
def _play_with_levels(self, audio: np.ndarray, sr: int, level_callback) -> None:
"""Play audio while tracking levels for visualization."""
chunk_size = int(sr * 0.05) # 50ms chunks
position = 0
finished = False
def callback(outdata, frames, time_info, status):
nonlocal position, finished
end = position + frames
if end >= len(audio):
                # Pad the final block with zeros once we've reached the end of the audio
                chunk = np.zeros(frames, dtype=np.float32)
                remaining = len(audio) - position
                if remaining > 0:
                    chunk[:remaining] = audio[position:]
outdata[:, 0] = chunk
position = len(audio)
finished = True
raise sd.CallbackStop()
else:
outdata[:, 0] = audio[position:end]
                # Report the chunk's RMS level, scaled so an RMS of 0.3 or more maps to 1.0
                rms = np.sqrt(np.mean(audio[position:end] ** 2))
                level_callback(min(1.0, rms / 0.3))
position = end
        with sd.OutputStream(
            samplerate=sr,
            channels=1,
            dtype=np.float32,
            callback=callback,
            blocksize=chunk_size,
        ) as stream:
            # Wait for playback to complete; stop waiting if the stream aborts early
            while not finished and stream.active:
                sd.sleep(50)
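

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative addition, not part of the original API):
    # loads the default voice and speaks a short phrase. Assumes pocket-tts can be
    # installed and an audio output device is present. Run as a module, e.g.
    # `python -m localvoicemode.speech.tts`, so the relative `.filter` import in
    # speak() resolves.
    engine = TTSEngine()
    engine.load_voice()
    engine.speak("Text to speech is up and running.")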