test_async_appbuilder_client_toolcall.py•4.59 kB
# Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import appbuilder
import asyncio
import os
from appbuilder.core.console.appbuilder_client.async_event_handler import (
AsyncAppBuilderEventHandler,
)
class MyEventHandler(AsyncAppBuilderEventHandler):
def get_current_weather(self, location=None, unit="摄氏度"):
return "{} 的温度是 {} {}".format(location, 20, unit)
async def interrupt(self, run_context, run_response):
thought = run_context.current_thought
# 绿色打印
print("\033[1;32m", "-> Agent 中间思考: ", thought, "\033[0m")
tool_output = []
for tool_call in run_context.current_tool_calls:
tool_call_id = tool_call.id
tool_res = self.get_current_weather(**tool_call.function.arguments)
# 蓝色打印
print("\033[1;34m", "-> 本地ToolCallId: ", tool_call_id, "\033[0m")
print("\033[1;34m", "-> ToolCall结果: ", tool_res, "\033[0m\n")
tool_output.append(
{"tool_call_id": tool_call_id, "output": tool_res})
return tool_output
async def success(self, run_context, run_response):
print("\n\033[1;31m", "-> Agent 非流式回答: ",
run_response.answer, "\033[0m")
@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "")
class TestAgentRuntime(unittest.TestCase):
def setUp(self):
"""
设置环境变量。
Args:
无参数,默认值为空。
Returns:
无返回值,方法中执行了环境变量的赋值操作。
"""
self.app_id = "b2a972c5-e082-46e5-b313-acbf51792422"
def test_appbuilder_client_tool_call(self):
# 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败
"""
如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败
Args:
self (unittest.TestCase): unittest的TestCase对象
Raises:
None: 如果app_id不为空,则不会引发任何异常
unittest.SkipTest (optional): 如果app_id为空,则跳过单测执行
"""
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "仅支持中国城市的天气查询,参数location为中国城市名称,其他国家城市不支持天气查询",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "城市名,举例:北京",
},
"unit": {"type": "string", "enum": ["摄氏度", "华氏度"]},
},
"required": ["location", "unit"],
},
},
}
]
appbuilder.logger.setLoglevel("DEBUG")
async def agent_run(client, query):
conversation_id = await client.create_conversation()
with await client.run_with_handler(
conversation_id=conversation_id,
query=query,
tools=tools,
event_handler=MyEventHandler(),
) as run:
await run.until_done()
async def agent_handle():
client = appbuilder.AsyncAppBuilderClient(self.app_id)
task1 = asyncio.create_task(
agent_run(client, "北京的天气怎么样"))
task2 = asyncio.create_task(
agent_run(client, "上海的天气怎么样"))
await asyncio.gather(task1, task2)
await client.http_client.session.close()
loop = asyncio.get_event_loop()
loop.run_until_complete(agent_handle())
if __name__ == "__main__":
unittest.main()