import shutil
import tempfile
import logging
import concurrent.futures
from typing import Dict, Any, List
from youtube_mcp_server.config import Config
from youtube_mcp_server.services.download_service import DownloadService
from youtube_mcp_server.services.vad_service import VADService
from youtube_mcp_server.services.whisper_service import WhisperService
from youtube_mcp_server.services.cache_service import CacheService
logger = logging.getLogger("youtube-mcp-server")
class TranscriptionService:
    """Orchestrates the YouTube transcription pipeline.

    Composes the download, cache, VAD, and Whisper services: audio is
    downloaded, split into speech segments, transcribed in parallel worker
    threads, and the assembled result is cached per (video_id, language).
    """

    def __init__(self):
        # Collaborating services; one instance each, shared across calls.
        self.download_service = DownloadService()
        self.cache_service = CacheService()
        self.vad_service = VADService()
        self.whisper_service = WhisperService()

    @staticmethod
    def _format_time(milliseconds: int) -> str:
        """Format a millisecond offset as zero-padded HH:MM:SS.

        Sub-second precision is truncated (floor division), so e.g. 1999 ms
        renders as "00:00:01".
        """
        seconds = milliseconds // 1000
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def _process_task(self, args):
        """Worker function for threads.

        Args:
            args: Tuple of (audio_chunk, start_ms, end_ms, language, task)
                as built by transcribe().

        Returns:
            Dict with "from"/"to" display timestamps and "transcription" text.
            A failing segment returns "[Error]" instead of raising, so one bad
            segment cannot abort the whole batch.
        """
        audio_chunk, start_ms, end_ms, language, task = args
        try:
            # Lazy %-args: formatting is skipped when INFO is disabled.
            logger.info("Transcribing segment %sms-%sms (lang=%s, task=%s)",
                        start_ms, end_ms, language, task)
            text = self.whisper_service.transcribe_segment(audio_chunk, language=language, task=task)
            from_str = self._format_time(start_ms)
            to_str = self._format_time(end_ms)
            logger.info("Segment %s-%s: %s", from_str, to_str, text)
            return {
                "from": from_str,
                "to": to_str,
                "transcription": text
            }
        except Exception as e:
            logger.error("Error in segment task %s-%s: %s", start_ms, end_ms, e)
            return {
                "from": self._format_time(start_ms),
                "to": self._format_time(end_ms),
                "transcription": "[Error]"
            }

    def transcribe(self, url: str, language: str = "auto", progress_callback=None) -> Dict[str, Any]:
        """
        Orchestrates the full transcription pipeline.
        1. Check Cache
        2. Download
        3. Load Memory
        4. VAD
        5. Transcribe Segments
        6. Save Cache

        Args:
            url: Video URL understood by the download service.
            language: "auto" to transcribe in the detected language, "en" to
                translate into English, or an explicit source-language code.
            progress_callback: Optional callable(current, total, message).

        Returns:
            Dict with video metadata and a "transcription" list of
            {"from", "to", "transcription"} segments in chronological order.
        """
        # Map the requested language onto Whisper's (language, task) pair.
        if language == "en":
            whisper_task = "translate"
            whisper_lang = "auto"  # Whisper determines source, output is English
        elif language == "auto":
            whisper_task = "transcribe"
            whisper_lang = None  # Whisper detects language and transcribes in it
        else:
            whisper_task = "transcribe"
            whisper_lang = language
        # 1. Quick cache check using the ID parsed from the URL (no download).
        video_id_preview = self.download_service.get_video_id(url)
        if video_id_preview:
            cached = self.cache_service.load_transcription(video_id_preview, language)
            if cached:
                return cached
        if progress_callback:
            progress_callback(0, 100, "Downloading audio...")
        # 2. Download into a temp dir that is always removed in the finally block.
        temp_dir = tempfile.mkdtemp()
        try:
            audio_path, info = self.download_service.download_audio(url, temp_dir)
            video_id = info['id']
            # Double check cache with definitive ID (URL parsing may have missed it).
            cached = self.cache_service.load_transcription(video_id, language)
            if cached:
                return cached
            # 3. Load Audio
            if progress_callback:
                progress_callback(10, 100, "Loading audio...")
            logger.info("Loading audio into memory...")
            audio_np = self.whisper_service.load_audio(audio_path)
            # 4. VAD Segmentation
            if progress_callback:
                progress_callback(20, 100, "Analyzing speech (VAD)...")
            logger.info("Running VAD...")
            segments_raw = self.vad_service.get_segments(audio_np)
            # 5. Build one task per VAD segment, padded on both sides so Whisper
            # sees a little context around the detected speech.
            tasks = []
            pad_samples = int(Config.SEGMENT_PADDING_MS * Config.SAMPLING_RATE / 1000)
            for ts in segments_raw:
                start_sample = int(ts['start'])
                end_sample = int(ts['end'])
                # Apply padding, clamped to the audio bounds.
                start_sample_padded = max(0, start_sample - pad_samples)
                end_sample_padded = min(len(audio_np), end_sample + pad_samples)
                chunk_np = audio_np[start_sample_padded:end_sample_padded]
                # Display times describe the padded chunk actually transcribed.
                start_ms = int(start_sample_padded / Config.SAMPLING_RATE * 1000)
                end_ms = int(end_sample_padded / Config.SAMPLING_RATE * 1000)
                tasks.append((chunk_np, start_ms, end_ms, whisper_lang, whisper_task))
            # Execute in parallel. Chronological order is preserved by writing
            # each result back at its submission index — the previous sort on
            # the "HH:MM:SS" string could mis-order segments that start within
            # the same second (sub-second precision is truncated).
            logger.info("Starting batched transcription of %s segments...", len(tasks))
            total_segments = len(tasks)
            results = [None] * total_segments
            with concurrent.futures.ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) as executor:
                future_to_index = {
                    executor.submit(self._process_task, task): idx
                    for idx, task in enumerate(tasks)
                }
                for i, future in enumerate(concurrent.futures.as_completed(future_to_index)):
                    results[future_to_index[future]] = future.result()
                    # Calculate progress from 30% to 90%
                    if progress_callback:
                        pct = 30 + int((i + 1) / total_segments * 60)
                        progress_callback(pct, 100, f"Transcribing segment {i+1}/{total_segments}")
                    if (i + 1) % 5 == 0:
                        logger.info("Progress: %s/%s", i + 1, total_segments)
            # 6. Final Result & Cache
            if progress_callback:
                progress_callback(95, 100, "Saving results...")
            final_output = {
                "id": info.get("id"),
                "title": info.get("title"),
                "description": info.get("description"),
                "url": info.get("webpage_url", url),
                "uploader": info.get("uploader"),
                "duration": info.get("duration"),
                "transcription": results
            }
            self.cache_service.save_transcription(video_id, final_output, language)
            logger.info("Pipeline finished successfully.")
            return final_output
        finally:
            # ignore_errors: a cleanup failure must never mask a pipeline error.
            shutil.rmtree(temp_dir, ignore_errors=True)
            logger.info("Cleanup complete.")
# Module-level singleton; created lazily on first access.
_service = None


def get_transcription_service():
    """Return the process-wide TranscriptionService, creating it on first use."""
    global _service
    if _service is not None:
        return _service
    _service = TranscriptionService()
    return _service