test_embedding.py•7.09 kB
#!/usr/bin/env python3
"""
测试嵌入模型和重排模型的工作状态
"""
import os
import json
from pathlib import Path
from openai import OpenAI
def load_env_file():
    """Load ``KEY=VALUE`` pairs from ``.env`` in the working directory.

    Each usable line is written into ``os.environ``, overwriting any
    variable of the same name. Blank lines, lines starting with ``#`` and
    lines without ``=`` are ignored. Whitespace around the key and the
    value is stripped, an optional leading ``export `` is dropped, and one
    pair of matching surrounding quotes is removed from the value.

    Does nothing when ``.env`` is missing or is not a regular file.
    """
    env_file = Path(".env")
    if not env_file.is_file():
        return

    # 'utf-8-sig' transparently drops a BOM, which would otherwise become
    # part of the first variable name.
    with open(env_file, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue

            key, value = line.split('=', 1)
            key = key.strip()
            if key.startswith('export '):
                # Accept shell-style "export KEY=value" lines.
                key = key[len('export '):].strip()
            if not key:
                # os.environ rejects an empty variable name.
                continue

            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                # KEY="value" should yield value, not "value" with quotes.
                value = value[1:-1]

            os.environ[key] = value
def test_siliconflow_embedding():
    """Smoke-test the SiliconFlow embedding endpoint via the OpenAI client.

    Reads ``SILICONFLOW_API_KEY`` and ``SILICONFLOW_BASE_URL`` from the
    environment, requests an embedding from ``BAAI/bge-large-zh-v1.5`` for
    each sample text, and prints the vector length and its first five
    values.

    Returns:
        bool: True when every request succeeded. False when the API key is
        not set or when any exception was raised; the error is printed
        rather than propagated.
    """
    print("🧪 测试SiliconFlow嵌入模型...")

    api_key = os.getenv("SILICONFLOW_API_KEY")
    base_url = os.getenv("SILICONFLOW_BASE_URL")
    if not api_key:
        # Fail with an explicit reason instead of an opaque TypeError from
        # slicing None or a generic client error.
        print("❌ SiliconFlow嵌入模型测试失败: 未设置SILICONFLOW_API_KEY环境变量")
        return False

    # Echo only a short hint of the key so the secret never lands in logs;
    # very short keys are hidden entirely.
    key_hint = api_key[:4] if len(api_key) > 8 else ""
    print(f"API Key: {key_hint}***")
    print(f"Base URL: {base_url}")

    try:
        client = OpenAI(api_key=api_key, base_url=base_url)

        test_texts = [
            "今天学习了MemOS的嵌入模型",
            "手动验证测试功能正常工作",
            "人工智能和机器学习",
        ]
        for i, text in enumerate(test_texts, 1):
            print(f"\n📝 测试文本 {i}: {text}")
            response = client.embeddings.create(
                model="BAAI/bge-large-zh-v1.5",  # Chinese embedding model
                input=text,
            )
            embedding = response.data[0].embedding
            print(f"✅ 嵌入向量维度: {len(embedding)}")
            print(f"✅ 向量前5个值: {embedding[:5]}")

        print("\n🎉 SiliconFlow嵌入模型测试通过!")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a top-level smoke test that
        # reports any failure and lets the remaining checks run.
        print(f"❌ SiliconFlow嵌入模型测试失败: {e}")
        return False
def test_local_embedding():
    """Smoke-test a locally loaded sentence-transformers embedding model.

    Loads ``BAAI/bge-large-zh-v1.5``, encodes two sample sentences and
    prints the shape of the encoded result.

    Returns:
        bool: True when encoding completed. False when an ImportError was
        raised (reported as sentence-transformers not being installed) or
        when any other exception occurred; errors are printed, not raised.
    """
    print("\n🧪 测试本地嵌入模型...")

    passed = False
    try:
        # Imported here so a missing optional dependency is reported as a
        # skipped check instead of breaking the whole script at import time.
        from sentence_transformers import SentenceTransformer

        checkpoint = "BAAI/bge-large-zh-v1.5"
        print(f"📦 尝试加载模型: {checkpoint}")
        encoder = SentenceTransformer(checkpoint)

        samples = [
            "今天学习了MemOS的嵌入模型",
            "手动验证测试功能正常工作",
        ]
        vectors = encoder.encode(samples)

        print("✅ 本地嵌入模型加载成功")
        print(f"✅ 嵌入向量形状: {vectors.shape}")
        print(f"✅ 向量维度: {vectors.shape[1]}")
        passed = True
    except ImportError:
        print("⚠️ sentence-transformers未安装,跳过本地模型测试")
    except Exception as err:
        print(f"❌ 本地嵌入模型测试失败: {err}")

    return passed
def test_reranker():
    """Smoke-test the ``BAAI/bge-reranker-v2-m3`` cross-encoder reranker.

    Scores four candidate sentences against a fixed query and prints them
    ordered from highest to lowest score.

    Returns:
        bool: True when scoring and printing completed. False when an
        ImportError was raised (reported as sentence-transformers not being
        installed) or when any other exception occurred; errors are
        printed, not raised.
    """
    print("\n🧪 测试重排模型...")

    passed = False
    try:
        # Imported here so a missing optional dependency is reported as a
        # skipped check instead of breaking the whole script at import time.
        from sentence_transformers import CrossEncoder

        checkpoint = "BAAI/bge-reranker-v2-m3"
        print(f"📦 尝试加载重排模型: {checkpoint}")
        cross_encoder = CrossEncoder(checkpoint)

        question = "MemOS记忆管理"
        documents = [
            "今天学习了MemOS的嵌入模型功能",
            "手动验证测试功能正常工作",
            "人工智能和机器学习技术",
            "MemOS是一个智能记忆管理系统",
        ]

        # One [query, document] pair per candidate.
        relevance = cross_encoder.predict([[question, doc] for doc in documents])

        print("✅ 重排模型加载成功")
        print(f"✅ 查询: {question}")
        print("✅ 重排结果:")

        # Highest score first.
        ordered = sorted(
            zip(documents, relevance),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for rank, (doc, score) in enumerate(ordered, start=1):
            print(f" {rank}. [{score:.4f}] {doc}")
        passed = True
    except ImportError:
        print("⚠️ sentence-transformers未安装,跳过重排模型测试")
    except Exception as err:
        print(f"❌ 重排模型测试失败: {err}")

    return passed
def test_current_system():
    """Smoke-test the embedding and search methods of ``SimpleMemOS``.

    Imports ``SimpleMemOS`` from ``usage_examples``, calls its
    ``_get_embedding`` on a sample sentence, then calls ``search_memories``
    with ``limit=3`` and prints what came back.

    Returns:
        bool: True when both calls completed. False when any exception was
        raised, including a failed import; the error is printed, not
        propagated.
    """
    print("\n🧪 测试当前系统的嵌入实现...")

    passed = False
    try:
        from usage_examples import SimpleMemOS

        system = SimpleMemOS()

        # Exercise the embedding method the system currently ships with.
        vector = system._get_embedding("测试当前系统的嵌入功能")
        print("✅ 当前系统嵌入方法工作正常")
        print(f"✅ 嵌入向量维度: {len(vector)}")
        print("✅ 向量类型: 哈希向量(简化版)")
        print(f"✅ 向量前5个值: {vector[:5]}")

        # Exercise the search path.
        print("\n🔍 测试搜索功能...")
        hits = system.search_memories("测试", limit=3)
        print(f"✅ 搜索功能正常,找到 {len(hits)} 条结果")
        passed = True
    except Exception as err:
        print(f"❌ 当前系统测试失败: {err}")

    return passed
def main():
    """Run every smoke test, then print a pass/fail summary and advice.

    Loads ``.env`` first, runs the four checks in a fixed order, prints one
    summary line per check, and finally prints one suggestion each for the
    SiliconFlow, local-embedding and reranker checks based on their result.
    """
    print("🚀 开始测试嵌入和重排模型...")
    print("=" * 60)

    load_env_file()

    # (result key, check function), in execution order.
    checks = (
        ('siliconflow_embedding', test_siliconflow_embedding),
        ('local_embedding', test_local_embedding),
        ('reranker', test_reranker),
        ('current_system', test_current_system),
    )
    results = {label: check() for label, check in checks}

    # Summary table.
    print("\n" + "=" * 60)
    print("📊 测试结果总结:")
    print("-" * 30)
    for label, ok in results.items():
        print(f"{label}: {'✅ 通过' if ok else '❌ 失败'}")

    # (result key, message on success, message on failure).
    advice = (
        (
            'siliconflow_embedding',
            "- SiliconFlow嵌入API可用,建议升级系统使用真正的嵌入模型",
            "- SiliconFlow嵌入API不可用,当前使用哈希向量",
        ),
        (
            'local_embedding',
            "- 本地嵌入模型可用,可以考虑离线部署",
            "- 本地嵌入模型不可用,需要安装sentence-transformers",
        ),
        (
            'reranker',
            "- 重排模型可用,可以提升搜索精度",
            "- 重排模型不可用,搜索精度可能受限",
        ),
    )
    print("\n💡 建议:")
    for label, when_ok, when_failed in advice:
        print(when_ok if results[label] else when_failed)


if __name__ == "__main__":
    main()