"""
即梦AI图像生成服务
提供图像生成的核心功能
"""
import json
import traceback
import uuid
import time
import os
import logging
import requests
from typing import Dict, Any, Optional, List
from pathlib import Path
from header_util import HeaderGenerator
# 配置日志
logger = logging.getLogger("jimeng-service")
class JimengImageService:
"""即梦AI图像生成服务类"""
# API端点
API_BASE_URL = "https://jimeng.jianying.com/mweb/v1/aigc_draft/generate"
API_HISTORY_URL = "https://jimeng.jianying.com/mweb/v1/get_history_by_ids"
# 默认配置
DEFAULT_CONFIG = {
"model": "high_aes_general_v40",
"sample_strength": 0.5,
"negative_prompt": "",
"resolution_type": "2k"
}
# 轮询配置
POLL_INTERVAL = 3 # 轮询间隔(���)
MAX_POLL_ATTEMPTS = 60 # 最大轮询次数(总计3分钟)
def __init__(self, token: str):
"""
初始化服务
Args:
token: 即梦平台的认证token (sessionid)
"""
self.token = token
self.header_generator = HeaderGenerator()
def _generate_submit_id(self) -> str:
"""生成唯一的提交ID"""
return str(uuid.uuid4())
def _generate_draft_id(self) -> str:
"""生成唯一的草稿ID"""
return str(uuid.uuid4())
def _generate_component_id(self) -> str:
"""生成组件ID"""
# 生成类似 912136cc-3bc3-1e19-8e70-372a7b639b10 的UUID
return str(uuid.uuid4())
def _generate_metadata_id(self) -> str:
"""生成元数据ID"""
return str(uuid.uuid4())
def _generate_seed(self) -> int:
"""生成随机种子"""
import random
return random.randint(1000000000, 9999999999)
def _calculate_dimensions(self, width: int, height: int) -> Dict[str, Any]:
"""
根据宽高计算图像比例和分辨率信息
Args:
width: 图像宽度
height: 图像高度
Returns:
包含比例和分辨率信息的字典
"""
# 计算宽高比
ratio = width / height
# 确定image_ratio (根据curl.sh中的示例,2代表竖图)
if ratio < 1:
image_ratio = 2 # 竖图
elif ratio > 1:
image_ratio = 1 # 横图
else:
image_ratio = 0 # 正方形
return {
"image_ratio": image_ratio,
"large_image_info": {
"type": "",
"id": str(uuid.uuid4()),
"height": height,
"width": width,
"resolution_type": self.DEFAULT_CONFIG["resolution_type"]
}
}
def _build_draft_content(
self,
prompt: str,
width: int,
height: int,
model: str = None,
negative_prompt: str = None,
sample_strength: float = None
) -> str:
"""
构建draft_content JSON字符串
Args:
prompt: 图像生成提示词
width: 图像宽度
height: 图像高度
model: 使用的模型
negative_prompt: 负面提示词
sample_strength: 采样强度
Returns:
JSON格式的draft_content字符串
"""
import time
# 使用默认值
model = model or self.DEFAULT_CONFIG["model"]
negative_prompt = negative_prompt or self.DEFAULT_CONFIG["negative_prompt"]
sample_strength = sample_strength if sample_strength is not None else self.DEFAULT_CONFIG["sample_strength"]
# 生成各种ID
draft_id = self._generate_draft_id()
component_id = self._generate_component_id()
metadata_id = self._generate_metadata_id()
seed = self._generate_seed()
# 计算尺寸信息
dimension_info = self._calculate_dimensions(width, height)
# 获取当前时间戳(毫秒)
created_time = int(time.time() * 1000)
# 构建draft_content结构
draft_content = {
"type": "draft",
"id": draft_id,
"min_version": "3.0.2",
"min_features": [],
"is_from_tsn": True,
"version": "3.3.3",
"main_component_id": component_id,
"component_list": [
{
"type": "image_base_component",
"id": component_id,
"min_version": "3.0.2",
"aigc_mode": "workbench",
"metadata": {
"type": "",
"id": metadata_id,
"created_platform": 3,
"created_platform_version": "",
"created_time_in_ms": str(created_time),
"created_did": ""
},
"generate_type": "generate",
"abilities": {
"type": "",
"id": str(uuid.uuid4()),
"generate": {
"type": "",
"id": str(uuid.uuid4()),
"core_param": {
"type": "",
"id": str(uuid.uuid4()),
"model": model,
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"sample_strength": sample_strength,
"image_ratio": dimension_info["image_ratio"],
"large_image_info": dimension_info["large_image_info"],
"intelligent_ratio": False
}
},
"gen_option": {
"type": "",
"id": str(uuid.uuid4()),
"generate_all": False
}
}
}
]
}
# 转换为JSON字符串
return json.dumps(draft_content, ensure_ascii=False, separators=(',', ':'))
def _build_request_payload(
self,
prompt: str,
width: int,
height: int,
submit_id: str,
model: str = None,
negative_prompt: str = None,
sample_strength: float = None
) -> Dict[str, Any]:
"""
构建完整的请求负载
Args:
prompt: 图像生成提示词
width: 图像宽度
height: 图像高度
submit_id: 提交ID
model: 使用的模型
negative_prompt: 负面提示词
sample_strength: 采样强度
Returns:
请求负载字典
"""
model = model or self.DEFAULT_CONFIG["model"]
draft_content = self._build_draft_content(
prompt, width, height, model, negative_prompt, sample_strength
)
# 构建metrics_extra
metrics_extra = {
"promptSource": "custom",
"generateCount": 1,
"enterFrom": "reprompt",
"templateId": "0",
"generateId": submit_id,
"isRegenerate": False
}
payload = {
"extend": {
"root_model": model
},
"submit_id": submit_id,
"metrics_extra": json.dumps(metrics_extra, ensure_ascii=False),
"draft_content": draft_content,
"http_common_info": {
"aid": int(self.header_generator.assistant_id)
}
}
return payload
def _poll_generation_result(self, submit_id: str) -> Dict[str, Any]:
"""
轮询图像生成结果
Args:
submit_id: 提交ID
Returns:
生成结果字典
"""
logger.info(f"开始轮询生成结果,submit_id: {submit_id}")
logger.info(f"最大轮询次数: {self.MAX_POLL_ATTEMPTS}, 间隔: {self.POLL_INTERVAL}秒")
uri = "/mweb/v1/get_history_by_ids"
params = self.header_generator.generate_params({
"aigc_features": "app_lip_sync",
"da_version": "3.3.3",
"web_version": "7.5.0"
})
headers = self.header_generator.generate_headers(
token=self.token,
uri=uri,
additional_headers={
"Content-Type": "application/json",
"Lan": "zh-Hans",
"Loc": "cn"
}
)
payload = {
"submit_ids": [submit_id]
}
for attempt in range(self.MAX_POLL_ATTEMPTS):
logger.debug(f"轮询尝试 {attempt + 1}/{self.MAX_POLL_ATTEMPTS}")
try:
response = requests.post(
self.API_HISTORY_URL,
params=params,
headers=headers,
json=payload,
timeout=30
)
response.raise_for_status()
# requests 会自动处理 gzip/br/deflate 解压
result = response.json()
# 检查返回状态
if result.get("ret") != "0":
if attempt == self.MAX_POLL_ATTEMPTS - 1:
return {
"success": False,
"error": f"API返回错误: {result.get('errmsg', 'Unknown error')}"
}
time.sleep(self.POLL_INTERVAL)
continue
# 检查是否有数据返回
# 数据结构: {"data": {submit_id: {task_data}}}
data_dict = result.get("data")
if not data_dict or not isinstance(data_dict, dict):
time.sleep(self.POLL_INTERVAL)
continue
# 获取当前 submit_id 的数据
item = data_dict.get(submit_id)
if not item:
time.sleep(self.POLL_INTERVAL)
continue
status = item.get("status")
logger.debug(f"当前状态: {status}")
# 状态: 50=成功, 其他状态需要继续等待或失败
if status == 50:
logger.info(f"✅ 图像生成完成!耗时约 {(attempt + 1) * self.POLL_INTERVAL} 秒")
return {
"success": True,
"data": item
}
elif status in [40, 41]: # 失败状态
fail_msg = item.get('fail_msg', 'Unknown error')
logger.error(f"图像生成失败: {fail_msg}")
return {
"success": False,
"error": f"图像生成失败: {fail_msg}"
}
# 如果还在生成中,等待后继续轮询
if (attempt + 1) % 5 == 0: # 每5次轮询记录一次
logger.info(f"仍在生成中... 已等待 {(attempt + 1) * self.POLL_INTERVAL} 秒")
time.sleep(self.POLL_INTERVAL)
except Exception as e:
logger.warning(f"轮询出错 (尝试 {attempt + 1}): {str(e)}")
if attempt == self.MAX_POLL_ATTEMPTS - 1:
logger.error("轮询失败,已达最大尝试次数")
return {
"success": False,
"error": f"轮询失败: {str(e)}"
}
time.sleep(self.POLL_INTERVAL)
logger.error(f"轮询超时,已等待 {self.MAX_POLL_ATTEMPTS * self.POLL_INTERVAL} 秒")
return {
"success": False,
"error": "轮询超时,图像生成耗时过长"
}
def _download_images(self, image_urls: List[str], save_folder: str) -> Dict[str, Any]:
"""
下载图片到指定文件夹
Args:
image_urls: 图片URL列表
save_folder: 保存文件夹路径
Returns:
下载结果字典
"""
logger.info(f"开始下载图片,共 {len(image_urls)} 张")
logger.info(f"保存路径: {save_folder}")
# 创建文件夹
folder_path = Path(save_folder)
folder_path.mkdir(parents=True, exist_ok=True)
downloaded_files = []
failed_downloads = []
for idx, url in enumerate(image_urls, start=1):
try:
logger.debug(f"下载第 {idx} 张图片...")
# 下载图片
response = requests.get(url, timeout=60)
response.raise_for_status()
# 确定文件扩展名
content_type = response.headers.get('content-type', '')
if 'jpeg' in content_type or 'jpg' in content_type:
ext = 'jpg'
elif 'png' in content_type:
ext = 'png'
elif 'webp' in content_type:
ext = 'webp'
else:
ext = 'jpg' # 默认使用jpg
# 保存文件
filename = f"{idx:02d}.{ext}"
file_path = folder_path / filename
with open(file_path, 'wb') as f:
f.write(response.content)
file_size = len(response.content) / 1024 # KB
logger.info(f"✓ 第 {idx} 张下载完成: {filename} ({file_size:.1f} KB)")
downloaded_files.append(str(file_path))
except Exception as e:
logger.error(f"✗ 第 {idx} 张下载失败: {str(e)}")
failed_downloads.append({
"index": idx,
"url": url,
"error": str(e)
})
logger.info(f"下载完成: 成功 {len(downloaded_files)}/{len(image_urls)} 张")
if failed_downloads:
logger.warning(f"失败 {len(failed_downloads)} 张")
return {
"success": len(failed_downloads) == 0,
"downloaded": downloaded_files,
"failed": failed_downloads,
"total": len(image_urls),
"succeeded": len(downloaded_files)
}
def generate_image(
self,
prompt: str,
save_folder: str,
width: int = 1728,
height: int = 2304,
model: Optional[str] = None,
negative_prompt: Optional[str] = None,
sample_strength: Optional[float] = None
) -> Dict[str, Any]:
"""
生成图像并保存到指定文件夹
Args:
prompt: 图像生成提示词
save_folder: 图片保存文件夹路径
width: 图像宽度,默认1728
height: 图像高度,默认2304
model: 使用的模型,默认使用high_aes_general_v40
negative_prompt: 负面提示词
sample_strength: 采样强度,默认0.5
Returns:
包含生成和下载结果的字典
Raises:
Exception: 当API请求失败时
"""
# 生成提交ID
submit_id = self._generate_submit_id()
# 构建请求负载
payload = self._build_request_payload(
prompt=prompt,
width=width,
height=height,
submit_id=submit_id,
model=model,
negative_prompt=negative_prompt,
sample_strength=sample_strength
)
# 构建请求URL和参数
uri = "/mweb/v1/aigc_draft/generate"
params = self.header_generator.generate_params({
"aigc_features": "app_lip_sync",
"da_version": "3.3.3",
"web_component_open_flag": "1",
"web_version": "7.5.0"
})
# 生成请求头
headers = self.header_generator.generate_headers(
token=self.token,
uri=uri,
additional_headers={
"Content-Type": "application/json",
"Lan": "zh-Hans",
"Loc": "cn"
}
)
try:
# 1. 发送生成请求
logger.info("步骤1: 提交图像生成请求...")
response = requests.post(
self.API_BASE_URL,
params=params,
headers=headers,
json=payload,
timeout=60
)
# 检查响应状态
response.raise_for_status()
# 解析响应
logger.info(f"请求已提交,submit_id: {submit_id}")
# 2. 轮询获取生成结果
logger.info("步骤2: 轮询获取生成结果...")
poll_result = self._poll_generation_result(submit_id)
if not poll_result["success"]:
logger.error(f"轮询失败: {poll_result.get('error')}")
return {
"success": False,
"submit_id": submit_id,
"error": poll_result.get("error", "获取生成结果失败")
}
# 3. 提取图片URL
logger.info("步骤3: 提取图片URL...")
generation_data = poll_result["data"]
image_urls = []
# 从返回数据中提取图片URL
# 数据结构: generation_data["item_list"][i]["image"]["large_images"][j]["url"]
item_list = generation_data.get("item_list", [])
logger.debug(f"找到 {len(item_list)} 个item")
for item in item_list:
if not isinstance(item, dict):
continue
image_data = item.get("image", {})
if not isinstance(image_data, dict):
continue
large_images = image_data.get("large_images", [])
if not isinstance(large_images, list):
continue
# 通常取第一张大图(最高质量)
if len(large_images) > 0:
img = large_images[0]
if isinstance(img, dict) and "image_url" in img:
image_urls.append(img["image_url"])
logger.debug(f"提取URL: {img['image_url'][:50]}...")
logger.info(f"共提取到 {len(image_urls)} 个图片URL")
if not image_urls:
logger.error("未找到图片URL")
return {
"success": False,
"submit_id": submit_id,
"error": "未找到生成的图片URL",
"debug_data": generation_data # 用于调试
}
# 4. 下载图片
logger.info("步骤4: 下载图片...")
download_result = self._download_images(image_urls, save_folder)
return {
"success": download_result["success"],
"submit_id": submit_id,
"image_urls": image_urls,
"saved_files": download_result["downloaded"],
"failed_downloads": download_result["failed"],
"total_images": download_result["total"],
"save_folder": save_folder
}
except requests.exceptions.RequestException as e:
logger.info(traceback.print_exc())
return {
"success": False,
"error": str(e),
"submit_id": submit_id
}
except Exception as e:
return {
"success": False,
"error": f"Unexpected error: {str(e)}",
"submit_id": submit_id
}
def create_service(token: str) -> JimengImageService:
"""
创建服务实例的工厂函数
Args:
token: 即梦平台的认证token
Returns:
JimengImageService实例
"""
return JimengImageService(token)