Skip to main content
Glama
AIDC-AI

pixelle-mcp-Image-generation

by AIDC-AI
websocket_executor.py (18.6 kB)
# Copyright (C) 2025 AIDC-AI
# This project is licensed under the MIT License (SPDX-License-identifier: MIT).

import os
import json
import time
import uuid
import asyncio
from typing import Optional, Dict, Any
from urllib.parse import urlparse, urlunparse

import websockets

from pixelle.comfyui.base_executor import ComfyUIExecutor, COMFYUI_API_KEY, logger
from pixelle.comfyui.models import ExecuteResult


class WebSocketExecutor(ComfyUIExecutor):
    """WebSocket executor for ComfyUI.

    Submits a workflow over HTTP (``/prompt``) and then listens on the ComfyUI
    WebSocket endpoint (``/ws``) for per-node ``executed`` messages, collecting
    node outputs as they arrive instead of polling ``/history`` afterwards.
    """

    def __init__(self, base_url: str = None):
        """Initialize the executor and derive HTTP/WebSocket base URLs.

        Args:
            base_url: ComfyUI server URL; forwarded to ``ComfyUIExecutor``.
        """
        super().__init__(base_url)
        self._parse_ws_url()
        logger.info(f"HTTP Base URL: {self.http_base_url}")
        logger.info(f"WebSocket Base URL: {self.ws_base_url}")

    def _parse_ws_url(self):
        """Parse ``self.base_url`` and build WebSocket/HTTP base URLs.

        Sets ``self.ws_base_url`` (scheme ws/wss, path suffixed with ``/ws``)
        and ``self.http_base_url`` (scheme http/https, original path).
        """
        parsed = urlparse(self.base_url)
        # Mirror the original scheme: https -> wss, anything else -> ws.
        ws_scheme = 'wss' if parsed.scheme == 'https' else 'ws'
        http_scheme = 'https' if parsed.scheme == 'https' else 'http'
        # Drop a trailing slash so the /ws suffix never doubles up.
        base_path = parsed.path.rstrip('/')
        ws_netloc = parsed.netloc
        ws_path = f"{base_path}/ws"
        self.ws_base_url = urlunparse((ws_scheme, ws_netloc, ws_path, '', '', ''))
        self.http_base_url = urlunparse((http_scheme, ws_netloc, base_path, '', '', ''))

    async def _queue_prompt(self, workflow: Dict[str, Any], client_id: str,
                            prompt_ext_params: Optional[Dict[str, Any]] = None) -> str:
        """Submit a workflow to the ComfyUI queue.

        Args:
            workflow: Workflow graph in ComfyUI API (prompt) format.
            client_id: WebSocket client id so progress messages are routed
                back to the already-open connection.
            prompt_ext_params: Optional extra top-level fields merged into the
                request payload (e.g. ``extra_data`` with an API key).

        Returns:
            The ``prompt_id`` assigned by the server.

        Raises:
            Exception: If the HTTP status is not 200 or the response carries
                no ``prompt_id``.
        """
        prompt_data = {
            "prompt": workflow,
            "client_id": client_id,
        }
        # Merge any extra top-level parameters into the payload.
        if prompt_ext_params:
            prompt_data.update(prompt_ext_params)
        json_data = json.dumps(prompt_data)

        # NOTE(review): posts to self.base_url while the class also builds
        # self.http_base_url — presumably equivalent; confirm before unifying.
        prompt_url = f"{self.base_url}/prompt"
        async with self.get_comfyui_session() as session:
            async with session.post(
                prompt_url,
                data=json_data,
                headers={"Content-Type": "application/json"}
            ) as response:
                if response.status != 200:
                    response_text = await response.text()
                    raise Exception(f"Submit workflow failed: [{response.status}] {response_text}")
                result = await response.json()
                prompt_id = result.get("prompt_id")
                if not prompt_id:
                    raise Exception(f"Get prompt_id failed: {result}")
                logger.info(f"Task submitted: {prompt_id}")
                return prompt_id

    def _parse_ws_message(self, message: dict, prompt_id: str) -> tuple[bool, dict]:
        """
        Parse websocket message

        Args:
            message: websocket message
            prompt_id: prompt_id to listen

        Returns:
            (invoke completed, message content)
        """
        # ComfyUI signals completion with an 'executing' message whose
        # node is None for the matching prompt_id.
        invoke_completed = False
        if message.get('type') == 'executing':
            data = message.get('data', {})
            if data.get('node') is None and data.get('prompt_id') == prompt_id:
                invoke_completed = True
        return invoke_completed, message

    def _build_result_from_collected_outputs(self, collected_outputs: Dict[str, Any], prompt_id: str,
                                             output_id_2_var: Optional[Dict[str, str]] = None) -> ExecuteResult:
        """Build an ``ExecuteResult`` from outputs collected over WebSocket.

        Args:
            collected_outputs: Mapping of node id -> raw output dict gathered
                from ``executed`` messages.
            prompt_id: Prompt id the outputs belong to.
            output_id_2_var: Optional mapping of node id -> variable name used
                to group outputs by variable.

        Returns:
            An ``ExecuteResult`` with media/text outputs populated, or with
            ``status="error"`` when nothing usable was found or building failed.
        """
        try:
            logger.info(f"Build execution result from collected WebSocket outputs (prompt_id: {prompt_id})")

            # Bucket media outputs per node, classified by file extension.
            output_id_2_images = {}
            output_id_2_videos = {}
            output_id_2_audios = {}
            output_id_2_texts = {}

            for node_id, output in collected_outputs.items():
                images, videos, audios = self._split_media_by_suffix(output, self.http_base_url)
                if images:
                    output_id_2_images[node_id] = images
                if videos:
                    output_id_2_videos[node_id] = videos
                if audios:
                    output_id_2_audios[node_id] = audios
                # Text outputs: normalize scalars to a one-element list.
                if "text" in output:
                    texts = output["text"]
                    if isinstance(texts, str):
                        texts = [texts]
                    elif not isinstance(texts, list):
                        texts = [str(texts)]
                    output_id_2_texts[node_id] = texts

            result = ExecuteResult(
                status="completed",
                prompt_id=prompt_id,
                outputs=collected_outputs
            )

            # If there is a mapping, map by variable name and flatten.
            if output_id_2_images:
                result.images_by_var = self._map_outputs_by_var(output_id_2_var or {}, output_id_2_images)
                result.images = self._extend_flat_list_from_dict(result.images_by_var)
            if output_id_2_videos:
                result.videos_by_var = self._map_outputs_by_var(output_id_2_var or {}, output_id_2_videos)
                result.videos = self._extend_flat_list_from_dict(result.videos_by_var)
            if output_id_2_audios:
                result.audios_by_var = self._map_outputs_by_var(output_id_2_var or {}, output_id_2_audios)
                result.audios = self._extend_flat_list_from_dict(result.audios_by_var)
            if output_id_2_texts:
                result.texts_by_var = self._map_outputs_by_var(output_id_2_var or {}, output_id_2_texts)
                result.texts = self._extend_flat_list_from_dict(result.texts_by_var)

            # A run that produced nothing at all is treated as an error.
            if not (result.images or result.videos or result.audios or result.texts):
                logger.warning("No outputs found")
                result.status = "error"
                result.msg = "No outputs found"

            logger.info(f"Build result successfully, output count: images={len(result.images)}, videos={len(result.videos)}, audios={len(result.audios)}, texts={len(result.texts)}")
            return result
        except Exception as e:
            logger.error(f"Build execution result from collected WebSocket outputs failed: {str(e)}")
            return ExecuteResult(
                status="error",
                prompt_id=prompt_id,
                msg=f"Build execution result from collected WebSocket outputs failed: {str(e)}"
            )

    async def execute_workflow(self, workflow_file: str, params: Dict[str, Any] = None) -> ExecuteResult:
        """Execute workflow (WebSocket way).

        Connects to the ComfyUI WebSocket first, then submits the workflow so
        no progress message is missed, and streams ``executed`` outputs until
        completion, error, or timeout (30 minutes).

        Args:
            workflow_file: Path to the workflow JSON file.
            params: Optional user parameters applied to the workflow.

        Returns:
            An ``ExecuteResult`` with status ``completed``, ``error`` or
            ``timeout``.
        """
        try:
            start_time = time.time()

            if not os.path.exists(workflow_file):
                logger.error(f"Workflow file does not exist: {workflow_file}")
                return ExecuteResult(status="error", msg=f"Workflow file does not exist: {workflow_file}")

            # Get workflow metadata
            metadata = self.get_workflow_metadata(workflow_file)
            if not metadata:
                return ExecuteResult(status="error", msg="Cannot parse workflow metadata")

            # Load workflow JSON
            with open(workflow_file, 'r', encoding='utf-8') as f:
                workflow_data = json.load(f)
            if not workflow_data:
                return ExecuteResult(status="error", msg="Workflow data is missing")

            # Apply parameter mapping; defaults are applied even when no
            # user parameters were passed (hence the `or {}`).
            workflow_data = await self._apply_params_to_workflow(workflow_data, metadata, params or {})

            # Replace any seed == 0 with a random 63-bit seed before submission
            workflow_data, _ = self._randomize_seed_in_workflow(workflow_data)

            # Extract output node information from metadata
            output_id_2_var = self._extract_output_nodes(metadata)

            # Generate client ID
            client_id = str(uuid.uuid4())

            # Prepare extra parameters
            prompt_ext_params = {}
            if COMFYUI_API_KEY:
                prompt_ext_params = {
                    "extra_data": {
                        "api_key_comfy_org": COMFYUI_API_KEY
                    }
                }
            else:
                logger.warning("COMFYUI_API_KEY is not set")

            # First establish WebSocket connection, then submit task
            timeout = 30 * 60  # Default 30 minutes timeout

            # Build WebSocket URL
            ws_url = f"{self.ws_base_url}?clientId={client_id}"
            logger.info(f"WebSocket URL: {ws_url}")
            logger.info(f"Prepare to connect websocket service")

            # For collecting nodes with outputs
            collected_outputs = {}
            prompt_id = None

            try:
                # Prepare extra headers for WebSocket connection, include cookies
                additional_headers = {}
                cookies = await self._parse_comfyui_cookies()
                if cookies:
                    try:
                        if isinstance(cookies, dict):
                            cookie_string = "; ".join([f"{k}={v}" for k, v in cookies.items()])
                        else:
                            cookie_string = str(cookies)
                        additional_headers["Cookie"] = cookie_string
                        logger.debug(f"WebSocket connection will use cookies: {cookie_string[:50]}...")
                    except Exception as e:
                        logger.warning(f"Parse WebSocket cookies failed: {e}")

                # Establish WebSocket connection
                async with websockets.connect(ws_url, additional_headers=additional_headers) as websocket:
                    logger.info('WebSocket connection established, now submit workflow')

                    # After connection established, immediately submit workflow
                    try:
                        prompt_id = await self._queue_prompt(workflow_data, client_id, prompt_ext_params)
                    except Exception as e:
                        error_message = f"Submit workflow failed: [{type(e)}] {str(e)}"
                        logger.error(error_message)
                        return ExecuteResult(status="error", msg=error_message)

                    logger.info(f"Workflow submitted, prompt_id: {prompt_id}, now wait for result")

                    while True:
                        # Check timeout
                        elapsed = time.time() - start_time
                        if elapsed > timeout:
                            logger.warning(f"WebSocket timeout ({timeout} seconds)")
                            result = ExecuteResult(
                                status="timeout",
                                prompt_id=prompt_id,
                                msg=f"WebSocket timeout ({timeout} seconds)",
                                duration=elapsed
                            )
                            return result

                        try:
                            # Wait for message, set shorter timeout to check total timeout
                            message_str = await asyncio.wait_for(websocket.recv(), timeout=3.0)
                            if not isinstance(message_str, str):
                                # Binary frames (e.g. preview images) are skipped.
                                continue
                            message = json.loads(message_str)

                            # Print full message for target prompt_id for debugging
                            if message.get('data', {}).get('prompt_id') == prompt_id:
                                logger.debug(f'Received target WebSocket message (prompt_id: {prompt_id}): {json.dumps(message, ensure_ascii=False)}')

                            # Process different types of messages
                            msg_type = message.get('type')
                            data = message.get('data', {})

                            if msg_type == 'execution_cached':
                                # Process cached execution message
                                cached_nodes = data.get('nodes', [])
                                logger.debug(f"Detected cached execution, skip nodes: {cached_nodes}")
                            elif msg_type == 'executed':
                                # Collect nodes with outputs
                                node_id = data.get('node')
                                output = data.get('output')
                                if output and node_id:
                                    # Check if there are outputs we are interested in
                                    has_media = any(output.get(key) for key in ('images', 'gifs', 'audio', 'text'))
                                    if has_media:
                                        logger.info(f"Collected outputs from node {node_id}")
                                        collected_outputs[node_id] = output
                            elif msg_type == 'execution_error':
                                # Process execution error
                                error_message = data.get('exception_message', 'Unknown error')
                                logger.error(f"Execution error: {error_message}")
                                return ExecuteResult(
                                    status="error",
                                    prompt_id=prompt_id,
                                    msg=error_message,
                                    duration=time.time() - start_time
                                )
                            elif msg_type == 'status':
                                # For status message, record queue status
                                # (use already-bound `data` instead of re-reading `message`)
                                queue_remaining = data.get('status', {}).get('exec_info', {}).get('queue_remaining', 'unknown')
                                logger.debug(f'Queue status updated: remaining tasks {queue_remaining} ')
                            else:
                                logger.debug(f'Received other WebSocket message: {message}')

                            # Parse message
                            invoke_completed, parsed_message = self._parse_ws_message(message, prompt_id)
                            if invoke_completed:
                                logger.info('WebSocket detected execution completed')
                                # Set execution duration
                                duration = time.time() - start_time

                                # If there are collected outputs, use them to build result
                                if collected_outputs:
                                    result = self._build_result_from_collected_outputs(collected_outputs, prompt_id, output_id_2_var)
                                    result.duration = duration
                                    # Transfer result files
                                    result = await self.transfer_result_files(result)
                                    return result
                                else:
                                    # WebSocket way did not collect any outputs, return error
                                    logger.warning("WebSocket did not collect any outputs")
                                    result = ExecuteResult(
                                        status="error",
                                        prompt_id=prompt_id,
                                        msg="WebSocket did not collect any outputs",
                                        duration=duration
                                    )
                                    return result

                        except asyncio.TimeoutError:
                            # Wait for message timeout, continue loop to check total timeout
                            continue
                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            result = ExecuteResult(
                                status="error",
                                prompt_id=prompt_id,
                                msg="WebSocket connection closed",
                                duration=time.time() - start_time
                            )
                            return result

            except Exception as e:
                logger.error(f"WebSocket connection or execution exception: {str(e)}")
                result = ExecuteResult(
                    status="error",
                    prompt_id=prompt_id,
                    msg=f"WebSocket connection or execution exception: {str(e)}",
                    duration=time.time() - start_time
                )
                return result

        except Exception as e:
            logger.error(f"Execute workflow failed: {str(e)}", exc_info=True)
            return ExecuteResult(status="error", msg=str(e))

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/AIDC-AI/Pixelle-MCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server