baidu-ai-search (Official), by baidubce

base.py (18.8 kB)
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import json
import os
import uuid
from enum import Enum
import logging
import requests
import copy
import collections.abc

from appbuilder.core.constants import GATEWAY_URL, GATEWAY_INNER_URL
from pydantic import BaseModel, Field, ValidationError, HttpUrl, validator
from pydantic.types import confloat
from appbuilder.core.component import Component
from appbuilder.core.message import Message, _T
from appbuilder.utils.logger_util import logger
from typing import Dict, List, Optional, Any
from appbuilder.core.component import ComponentArguments
from appbuilder.core.utils import ModelInfo, ttl_lru_cache
from appbuilder.utils.sse_util import SSEClient
from appbuilder.core._exception import AppBuilderServerException, ModelNotSupportedException


class LLMMessage(Message):
    content: Optional[_T] = {}
    extra: Optional[Dict] = {}
    token_usage: Optional[Dict] = {}

    def __str__(self):
        return f"Message(name={self.name}, content={self.content}, " \
               f"mtype={self.mtype}, extra={self.extra}, token_usage={self.token_usage})"

    def __deepcopy__(self, memo):
        new_instance = self.__class__()
        memo[id(self)] = new_instance
        for k, v in self.__dict__.items():
            # A streaming content iterator cannot be deep-copied; skip it.
            if k == "content" and isinstance(v, collections.abc.Iterator):
                pass
            else:
                setattr(new_instance, k, copy.deepcopy(v, memo))
        return new_instance


class CompletionRequest(object):
    r"""Request payload for the LLM completion API."""
    params = None
    response_mode = "blocking"

    def __init__(self, params: Dict[str, Any] = None, response_mode: str = None, **kwargs):
        r"""Initialize the request state."""
        self.params = params
        self.response_mode = response_mode


class ModelArgsConfig(BaseModel):
    stream: bool = Field(default=False, description="Whether to stream the response. Defaults to False.")
    temperature: float = Field(default=1e-10, gt=0.0, le=1.0,
                               description="Model temperature, in the range (0.0, 1.0].")
    top_p: float = Field(default=1e-10, ge=0.0, le=1.0,
                         description="Model top_p, in the range [0.0, 1.0].")
    max_output_tokens: int = Field(default=1024, ge=2, description="Maximum number of output tokens.")
    disable_search: bool = Field(default=True, description="Whether to disable search. Defaults to True.")
    response_format: str = Field(default="text", description="Response format: 'text' or 'json_object'.")
    stop: list[str] = Field(default_factory=list, description="List of stop words.", max_length=4)


class CompletionResponse(object):
    r"""Response wrapper for the LLM completion API."""
    error_no = 0
    error_msg = ""
    result = None
    log_id = ""
    extra = None
    token_usage = {}

    def __init__(self, response, stream: bool = False):
        """Initialize the response state from an HTTP response."""
        self.error_no = 0
        self.error_msg = ""
        self.log_id = response.headers.get("X-Appbuilder-Request-Id", None)
        self.extra = {}
        self.token_usage = {}

        if stream:
            # Streaming response: lazily parse server-sent events.
            def stream_data():
                sse_client = SSEClient(response)
                for event in sse_client.events():
                    if not event:
                        continue
                    answer = self.parse_stream_data(event)
                    if answer is not None:
                        yield answer

            self.result = stream_data()
        else:
            # Blocking (non-streaming) response.
            if response.status_code != 200:
                self.error_no = response.status_code
                self.error_msg = "error"
                self.result = response.text
                raise AppBuilderServerException(self.log_id, self.error_no, self.result)

            data = response.json()

            if data.get("code") and "message" in data:
                raise AppBuilderServerException(self.log_id, data["code"], data["message"])
            if "code" in data and "message" in data and "requestId" in data:
                raise AppBuilderServerException(self.log_id, data["code"], data["message"])
            if "code" in data and "message" in data and "status" in data:
                raise AppBuilderServerException(self.log_id, data["code"], data["message"])

            self.result = data.get("answer", None)
            trace_log_list = data.get("trace_log", None)
            if trace_log_list is not None:
                for trace_log in trace_log_list:
                    key = trace_log["tool"]
                    result_list = ResultProcessor.process(key, trace_log["result"])
                    self.extra[key] = result_list
            self.token_usage = data.get("usage", {})

    def parse_stream_data(self, event):
        """Parse one streamed event and return the decoded JSON payload."""
        parsed_str = event.data
        raw_str = event.raw
        if parsed_str:
            try:
                data = json.loads(parsed_str)
                if data.get("code") and "message" in data:
                    raise AppBuilderServerException(self.log_id, data["code"], data["message"])
                if "code" in data and "message" in data and "requestId" in data:
                    raise AppBuilderServerException(self.log_id, data["code"], data["message"])
                if "code" in data and "message" in data and "status" in data:
                    raise AppBuilderServerException(self.log_id, data["code"], data["message"])
                return data
            except json.JSONDecodeError:
                # Handle JSON decoding errors.
                logging.error("failed to parse: " + parsed_str)
                raise AppBuilderServerException("unknown", "unknown", parsed_str)
        else:
            try:
                data = json.loads(raw_str)
                if "code" in data and "message" in data:
                    raise AppBuilderServerException(self.log_id, data["code"], data["message"])
                return data
            except json.JSONDecodeError:
                # Handle JSON decoding errors.
                logging.error("failed to parse: " + raw_str)
                raise AppBuilderServerException("unknown", "unknown", raw_str)

    def get_stream_data(self):
        """Return the iterator over the processed streaming data."""
        return self.result

    def to_message(self):
        """Convert the response into a Message object.

        Returns:
            Message: the wrapped result.
        """
        message = LLMMessage()
        message.id = self.log_id
        message.content = self.result
        message.extra = self.extra
        message.token_usage = self.token_usage
        return self.message_iterable_wrapper(message)

    def message_iterable_wrapper(self, message):
        """Wrap the Message produced by the model.

        When the Message carries streaming data, its content is replaced by the
        concatenated (blocking-style) string once the iterator is exhausted.
        """
        class IterableWrapper:
            def __init__(self, stream_content):
                self._content = stream_content
                self._concat = ""
                self._token_usage = {}

            def __iter__(self):
                return self

            def __next__(self):
                try:
                    result_json = next(self._content)
                    char = result_json.get("answer", "")
                    result_list = result_json.get("result")
                    key = result_json.get("tool")
                    if result_list is not None:
                        result_list = ResultProcessor.process(key, result_list)
                        message.extra = {key: result_list}  # update the original extra
                    else:
                        message.extra = {}
                    if "usage" in result_json:
                        self._token_usage = result_json.get("usage")
                        message.token_usage = self._token_usage
                    self._concat += char
                    return char
                except StopIteration:
                    message.content = self._concat  # update the original content
                    raise

        if isinstance(message.content, collections.abc.Generator):
            # Replace the original content with the custom iterable.
            message.content = IterableWrapper(message.content)
        return message


class ResultProcessor:
    """Normalize tool trace results by tool name."""

    @staticmethod
    def process(key, result_list):
        if key == 'search_baidu':
            rename_fields = {
                'id': 'url',
                'mock_id': 'ref_id',
                'content': 'content',
                'title': 'title',
                'icon': 'icon',
                'site_name': 'site_name',
            }
            return [{rename_fields[k]: v for k, v in result.items() if k in rename_fields}
                    for result in result_list]
        elif key == 'search_db':
            return result_list
        else:
            raise TypeError(f"illegal argument key, expected one of ('search_baidu', 'search_db'), got {key}")


class CompletionBaseComponent(Component):
    name: str
    version: str
    base_url: str = "/rpc/2.0/cloud_hub/v1/ai_engine/copilot_engine"
    model_name: str = ""
    model_url: str = ""
    model_type: str = "chat"
    excluded_models: List[str] = ["Yi-34B-Chat", "ChatLaw"]
    model_info: ModelInfo = None
    model_config: Dict[str, Any] = {
        "model": {
            "provider": "baidu",
            "name": "ERNIE-Bot",
            "completion_params": {
                "temperature": 1e-10,
                "top_p": 0,
            }
        }
    }

    def __init__(
            self,
            meta: ComponentArguments,
            model: str = None,
            secret_key: Optional[str] = None,
            gateway: str = "",
            lazy_certification: bool = False,
            **kwargs
    ):
        """
        Args:
            meta (ComponentArguments): component argument metadata.
            model (str, optional): model name. Defaults to None.
            secret_key (Optional[str], optional): optional secret key. Defaults to None.
            gateway (str, optional): gateway address. Defaults to "".
            lazy_certification (bool, optional): defer certification until the first run. Defaults to False.
            **kwargs: additional keyword arguments.
        """
        super(CompletionBaseComponent, self).__init__(
            meta=meta, secret_key=secret_key, gateway=gateway, lazy_certification=lazy_certification)

        self.model_name = model
        self.version = self.version

        if not lazy_certification:
            self._check_model_and_get_model_url(self.model_name, self.model_type)

    @ttl_lru_cache(seconds_to_live=1 * 60 * 60)  # 1h
    def set_secret_key_and_gateway(self, secret_key: Optional[str] = None, gateway: str = ""):
        super(CompletionBaseComponent, self).set_secret_key_and_gateway(
            secret_key=secret_key, gateway=gateway)
        # Under a private deployment there is no need to re-fetch the model list.
        if os.environ.get("PRIVATE_AB", "OFF") == "OFF":
            self.__class__.model_info = ModelInfo(client=self.http_client)

    def set_model_info(self, model_name: str, model_url: str):
        """Set the model info for this LLM component."""
        self.model_name = model_name
        self.model_url = model_url

    @ttl_lru_cache(seconds_to_live=1 * 60 * 60)  # 1h
    def _check_model_and_get_model_url(self, model, model_type):
        if model and model in self.excluded_models:
            raise ModelNotSupportedException(
                f"model {model} is not supported; excluded models: {self.excluded_models}")
        if not model:
            raise ValueError("illegal argument, model_name can't be empty")
        if self.__class__.model_info is None:
            self.set_secret_key_and_gateway()
        m_type = self.model_info.get_model_type(model)
        if m_type != model_type:
            raise ModelNotSupportedException(
                f"unsupported model_type for model {model}: expected {model_type}, got {m_type}")
        model_url = self.model_info.get_model_url(model)
        return model_url

    def gene_request(self, query, inputs, response_mode, message_id, model_config):
        """Build a CompletionRequest from the given parameters."""
        data = {
            "query": query,
            "inputs": inputs,
            "response_mode": response_mode,
            "user": message_id,
            "model_config": model_config
        }
        return CompletionRequest(data, response_mode)

    def gene_response(self, response, stream: bool = False):
        """Wrap an HTTP response in a CompletionResponse."""
        return CompletionResponse(response, stream)

    def run(self, *args, **kwargs):
        """
        Run the model with the given input and return the result.

        Args:
            **kwargs: keyword arguments for both the concrete component and common component inputs.

        Returns:
            obj:`Message`: output message after running the model.
        """
        timeout = kwargs.get('timeout')
        retry = kwargs.get('retry', 0)
        request_id = kwargs.get('request_id')
        specific_params = {k: v for k, v in kwargs.items() if k in self.meta.model_fields}
        model_config_params = {k: v for k, v in kwargs.items() if k in ModelArgsConfig.model_fields}
        # Parameters not covered by timeout/retry/request_id, specific_params, or model_config_params.
        other_params = {
            k: v for k, v in kwargs.items()
            if k not in ['timeout', 'retry', 'request_id']
            + list(specific_params.keys()) + list(model_config_params.keys())
        }

        try:
            specific_inputs = self.meta(**specific_params)
            model_config_inputs = ModelArgsConfig(**model_config_params)
        except ValidationError as e:
            raise ValueError(e)

        query, inputs, response_mode, user_id = self.get_completion_params(specific_inputs, model_config_inputs)
        model_config = self.get_model_config(model_config_inputs, other_params)
        request = self.gene_request(query, inputs, response_mode, user_id, model_config)
        response = self.completion(
            version=self.version,
            base_url=self.base_url,
            request=request,
            timeout=timeout,
            retry=retry,
            request_id=request_id
        )

        if response.error_no != 0:
            raise AppBuilderServerException(service_err_code=response.error_no,
                                            service_err_message=response.error_msg)

        return response.to_message()

    def get_completion_params(self, specific_inputs, model_config_inputs):
        """Extract the model request parameters."""
        inputs = specific_inputs.extract_values_to_dict()
        query = inputs["query"]
        user_id = str(uuid.uuid4())
        response_mode = "streaming" if model_config_inputs.stream else "blocking"
        return query, inputs, response_mode, user_id

    def get_model_config(self, model_config_inputs: ModelArgsConfig, other_params: Optional[dict] = None):
        """Build the model configuration."""
        self.model_config["model"]["name"] = self.model_name

        # Outside a private deployment, resolve the model URL from the model list;
        # inside one, use the explicitly configured model_url instead.
        if os.environ.get("PRIVATE_AB", "false") == "false":
            model_url = self._check_model_and_get_model_url(self.model_name, self.model_type)
            if model_url:
                self.model_config["model"]["url"] = model_url
        elif os.environ.get("PRIVATE_AB", "false") == "true":
            if self.model_url:
                self.model_config["model"]["url"] = self.model_url

        completion_params = self.model_config["model"]["completion_params"]
        completion_params["temperature"] = model_config_inputs.temperature
        completion_params["top_p"] = model_config_inputs.top_p
        completion_params["max_output_tokens"] = model_config_inputs.max_output_tokens
        completion_params["disable_search"] = model_config_inputs.disable_search
        completion_params["response_format"] = model_config_inputs.response_format
        completion_params["stop"] = model_config_inputs.stop

        other_params = other_params or {}
        if len(other_params) > 0:
            logger.info("Some parameters are not expected by the model configuration; "
                        "we assume they will be used by the LLM completion API.")
            for k, v in other_params.items():
                completion_params[k] = v
                logger.info("Add parameter: {}, value: {} in completion_params.".format(k, v))

        return self.model_config

    def completion(
            self,
            version,
            base_url,
            request: CompletionRequest,
            timeout: float = None,
            retry: int = 0,
            request_id: str = None,
    ) -> CompletionResponse:
        r"""Send a completion request to the gateway and wrap the HTTP response."""
        headers = self.http_client.auth_header(request_id)
        headers["Content-Type"] = "application/json"

        completion_url = "/" + self.version + "/api/llm/" + self.name
        stream = request.response_mode == "streaming"
        url = self.http_client.service_url(completion_url, self.base_url)
        response = self.http_client.session.post(
            url, json=request.params, headers=headers, timeout=timeout, stream=stream)
        return self.gene_response(response, stream)

    @staticmethod
    def check_service_error(data: dict):
        r"""Check for a service-internal error in the response body.

        Args:
            data (dict): service response body.
        """
        if "err_no" in data and "err_msg" in data and data["err_no"] != 0:
            raise AppBuilderServerException(service_err_code=data["err_no"],
                                            service_err_message=data["err_msg"])
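For orientation, the sketch below shows how a concrete component built on CompletionBaseComponent is typically consumed. It is a minimal sketch, not part of base.py: the subclass MyComponent and its query argument are hypothetical stand-ins for the real appbuilder components derived from this base class, and the APPBUILDER_TOKEN value is a placeholder credential.

# Minimal usage sketch (assumptions: a hypothetical subclass "MyComponent"
# whose ComponentArguments meta defines a "query" field, and a valid token).
import os

os.environ["APPBUILDER_TOKEN"] = "..."      # placeholder credential

component = MyComponent(model="ERNIE-Bot")  # hypothetical subclass

# Blocking call: run() returns an LLMMessage whose content is the full
# answer string; tool traces land in extra, usage counts in token_usage.
message = component.run(query="Introduce Mount Tai", stream=False)
print(message.content, message.token_usage)

# Streaming call: content is the IterableWrapper defined above; once the
# iterator is exhausted, content is replaced by the concatenated string.
message = component.run(query="Introduce Mount Tai", stream=True)
for chunk in message.content:
    print(chunk, end="")
print("\nfull answer:", message.content)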
