orchestration-agent.ipynb•13.3 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<center>\n",
" <p style=\"text-align:center\">\n",
" <img alt=\"phoenix logo\" src=\"https://storage.googleapis.com/arize-phoenix-assets/assets/phoenix-logo-light.svg\" width=\"200\"/>\n",
" <br>\n",
" <a href=\"https://arize.com/docs/phoenix/\">Docs</a>\n",
" |\n",
" <a href=\"https://github.com/Arize-ai/phoenix\">GitHub</a>\n",
" |\n",
" <a href=\"https://arize-ai.slack.com/join/shared_invite/zt-2w57bhem8-hq24MB6u7yE_ZF_ilOYSBw#/shared-invite/email\">Community</a>\n",
" </p>\n",
"</center>\n",
"\n",
"# Google GenAI SDK - Building an Orchestrator Agent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the %pip magic so packages are installed into the environment of the running kernel.\n",
"%pip install -q google-genai arize-phoenix-otel openinference-instrumentation-google-genai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connect to Arize Phoenix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"from google import genai\n",
"from google.genai import types\n",
"\n",
"from phoenix.otel import register\n",
"\n",
"# Prompt for the API key only when it is not already set; getpass keeps the secret off-screen.\n",
"if \"PHOENIX_API_KEY\" not in os.environ:\n",
"    os.environ[\"PHOENIX_API_KEY\"] = getpass(\"🔑 Enter your Phoenix API key: \")\n",
"\n",
"# The collector endpoint is a URL rather than a secret, so echo it while typing\n",
"# (typos stay visible) and drop stray surrounding whitespace.\n",
"if \"PHOENIX_COLLECTOR_ENDPOINT\" not in os.environ:\n",
"    os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = input(\n",
"        \"Enter your Phoenix Collector Endpoint: \"\n",
"    ).strip()\n",
"\n",
"# Register the tracer provider for this project and create the tracer whose\n",
"# decorators (`tracer.chain`, `tracer.agent`) are used by the cells below.\n",
"tracer_provider = register(auto_instrument=True, project_name=\"google-genai-orchestrator-agent\")\n",
"tracer = tracer_provider.get_tracer(__name__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Authenticate with Google Vertex AI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The Vertex AI client authenticates through Application Default Credentials (ADC).\n",
"# `gcloud auth login` only signs in the gcloud CLI itself, so create ADC explicitly.\n",
"!gcloud auth application-default login"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a client that talks to the Vertex AI API.\n",
"# The Google GenAI API could be used here instead.\n",
"client = genai.Client(\n",
"    vertexai=True,\n",
"    project=\"<ADD YOUR GCP PROJECT ID>\",\n",
"    location=\"us-central1\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Orchestration Agent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, define the sub-agents that the orchestrator can choose between. In this notebook each sub-agent is exposed to the orchestrator as a tool."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define models for different specialized agents\n",
"FLASH_MODEL = \"gemini-2.0-flash-001\"\n",
"\n",
"\n",
"@tracer.chain()\n",
"def call_user_proxy_agent(query, context=\"\"):\n",
"    \"\"\"Ask the model to act as the user and give feedback.\n",
"\n",
"    Args:\n",
"        query: The user query the feedback should address.\n",
"        context: Text accumulated so far that is inserted into the prompt (may be empty).\n",
"\n",
"    Returns:\n",
"        The response text with surrounding whitespace removed, or an empty\n",
"        string when the response carries no text.\n",
"    \"\"\"\n",
"    prompt = f\"\"\"You are a user proxy assistant. Provide feedback as if you were the user on:\n",
"    Context: {context}\n",
"    Query: {query}\n",
"    Give honest, constructive feedback from a user's perspective.\"\"\"\n",
"\n",
"    response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=prompt,\n",
"    )\n",
"    # `response.text` can be None (for example when the reply has no text parts),\n",
"    # so guard before calling `.strip()` to avoid an AttributeError.\n",
"    return (response.text or \"\").strip()\n",
"\n",
"\n",
"@tracer.chain()\n",
"def call_flight_planning_agent(query, context=\"\"):\n",
"    \"\"\"Ask the model for flight options relevant to the query.\n",
"\n",
"    Args:\n",
"        query: The user query to plan flights for.\n",
"        context: Text accumulated so far that is inserted into the prompt (may be empty).\n",
"\n",
"    Returns:\n",
"        The response text with surrounding whitespace removed, or an empty\n",
"        string when the response carries no text.\n",
"    \"\"\"\n",
"    prompt = f\"\"\"You are a flight planning assistant. Help plan flights for:\n",
"    Context: {context}\n",
"    Query: {query}\n",
"    Provide detailed flight options with considerations for price, timing, and convenience.\"\"\"\n",
"\n",
"    response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=prompt,\n",
"    )\n",
"    # `response.text` can be None (for example when the reply has no text parts),\n",
"    # so guard before calling `.strip()` to avoid an AttributeError.\n",
"    return (response.text or \"\").strip()\n",
"\n",
"\n",
"@tracer.chain()\n",
"def call_hotel_recommendation_agent(query, context=\"\"):\n",
"    \"\"\"Ask the model for accommodation suggestions relevant to the query.\n",
"\n",
"    Args:\n",
"        query: The user query to suggest hotels for.\n",
"        context: Text accumulated so far that is inserted into the prompt (may be empty).\n",
"\n",
"    Returns:\n",
"        The response text with surrounding whitespace removed, or an empty\n",
"        string when the response carries no text.\n",
"    \"\"\"\n",
"    prompt = f\"\"\"You are a hotel recommendation assistant. Suggest accommodations for:\n",
"    Context: {context}\n",
"    Query: {query}\n",
"    Provide suitable hotel options with details on amenities, location, and price ranges.\"\"\"\n",
"\n",
"    response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=prompt,\n",
"    )\n",
"    # `response.text` can be None (for example when the reply has no text parts),\n",
"    # so guard before calling `.strip()` to avoid an AttributeError.\n",
"    return (response.text or \"\").strip()\n",
"\n",
"\n",
"@tracer.chain()\n",
"def call_travel_attraction_agent(query, context=\"\"):\n",
"    \"\"\"Ask the model for attractions and places to visit relevant to the query.\n",
"\n",
"    Args:\n",
"        query: The user query to suggest attractions for.\n",
"        context: Text accumulated so far that is inserted into the prompt (may be empty).\n",
"\n",
"    Returns:\n",
"        The response text with surrounding whitespace removed, or an empty\n",
"        string when the response carries no text.\n",
"    \"\"\"\n",
"    prompt = f\"\"\"You are a travel attraction recommendation assistant. Suggest attractions for:\n",
"    Context: {context}\n",
"    Query: {query}\n",
"    Provide interesting places to visit with descriptions, highlights, and practical information.\"\"\"\n",
"\n",
"    response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=prompt,\n",
"    )\n",
"    # `response.text` can be None (for example when the reply has no text parts),\n",
"    # so guard before calling `.strip()` to avoid an AttributeError.\n",
"    return (response.text or \"\").strip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tracer.chain()\n",
"def determine_next_step(user_query, context, cycle, max_cycles):\n",
"    \"\"\"\n",
"    Asks the model which agent to call next, or whether to return a final answer.\n",
"    Args:\n",
"        user_query: The initial user query\n",
"        context: Current accumulated context\n",
"        cycle: Current cycle number\n",
"        max_cycles: Maximum number of agent calls\n",
"    Returns:\n",
"        The name of one of the declared tools. Falls back to\n",
"        \"return_final_answer\" when the response contains no function call\n",
"        or names a tool that was not declared.\n",
"    \"\"\"\n",
"    orchestration_prompt = f\"\"\"You are an orchestration agent. Decide the next step to take:\n",
"    User query: {user_query}\n",
"    Current context: {context}\n",
"    Current cycle: {cycle}/{max_cycles}\n",
"    Choose one of the available tools to help address the user query, or decide to return a final answer.\n",
"    \"\"\"\n",
"\n",
"    # Declare only tools that `execute_agent_call` can dispatch. A tool with no\n",
"    # implementation would consume a cycle and add an empty entry to the context.\n",
"    orchestrator_tools = {\n",
"        \"function_declarations\": [\n",
"            {\n",
"                \"name\": \"call_flight_planning_agent\",\n",
"                \"description\": \"Call flight planning agent to help find and recommend flights\",\n",
"            },\n",
"            {\n",
"                \"name\": \"call_hotel_recommendation_agent\",\n",
"                \"description\": \"Call hotel recommendation agent to suggest accommodations\",\n",
"            },\n",
"            {\n",
"                \"name\": \"call_travel_attraction_agent\",\n",
"                \"description\": \"Call travel attraction agent to suggest interesting places to visit with descriptions\",\n",
"            },\n",
"            {\n",
"                \"name\": \"call_user_proxy_agent\",\n",
"                \"description\": \"Call user proxy agent that acts as the user and gives feedback\",\n",
"            },\n",
"            {\n",
"                \"name\": \"return_final_answer\",\n",
"                \"description\": \"Return to user with final answer when sufficient information has been gathered\",\n",
"            },\n",
"        ]\n",
"    }\n",
"    declared_names = {\n",
"        declaration[\"name\"] for declaration in orchestrator_tools[\"function_declarations\"]\n",
"    }\n",
"\n",
"    orchestration_response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=orchestration_prompt,\n",
"        config=types.GenerateContentConfig(tools=[orchestrator_tools]),\n",
"    )\n",
"\n",
"    # Candidates, content or parts may be missing, and the function call is not\n",
"    # guaranteed to be the first part (a text part can precede it), so scan every\n",
"    # part of the first candidate instead of indexing parts[0] directly.\n",
"    candidates = orchestration_response.candidates or []\n",
"    if candidates and candidates[0].content:\n",
"        for part in candidates[0].content.parts or []:\n",
"            function_call = part.function_call\n",
"            if function_call and function_call.name in declared_names:\n",
"                return function_call.name\n",
"\n",
"    # Default to returning the final answer if no usable tool call was made\n",
"    return \"return_final_answer\"\n",
"\n",
"\n",
"@tracer.chain()\n",
"def execute_agent_call(function_name, user_query, context):\n",
"    \"\"\"\n",
"    Runs the agent selected by `function_name` and labels its output.\n",
"    Args:\n",
"        function_name: The name of the function to call\n",
"        user_query: The initial user query\n",
"        context: Current accumulated context\n",
"    Returns:\n",
"        Tuple of (agent_response, agent_type). An unrecognized name yields\n",
"        an empty response with the agent type \"Unknown\".\n",
"    \"\"\"\n",
"    if function_name == \"call_flight_planning_agent\":\n",
"        return call_flight_planning_agent(user_query, context), \"Flight Planning\"\n",
"\n",
"    if function_name == \"call_hotel_recommendation_agent\":\n",
"        return call_hotel_recommendation_agent(user_query, context), \"Hotel Recommendation\"\n",
"\n",
"    if function_name == \"call_travel_attraction_agent\":\n",
"        return call_travel_attraction_agent(user_query, context), \"Travel Attraction\"\n",
"\n",
"    if function_name == \"call_user_proxy_agent\":\n",
"        return call_user_proxy_agent(user_query, context), \"User Proxy\"\n",
"\n",
"    # No agent matches the requested name\n",
"    return \"\", \"Unknown\"\n",
"\n",
"\n",
"@tracer.chain()\n",
"def generate_final_answer(user_query, context, max_cycles_reached=False):\n",
"    \"\"\"\n",
"    Generates a final answer based on the accumulated context.\n",
"    Args:\n",
"        user_query: The initial user query\n",
"        context: Current accumulated context\n",
"        max_cycles_reached: Whether the maximum cycles were reached; when True\n",
"            the prompt asks the model to mention that limit\n",
"    Returns:\n",
"        Final response to the user with surrounding whitespace removed, or an\n",
"        empty string when the response carries no text\n",
"    \"\"\"\n",
"    final_prompt = f\"\"\"Create a final response to the user query: {user_query}\n",
"    Based on this context: {context}\n",
"    \"\"\"\n",
"\n",
"    if max_cycles_reached:\n",
"        final_prompt += \"\\n\\nProvide a comprehensive and helpful answer, noting that we've reached our maximum processing cycles.\"\n",
"    else:\n",
"        final_prompt += \"\\n\\nProvide a comprehensive and helpful answer.\"\n",
"\n",
"    final_response = client.models.generate_content(\n",
"        model=FLASH_MODEL,\n",
"        contents=final_prompt,\n",
"    )\n",
"\n",
"    # `final_response.text` can be None (for example when the reply has no text\n",
"    # parts), so guard before calling `.strip()` to avoid an AttributeError.\n",
"    return (final_response.text or \"\").strip()\n",
"\n",
"\n",
"@tracer.agent()\n",
"def orchestrator(user_query, max_cycles=3):\n",
"    \"\"\"\n",
"    Runs the orchestration loop: pick an agent, call it, and fold its output\n",
"    into the context until a final answer is requested or the cycle budget\n",
"    is used up.\n",
"    Args:\n",
"        user_query: The initial user query\n",
"        max_cycles: Maximum number of agent calls before returning to user\n",
"    Returns:\n",
"        Final response to the user\n",
"    \"\"\"\n",
"    context = \"\"\n",
"\n",
"    for cycle in range(max_cycles):\n",
"        next_step = determine_next_step(user_query, context, cycle, max_cycles)\n",
"\n",
"        # The orchestrator decided it has enough information\n",
"        if next_step == \"return_final_answer\":\n",
"            return generate_final_answer(user_query, context)\n",
"\n",
"        agent_response, agent_type = execute_agent_call(next_step, user_query, context)\n",
"\n",
"        # Append this agent's labelled output to the running context\n",
"        context = f\"{context}\\n\\n{agent_type} Agent Output:\\n{agent_response}\"\n",
"\n",
"    # Cycle budget exhausted: answer with whatever has been gathered\n",
"    return generate_final_answer(user_query, context, max_cycles_reached=True)\n",
"\n",
"\n",
"# Example usage\n",
"user_query = \"\"\"I want to plan a 5-day trip to Paris, France, sometime in October. I'm interested\n",
"in museums and good food. Find flight options from SFO, suggest mid-range hotels near the city center,\n",
"and recommend some relevant activities.\"\"\"\n",
"\n",
"response = orchestrator(user_query)\n",
"print(response)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}