Skip to main content
Glama

baidu-ai-search

Official
by baidubce
component.py (5.1 kB)
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from appbuilder.core.message import Message
from appbuilder.core.components.embeddings import EmbeddingBaseComponent
from appbuilder.utils.trace.tracer_wrapper import components_run_trace, components_run_stream_trace
from .base import MatchingBaseComponent, MatchingArgs


class Matching(MatchingBaseComponent):
    """
    Matching uses an Embedding-type text representation model to rank a list
    of texts against a query (or to compute their similarity to it).

    Examples:

    .. code-block:: python

        import os
        import appbuilder
        os.environ["APPBUILDER_TOKEN"] = '...'

        # Initialize the required components
        embedding = appbuilder.Embedding()
        matching = appbuilder.Matching(embedding)

        # Define the input query and the list of texts
        query = appbuilder.Message("你好")
        contexts = appbuilder.Message(["世界", "你好"])

        # Rank the texts by similarity to the query
        contexts_matched = matching(query, contexts)
        print(contexts_matched.content)
        # ['你好', '世界']
    """

    name: str = "Matching"
    version: str = "v1"
    meta: MatchingArgs = MatchingArgs

    def __init__(
        self,
        embedding_component: EmbeddingBaseComponent,
        **kwargs
    ):
        """
        Args:
            embedding_component (EmbeddingBaseComponent): component used to
                compute text embeddings for the query and the contexts.
            **kwargs: accepted for call compatibility; not used by this class.
        """
        self.embedding_component = embedding_component
        super().__init__(self.meta)

    @components_run_trace
    def run(
        self,
        query: Union[Message[str], str],
        contexts: Union[Message[List[str]], List[str]],
        return_score: bool = False,
    ) -> Message[List[str]]:
        """
        Rank the given contexts by cosine similarity to the query, highest first.

        Args:
            query (Union[Message[str], str]): the query string, or a Message
                wrapping it.
            contexts (Union[Message[List[str]], List[str]]): the candidate
                context strings, or a Message wrapping that list.
            return_score (bool, optional): whether to include the similarity
                score with each context. Defaults to False.

        Returns:
            Message[List[str]]: the contexts sorted by descending similarity.
            If ``return_score`` is True, each element is a ``(score, context)``
            tuple instead of a bare context. An empty ``contexts`` yields an
            empty list.
        """
        # The signature accepts a plain list as well as a Message, so only
        # unwrap ``.content`` when a Message was actually passed.
        context_list = contexts.content if isinstance(contexts, Message) else contexts
        # Nothing to rank: return early without calling the embedding component.
        if not context_list:
            return Message([])

        query_embedding = self.embedding_component(query)
        contexts_embedding = self.embedding_component.batch(contexts)
        scores = self.semantics(query_embedding, contexts_embedding)

        # Tuples compare by score first; equal scores fall back to comparing
        # the context strings (also in descending order).
        ranked = sorted(zip(scores.content, context_list), reverse=True)
        if return_score:
            return Message(ranked)
        return Message([context for _, context in ranked])

    def _cosine_similarity(self, X, Y):
        """
        Compute the cosine similarity between ``X`` and every row of ``Y``.

        Args:
            X: array-like 1 x n matrix.
            Y: array-like m x n matrix.

        Returns:
            numpy.ndarray: an m x 1 matrix whose i-th entry is the cosine
            similarity between X and row i of Y. Zero vectors score 0.0
            instead of NaN, and an empty Y yields a 0 x 1 matrix.
        """
        x = np.asarray(X, dtype=float)
        y = np.asarray(Y, dtype=float)
        # np.linalg.norm(..., axis=1) cannot operate on an empty 1-D array.
        if y.shape[0] == 0:
            return np.zeros((0, 1))

        x_norm = np.linalg.norm(x)
        y_norms = np.linalg.norm(y, axis=1, keepdims=True)
        # A zero vector has no direction. Dividing it by 1 keeps it at zero, so
        # its similarity is 0 rather than the NaN produced by 0/0.
        x_unit = x / (x_norm if x_norm > 0 else 1.0)
        y_unit = y / np.where(y_norms > 0, y_norms, 1.0)
        return np.dot(y_unit, x_unit.T)

    def semantics(
        self,
        query_embedding: Union[Message[List[float]], List[float]],
        context_embeddings: Union[Message[List[List[float]]], List[List[float]]],
    ) -> Message[List[float]]:
        """
        Compute the similarity between the query and each context.

        Args:
            query_embedding (Union[Message[List[float]], List[float]]): the
                query embedding, an array of length n.
            context_embeddings (Union[Message[List[List[float]]], List[List[float]]]):
                the context embeddings, an m x n matrix where m is the number
                of candidate contexts.

        Returns:
            Message[List[float]]: the cosine similarity between the query and
            each candidate context, in the order of ``context_embeddings``.
        """
        _query_embedding = query_embedding.content if isinstance(query_embedding, Message) else query_embedding
        _context_embeddings = context_embeddings.content if isinstance(context_embeddings, Message) else context_embeddings
        # Wrap the query as a 1 x n matrix; the m x 1 result is flattened to m scores.
        similarity_matrix = self._cosine_similarity([_query_embedding], _context_embeddings)
        return Message(similarity_matrix.flatten().tolist())

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/baidubce/app-builder'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.