# factory.py
from typing import Any, ClassVar
from importlib import import_module
from memos.configs.memory import MemoryConfigFactory
from memos.memories.activation.base import BaseActMemory
from memos.memories.base import BaseMemory
from memos.memories.parametric.base import BaseParaMemory
from memos.memories.textual.base import BaseTextMemory
class MemoryFactory(BaseMemory):
    """Factory class for creating memory instances with delayed import.

    Each backend name maps to a ``"module.path:ClassName"`` string. The target
    module is imported only when ``from_config`` is called for that backend,
    so importing this factory does not import every memory implementation.
    """

    # Use string paths for lazy imports to avoid circular dependencies.
    backend_to_class: ClassVar[dict[str, str]] = {
        "naive_text": "memos.memories.textual.naive:NaiveTextMemory",
        "general_text": "memos.memories.textual.general:GeneralTextMemory",
        "tree_text": "memos.memories.textual.tree:TreeTextMemory",
        "kv_cache": "memos.memories.activation.kv:KVCacheMemory",
        "lora": "memos.memories.parametric.lora:LoRAMemory",
    }

    @classmethod
    def from_config(
        cls, config_factory: MemoryConfigFactory
    ) -> BaseTextMemory | BaseActMemory | BaseParaMemory:
        """Instantiate the memory backend selected by ``config_factory``.

        Args:
            config_factory: Supplies ``backend`` (looked up in
                ``backend_to_class``) and ``config`` (passed as the sole
                positional argument to the resolved backend class).

        Returns:
            A new instance of the class registered for
            ``config_factory.backend``.

        Raises:
            ValueError: If the backend is not registered, or if its module
                cannot be imported or does not define the registered class.
        """
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")

        # Lazy import: only import the concrete memory class when it is needed.
        module_path = cls.backend_to_class[backend]
        module_name, class_name = module_path.rsplit(":", 1)
        try:
            module = import_module(module_name)
            memory_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f"Failed to import backend '{backend}' from '{module_path}': {e}"
            ) from e

        # Construct outside the try block so that ImportError/AttributeError
        # raised by the backend's own constructor propagate unchanged instead
        # of being misreported as an import failure.
        return memory_class(config_factory.config)