pixelle-mcp-Image-generation

by AIDC-AI
chat_handler.py (20.2 kB)
# Copyright (C) 2025 AIDC-AI
# This project is licensed under the MIT License (SPDX-License-identifier: MIT).

from datetime import timedelta
import json
import os
import time
import chainlit as cl
from typing import Any, Dict, List, Optional
from mcp import ClientSession
import re

from pixelle.web.utils.llm_util import ModelInfo, ModelType
from litellm import acompletion
from pixelle.web.chat.starters import build_save_action
from pixelle.web.utils.time_util import format_duration
from pixelle.logger import logger
from pixelle.settings import settings

save_starter_enabled = settings.chainlit_save_starter_enabled


def format_llm_error_message(model_name: str, error_str: str) -> str:
    """Unified LLM error message formatting function"""
    # Handle common error types and provide friendly English error messages
    if "RateLimitError" in error_str or "429" in error_str:
        if "quota" in error_str.lower() or "exceed" in error_str.lower():
            return f"⚠️ {model_name} API quota exceeded. Please check your plan and billing details."
        else:
            return f"⚠️ {model_name} API rate limit hit. Please try again later."
    elif "401" in error_str or "authentication" in error_str.lower():
        return f"🔑 {model_name} API key is invalid. Please check your configuration."
    elif "403" in error_str or "permission" in error_str.lower():
        return f"🚫 {model_name} API access denied. Please check permissions."
    elif "timeout" in error_str.lower():
        return f"⏰ {model_name} API call timed out. Please retry."
    else:
        return f"❌ {model_name} model call failed: {error_str}"


def get_all_tools() -> List[Dict[str, Any]]:
    """Get all available MCP tools"""
    mcp_tools = cl.user_session.get("mcp_tools", {})
    all_tools = []
    for connection_tools in mcp_tools.values():
        all_tools.extend(connection_tools)
    return all_tools


def find_tool_connection(tool_name: str) -> Optional[str]:
    """Find the MCP connection that owns the tool"""
    mcp_tools = cl.user_session.get("mcp_tools", {})
    for connection_name, tools in mcp_tools.items():
        if any(tool["function"]["name"] == tool_name for tool in tools):
            return connection_name
    return None


def _extract_content(content_list):
    """Extract text content from CallToolResult's content list"""
    if not content_list:
        return ""
    text_parts = []
    for content_item in content_list:
        # Handle different types of content
        if hasattr(content_item, 'text'):
            # TextContent
            text_parts.append(content_item.text)
        elif hasattr(content_item, 'data'):
            # ImageContent or other binary content
            text_parts.append(f"[Binary content: {getattr(content_item, 'mimeType', 'unknown')}]")
        elif hasattr(content_item, 'uri'):
            # EmbeddedResource
            text_parts.append(f"[Resource: {content_item.uri}]")
        else:
            # Other unknown types, try to convert to string
            text_parts.append(str(content_item))
    return text_parts[0] if len(text_parts) == 1 else text_parts
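# Illustrative sketch (not part of the original module): _extract_content returns a
# plain string for a single text part and a list when several parts are present.
# SimpleNamespace stands in here for the mcp TextContent type:
#
#     from types import SimpleNamespace
#     _extract_content([SimpleNamespace(text="hello")])   # -> "hello"
#     _extract_content([SimpleNamespace(text="a"),
#                       SimpleNamespace(text="b")])       # -> ["a", "b"]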
@cl.step(type="tool")
async def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> str:
    """Execute MCP tool call"""
    # Record start time
    start_time = time.time()

    def _format_result_with_duration(content: str) -> str:
        """Unified formatting of results with duration"""
        duration = time.time() - start_time
        return f"[Took {format_duration(duration)}] {content}"

    current_step = cl.context.current_step
    current_step.input = tool_input
    current_step.name = tool_name
    await current_step.update()

    def record_step():
        # Get existing step history
        current_steps = cl.user_session.get("current_steps", [])
        current_steps.append(current_step)
        # Update step history
        cl.user_session.set("current_steps", current_steps)

    # Find the connection that owns the tool
    mcp_name = find_tool_connection(tool_name)
    if not mcp_name:
        error_msg = json.dumps({"error": f"Tool {tool_name} not found in any MCP connection"})
        result_with_duration = _format_result_with_duration(error_msg)
        current_step.output = result_with_duration
        record_step()
        return result_with_duration

    # Get MCP session; .get() returns None for an unknown name, so check before unpacking
    session_entry = cl.context.session.mcp_sessions.get(mcp_name)
    if not session_entry:
        error_msg = json.dumps({"error": f"MCP {mcp_name} not found in any MCP connection"})
        result_with_duration = _format_result_with_duration(error_msg)
        current_step.output = result_with_duration
        record_step()
        return result_with_duration
    mcp_session, _ = session_entry

    try:
        # Call MCP tool, returns CallToolResult object
        logger.info(f"Calling MCP tool: {tool_name} with input: {tool_input}")
        result = await mcp_session.call_tool(tool_name, tool_input, read_timeout_seconds=timedelta(hours=1))
        # Check if there's an error
        if result.isError:
            error_content = _extract_content(result.content)
            logger.error(f"Tool execution failed: {error_content}")
            error_msg = json.dumps({"error": f"Tool execution failed: {error_content}"})
            result_with_duration = _format_result_with_duration(error_msg)
            current_step.output = result_with_duration
            record_step()
            return result_with_duration
        # Extract content text
        content = _extract_content(result.content)
        logger.info(f"Tool execution succeeded: {content}")
        result_with_duration = _format_result_with_duration(str(content))
        current_step.output = result_with_duration
        record_step()
        return result_with_duration
    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        error_msg = json.dumps({"error": str(e)})
        result_with_duration = _format_result_with_duration(error_msg)
        current_step.output = result_with_duration
        record_step()
        return result_with_duration


def _extract_and_clean_media_markers(text: str) -> tuple[Dict[str, List[str]], str]:
    """Extract media markers and clean text

    Returns:
        tuple: (media files dict, cleaned text)
        media files dict format: {"images": [...], "audios": [...], "videos": [...]}
    """
    # Match different types of media markers
    patterns = {
        "images": r'\[SHOW_IMAGE:([^\]]+)\]',
        "audios": r'\[SHOW_AUDIO:([^\]]+)\]',
        "videos": r'\[SHOW_VIDEO:([^\]]+)\]'
    }
    media_files = {"images": [], "audios": [], "videos": []}
    cleaned_text = text
    # Extract various types of media files and clean text
    for media_type, pattern in patterns.items():
        media_files[media_type] = re.findall(pattern, cleaned_text)
        cleaned_text = re.sub(pattern, '', cleaned_text)
    # Remove extra whitespace
    cleaned_text = cleaned_text.rstrip()
    return media_files, cleaned_text


def _is_url(source: str) -> bool:
    """Check if media source is a URL"""
    return source.startswith(('http://', 'https://'))
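# Illustrative sketch (not part of the original module): given a response such as
#
#     "Here you go! [SHOW_IMAGE:https://example.com/cat.png]"
#
# _extract_and_clean_media_markers returns
# ({"images": ["https://example.com/cat.png"], "audios": [], "videos": []},
#  "Here you go!"): the marker is stripped from the text and its URL is collected
# for rendering as a Chainlit element below.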
async def _process_media_markers(msg: cl.Message):
    """Process media markers in messages"""
    if not msg.content:
        return
    # Extract media markers and clean text
    media_files, cleaned_content = _extract_and_clean_media_markers(msg.content)
    # If content becomes empty after removing media markers but we have media files,
    # set a default message to ensure the message with media elements is displayed
    has_media = any(media_files.values())
    if not cleaned_content.strip() and has_media:
        cleaned_content = "📎"  # Simple media indicator
    # Update message content (remove markers)
    msg.content = cleaned_content
    # Process images
    for i, img_source in enumerate(media_files["images"]):
        img_source = img_source.strip()
        img_params = {
            "name": f"Generated_Image_{i+1}",
            "display": "inline",
            "size": "small",
        }
        if _is_url(img_source):
            img_params["url"] = img_source
        else:
            img_params["path"] = img_source
        img_element = cl.Image(**img_params)
        msg.elements.append(img_element)
        logger.info(f"Added image element: {img_source}")
    # Process audio
    for i, audio_source in enumerate(media_files["audios"]):
        audio_source = audio_source.strip()
        audio_params = {
            "name": f"Generated_Audio_{i+1}",
            "display": "inline",
            "size": "small",
        }
        if _is_url(audio_source):
            audio_params["url"] = audio_source
        else:
            audio_params["path"] = audio_source
        audio_element = cl.Audio(**audio_params)
        msg.elements.append(audio_element)
        logger.info(f"Added audio element: {audio_source}")
    # Process video
    for i, video_source in enumerate(media_files["videos"]):
        video_source = video_source.strip()
        video_params = {
            "name": f"Generated_Video_{i+1}",
            "display": "inline",
            "size": "small",
        }
        if _is_url(video_source):
            video_params["url"] = video_source
        else:
            video_params["path"] = video_source
        video_element = cl.Video(**video_params)
        msg.elements.append(video_element)
        logger.info(f"Added video element: {video_source}")


async def _process_tool_call_delta(
    tool_calls_delta: List[Any],
    current_tool_calls: Dict[int, Dict],
    current_args: Dict[int, str]
):
    """Process incremental data for tool calls"""
    for tool_call_delta in tool_calls_delta:
        index = tool_call_delta.index
        # Initialize tool call data structure
        if index not in current_tool_calls:
            current_tool_calls[index] = {
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""}
            }
            current_args[index] = ""
        # Accumulate tool call information
        if tool_call_delta.id:
            current_tool_calls[index]["id"] = tool_call_delta.id
        if tool_call_delta.function and tool_call_delta.function.name:
            current_tool_calls[index]["function"]["name"] = tool_call_delta.function.name
        if tool_call_delta.function and tool_call_delta.function.arguments:
            current_args[index] += tool_call_delta.function.arguments
            current_tool_calls[index]["function"]["arguments"] = current_args[index]


async def _execute_tool_calls(
    current_tool_calls: Dict[int, Dict],
    messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Execute all tool calls and update message history"""
    # Build assistant message with tool calls
    tool_calls_list = list(current_tool_calls.values())
    messages.append({
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls_list
    })
    # Execute all tool calls
    for tool_call in tool_calls_list:
        tool_name = tool_call["function"]["name"]
        tool_args_str = tool_call["function"]["arguments"]
        try:
            # Parse tool arguments
            tool_args = json.loads(tool_args_str)
            # Execute tool call
            tool_response = await execute_tool(tool_name, tool_args)
            # Add tool response to message history
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": tool_response
            })
        except Exception as e:
            error_message = f"Tool call error: {str(e)}"
            logger.error(error_message)
            # Add error response to message history
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": error_message
            })
    return messages


async def _handle_stream_chunk(chunk, msg, current_tool_calls, current_args):
    """Process a single chunk of streaming response"""
    choice = chunk.choices[0]
    delta = choice.delta
    has_tool_call = False
    # Handle tool calls
    if delta.tool_calls:
        has_tool_call = True
        await _process_tool_call_delta(
            delta.tool_calls, current_tool_calls, current_args
        )
    # Handle regular text response
    elif delta.content:
        await msg.stream_token(delta.content)
    return has_tool_call, choice.finish_reason
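# Illustrative sketch (not part of the original module): with OpenAI-style streaming,
# one tool call arrives as several deltas keyed by index, e.g.
#
#     delta 1: index=0, id="call_1", function.name="generate_image", arguments=""
#     delta 2: index=0, function.arguments='{"prompt": "a '
#     delta 3: index=0, function.arguments='red fox"}'
#
# _process_tool_call_delta concatenates the argument fragments per index, so after
# the final delta current_tool_calls[0]["function"]["arguments"] holds the complete
# JSON string '{"prompt": "a red fox"}', ready for json.loads in _execute_tool_calls.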
async def _handle_response(model_info, api_params, enhanced_messages, messages):
    """Handle streaming response"""
    # Create independent message object for this round of response
    msg = cl.Message(content="")
    current_tool_calls = {}
    current_args = {}
    has_tool_call = False
    try:
        # Prepare LiteLLM parameters - directly pass all necessary parameters
        litellm_params = {
            "model": f"{model_info.provider}/{model_info.model}",
            "stream": True,
            "num_retries": 0,
            "timeout": 30,
            "api_key": model_info.api_key,
            "base_url": model_info.base_url or None,
            **api_params,
        }
        logger.info(f"Call LLM: {model_info.provider}/{model_info.model}")
        response = await acompletion(**litellm_params)
        try:
            async for chunk in response:
                chunk_has_tool_call, finish_reason = await _handle_stream_chunk(
                    chunk, msg, current_tool_calls, current_args
                )
                if chunk_has_tool_call:
                    has_tool_call = True
                # Check completion status
                if finish_reason == 'tool_calls':
                    try:
                        # First send the current round's message (if there's content)
                        if msg.content and msg.content.strip():
                            await msg.send()
                        # Execute tool calls
                        enhanced_messages = await _execute_tool_calls(
                            current_tool_calls, enhanced_messages
                        )
                        return enhanced_messages, True  # Continue to next round
                    except Exception as e:
                        error_message = f"Error when processing tool calls: {str(e)}"
                        logger.error(error_message)
                        await msg.stream_token(f"\n{error_message}\n")
                        await msg.send()
                        return messages, False  # End processing
                elif finish_reason:
                    # Other completion reasons, end streaming processing
                    break
            # Process media markers and send message
            if not has_tool_call:
                await _process_media_markers(msg)
                if msg.content and msg.content.strip():
                    if save_starter_enabled:
                        # Add save action to AI reply message
                        msg.actions = [
                            build_save_action()
                        ]
                    await msg.send()
                # Add assistant message to history
                if msg.content and msg.content.strip():
                    enhanced_messages.append({
                        "role": "assistant",
                        "content": msg.content
                    })
            else:
                # If there are tool calls, send message directly (tool call related messages are handled elsewhere)
                await msg.send()
            return enhanced_messages, False  # End processing
        except Exception as e:
            error_str = str(e)
            error_message = format_llm_error_message(model_info.name, error_str)
            logger.error(f"Stream processing error: {error_str}")
            await msg.stream_token(f"\n{error_message}\n")
            # Process media markers even if there's an error
            await _process_media_markers(msg)
            await msg.send()
            return messages, False
    except Exception as e:
        error_str = str(e)
        error_message = format_llm_error_message(model_info.name, error_str)
        logger.error(f"LiteLLM call failed: {error_str}")
        await msg.stream_token(f"\n{error_message}\n")
        # Process media markers even if there's an error
        await _process_media_markers(msg)
        await msg.send()
        return messages, False
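# Illustrative note (not part of the original module): _handle_response returns a
# (messages, should_continue) pair. should_continue is True only after tool calls
# were executed, which tells the while-loop in process_streaming_response below to
# call the LLM again with the tool results appended; every terminal or error path
# returns False so the loop exits.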
async def process_streaming_response(
    messages: List[Dict[str, Any]],
    model_info: ModelInfo,
) -> List[Dict[str, Any]]:
    """
    Process streaming response and tool calls

    Args:
        messages: Message history
        model_info: Model information to use

    Returns:
        Updated message history
    """
    # Clean up steps, handle scenarios where users re-edit sent messages
    chat_messages = cl.chat_context.get()
    if len(chat_messages) >= 2:  # Need at least 2 messages to process
        # Get timestamp of second-to-last message
        second_last_msg = chat_messages[-2]
        second_last_time = second_last_msg.created_at
        if second_last_time:
            # Get current steps
            current_steps = cl.user_session.get("current_steps", [])
            # Only keep steps with timestamps before the second-to-last message
            filtered_steps = [
                step for step in current_steps
                if str(step.created_at) <= second_last_time
            ]
            # Update steps
            cl.user_session.set("current_steps", filtered_steps)

    tools = get_all_tools()

    # Inject media display system instructions
    enhanced_messages = messages.copy()

    while True:  # Loop to handle tool calls
        # Prepare API parameters
        api_params = {
            "messages": enhanced_messages,
        }
        # If there are tools, add tool parameters
        if tools:
            api_params["tools"] = tools
            api_params["tool_choice"] = "auto"
        # All parameters are passed through LiteLLM function parameters, not using environment variables
        try:
            enhanced_messages, should_continue = await _handle_response(
                model_info, api_params, enhanced_messages, messages
            )
            if not should_continue:
                return enhanced_messages
        except Exception as e:
            error_str = str(e)
            error_message = format_llm_error_message(model_info.name, error_str)
            logger.error(f"LiteLLM main loop error: {error_str}")
            # Send error message to user
            error_msg = cl.Message(content=error_message)
            await error_msg.send()
            return enhanced_messages  # Return directly, don't continue loop


# MCP connection management convenience functions
async def handle_mcp_connect(connection, session: ClientSession, tools_converter_func):
    """Handle common logic for MCP connections"""
    tools_result = await session.list_tools()
    openai_tools = tools_converter_func(tools_result.tools)
    mcp_tools = cl.user_session.get("mcp_tools", {})
    mcp_tools[connection.name] = openai_tools
    cl.user_session.set("mcp_tools", mcp_tools)


async def handle_mcp_disconnect(name: str):
    """Handle common logic for MCP disconnections"""
    mcp_tools = cl.user_session.get("mcp_tools", {})
    if name in mcp_tools:
        del mcp_tools[name]
    cl.user_session.set("mcp_tools", mcp_tools)
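# Usage sketch (illustrative, not part of the original module): these helpers are
# intended to be called from Chainlit's MCP lifecycle hooks. The converter name
# convert_mcp_tools_to_openai is hypothetical; any callable that maps mcp.Tool
# objects to OpenAI function-calling dicts would fit.
#
#     @cl.on_mcp_connect
#     async def on_mcp_connect(connection, session: ClientSession):
#         await handle_mcp_connect(connection, session, convert_mcp_tools_to_openai)
#
#     @cl.on_mcp_disconnect
#     async def on_mcp_disconnect(name: str, session: ClientSession):
#         await handle_mcp_disconnect(name)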
