"""Automatic Speech Recognition (ASR) module.
Provides ASREngine using ONNX-ASR with Parakeet TDT 0.6B v3
for GPU-accelerated speech-to-text.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import Optional
import numpy as np
# Project root (five levels above this file); used on Windows to locate the
# virtualenv's bundled NVIDIA DLLs before importing onnx_asr/onnxruntime.
SCRIPT_DIR = Path(__file__).parent.parent.parent.parent.parent.absolute()
class ASREngine:
    """Speech-to-text engine using onnx-asr with Parakeet TDT 0.6B v3.

    Parakeet TDT 0.6B v3:
    - 600M parameters, multilingual ASR (25 languages)
    - GPU acceleration via CUDA/TensorRT
    - Handles all preprocessing and decoding internally

    The model is loaded lazily on first use (see ``_ensure_loaded``), so
    constructing an engine is cheap and never touches the GPU or network.

    Args:
        device: Device to use for inference ("cuda" or "cpu")
        sample_rate: Expected audio sample rate (default: 16000)
    """

    MODEL_NAME = "nemo-parakeet-tdt-0.6b-v3"

    def __init__(self, device: str = "cuda", sample_rate: int = 16000):
        # Model handle; populated lazily by _ensure_loaded().
        self._model = None
        self._device = device
        self._sample_rate = sample_rate

    @staticmethod
    def _register_nvidia_dll_dirs() -> None:
        """On Windows, register the venv's NVIDIA DLL directories.

        Must run BEFORE importing onnx_asr / onnxruntime so the CUDA/cuDNN
        DLLs shipped in ``.venv/Lib/site-packages/nvidia/*/bin`` can be
        resolved. Best-effort: directories that cannot be registered are
        skipped (the load may still succeed, e.g. via a system-wide CUDA).
        """
        if sys.platform != "win32" or not hasattr(os, "add_dll_directory"):
            return
        venv_nvidia = SCRIPT_DIR / ".venv" / "Lib" / "site-packages" / "nvidia"
        if not venv_nvidia.exists():
            return
        for pkg_dir in venv_nvidia.iterdir():
            bin_dir = pkg_dir / "bin"
            if not (pkg_dir.is_dir() and bin_dir.exists()):
                continue
            try:
                os.add_dll_directory(str(bin_dir.absolute()))
            except OSError:
                # add_dll_directory raises OSError for invalid/inaccessible
                # paths — skip this directory rather than abort the scan.
                pass

    def _ensure_loaded(self) -> None:
        """Lazy load the ASR model. Idempotent; returns fast once loaded."""
        if self._model is not None:
            return
        # Setup NVIDIA DLL paths on Windows (must be done before importing onnx_asr)
        self._register_nvidia_dll_dirs()
        # Suppress ONNX Runtime warnings about Memcpy nodes (harmless performance warning)
        os.environ.setdefault("ORT_LOGGING_LEVEL", "ERROR")
        try:
            import onnx_asr
        except ImportError:
            # Self-install on first run; check_call raises if pip fails,
            # which is preferable to continuing without the dependency.
            print("Installing onnx-asr...")
            import subprocess
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "onnx-asr[gpu,hub]"]
            )
            import onnx_asr
        # Configure execution providers based on device setting
        providers = self._get_providers()
        print(f"Loading Parakeet TDT ({self.MODEL_NAME})...")
        print(
            f"Device: {self._device}, Providers: {[p[0] if isinstance(p, tuple) else p for p in providers]}"
        )
        try:
            self._model = onnx_asr.load_model(self.MODEL_NAME, providers=providers)
            # Providers may be plain strings or (name, options) tuples.
            provider_name = (
                providers[0][0] if isinstance(providers[0], tuple) else providers[0]
            )
            print(f"Parakeet TDT loaded successfully (provider: {provider_name})")
        except Exception as e:
            # GPU init can fail for many reasons (driver, VRAM, DLLs);
            # fall back to onnx-asr's default (CPU) provider selection.
            print(f"Failed to load with GPU, falling back to CPU: {e}")
            self._model = onnx_asr.load_model(self.MODEL_NAME)
            print("Parakeet TDT loaded successfully (provider: CPU)")

    def _get_providers(self) -> list[str | tuple[str, dict]]:
        """Get ONNX Runtime execution providers based on device config.

        Returns:
            Provider list for onnxruntime: either plain provider names or
            ``(name, options)`` tuples, always ending with the CPU fallback.
            Returns CPU-only when device != "cuda" or onnxruntime is missing.
        """
        if self._device != "cuda":
            return ["CPUExecutionProvider"]
        # Try to detect available GPU providers
        try:
            import onnxruntime as ort
            available = ort.get_available_providers()
        except ImportError:
            return ["CPUExecutionProvider"]
        providers: list[str | tuple[str, dict]] = []
        # Use CUDA provider (more widely supported than TensorRT)
        if "CUDAExecutionProvider" in available:
            providers.append(
                (
                    "CUDAExecutionProvider",
                    {
                        "device_id": 0,
                        "arena_extend_strategy": "kNextPowerOfTwo",
                        "cudnn_conv_algo_search": "EXHAUSTIVE",
                    },
                )
            )
        # Always have CPU as final fallback
        providers.append("CPUExecutionProvider")
        if len(providers) == 1:
            print("No GPU providers available in onnxruntime, using CPU")
        return providers

    def transcribe(self, audio: np.ndarray) -> str:
        """Transcribe audio to text.

        Args:
            audio: Audio samples as numpy array (float32, 16kHz mono)

        Returns:
            Transcribed text, stripped of surrounding whitespace; empty
            string for audio shorter than 100 ms or on transcription error.
        """
        self._ensure_loaded()
        if len(audio) < self._sample_rate * 0.1:  # Less than 100ms
            return ""
        try:
            # onnx-asr expects audio at 16kHz
            # It handles all preprocessing internally
            result = self._model.recognize(audio, sample_rate=self._sample_rate)
            # Result can be a string or a dict with 'text' key
            if isinstance(result, dict):
                return result.get("text", "").strip()
            return str(result).strip()
        except Exception as e:
            # Best-effort: a failed chunk yields "" rather than crashing
            # the caller's audio loop; full traceback aids debugging.
            print(f"Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return ""