//! ML server client implementation.
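//!
//! The client speaks a line-delimited JSON protocol over TCP: each request is
//! serialized to a single line of JSON and each response is read back as one
//! line. The concrete message shapes are defined in [`super::protocol`].
//!
//! A minimal usage sketch (marked `ignore` because it assumes a running
//! server at a placeholder address):
//!
//! ```ignore
//! let mut client = MlClient::connect("127.0.0.1:8765")?;
//! let status = client.status()?;
//! ```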
use super::protocol::*;
use anyhow::Result;
use base64::Engine as _;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::sync::atomic::{AtomicU64, Ordering};

/// Client for communicating with the Python ML server.
pub struct MlClient {
    /// Write half: requests are serialized as single-line JSON and written here.
    stream: TcpStream,
    /// Buffered read half (a `try_clone` of `stream`) for line-oriented responses.
    reader: BufReader<TcpStream>,
    /// Monotonically increasing id attached to each outgoing request.
    request_id: AtomicU64,
}

impl MlClient {
    /// Connect to the ML server.
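    ///
    /// `TCP_NODELAY` is enabled so small request lines are sent immediately.
    /// A sketch (`ignore`d; the address is a placeholder):
    ///
    /// ```ignore
    /// let mut client = MlClient::connect("127.0.0.1:8765")?;
    /// ```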
    pub fn connect(address: &str) -> Result<Self> {
        let stream = TcpStream::connect(address)?;
        // Disable Nagle's algorithm: requests are small, latency-sensitive lines.
        stream.set_nodelay(true)?;
        // Clone the socket so reads can be buffered independently of writes.
        let reader = BufReader::new(stream.try_clone()?);
        Ok(Self {
            stream,
            reader,
            request_id: AtomicU64::new(1),
        })
    }

    /// Send a request and wait for its response.
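    ///
    /// Framing is one JSON object per line in each direction, roughly as
    /// below (exact field names are defined by the types in `super::protocol`):
    ///
    /// ```text
    /// -> {"id": 1, "method": "status", "params": {}}
    /// <- {"id": 1, "result": {...}}
    /// <- {"id": 1, "error": {"code": ..., "message": "..."}}
    /// ```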
    fn call(&mut self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
        let request = Request::new(id, method, params);

        // Send the request as a single newline-terminated JSON line.
        let mut request_str = serde_json::to_string(&request)?;
        request_str.push('\n');
        self.stream.write_all(request_str.as_bytes())?;
        self.stream.flush()?;

        // Read exactly one response line; zero bytes means the server hung up.
        let mut response_str = String::new();
        if self.reader.read_line(&mut response_str)? == 0 {
            return Err(anyhow::anyhow!("ML server closed the connection"));
        }
        let response: Response = serde_json::from_str(&response_str)?;
        if let Some(error) = response.error {
            return Err(anyhow::anyhow!("RPC error {}: {}", error.code, error.message));
        }
        response
            .result
            .ok_or_else(|| anyhow::anyhow!("no result in response"))
    }

    /// Transcribe audio.
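    ///
    /// Samples are packed as little-endian `f32` bytes and base64-encoded
    /// before being sent. A sketch (`ignore`d; assumes 16 kHz mono audio;
    /// the fields available on the result depend on `TranscribeResult`):
    ///
    /// ```ignore
    /// let audio: Vec<f32> = vec![0.0; 16_000]; // one second of silence
    /// let result = client.transcribe(&audio, 16_000)?;
    /// ```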
    pub fn transcribe(&mut self, audio: &[f32], sample_rate: u32) -> Result<TranscribeResult> {
        // Pack samples as little-endian f32 bytes, then base64-encode.
        let bytes: Vec<u8> = audio.iter().flat_map(|&s| s.to_le_bytes()).collect();
        let audio_b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
        let params = TranscribeParams {
            audio_b64,
            sample_rate,
        };
        let result = self.call("transcribe", serde_json::to_value(params)?)?;
        Ok(serde_json::from_value(result)?)
    }

    /// Generate speech.
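    ///
    /// The response carries base64-encoded little-endian `f32` samples, which
    /// are decoded back into a `Vec<f32>`. A sketch (`ignore`d; the voice
    /// name is a placeholder):
    ///
    /// ```ignore
    /// let samples = client.speak("Hello there", Some("default"), None)?;
    /// ```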
    pub fn speak(
        &mut self,
        text: &str,
        voice: Option<&str>,
        emotion: Option<&str>,
    ) -> Result<Vec<f32>> {
        let params = SpeakParams {
            text: text.to_string(),
            voice: voice.map(String::from),
            emotion: emotion.map(String::from),
        };
        let result = self.call("speak", serde_json::to_value(params)?)?;

        // Decode the base64 payload; report a missing field explicitly instead
        // of surfacing a cryptic null-deserialization error.
        let audio_b64 = result
            .get("audio_b64")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("speak response missing audio_b64"))?;
        let bytes = base64::engine::general_purpose::STANDARD.decode(audio_b64)?;

        // Reassemble little-endian f32 samples from the raw bytes.
        let audio: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Ok(audio)
    }

    /// Send a chat message.
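    ///
    /// A sketch (`ignore`d; the emotion label is a placeholder):
    ///
    /// ```ignore
    /// let reply = client.chat("How's the weather?", Some("curious"))?;
    /// ```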
    pub fn chat(&mut self, message: &str, user_emotion: Option<&str>) -> Result<ChatResult> {
        let params = ChatParams {
            message: message.to_string(),
            user_emotion: user_emotion.map(String::from),
        };
        let result = self.call("chat", serde_json::to_value(params)?)?;
        Ok(serde_json::from_value(result)?)
    }

    /// Get server status.
    pub fn status(&mut self) -> Result<ServerStatus> {
        let result = self.call("status", serde_json::json!({}))?;
        Ok(serde_json::from_value(result)?)
    }

    /// Load a skill.
    pub fn load_skill(&mut self, skill_id: &str) -> Result<()> {
        self.call("load_skill", serde_json::json!({ "skill_id": skill_id }))?;
        Ok(())
    }

    /// List available skills.
    pub fn list_skills(&mut self) -> Result<Vec<SkillInfo>> {
        let result = self.call("list_skills", serde_json::json!({}))?;
        let skills: SkillListResult = serde_json::from_value(result)?;
        Ok(skills.skills)
    }

    /// Set the TTS model.
    pub fn set_tts_model(&mut self, model: &str) -> Result<()> {
        self.call("set_tts_model", serde_json::json!({ "model": model }))?;
        Ok(())
    }

    /// Check whether audio contains speech using VAD.
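    ///
    /// Audio is encoded exactly as in [`Self::transcribe`]. A sketch
    /// (`ignore`d; the 0.5 threshold is a placeholder, not a tuned value):
    ///
    /// ```ignore
    /// let result = client.vad(&frame, 16_000, 0.5)?;
    /// ```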
    pub fn vad(&mut self, audio: &[f32], sample_rate: u32, threshold: f32) -> Result<VadResult> {
        // Same encoding as `transcribe`: little-endian f32 bytes, base64-encoded.
        let bytes: Vec<u8> = audio.iter().flat_map(|&s| s.to_le_bytes()).collect();
        let audio_b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
        let result = self.call(
            "vad",
            serde_json::json!({
                "audio_b64": audio_b64,
                "sample_rate": sample_rate,
                "threshold": threshold,
            }),
        )?;
        Ok(serde_json::from_value(result)?)
    }
}