# cogview_server.py
import asyncio
import os
import base64
import json
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
from zai import ZhipuAiClient
from datetime import datetime
# 从环境变量中获取 API Key,这是更安全的做法
API_KEY = os.getenv("ZHIPU_API_KEY")
if not API_KEY:
raise ValueError("环境变量 ZHIPU_API_KEY 未设置。请确保配置了智谱AI的API密钥。")
# 1. 创建一个 MCP Server 实例
server = Server("cogview4-server")
# 2. 定义工具列表
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
"""
这个函数向 AI Agent 声明我们服务器提供了哪些工具。
Agent 会根据这里的描述来决定调用哪个工具。
"""
return [
Tool(
name="generate_cogview4_image",
description="使用智谱AI的 CogView4 模型根据文本提示词生成一张图片。生成成功后会返回图片的本地保存路径。",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "用于生成图片的详细中文描述。例如:'一只戴着宇航员头盔的猫漂浮在太空中,背景是璀璨的星河'",
},
# 可以根据需要添加更多参数,例如图片尺寸、风格等
"size": {
"type": "string",
"description": "生成图片的尺寸,可选 '1024x1024', '768x768', '576x1024'。默认为 '1024x1024'。",
"default": "1024x1024",
"enum": ["1024x1024", "768x768", "576x1024"]
}
},
"required": ["prompt"],
},
)
]
# 3. 实现工具的具体逻辑
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[TextContent]:
"""
这个函数处理 Agent 对工具的实际调用。
"""
if name != "generate_cogview4_image":
raise ValueError(f"未知的工具: {name}")
prompt = arguments.get("prompt")
size = arguments.get("size", "1024x1024")
if not prompt:
return [TextContent(type="text", text="错误:缺少必需的 'prompt' 参数。")]
# --- 使用zai-sdk调用智谱AI CogView4 API ---
client = ZhipuAiClient(api_key=API_KEY)
print(f"正在调用CogView4 API,提示词: {prompt}")
try:
response = client.images.generations(
model="cogView-4-250304",
prompt=prompt
)
if not response.data or len(response.data) == 0:
return [TextContent(type="text", text="API响应中未找到图片数据。")]
image_url = response.data[0].url
# 下载图片
import requests
img_response = requests.get(image_url)
img_response.raise_for_status()
# 创建一个有意义的文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"cogview4_image_{timestamp}.png"
# 将图片保存在当前工作目录下的 `generated_images` 文件夹中
os.makedirs("generated_images", exist_ok=True)
filepath = os.path.join("generated_images", filename)
with open(filepath, "wb") as f:
f.write(img_response.content)
print(f"图片已成功保存至: {filepath}")
# 返回成功消息和文件路径
return [TextContent(type="text", text=f"图片已生成并保存至: {filepath}")]
except Exception as e:
return [TextContent(type="text", text=f"发生错误: {e}")]
# 4. 启动服务器
async def main():
# 使用 stdio 作为通信协议
async with stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
server.create_initialization_options()
)
if __name__ == "__main__":
asyncio.run(main())