import argparse
import json
import time
import random
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent / "src"))
from main.server import split_text, concat_wav_files, load_config
def readFile(path: Path) -> str:
    """Return the whole contents of *path*, decoded as UTF-8 text."""
    with path.open("r", encoding="utf-8") as handle:
        return handle.read()
def createTtsModel(language: str, device: str):
    """Construct a MeloTTS ``TTS`` model for *language* on *device*.

    The ``melo`` package is imported only when this is called, presumably so
    the HTTP-only code path works without MeloTTS installed.
    """
    import melo.api as meloApi

    ttsClass = meloApi.TTS
    return ttsClass(language=language, device=device)
def resolveSpeakerId(model, language: str, speakerTag: str | None):
try:
spk2id = model.hps.data.spk2id
except AttributeError:
spk2id = getattr(getattr(model, "hps", None), "data", None)
if spk2id:
spk2id = getattr(spk2id, "spk2id", {})
else:
spk2id = {}
targetKey = speakerTag if speakerTag else language
speakerId = spk2id.get(targetKey)
if speakerId is None and not speakerTag:
if language in spk2id:
speakerId = spk2id[language]
elif spk2id:
firstKey = list(spk2id.keys())[0]
speakerId = spk2id[firstKey]
if speakerId is None:
available = list(spk2id.keys())
raise ValueError(f"无效的说话人: {targetKey}. 可用: {available}")
return speakerId
def _parseQueueMessage(line: str) -> dict:
    """Decode one line of the queue event stream into a dict.

    Two layouts are tried in order: the text after the first tab (or the
    whole line when there is no tab), then an SSE ``data: <json>`` line.
    Returns ``{}`` when no candidate decodes to a JSON object.
    """
    _, sep, afterTab = line.partition("\t")
    candidates = [afterTab if sep else line]
    if line.startswith("data: "):
        candidates.append(line[len("data: "):])
    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(obj, dict):
            return obj
    return {}


def httpGenerateChunk(apiBaseUrl: str, language: str, text: str, speed: float, speakerTag: str | None, fnIndex: int, sessionHash: str | None, saveTo: Path):
    """Synthesize one text chunk through a Gradio-style queue API and save it.

    Posts ``[language, text, speed, speakerTag or language]`` to
    ``/queue/join`` for function index *fnIndex*, reads
    ``/queue/data?session_hash=...`` until a line mentioning
    ``process_completed`` arrives, then downloads the first output's ``url``
    into *saveTo*.

    NOTE(review): the order of the ``data`` list must match the server
    function's inputs; it is assumed to be (language, text, speed, speaker).

    Args:
        apiBaseUrl: Server base URL; trailing slashes are ignored.
        language: Language code sent to the server.
        text: Text to synthesize.
        speed: Speech speed sent to the server.
        speakerTag: Speaker key; the language code is sent when empty.
        fnIndex: Value for the ``fn_index`` field of the join payload.
        sessionHash: Session id to reuse; a random 12-char one is generated
            when not given.
        saveTo: Destination file for the downloaded audio bytes.

    Raises:
        RuntimeError: if the join request is not HTTP 200, or no audio URL is
            found in the completion message. When the server reported an
            ``error`` in that message, it is appended to the error text.
        requests.HTTPError: if the event stream or the audio download fails.
    """
    import requests
    from urllib.parse import urljoin

    base = apiBaseUrl.rstrip("/")
    # The same session hash is used for joining the queue and for reading
    # its event stream below.
    sess = sessionHash or "".join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(12))
    payload = {
        "data": [language, text, speed, speakerTag or language],
        "event_data": None,
        "fn_index": fnIndex,
        "trigger_id": random.randint(1, 1000000),
        "session_hash": sess,
    }
    r = requests.post(f"{base}/queue/join?", json=payload, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"join 失败: {r.status_code} {r.text}")

    dataUrl = f"{base}/queue/data?session_hash={sess}"
    audioUrl = None
    errorDetail = None
    with requests.get(dataUrl, headers={"Accept": "text/event-stream"}, stream=True, timeout=300) as s:
        s.raise_for_status()
        # Event streams are always UTF-8; do not let requests guess another
        # charset when the server's Content-Type omits one.
        s.encoding = "utf-8"
        for line in s.iter_lines(decode_unicode=True):
            if not line or "process_completed" not in line:
                continue
            obj = _parseQueueMessage(line)
            output = obj.get("output")
            if not isinstance(output, dict):
                output = {}
            outputs = output.get("data") or []
            if outputs and isinstance(outputs[0], dict):
                audioUrl = outputs[0].get("url")
            # Presumably a failed job reports {"error": ...} under "output";
            # surface it instead of a bare "no URL" error — verify vs server.
            errorDetail = output.get("error")
            break

    if not audioUrl:
        detail = f": {errorDetail}" if errorDetail else ""
        raise RuntimeError(f"未获取到生成的音频URL{detail}")

    # Resolve a server-relative URL against the API base; urljoin returns an
    # absolute URL unchanged.
    ar = requests.get(urljoin(base + "/", audioUrl), timeout=120)
    ar.raise_for_status()
    saveTo.write_bytes(ar.content)
def generateChunksLocal(texts: list[str], language: str, device: str, speed: float, speakerTag: str | None, outDir: Path) -> list[Path]:
    """Synthesize each text segment with a local MeloTTS model.

    The model is built (and the speaker resolved) once, before the loop.
    Segment *i* (1-based) is written to ``outDir/chunk_<i>.wav``, where ``i``
    is zero-padded to at least 3 digits (wider when there are more segments).

    Returns:
        The written chunk paths, in segment order.

    Raises:
        RuntimeError: if a chunk file is missing or empty after synthesis.
    """
    tts = createTtsModel(language, device)
    spkId = resolveSpeakerId(tts, language, speakerTag)
    padWidth = max(3, len(str(len(texts))))
    written: list[Path] = []
    for number, segment in enumerate(texts, start=1):
        target = outDir / f"chunk_{number:0{padWidth}d}.wav"
        tts.tts_to_file(segment, spkId, str(target), speed=speed)
        produced = target.exists() and target.stat().st_size > 0
        if not produced:
            raise RuntimeError(f"分段音频生成失败: {target}")
        written.append(target)
    return written
def generateChunksHttp(texts: list[str], apiBaseUrl: str, language: str, speed: float, speakerTag: str | None, fnIndex: int, sessionHash: str | None, outDir: Path) -> list[Path]:
    """Synthesize each text segment through the HTTP queue API.

    Segments are requested one at a time, in order. Segment *i* (1-based) is
    saved to ``outDir/chunk_<i>.wav``, with ``i`` zero-padded to at least 3
    digits.

    Returns:
        The written chunk paths, in segment order.

    Raises:
        RuntimeError: if a chunk file is missing or empty after download.
    """
    digits = max(3, len(str(len(texts))))
    collected: list[Path] = []
    for position, chunkText in enumerate(texts, 1):
        dest = outDir.joinpath(f"chunk_{position:0{digits}d}.wav")
        httpGenerateChunk(apiBaseUrl, language, chunkText, speed, speakerTag, fnIndex, sessionHash, dest)
        if dest.exists() and dest.stat().st_size:
            collected.append(dest)
            continue
        raise RuntimeError(f"分段音频生成失败: {dest}")
    return collected
def _buildArgParser() -> argparse.ArgumentParser:
    """Return the CLI parser; ``None`` defaults are filled from config in main()."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="test-text.txt")
    parser.add_argument("--output-dir", default="output")
    parser.add_argument("--target-filename")
    parser.add_argument("--language", default=None)
    parser.add_argument("--speaker")
    parser.add_argument("--speed", type=float, default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--use-http-api", action="store_true")
    parser.add_argument("--api-base-url", default="http://localhost:9900")
    parser.add_argument("--fn-index", type=int, default=1)
    parser.add_argument("--session-hash")
    return parser


def _resolveFinalWav(outDir: Path, targetFilename: str | None) -> Path:
    """Return the output path inside *outDir*, always ending in ``.wav``.

    Without *targetFilename*, the name is ``<unix-seconds>_<5-digit-random>.wav``.
    """
    if targetFilename:
        finalName = targetFilename
    else:
        finalName = f"{int(time.time())}_{random.randint(10000, 99999)}.wav"
    if not finalName.lower().endswith(".wav"):
        finalName += ".wav"
    return outDir / finalName


def main():
    """CLI entry point: synthesize the input text file into one WAV file.

    Options left unset on the command line fall back to ``load_config()``
    values. The text is split with ``split_text`` using the configured
    ``chunk_size_limit``, and each segment is synthesized locally or via HTTP
    (``--use-http-api``). A single chunk is moved to the final path; several
    chunks are merged with ``concat_wav_files``. Chunk files are then removed
    on a best-effort basis, and the final path is printed.

    Raises:
        SystemExit: if the input file yields no text to synthesize.
    """
    args = _buildArgParser().parse_args()
    cfg = load_config()
    language = args.language or cfg.get("default_language", "ZH")
    # Coerce like chunk_size_limit below, in case the config stores a string.
    speed = float(args.speed if args.speed is not None else cfg.get("default_speed", 1.0))
    device = args.device or cfg.get("default_device", "cpu")
    limit = int(cfg.get("chunk_size_limit", 100))

    inputPath = Path(args.input)
    outDir = Path(args.output_dir)
    outDir.mkdir(parents=True, exist_ok=True)
    finalWav = _resolveFinalWav(outDir, args.target_filename)

    text = readFile(inputPath).strip()
    segments = split_text(text, limit) if text else []
    if not segments:
        # Nothing to synthesize; fail clearly instead of merging zero chunks.
        raise SystemExit(f"输入文本为空: {inputPath}")

    if args.use_http_api:
        chunkPaths = generateChunksHttp(segments, args.api_base_url, language, speed, args.speaker, args.fn_index, args.session_hash, outDir)
    else:
        chunkPaths = generateChunksLocal(segments, language, device, speed, args.speaker, outDir)

    if len(chunkPaths) == 1:
        # Path.replace overwrites an existing target on every platform
        # (Path.rename raises on Windows), matching the concat branch.
        chunkPaths[0].replace(finalWav)
    else:
        concat_wav_files(chunkPaths, finalWav)

    # Best-effort removal of intermediate chunks; a moved chunk is already gone.
    for p in chunkPaths:
        try:
            p.unlink(missing_ok=True)
        except OSError:
            pass
    print(str(finalWav))
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()