"""ML Server entry point.
Runs a JSON-RPC server for the Rust GUI to communicate with.
"""
import asyncio
import json
import logging
import signal
import sys
from pathlib import Path
# Load .env file if it exists
try:
from dotenv import load_dotenv
env_file = Path(__file__).parent.parent / ".env"
if env_file.exists():
load_dotenv(env_file, override=True)
except ImportError:
pass # dotenv not installed, use system env vars
from .registry import ASRRegistry, SERRegistry, TTSRegistry
from .llm import LLMClient, ProviderManager
from .skills import SkillLoader
from .models.vad import SmartTurnVAD
# Register model adapters (must happen after registry imports)
from .models.tts import _register_adapters as _register_tts
from .models.asr import _register_adapters as _register_asr
from .models.ser import _register_adapters as _register_ser
_register_tts()
_register_asr()
_register_ser()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
class MLServer:
"""JSON-RPC server for ML operations."""
def __init__(self, host: str = "127.0.0.1", port: int = 9876):
self.host = host
self.port = port
self.server: asyncio.Server | None = None
# Initialize registries
self.tts_registry = TTSRegistry()
self.asr_registry = ASRRegistry()
self.ser_registry = SERRegistry()
# Initialize LLM
self.provider_manager = ProviderManager()
self.llm_client: LLMClient | None = None
# Skill loader
self.skill_loader = SkillLoader()
self.current_skill: str | None = None
        # VAD for speech/turn detection (Smart Turn: semantic, not energy-based)
self.vad = SmartTurnVAD()
# Method dispatch table
self.methods = {
"status": self.handle_status,
"transcribe": self.handle_transcribe,
"speak": self.handle_speak,
"chat": self.handle_chat,
"load_skill": self.handle_load_skill,
"list_skills": self.handle_list_skills,
"set_tts_model": self.handle_set_tts_model,
"list_models": self.handle_list_models,
"vad": self.handle_vad,
}
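        # To expose a new RPC method, add an entry here mapping the method
        # name to an `async def handler(self, params: dict) -> dict`.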
async def start(self):
"""Start the server."""
        # Generous read limit for base64 audio payloads: 10 MiB covers roughly
        # two minutes of 16 kHz float32 audio once base64 overhead is counted.
self.server = await asyncio.start_server(
self.handle_client,
self.host,
self.port,
            limit=10 * 1024 * 1024,  # 10 MiB StreamReader limit (caps line length)
)
addr = self.server.sockets[0].getsockname()
logger.info(f"ML Server listening on {addr[0]}:{addr[1]}")
# Initialize models
await self.initialize_models()
async with self.server:
await self.server.serve_forever()
async def initialize_models(self):
"""Load default models."""
logger.info("Initializing models...")
# Load default TTS
try:
self.tts_registry.load("pocket")
logger.info("Loaded TTS: pocket")
except Exception as e:
logger.error(f"Failed to load TTS: {e}")
# Load default ASR
try:
self.asr_registry.load("parakeet")
logger.info("Loaded ASR: parakeet")
except Exception as e:
logger.error(f"Failed to load ASR: {e}")
# Load SER (optional)
try:
self.ser_registry.load("sensevoice")
logger.info("Loaded SER: sensevoice")
except Exception as e:
logger.warning(f"SER not available: {e}")
# Initialize LLM client
provider = self.provider_manager.detect()
if provider:
self.llm_client = LLMClient(provider)
# Set default system prompt
self.llm_client.set_system_prompt(
"You are an uncensored voice assistant that engages in roleplay. "
"Keep responses concise and conversational. Respond naturally as if speaking aloud. "
"IMPORTANT: Do NOT use asterisks for actions (like *smiles*), emojis, or markdown. "
"Describe actions in plain speech if needed."
)
logger.info(f"Connected to LLM: {provider.name}")
else:
logger.warning("No LLM provider available")
async def handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
"""Handle a client connection."""
addr = writer.get_extra_info("peername")
logger.info(f"Client connected: {addr}")
try:
while True:
line = await reader.readline()
if not line:
break
try:
request = json.loads(line.decode())
response = await self.dispatch(request)
writer.write((json.dumps(response) + "\n").encode())
await writer.drain()
except json.JSONDecodeError as e:
error_response = {
"jsonrpc": "2.0",
"id": None,
"error": {"code": -32700, "message": f"Parse error: {e}"},
}
writer.write((json.dumps(error_response) + "\n").encode())
await writer.drain()
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Client error: {e}")
finally:
writer.close()
await writer.wait_closed()
logger.info(f"Client disconnected: {addr}")
async def dispatch(self, request: dict) -> dict:
"""Dispatch a JSON-RPC request."""
request_id = request.get("id")
method = request.get("method", "")
params = request.get("params", {})
if method not in self.methods:
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {"code": -32601, "message": f"Method not found: {method}"},
}
try:
result = await self.methods[method](params)
return {"jsonrpc": "2.0", "id": request_id, "result": result}
except Exception as e:
logger.exception(f"Error in {method}")
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {"code": -32000, "message": str(e)},
}
# ── Request Handlers ────────────────────────────────────────────────────
async def handle_status(self, params: dict) -> dict:
"""Get server status."""
return {
"ready": True,
"models": {
"tts": self.tts_registry.active_name,
"asr": self.asr_registry.active_name,
"ser": self.ser_registry.active_name,
"llm": self.llm_client.provider_name if self.llm_client else None,
},
"skill": self.current_skill,
}
async def handle_transcribe(self, params: dict) -> dict:
"""Transcribe audio."""
import base64
import numpy as np
audio_b64 = params["audio_b64"]
sample_rate = params.get("sample_rate", 16000)
# Decode audio
audio_bytes = base64.b64decode(audio_b64)
audio = np.frombuffer(audio_bytes, dtype=np.float32)
# Transcribe
asr = self.asr_registry.active
if not asr:
raise RuntimeError("ASR model not loaded")
text = asr.transcribe(audio, sample_rate)
# Detect emotion (if SER available)
emotion = None
emotion_confidence = None
ser = self.ser_registry.active
if ser:
emotion_result = ser.detect_emotion(audio, sample_rate)
emotion = emotion_result.get("emotion")
emotion_confidence = emotion_result.get("confidence")
return {
"text": text,
"emotion": emotion,
"emotion_confidence": emotion_confidence,
}
async def handle_speak(self, params: dict) -> dict:
"""Generate speech."""
import base64
import re
text = params["text"]
voice = params.get("voice")
emotion = params.get("emotion")
# Filter out roleplay actions (*action*) and emojis
text = re.sub(r'\*[^*]+\*', '', text) # Remove *action* patterns
        text = re.sub(r'[\U0001F300-\U0001F9FF]', '', text)  # Strip common emoji (U+1F300-U+1F9FF)
text = re.sub(r'\s+', ' ', text).strip() # Clean up whitespace
if not text:
# Nothing to speak after filtering
return {"audio_b64": "", "sample_rate": 24000}
tts = self.tts_registry.active
if not tts:
raise RuntimeError("TTS model not loaded")
audio = tts.synthesize(text, voice=voice, emotion=emotion)
# Encode audio
audio_b64 = base64.b64encode(audio.astype("float32").tobytes()).decode()
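        # Inverse operation on the caller side (sketch, assuming NumPy there):
        #   np.frombuffer(base64.b64decode(result["audio_b64"]), dtype=np.float32)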
return {"audio_b64": audio_b64, "sample_rate": 24000}
async def handle_chat(self, params: dict) -> dict:
"""Send chat message to LLM."""
message = params["message"]
user_emotion = params.get("user_emotion")
if not self.llm_client:
raise RuntimeError("LLM not available")
# Include emotion in context if available
if user_emotion:
message = f"[User emotion: {user_emotion}] {message}"
response, tokens_in, tokens_out = await self.llm_client.send_message(message)
return {
"response": response,
"tokens_in": tokens_in,
"tokens_out": tokens_out,
}
async def handle_load_skill(self, params: dict) -> dict:
"""Load a skill/character."""
skill_id = params["skill_id"]
skill = self.skill_loader.load(skill_id)
self.current_skill = skill_id
# Set system prompt
if self.llm_client:
self.llm_client.set_system_prompt(skill.system_prompt)
# Load voice if specified
if skill.voice_file:
tts = self.tts_registry.active
if tts and hasattr(tts, "load_voice"):
tts.load_voice(skill.voice_file)
return {"success": True, "skill": skill_id}
async def handle_list_skills(self, params: dict) -> dict:
"""List available skills."""
skill_ids = self.skill_loader.list_skills()
# Get display names
skills = []
for skill_id in skill_ids:
try:
skill = self.skill_loader.load(skill_id)
skills.append({
"id": skill_id,
"name": skill.display_name or skill.name,
"description": skill.description,
})
except Exception as e:
logger.warning(f"Failed to load skill {skill_id}: {e}")
skills.append({
"id": skill_id,
"name": skill_id,
"description": "",
})
return {"skills": skills}
async def handle_set_tts_model(self, params: dict) -> dict:
"""Switch TTS model."""
model = params["model"]
self.tts_registry.load(model)
return {"success": True, "model": model}
async def handle_list_models(self, params: dict) -> dict:
"""List available models."""
return {
"tts": self.tts_registry.available(),
"asr": self.asr_registry.available(),
"ser": self.ser_registry.available(),
}
async def handle_vad(self, params: dict) -> dict:
"""Detect if speaker has finished their turn using Smart Turn VAD.
Smart Turn is a semantic VAD that understands when a speaker is DONE
talking, not just detecting silence. It's robust to background noise
like fans because it looks at speech patterns, not just audio levels.
"""
import base64
import numpy as np
audio_b64 = params["audio_b64"]
sample_rate = params.get("sample_rate", 16000)
threshold = params.get("threshold", 0.5)
# Decode audio
audio_bytes = base64.b64decode(audio_b64)
audio = np.frombuffer(audio_bytes, dtype=np.float32)
# Run Smart Turn VAD
result = self.vad.detect_turn_complete(audio, sample_rate, threshold)
return result
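
# Minimal illustrative client for manual testing from Python. The Rust GUI
# speaks this protocol itself; nothing in the server depends on this helper.
async def _example_status_request(host: str = "127.0.0.1", port: int = 9876) -> dict:
    """Send a single `status` request and return the decoded response."""
    reader, writer = await asyncio.open_connection(host, port)
    request = {"jsonrpc": "2.0", "id": 1, "method": "status", "params": {}}
    writer.write((json.dumps(request) + "\n").encode())
    await writer.drain()
    response = json.loads(await reader.readline())
    writer.close()
    await writer.wait_closed()
    return response
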
async def main():
"""Run the server."""
server = MLServer()
# Handle shutdown
    loop = asyncio.get_running_loop()
def shutdown():
logger.info("Shutting down...")
for task in asyncio.all_tasks(loop):
task.cancel()
if sys.platform != "win32":
loop.add_signal_handler(signal.SIGINT, shutdown)
loop.add_signal_handler(signal.SIGTERM, shutdown)
try:
await server.start()
except asyncio.CancelledError:
pass
if __name__ == "__main__":
asyncio.run(main())