factory.py•1.54 kB
from typing import Any, ClassVar
from memos.configs.embedder import EmbedderConfigFactory
from memos.embedders.base import BaseEmbedder
from memos.embedders.ollama import OllamaEmbedder
from memos.embedders.siliconflow import SiliconFlowEmbedder
# 延迟导入sentence_transformer以避免torch依赖
def _get_sentran_embedder():
try:
from memos.embedders.sentence_transformer import SenTranEmbedder
return SenTranEmbedder
except ImportError as e:
raise ImportError(f"sentence_transformers not available: {e}. Please install with: pip install sentence-transformers")
class EmbedderFactory(BaseEmbedder):
    """Factory class for creating embedder instances.

    ``backend_to_class`` maps each backend name to one of two things:
    an embedder class, or a zero-argument loader function that imports
    and returns the class lazily.
    """

    backend_to_class: ClassVar[dict[str, Any]] = {
        "ollama": OllamaEmbedder,
        # A loader rather than the class itself: the sentence_transformer
        # backend is imported lazily to avoid requiring torch up front.
        "sentence_transformer": _get_sentran_embedder,
        "siliconflow": SiliconFlowEmbedder,
    }

    @classmethod
    def from_config(cls, config_factory: EmbedderConfigFactory) -> BaseEmbedder:
        """Instantiate the embedder selected by ``config_factory.backend``.

        Args:
            config_factory: Config whose ``backend`` selects the embedder and
                whose ``config`` is passed to the embedder's constructor.

        Returns:
            A new embedder instance for the requested backend.

        Raises:
            ValueError: If ``backend`` is not a key of ``backend_to_class``.
            ImportError: If a lazy loader cannot import its backend.
        """
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        embedder_class_or_loader = cls.backend_to_class[backend]
        # Classes are used as-is; any other entry is a lazy loader that
        # returns the class. `hasattr(x, "__init__")` cannot make this
        # distinction because every object, functions included, has an
        # `__init__` attribute, so an isinstance check against `type` is used.
        if isinstance(embedder_class_or_loader, type):
            embedder_class = embedder_class_or_loader
        else:
            embedder_class = embedder_class_or_loader()
        return embedder_class(config_factory.config)