openai-tool2mcp

import asyncio import json import time from typing import Any import openai from ..models.openai import ToolOutput, ToolRequest, ToolResponse from ..utils.logging import logger class ThreadCreationError(RuntimeError): """Exception raised when thread creation fails""" def __init__(self): super().__init__("Thread creation failed") class AssistantCreationError(RuntimeError): """Exception raised when assistant creation fails""" def __init__(self): super().__init__("Assistant creation failed") class RunCreationError(RuntimeError): """Exception raised when run creation fails""" def __init__(self): super().__init__("Run creation failed") class RunTimeoutError(TimeoutError): """Exception raised when run times out""" def __init__(self): super().__init__("Run timed out") class NoChoicesError(ValueError): """Exception raised when no choices are found in response""" def __init__(self): super().__init__("No choices") class OpenAIClient: """Client for interacting with the OpenAI API""" def __init__(self, api_key: str | None, request_timeout: int = 30, max_retries: int = 3): """ Initialize the OpenAI client. Args: api_key (str): OpenAI API key request_timeout (int): Request timeout in seconds max_retries (int): Maximum number of retries """ self.client = openai.Client(api_key=api_key, timeout=request_timeout) self.max_retries = max_retries self.retry_delay = 1 # Assuming a default retry_delay async def invoke_tool(self, request: ToolRequest) -> ToolResponse: """ Invoke an OpenAI tool. Args: request (ToolRequest): Tool request Returns: ToolResponse: Tool response """ # Special handling for web search if request.tool_type == "web_search": return await self._handle_web_search(request) # Create or get thread thread_id = request.thread_id if not thread_id: thread = await self._create_thread() if thread is None: raise ThreadCreationError() thread_id = thread.id # Create message with tool call await self._create_message( thread_id=thread_id, content=f"Please use the {request.tool_type} tool with these parameters: {request.parameters}", ) # Create assistant with the appropriate tool assistant = await self._create_assistant( tools=[{"type": request.tool_type}], instructions=request.instructions or "Execute the requested tool function.", ) if assistant is None: raise AssistantCreationError() # Run the assistant run = await self._create_run(thread_id=thread_id, assistant_id=assistant.id) if run is None: raise RunCreationError() # Wait for completion run = await self._wait_for_run(thread_id, run.id) # Get tool outputs tool_outputs = [] if hasattr(run, "required_action") and hasattr(run.required_action, "submit_tool_outputs"): for tool_call in run.required_action.submit_tool_outputs.tool_calls: # Process each tool call tool_outputs.append(ToolOutput(output=tool_call.function.arguments, error=None)) # Create response response = ToolResponse(thread_id=thread_id, tool_outputs=tool_outputs) return response async def _create_thread(self) -> Any | None: """Create a new thread""" for attempt in range(self.max_retries): try: thread = await asyncio.to_thread(self.client.beta.threads.create) except Exception as e: logger.error(f"Error creating thread (attempt {attempt + 1}/{self.max_retries}): {e!s}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) else: return thread return None async def _create_message(self, thread_id: str, content: str) -> Any | None: """Create a new message in a thread""" for attempt in range(self.max_retries): try: message = await asyncio.to_thread( self.client.beta.threads.messages.create, thread_id=thread_id, role="user", content=content ) except Exception as e: logger.error(f"Error creating message (attempt {attempt + 1}/{self.max_retries}): {e!s}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) else: return message return None async def _create_assistant( self, tools: list[dict], instructions: str = "", model: str = "gpt-4o-mini" ) -> Any | None: """Create a new assistant with the specified tools""" for attempt in range(self.max_retries): try: # Prepare assistant parameters based on tool type assistant_params = self._prepare_assistant_params(tools, instructions, model) # Create the assistant with prepared parameters assistant = await asyncio.to_thread(self.client.beta.assistants.create, **assistant_params) logger.info(f"Assistant created: {assistant.id}") except Exception as e: logger.error(f"Error creating assistant (attempt {attempt + 1}/{self.max_retries}): {e!s}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) else: return assistant return None def _prepare_assistant_params(self, tools: list[dict], instructions: str, model: str) -> dict: """Prepare parameters for assistant creation based on tool types""" # Check for web search tool if tools and tools[0].get("type") == "web_search": return self._prepare_web_search_assistant_params(instructions) else: return self._prepare_standard_assistant_params(tools, instructions, model) def _prepare_web_search_assistant_params(self, instructions: str) -> dict: """Prepare parameters specifically for web search assistant""" model = "gpt-4o-mini-search-preview" logger.info(f"Creating assistant with {model} for web search") return {"name": "Tool Assistant", "model": model, "instructions": instructions} def _prepare_standard_assistant_params(self, tools: list[dict], instructions: str, model: str) -> dict: """Prepare parameters for standard assistant with various tools""" # Convert tools to format expected by OpenAI formatted_tools = self._format_tools(tools) assistant_params = {"name": "Tool Assistant", "model": model, "instructions": instructions} # Add tools only if present if formatted_tools: assistant_params["tools"] = formatted_tools logger.info(f"Creating assistant with model {model} and tools: {formatted_tools}") return assistant_params def _format_tools(self, tools: list[dict]) -> list[dict]: """Format tools list for OpenAI API compatibility""" formatted_tools = [] for tool in tools: tool_type = tool.get("type", "") if tool_type == "code_interpreter": formatted_tools.append({"type": "code_interpreter"}) elif tool_type == "file_search": formatted_tools.append({"type": "retrieval"}) elif tool_type == "web_browser": # This type is not directly supported, so we substitute with retrieval formatted_tools.append({"type": "retrieval"}) return formatted_tools async def _create_run(self, thread_id: str, assistant_id: str) -> Any | None: """Create a new run""" for attempt in range(self.max_retries): try: run = await asyncio.to_thread( self.client.beta.threads.runs.create, thread_id=thread_id, assistant_id=assistant_id ) except Exception as e: logger.error(f"Error creating run (attempt {attempt + 1}/{self.max_retries}): {e!s}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) else: return run return None async def _wait_for_run(self, thread_id: str, run_id: str) -> Any: """Wait for a run to complete""" max_wait_time = 60 # Maximum wait time in seconds start_time = time.time() while True: if time.time() - start_time > max_wait_time: raise RunTimeoutError() for attempt in range(self.max_retries): try: run = await asyncio.to_thread( self.client.beta.threads.runs.retrieve, thread_id=thread_id, run_id=run_id ) break except Exception as e: logger.error(f"Error retrieving run (attempt {attempt + 1}/{self.max_retries}): {e!s}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) if run.status in ["completed", "failed", "requires_action"]: return run await asyncio.sleep(1) async def _handle_web_search(self, request: ToolRequest) -> ToolResponse: """ Handle web search request using the chat completions API. Args: request (ToolRequest): Tool request Returns: ToolResponse: Tool response """ logger.info(f"[WEB SEARCH] Handling request: {request.parameters}") # Validate and extract query query = request.parameters.get("query", "") if not query: logger.warning("[WEB SEARCH] Empty query received") return self._create_error_response(request.thread_id, "No query provided") logger.info(f"[WEB SEARCH] Executing with query: '{query}'") # Try multiple attempts based on max_retries for attempt in range(self.max_retries): try: # Create and send request to OpenAI response = await self._execute_web_search_request(query, attempt) # Process the response output_json = await self._process_web_search_response(response) # Return successful response logger.info("[WEB SEARCH] Successfully completed request") return ToolResponse( thread_id=request.thread_id or f"web_search_{int(time.time())}", tool_outputs=[ToolOutput(output=output_json, error=None)], ) except Exception as e: if not await self._handle_web_search_error(e, attempt): # Return error response on final attempt error_msg = str(e) return self._create_error_response(request.thread_id, error_msg) # Unexpected situation - all attempts failed without exception logger.error("[WEB SEARCH] Unexpected end of method without result or error") return self._create_error_response(request.thread_id, "Unknown error in web search process") def _create_error_response(self, thread_id: str | None, error_msg: str) -> ToolResponse: """Create an error response for web search""" return ToolResponse( thread_id=thread_id or "error", tool_outputs=[ToolOutput(output=json.dumps({"error": error_msg}), error=error_msg)], ) async def _execute_web_search_request(self, query: str, attempt: int) -> Any: """Execute the web search request to OpenAI API""" logger.debug(f"[WEB SEARCH] Request attempt {attempt + 1}/{self.max_retries}") logger.debug("[WEB SEARCH] Creating request with model: 'gpt-4o-mini-search-preview'") # Prepare request data request_data = { "model": "gpt-4o-mini-search-preview", "messages": [{"role": "user", "content": query}], "web_search_options": {}, # Empty object to enable web search "store": True, # Store the response (optional) } logger.debug(f"[WEB SEARCH] Full request data: {json.dumps(request_data)}") # Execute API call return await asyncio.to_thread(self.client.chat.completions.create, **request_data) async def _handle_web_search_error(self, error: Exception, attempt: int) -> bool: """Handle errors in web search. Returns True if should retry, False if should stop.""" logger.error(f"[WEB SEARCH] Error: {error!s}") if attempt == self.max_retries - 1: logger.error(f"[WEB SEARCH] All attempts failed. Last error: {error}") return False # Delay before retrying logger.info(f"[WEB SEARCH] Retrying in {self.retry_delay}s...") await asyncio.sleep(self.retry_delay) return True async def _process_web_search_response(self, response: Any) -> str: """Process the web search response and return the serialized JSON output""" # Log response information for debugging logger.debug(f"[WEB SEARCH] Response type: {type(response)}") logger.debug(f"[WEB SEARCH] Response attributes: {dir(response)}") # Check if response has choices if not response.choices: logger.warning("[WEB SEARCH] No choices found in response") raise NoChoicesError() # Extract message content message = response.choices[0].message logger.debug(f"[WEB SEARCH] Message role: {message.role}") logger.debug(f"[WEB SEARCH] Message attributes: {dir(message)}") # Get content content = message.content or "" logger.info(f"[WEB SEARCH] Response content length: {len(content)}") logger.debug(f"[WEB SEARCH] Content preview: {content[:100]}...") # Check for annotations annotations = self._extract_annotations(message) # Prepare response data response_data = {"content": content, "annotations": annotations} logger.debug(f"[WEB SEARCH] Constructed response data with keys: {list(response_data.keys())}") # Serialize to JSON return self._serialize_response_data(response_data, content) def _extract_annotations(self, message: Any) -> list: """Extract annotations from the message if available""" annotations = [] has_annotations = hasattr(message, "annotations") logger.debug(f"[WEB SEARCH] Has annotations attribute: {has_annotations}") if has_annotations and message.annotations is not None: annotations = message.annotations logger.info(f"[WEB SEARCH] Found {len(annotations)} annotations") # Log annotation details for i, annotation in enumerate(annotations): logger.debug(f"[WEB SEARCH] Annotation {i + 1}: {type(annotation)}") if hasattr(annotation, "type"): logger.debug(f"[WEB SEARCH] Annotation type: {annotation.type}") return annotations def _serialize_response_data(self, response_data: dict, content: str) -> str: """Serialize the response data to JSON""" try: # Helper function for serializing complex objects def serialize_openai_objects(obj): if hasattr(obj, "model_dump_json"): # Use model_dump_json for Pydantic objects return json.loads(obj.model_dump_json()) elif hasattr(obj, "to_dict"): # Use to_dict method if available return obj.to_dict() elif hasattr(obj, "__dict__"): # Use __dict__ attribute if available return obj.__dict__ else: # Convert to string for other objects return str(obj) # Convert complex OpenAI objects to standard Python objects output_json = json.dumps(response_data, default=serialize_openai_objects) logger.debug(f"[WEB SEARCH] Successfully serialized to JSON (length: {len(output_json)})") except (TypeError, Exception) as e: logger.error(f"[WEB SEARCH] JSON serialization error: {e}") # Fallback response fallback_data = {"content": content, "error": f"Failed to serialize annotations: {e!s}"} output_json = json.dumps(fallback_data) logger.debug(f"[WEB SEARCH] Using fallback JSON: {output_json[:100]}...") return output_json