Skip to main content
Glama

baidu-ai-search

Official
by baidubce
component.py10.9 kB
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""ASR component."""

import os
import uuid
import json
import requests
import tempfile

from urllib.parse import urlparse

from appbuilder.core.component import Component
from appbuilder.core.message import Message
from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError
from appbuilder.core._client import HTTPClient
from appbuilder.core.components.asr.model import ShortSpeechRecognitionRequest, ShortSpeechRecognitionResponse, \
    ASRInMsg, ASROutMsg
from appbuilder.utils.trace.tracer_wrapper import components_run_trace, components_run_stream_trace

# Maximum audio duration (ms) accepted per recognition request; longer audio
# is split into chunks of at most this length by ``_convert``.
# See the short-speech (fast) recognition API: https://ai.baidu.com/ai-doc/SPEECH/Jlbxdezuf
DEFAULT_AUDIO_MAX_DURATION = 55 * 1000  # 55s
# Sampling rate required by the recognition service.
DEFAULT_FRAME_RATE = 16000


class ASR(Component):
    r"""ASR component: recognizes speech in an input audio file and returns the
    recognition result.

    Examples:

    .. code-block:: python

        import appbuilder
        asr = appbuilder.ASR()
        os.environ["APPBUILDER_TOKEN"] = '...'

        with open("xxxx.pcm", "rb") as f:
            audio_data = f.read()
        content_data = {"audio_format": "pcm", "raw_audio": audio_data, "rate": 16000}
        msg = appbuilder.Message(content_data)
        out = asr.run(msg)
        print(out.content)  # eg: {"result": ["北京科技馆。"]}
    """

    name = "asr"
    version = "v1"

    manifests = [
        {
            "name": "asr",
            "description": "对于输入的语音文件进行识别,输出语音识别结果。",
            "parameters": {
                "type": "object",
                "properties": {
                    "file_url": {
                        "type": "string",
                        "description": "输入语音文件的url,根据url获取到语音文件"
                    },
                    "file_name": {
                        "type": "string",
                        "description": "待识别语音文件名,用于生成获取语音的url"
                    },
                    "file_type": {
                        "type": "string",
                        "description": "语音文件类型,支持pcm/wav/amr/m4a",
                        "enum": ["pcm", "wav", "amr", "m4a"]
                    }
                },
                "anyOf": [
                    {
                        "required": [
                            "file_url"
                        ]
                    },
                    {
                        "required": [
                            "file_name"
                        ]
                    }
                ]
            }
        }
    ]

    @HTTPClient.check_param
    @components_run_trace
    def run(self, message: Message, audio_format: str = "pcm", rate: int = 16000,
            timeout: float = None, retry: int = 0, **kwargs) -> Message:
        r"""Run speech recognition and return the recognized text.

        Args:
            message (Message): required input message carrying the audio to
                recognize, e.g. ``Message(content={"raw_audio": b"..."})``.
            audio_format (str, optional): audio format, one of pcm/wav/amr/m4a
                (case-insensitive); pcm is recommended. Defaults to ``"pcm"``.
            rate (int, optional): sampling rate, fixed at 16000. Defaults to 16000.
            timeout (float, optional): HTTP request timeout. Defaults to None.
            retry (int, optional): number of HTTP retries. Defaults to 0.

        Returns:
            Message: recognition result, e.g.
            ``Message(content={"result": ["recognized text"]})``.
        """
        inp = ASRInMsg(**message.content)
        request = ShortSpeechRecognitionRequest()
        request.format = audio_format
        request.rate = rate
        # A fresh cuid per call; dev_pid 80001 selects the fast short-speech model.
        request.cuid = str(uuid.uuid4())
        request.dev_pid = "80001"
        request.speech = inp.raw_audio
        traceid = kwargs.get("traceid", "")
        response = self._recognize(request, timeout, retry, request_id=traceid)
        out = ASROutMsg(result=list(response.result))
        return Message(content=out.model_dump())

    def _recognize(
        self,
        request: ShortSpeechRecognitionRequest,
        timeout: float = None,
        retry: int = 0,
        request_id: str = None,
    ) -> ShortSpeechRecognitionResponse:
        r"""Call the recognition service with the given request and return its result.

        Args:
            request (ShortSpeechRecognitionRequest): the request to send; required.
            timeout (float, optional): request timeout.
            retry (int, optional): number of retries.
            request_id (str, optional): trace/request id to propagate in headers.

        Returns:
            ShortSpeechRecognitionResponse: the parsed service response.
        """
        content_type = "audio/" + request.format + ";rate=" + str(request.rate)
        headers = self.http_client.auth_header(request_id)
        headers['content-type'] = content_type
        params = {
            'dev_pid': request.dev_pid,
            'cuid': request.cuid
        }
        # Only touch the shared retry config when the caller asks for a
        # different total, to avoid needless session churn.
        if retry != self.http_client.retry.total:
            self.http_client.retry.total = retry
        response = self.http_client.session.post(
            self.http_client.service_url("/v1/bce/aip_speech/asrpro"),
            params=params, headers=headers, data=request.speech, timeout=timeout)
        self.http_client.check_response_header(response)
        data = response.json()
        self.http_client.check_response_json(data)
        request_id = self.http_client.response_request_id(response)
        self.__class__._check_service_error(request_id, data)
        response = ShortSpeechRecognitionResponse.from_json(payload=json.dumps(data))
        response.request_id = request_id
        return response

    @staticmethod
    def _check_service_error(request_id: str, data: dict):
        r"""Check the service response body for an error payload.

        Args:
            request_id (str): request id to attach to any raised exception.
            data (dict): parsed JSON body of the short-speech recognition response.

        Raises:
            AppBuilderServerException: if the body carries a non-zero ``err_no``.
        """
        if "err_no" in data and "err_msg" in data:
            if data["err_no"] != 0:
                raise AppBuilderServerException(
                    request_id=request_id,
                    service_err_code=data["err_no"],
                    service_err_message=data["err_msg"]
                )

    @components_run_stream_trace
    def tool_eval(self, name: str, streaming: bool, **kwargs):
        """Recognize the speech in a file identified by name or URL.

        Args:
            name (str): tool-call name.
            streaming (bool): whether to yield results as a stream.
            **kwargs: keyword arguments carrying ``file_url``, ``file_name``,
                ``file_urls`` and optionally ``traceid``.

        Returns:
            If ``streaming`` is True, yields message dicts containing the
            recognition result; otherwise returns the result as a JSON string.

        Raises:
            InvalidRequestArgumentError: if neither a file name nor a usable
                file URL is provided.
        """
        file_url = kwargs.get("file_url", None)
        if not file_url:
            file_urls = kwargs.get("file_urls", {})
            file_path = kwargs.get("file_name", None)
            if not file_path:
                raise InvalidRequestArgumentError("request format error, file name is not set")
            file_name = os.path.basename(file_path)
            file_url = file_urls.get(file_name, None)
            if not file_url:
                # Bug fix: report the missing file's name, not the (None) URL.
                raise InvalidRequestArgumentError(
                    f"request format error, file {file_name} url does not exist"
                )
        _, file_type = os.path.splitext(os.path.basename(urlparse(file_url).path))
        file_type = file_type.strip('.')
        # Bug fix: suffix needs the leading dot to produce e.g. "...xyz.wav".
        audio_file = tempfile.NamedTemporaryFile("wb", suffix="." + file_type)
        # NOTE(review): no timeout on this download — confirm whether one is wanted.
        audio_file.write(requests.get(file_url).content)
        # Bug fix: flush buffered bytes so _convert/pydub can read the file by name.
        audio_file.flush()
        raw_audios = _convert(audio_file.name, file_type)
        text = ""
        for raw_audio in raw_audios:
            content_data = {"audio_format": file_type, "raw_audio": raw_audio,
                            "rate": DEFAULT_FRAME_RATE}
            msg = Message(content_data)
            out = self.run(msg)
            text += "".join(out.content["result"])
        results = {"识别结果": text}
        audio_file.close()
        res = json.dumps(results, ensure_ascii=False, indent=4)
        if streaming:
            yield {
                "type": "text",
                "text": res,
                "visible_scope": 'llm',
            }
            yield {
                "type": "text",
                "text": "",
                "visible_scope": 'user',
            }
        else:
            return res


def _convert(path, file_type):
    """Load an audio file and return a list of raw-audio chunks, each short
    enough (<= DEFAULT_AUDIO_MAX_DURATION ms at DEFAULT_FRAME_RATE Hz) for one
    recognition request.

    Args:
        path (str): path of the audio file on disk.
        file_type (str): audio type, e.g. "mp3"/"wav"/"pcm" (case-insensitive).

    Returns:
        list[bytes]: raw audio chunks ready to send to the service.
    """
    from pydub import AudioSegment
    if file_type.lower() == "mp3":
        audio = AudioSegment.from_mp3(path)
    elif file_type.lower() == "wav":
        audio = AudioSegment.from_wav(path)
    # pcm is already raw: return the bytes untouched.
    elif file_type.lower() == "pcm":
        with open(path, "rb") as f:
            return [f.read()]
    else:
        # Fall back to wav parsing for unrecognized extensions.
        audio = AudioSegment.from_wav(path)
    # If the sample rate already matches and the clip is short enough,
    # send the original bytes as a single chunk.
    if (audio.frame_rate == DEFAULT_FRAME_RATE
            and audio.frame_count() * 1000 / audio.frame_rate < DEFAULT_AUDIO_MAX_DURATION):
        with open(path, "rb") as f:
            return [f.read()]
    audio = audio.set_frame_rate(DEFAULT_FRAME_RATE)
    total_milliseconds = int(audio.frame_count() * 1000 / audio.frame_rate)
    start = 0
    raw_audios = []
    while start < total_milliseconds:
        end = start + DEFAULT_AUDIO_MAX_DURATION
        if start + DEFAULT_AUDIO_MAX_DURATION > total_milliseconds:
            end = total_milliseconds
        audio_seg = audio[start:end]
        # Bug fix: suffix needs the leading dot ("…xyz.wav", not "…xyzwav").
        audio_seg_file = tempfile.NamedTemporaryFile("wb", suffix=".wav")
        try:
            audio_seg.export(audio_seg_file.name, format="wav")
            with open(audio_seg_file.name, "rb") as f:
                raw_audios.append(f.read())
        finally:
            audio_seg_file.close()
        start = end
    return raw_audios

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/baidubce/app-builder'

If you have feedback or need assistance with the MCP directory API, please join our Discord server