#!/usr/bin/env python3
# memory_search.py: hybrid dense + sparse retrieval over Qdrant, with a
# best-effort Neo4j knowledge-graph lookup on the top hits.
import os, sys, json, math, re, hashlib
from collections import Counter
from qdrant_client import QdrantClient
from qdrant_client.http import models as qm
from sentence_transformers import SentenceTransformer
from neo4j import GraphDatabase
# Repo root: this script is expected to live one level below it (scripts/).
# abspath keeps the path valid regardless of the caller's working directory.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CFG_PATH = os.path.join(ROOT, "config", "memory.config.json")
# Load config with env overrides
sys.path.append(os.path.join(ROOT, "scripts"))
from config_loader import load_config # type: ignore
CFG = load_config(CFG_PATH)
q = QdrantClient(url=CFG["qdrant"]["url"])
collection = CFG["qdrant"]["collection"]
dense_model = SentenceTransformer(CFG["qdrant"]["dense_model"])
kg = GraphDatabase.driver(CFG["kg"]["uri"], auth=(CFG["kg"]["user"], CFG["kg"]["pass"]))
SPARSE_CFG = CFG.get("sparse", {"enabled": False, "hash_dim": 32768, "model": "bge-m3"})
SPARSE_ENABLED = bool(SPARSE_CFG.get("enabled", False))
HASH_DIM = int(SPARSE_CFG.get("hash_dim", 32768))
def embed_dense(text):
    # Normalized embeddings make dot-product and cosine similarity equivalent.
    return dense_model.encode(text, normalize_embeddings=True).tolist()
def _tokenize(text):
return re.findall(r"[a-z0-9]+", text.lower())
def _hash_token(tok):
    # Hashing trick: md5 gives a stable token -> bucket mapping across runs
    # and processes (unlike Python's salted built-in hash()).
    return int(hashlib.md5(tok.encode("utf-8")).hexdigest(), 16) % HASH_DIM
def sparse_hash_trick(text):
    # Fallback sparse encoder: L2-normalized hashed term counts. Collisions
    # can map distinct tokens into one bucket, but are rare at 32k buckets.
    toks = _tokenize(text)
    if not toks:
        return [], []
    counts = Counter(toks)
    idxs, vals = [], []
    for tok, c in counts.items():
        idxs.append(_hash_token(tok))
        vals.append(float(c))
    norm = math.sqrt(sum(v * v for v in vals)) or 1.0
    return idxs, [v / norm for v in vals]
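# Worked example (bucket indices depend on md5 and HASH_DIM, so they are
# illustrative; the values follow from the counts):
#   sparse_hash_trick("hello hello world")
#   counts: {"hello": 2, "world": 1}; norm = sqrt(2^2 + 1^2) = sqrt(5)
#   -> ([<bucket("hello")>, <bucket("world")>], [0.8944..., 0.4472...])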
_bge_m3_model = None
def try_bge_m3_sparse(text):
    # Learned sparse vectors from BGE-M3 via FlagEmbedding, which returns the
    # sparse output under "lexical_weights": one {token_id: weight} dict per
    # input. The model is cached so it loads once, not on every query.
    global _bge_m3_model
    try:
        from FlagEmbedding import BGEM3FlagModel
        if _bge_m3_model is None:
            _bge_m3_model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
        out = _bge_m3_model.encode([text], return_dense=False, return_sparse=True)
        weights = out["lexical_weights"][0]
        return [int(k) for k in weights], [float(v) for v in weights.values()]
    except Exception:
        # Missing dependency or model download failure: caller falls back to
        # the hashing trick.
        return [], []
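# Shape of FlagEmbedding's sparse output, for reference (token ids and
# weights here are made up):
#   m.encode(["qdrant stores vectors"], return_dense=False, return_sparse=True)
#   -> {"lexical_weights": [{"2232": 0.28, "10520": 0.31, ...}]}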
def embed_sparse(text):
    # Prefer BGE-M3 learned sparse vectors when configured; otherwise, or on
    # any BGE-M3 failure, fall back to the hashing trick.
    if not SPARSE_ENABLED:
        return [], []
    if SPARSE_CFG.get("model", "").lower().startswith("bge-m3"):
        idxs, vals = try_bge_m3_sparse(text)
        if idxs and vals:
            return idxs, vals
    return sparse_hash_trick(text)
def search(query, limit):
    # Dense search by default; when a sparse vector is available, fuse dense
    # and sparse hits with reciprocal rank fusion (RRF) via query_points.
    # Assumes the collection has an unnamed dense vector and a sparse vector
    # named "sparse"; adjust `using=` to match your schema.
    dvec = embed_dense(query)
    sidx, svals = embed_sparse(query)
    if sidx and svals:
        prefetch = [qm.Prefetch(query=dvec, limit=limit * 2),
                    qm.Prefetch(query=qm.SparseVector(indices=sidx, values=svals),
                                using="sparse", limit=limit * 2)]
        res = q.query_points(collection_name=collection, prefetch=prefetch,
                             query=qm.FusionQuery(fusion=qm.Fusion.RRF),
                             with_payload=True, limit=limit).points
    else:
        res = q.query_points(collection_name=collection, query=dvec,
                             with_payload=True, limit=limit).points
    return [{"id": r.id, "score": r.score, "text": r.payload.get("text", ""),
             "meta": {k: v for k, v in r.payload.items() if k != "text"}} for r in res]
def kg_symmetry_for_text(tx, text):
    # Undirected expansion: match entities by exact name and return edges in
    # both directions. Passing whole chunk text as an entity name only hits
    # when a chunk is itself an entity name, so misses are expected.
    cypher = """
    UNWIND $ents AS e
    MATCH (n:Entity {name: e})-[r:RELATION]-(m:Entity)
    RETURN e AS seed, type(r) AS rtype, n.name AS n, m.name AS m
    """
    return list(tx.run(cypher, ents=[text]).data())
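# Each returned row is a dict keyed by the RETURN clause above, e.g.
# (hypothetical data):
#   {"seed": "Qdrant", "rtype": "RELATION", "n": "Qdrant", "m": "HNSW"}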
def main():
    if len(sys.argv) < 2:
        print("Usage: memory-search '<query>' [k]")
        sys.exit(1)
    query = sys.argv[1]
    k = int(sys.argv[2]) if len(sys.argv) > 2 else CFG["qdrant"]["top_k"]
    dense = search(query, k)
    kg_edges = []
    try:
        # KG enrichment is best-effort: an unreachable Neo4j instance must
        # not block the vector results.
        with kg.session() as s:
            for d in dense[:3]:
                kg_edges.extend(s.execute_read(kg_symmetry_for_text, d["text"]))
    except Exception:
        pass
    finally:
        kg.close()
    print(json.dumps({"dense": dense, "kg_symmetry": kg_edges}, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()
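# Example run (hypothetical output; the shape matches the json.dumps above,
# and the path assumes this script lives under scripts/):
#   $ python scripts/memory_search.py "hybrid search" 5
#   {
#     "dense": [{"id": "...", "score": 0.83, "text": "...", "meta": {...}}],
#     "kg_symmetry": [{"seed": "...", "rtype": "RELATION", "n": "...", "m": "..."}]
#   }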