#!/usr/bin/env python3
"""
增强版SimpleMemOS - 使用Qwen嵌入和重排模型
基于原有SimpleMemOS,但使用真正的嵌入向量而不是哈希向量
"""
import os
import json
import math
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
# Import the optional performance-optimization modules
try:
from qdrant_config import get_optimized_search_params, qdrant_perf_config
from cache_utils import cached_search, cached_embedding, search_cache, embedding_cache
PERFORMANCE_MODULES_AVAILABLE = True
except ImportError:
PERFORMANCE_MODULES_AVAILABLE = False
print("⚠️ 性能优化模块未找到,将使用默认配置")
# Use the SiliconFlow API directly, avoiding the heavyweight memos framework dependency
def load_env_file():
"""加载环境变量"""
env_file = Path(".env")
if env_file.exists():
with open(env_file) as f:
for line in f:
if line.strip() and not line.startswith('#'):
key, value = line.strip().split('=', 1)
os.environ[key] = value
class EnhancedSimpleMemOS:
"""增强版SimpleMemOS,使用Qwen嵌入和重排模型"""
def __init__(self, data_dir="./memos_data"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(exist_ok=True)
        # Read the unified configuration file
config_path = os.environ.get('MEMOS_CONFIG', str(self.data_dir / "concurrent_config.json"))
qdrant_config = self._load_qdrant_config(config_path)
        # Initialize the vector database - 1024-dimensional vectors
qdrant_url = f"http://{qdrant_config['host']}:{qdrant_config['port']}"
print(f"[ENHANCED] 连接Qdrant: {qdrant_url}")
self.vector_db = QdrantClient(url=qdrant_url)
self.collection_name = "enhanced_memories"
self._init_collection()
        # Initialize the SiliconFlow client (OpenAI-compatible API); the API key
        # must come from the environment rather than being hardcoded.
        print("🔄 Initializing SiliconFlow client...")
        load_env_file()
        api_key = os.getenv("SILICONFLOW_API_KEY")
        if not api_key:
            print("⚠️ SILICONFLOW_API_KEY is not set; API calls will fail")
        self.llm_client = OpenAI(
            api_key=api_key,
            base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
        )
        print("✅ SiliconFlow client initialized")
        # Memory counter
self.memory_counter = self._load_counter()
def _load_qdrant_config(self, config_path: str) -> Dict[str, Any]:
"""读取Qdrant配置"""
        try:
            if Path(config_path).exists():
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                if config.get("qdrant_mode") == "server":
                    return {
                        "host": "localhost",
                        "port": config.get("qdrant_port", 6335)
                    }
                else:
                    print(f"[WARNING] qdrant_mode in the config file is not 'server': {config.get('qdrant_mode')}")
            else:
                print(f"[WARNING] Config file does not exist: {config_path}")
        except Exception as e:
            print(f"[ERROR] Failed to read config file: {e}")
        # Fall back to the QDRANT_URL environment variable or default values
        qdrant_url = os.environ.get('QDRANT_URL', 'http://localhost:6335')
        if qdrant_url.startswith('http://'):
            parts = qdrant_url.replace('http://', '').split(':')
            return {"host": parts[0], "port": int(parts[1]) if len(parts) > 1 else 6335}
        else:
            return {"host": "localhost", "port": 6335}
def _init_collection(self):
"""初始化向量集合 - 使用1024维向量"""
        # Connectivity check before any collection work
        try:
            health_check = self.vector_db.get_collections()
            print(f"✅ Qdrant connectivity check passed; found {len(health_check.collections)} collections")
        except Exception as e:
            raise RuntimeError(f"❌ Qdrant connection failed: {e}")
max_retries = 3
for attempt in range(max_retries):
try:
                # Check whether the collection already exists
collections = self.vector_db.get_collections()
collection_exists = any(c.name == self.collection_name for c in collections.collections)
                if not collection_exists:
                    print(f"🔧 Creating vector collection: {self.collection_name} (1024-dim) (attempt {attempt+1}/{max_retries})")
                    self.vector_db.create_collection(
                        collection_name=self.collection_name,
                        vectors_config=VectorParams(size=1024, distance=Distance.COSINE)  # 1024-dim, matching the embedding model
                    )
                    print(f"✅ Vector collection created: {self.collection_name}")
                else:
                    print(f"✅ Vector collection already exists: {self.collection_name}")
                # Verify that the collection is actually usable
                info = self.vector_db.get_collection(self.collection_name)
                print(f"📊 Collection info: {info.points_count} points")
return True
except Exception as e:
print(f"❌ 集合初始化失败 (尝试 {attempt+1}/{max_retries}): {e}")
if attempt == max_retries - 1:
raise
time.sleep(1)
def _ensure_collection_exists(self) -> bool:
"""确保集合存在"""
try:
self.vector_db.get_collection(self.collection_name)
return True
except Exception:
print("⚠️ 集合不存在,尝试重新创建...")
return self._init_collection()
def _load_counter(self) -> int:
"""加载记忆计数器"""
counter_file = self.data_dir / "counter.txt"
if counter_file.exists():
try:
with open(counter_file, 'r') as f:
return int(f.read().strip())
            except (ValueError, OSError):
                pass
return 0
def _save_counter(self):
"""保存记忆计数器"""
counter_file = self.data_dir / "counter.txt"
with open(counter_file, 'w') as f:
f.write(str(self.memory_counter))
def _get_embedding(self, text: str) -> List[float]:
"""使用SiliconFlow API获取Qwen嵌入向量 - 支持缓存"""
        # Check the cache first (renamed locally so it does not shadow the
        # cached_embedding helper imported from cache_utils)
        if PERFORMANCE_MODULES_AVAILABLE:
            cached_vec = embedding_cache.get(text)
            if cached_vec is not None:
                return cached_vec
        try:
            response = self.llm_client.embeddings.create(
                model="Qwen/Qwen3-Embedding-0.6B",
                input=text
            )
            embedding = response.data[0].embedding
            # Store in the cache
            if PERFORMANCE_MODULES_AVAILABLE:
                embedding_cache.put(text, embedding)
            return embedding
        except Exception as e:
            print(f"❌ Embedding generation failed: {e}")
            # Graceful degradation: return a zero vector
            return [0.0] * 1024
def _get_memory_count(self) -> int:
"""获取当前记忆数量"""
try:
collection_info = self.vector_db.get_collection(self.collection_name)
return collection_info.points_count or 0
except Exception:
return 0
def _rerank_with_siliconflow(self, query: str, documents: List[str], top_k: int = 5) -> List[tuple]:
"""使用SiliconFlow API进行重排"""
try:
# 构造重排请求
pairs = [[query, doc] for doc in documents]
# 这里我们使用一个简化的重排逻辑
# 实际上SiliconFlow可能有专门的重排API,但这里我们用嵌入相似度来模拟
query_embedding = self._get_embedding(query)
doc_scores = []
for doc in documents:
doc_embedding = self._get_embedding(doc)
                # Compute cosine similarity
similarity = self._cosine_similarity(query_embedding, doc_embedding)
doc_scores.append((doc, similarity))
            # Sort by similarity and return the top_k
doc_scores.sort(key=lambda x: x[1], reverse=True)
return doc_scores[:top_k]
        except Exception as e:
            print(f"❌ Reranking failed: {e}")
            # Graceful degradation: return the original order with a neutral score
            return [(doc, 0.5) for doc in documents[:top_k]]
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""计算余弦相似度"""
        try:
            dot_product = sum(a * b for a, b in zip(vec1, vec2))
            magnitude1 = math.sqrt(sum(a * a for a in vec1))
            magnitude2 = math.sqrt(sum(a * a for a in vec2))
            if magnitude1 == 0 or magnitude2 == 0:
                return 0.0
            return dot_product / (magnitude1 * magnitude2)
        except (TypeError, ValueError):
            return 0.0
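    # Example with hypothetical values: _cosine_similarity([1.0, 0.0], [1.0, 0.0])
    # returns 1.0, while orthogonal vectors such as [1.0, 0.0] and [0.0, 1.0] return 0.0.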
def add_memory(self, content, tags=None, metadata=None):
"""添加记忆 - 使用真实嵌入向量"""
if not self._ensure_collection_exists():
print("❌ 集合不可用,无法添加记忆")
return None
self.memory_counter += 1
memory_id = self.memory_counter
        # Generate a real embedding vector
        print("🔄 Generating embedding vector for the memory...")
vector = self._get_embedding(content)
        # Prepare the payload
payload = {
"content": content,
"tags": tags or [],
"metadata": metadata or {},
"timestamp": datetime.now().isoformat()
}
        # Store in the vector database
try:
point = PointStruct(
id=memory_id,
vector=vector,
payload=payload
)
print(f"🔍 [DEBUG] 存储到集合: {self.collection_name}")
self.vector_db.upsert(
collection_name=self.collection_name,
points=[point]
)
print(f"🔍 [DEBUG] 存储成功,向量维度: {len(vector)}")
self._save_counter()
print(f"✅ 添加记忆 #{memory_id}: {content[:50]}...")
return memory_id
except Exception as e:
print(f"❌ 添加记忆失败: {e}")
return None
def search_memories(self, query, limit=5, use_reranker=True):
"""搜索记忆 - 支持重排器优化和性能缓存"""
if not self._ensure_collection_exists():
print("❌ 集合不可用,无法搜索记忆")
return []
try:
            start_time = time.time()
            # Check the cache first
            if PERFORMANCE_MODULES_AVAILABLE:
                cached_result = search_cache.get(query, limit, use_reranker, False, False)
                if cached_result is not None:
                    print(f"🎯 Cache hit; returning cached results (took {time.time() - start_time:.3f}s)")
                    return cached_result
            # Generate the embedding vector for the query
            print("🔄 Generating embedding vector for the query...")
            query_vector = self._get_embedding(query)
            # Current memory count, used for performance tuning
            memory_count = self._get_memory_count()
            # Vector search; fetch extra candidates when reranking is enabled
            search_limit = limit * 3 if use_reranker else limit
            # Use optimized search parameters when available
            search_params = None
            if PERFORMANCE_MODULES_AVAILABLE:
                search_params = get_optimized_search_params(memory_count)
                qdrant_perf_config.print_config_summary(memory_count)
            print(f"🔍 [DEBUG] Searching collection: {self.collection_name}, query vector dimension: {len(query_vector)}")
            search_result = self.vector_db.search(
                collection_name=self.collection_name,
                query_vector=query_vector,  # query_vector is the correct keyword argument name
                limit=search_limit,
                search_params=search_params
            )
            print(f"🔍 [DEBUG] Search returned {len(search_result)} hits")
memories = []
documents = []
for point in search_result:
memory = {
"id": point.id,
"content": point.payload["content"],
"tags": point.payload.get("tags", []),
"score": point.score,
"metadata": point.payload.get("metadata", {})
}
memories.append(memory)
documents.append(point.payload["content"])
            # Rerank when enabled and there is more than one candidate
            if use_reranker and len(documents) > 1:
                try:
                    print("🔄 Refining results with the SiliconFlow-based reranker...")
                    reranked_results = self._rerank_with_siliconflow(query, documents, top_k=limit)
                    # Reorder memories according to the rerank results
                    reranked_memories = []
                    for reranked_doc, rerank_score in reranked_results:
                        # Match by content (assumes candidate contents are unique)
                        for memory in memories:
                            if memory["content"] == reranked_doc:
                                memory["rerank_score"] = rerank_score
                                reranked_memories.append(memory)
                                break
                    print(f"✅ Reranking complete; returning {len(reranked_memories)} results")
                    final_results = reranked_memories
                except Exception as e:
                    print(f"⚠️ Reranking failed; using the raw search results: {e}")
                    final_results = memories[:limit]
            else:
                final_results = memories[:limit]
            # Store in the cache
            if PERFORMANCE_MODULES_AVAILABLE and final_results:
                search_cache.put(query, final_results, limit, use_reranker, False, False)
            # Performance stats
            duration = time.time() - start_time
            print(f"🔍 Search finished; returning {len(final_results)} results (took {duration:.3f}s)")
            return final_results
except Exception as e:
print(f"❌ 搜索记忆失败: {e}")
return []
def get_all_memories(self):
"""获取所有记忆"""
if not self._ensure_collection_exists():
return []
try:
            # Fetch all points in the collection
            result = self.vector_db.scroll(
                collection_name=self.collection_name,
                limit=10000  # assumes no more than 10000 memories
            )
memories = []
            for point in result[0]:  # scroll returns (points, next_page_offset)
memories.append({
"id": point.id,
"content": point.payload["content"],
"tags": point.payload.get("tags", []),
"metadata": point.payload.get("metadata", {})
})
return memories
except Exception as e:
print(f"❌ 获取所有记忆失败: {e}")
return []
def update_memory_metadata(self, memory_id: int, new_metadata: Dict[str, Any]) -> bool:
"""更新记忆的metadata"""
try:
            # Use Qdrant's set_payload to update the metadata field
self.vector_db.set_payload(
collection_name=self.collection_name,
payload={"metadata": new_metadata},
points=[memory_id]
)
print(f"✅ 记忆#{memory_id}的metadata已更新")
return True
except Exception as e:
print(f"❌ 更新记忆#{memory_id}的metadata失败: {e}")
return False
def get_memory_by_id(self, memory_id: int) -> Optional[Dict[str, Any]]:
"""根据ID获取单个记忆"""
try:
            # Use Qdrant's retrieve to fetch the specific point
result = self.vector_db.retrieve(
collection_name=self.collection_name,
ids=[memory_id],
with_payload=True,
with_vectors=False
)
if result and len(result) > 0:
point = result[0]
payload = point.payload
memory = {
'id': point.id,
'content': payload.get('content', ''),
'metadata': payload.get('metadata', {}),
'tags': payload.get('tags', [])
}
return memory
else:
return None
except Exception as e:
print(f"❌ 获取记忆#{memory_id}失败: {e}")
return None
def print_performance_stats(self):
"""打印性能统计信息"""
if PERFORMANCE_MODULES_AVAILABLE:
from cache_utils import print_all_cache_stats
print_all_cache_stats()
else:
print("⚠️ 性能模块未加载,无法显示统计信息")
def clear_caches(self):
"""清空所有缓存"""
if PERFORMANCE_MODULES_AVAILABLE:
from cache_utils import clear_all_caches
clear_all_caches()
else:
print("⚠️ 性能模块未加载,无法清空缓存")
def quick_test():
"""快速测试修复效果"""
print("🔧 快速测试专家修复方案")
print("=" * 50)
    try:
        # Initialize the enhanced MemOS
        memos = EnhancedSimpleMemOS("./quick_test_data")
        print(f"✅ Initialization succeeded; collection name: {memos.collection_name}")
        # Test adding a memory
        test_content = "This is a quick test memory"
        memory_id = memos.add_memory(test_content, tags=["quick-test"])
        if memory_id:
            print(f"✅ Memory added successfully, ID: {memory_id}")
            # Test searching memories
            results = memos.search_memories("quick test", limit=3)
            if results:
                print(f"✅ Search succeeded; found {len(results)} memories")
                for result in results:
                    print(f"  - ID: {result['id']}, score: {result['score']:.3f}")
                return True
            else:
                print("❌ Search failed; no memories found")
                return False
        else:
            print("❌ Failed to add memory")
            return False
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False
if __name__ == "__main__":
quick_test()
def test_enhanced_memos():
"""测试增强版MemOS"""
print("🧪 测试增强版MemOS")
print("=" * 50)
    # Load environment variables
load_env_file()
    # Create an enhanced MemOS instance
memos = EnhancedSimpleMemOS("./test_enhanced_memos_data")
    # Test adding memories
    print("\n📝 Testing memory addition...")
    test_memories = [
        ("MemOS is an intelligent memory management system that supports vector search and semantic retrieval", ["MemOS", "technology"]),
        ("Qwen models provide high-quality embedding vectors and reranking capabilities", ["Qwen", "AI models"]),
        ("A data analytics platform must address four layers: data collection, storage, computation, and visualization", ["data analytics", "architecture"])
    ]
for content, tags in test_memories:
memory_id = memos.add_memory(content, tags=tags)
if memory_id:
print(f"✅ 记忆 #{memory_id} 添加成功")
    # Test searching memories
    print("\n🔍 Testing memory search...")
    query = "intelligent memory system"
    results = memos.search_memories(query, limit=3, use_reranker=True)
    print(f"Query: {query}")
    print(f"Found {len(results)} related memories:")
    for i, result in enumerate(results, 1):
        print(f"{i}. vector score: {result['score']:.4f}")
        if 'rerank_score' in result:
            print(f"   rerank score: {result['rerank_score']:.4f}")
        print(f"   content: {result['content']}")
        print(f"   tags: {result['tags']}")
        print()
if __name__ == "__main__":
test_enhanced_memos()