# memory_store.py
#!/usr/bin/env python3
import os, sys, json, hashlib, subprocess, math, re, uuid
from collections import Counter
from qdrant_client import QdrantClient
from qdrant_client.http import models as qm
from sentence_transformers import SentenceTransformer
ROOT = os.path.dirname(os.path.dirname(__file__))
CFG_PATH = os.path.join(ROOT, "config", "memory.config.json")
# Load config with env overrides
sys.path.append(os.path.join(ROOT, "scripts"))
from config_loader import load_config # type: ignore
CFG = load_config(CFG_PATH)
q = QdrantClient(url=CFG["qdrant"]["url"])
collection = CFG["qdrant"]["collection"]
dense_model_name = CFG["qdrant"]["dense_model"]
top_k = CFG["qdrant"]["top_k"]
dense_model = SentenceTransformer(dense_model_name)
SPARSE_CFG = CFG.get("sparse", {"enabled": False, "hash_dim": 32768, "model": "bge-m3"})
SPARSE_ENABLED = bool(SPARSE_CFG.get("enabled", False))
HASH_DIM = int(SPARSE_CFG.get("hash_dim", 32768))
def ensure_collection():
    from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
    try:
        q.get_collection(collection_name=collection)
    except Exception:
        # Named dense + sparse vector slots so the hybrid upsert below works.
        q.create_collection(
            collection_name=collection,
            vectors_config={"dense": VectorParams(size=dense_model.get_sentence_embedding_dimension(), distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
            on_disk_payload=True,
        )
def embed_dense(text):
    return dense_model.encode(text, normalize_embeddings=True).tolist()
def _tokenize(text):
    return re.findall(r"[a-z0-9]+", text.lower())
def _hash_token(tok):
    # Stable bucket in [0, HASH_DIM); hashlib is imported at module level.
    return int(hashlib.md5(tok.encode("utf-8")).hexdigest(), 16) % HASH_DIM
def sparse_hash_trick(text):
    toks = _tokenize(text)
    if not toks:
        return [], []
    # Accumulate counts per bucket so colliding tokens don't produce
    # duplicate indices in the sparse vector.
    buckets = Counter(_hash_token(tok) for tok in toks)
    idxs = list(buckets.keys())
    vals = [float(c) for c in buckets.values()]
    norm = math.sqrt(sum(v * v for v in vals)) or 1.0
    vals = [v / norm for v in vals]
    return idxs, vals
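# Illustrative walk-through of the hashing trick (bucket indices depend on
# md5 and HASH_DIM, so they are shown symbolically):
#   sparse_hash_trick("the quick quick fox")
#   -> token counts {"the": 1, "quick": 2, "fox": 1}
#   -> values L2-normalized by sqrt(1 + 4 + 1): [0.408, 0.816, 0.408]
#   -> indices: the three md5-derived buckets for "the", "quick", "fox"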
_BGE_M3 = None
def try_bge_m3_sparse(text):
    # BGEM3FlagModel.encode(..., return_sparse=True) reports sparse weights
    # under "lexical_weights": one {token_id: weight} dict per input text.
    global _BGE_M3
    try:
        from FlagEmbedding import BGEM3FlagModel
        if _BGE_M3 is None:
            # Cache the model; loading it on every call is prohibitively slow.
            _BGE_M3 = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
        out = _BGE_M3.encode([text], return_dense=False, return_sparse=True)
        weights = out.get("lexical_weights", [None])[0] if isinstance(out, dict) else None
        if weights:
            return [int(k) for k in weights], [float(v) for v in weights.values()]
    except Exception:
        pass
    return None, None
def embed_sparse(text):
    if not SPARSE_ENABLED:
        return [], []
    if SPARSE_CFG.get("model", "").lower().startswith("bge-m3"):
        idxs, vals = try_bge_m3_sparse(text)
        if idxs and vals:
            return idxs, vals
    # Fall back to the hashing trick when BGE-M3 is unavailable.
    return sparse_hash_trick(text)
def id_from(text):
    # Qdrant point IDs must be unsigned ints or UUIDs; derive a deterministic
    # UUID from the text (a raw sha1 hexdigest would be rejected).
    return str(uuid.uuid5(uuid.NAMESPACE_URL, text))
def upsert(pid, text, meta=None):
    dvec = embed_dense(text)
    sidx, svals = embed_sparse(text)
    payload = {"text": text}
    if meta:
        payload.update(meta)
    # PointStruct has no sparse_vector kwarg; sparse vectors ride alongside
    # the dense one in the named-vector dict.
    vector = {"dense": dvec}
    if sidx and svals:
        vector["sparse"] = qm.SparseVector(indices=sidx, values=svals)
    q.upsert(
        collection_name=collection,
        points=[qm.PointStruct(id=pid, vector=vector, payload=payload)],
    )
    return pid
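# Shape of the stored point (illustrative values only):
#   PointStruct(id="9b2f...-uuid",
#               vector={"dense": [0.12, -0.03, ...],
#                       "sparse": SparseVector(indices=[7, 9321], values=[0.41, 0.82])},
#               payload={"text": "...", "tags": ["x"]})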
def run_graphsync(text, pid):
    hook = os.path.join(ROOT, "hooks", "graphsync_post_store.py")
    # Pass arguments as a list; no shell, so no quoting is needed.
    subprocess.run([sys.executable, hook, text, pid], check=False)
def main():
    # Prefer piped stdin; fall back to argv. Reading stdin unconditionally
    # would block on a tty, and bare sys.argv[1] can raise IndexError.
    if not sys.stdin.isatty():
        raw = sys.stdin.read()
    elif len(sys.argv) > 1:
        raw = sys.argv[1]
    else:
        print('Usage: echo \'{"text":"...","meta":{"tags":["x"]}}\' | memory-store')
        sys.exit(1)
    data = json.loads(raw)
    text = data["text"]
    meta = data.get("meta", {})
    ensure_collection()
    pid = data.get("id") or id_from(text)
    upsert(pid, text, meta)
    run_graphsync(text, pid)
    print(json.dumps({"status": "ok", "id": pid}, ensure_ascii=False))
if __name__ == "__main__":
main()
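# Example invocation (assumes Qdrant is reachable at the configured URL and
# config/memory.config.json exists; "memory-store" is whatever alias wraps
# this script):
#   echo '{"text":"hybrid memory demo","meta":{"tags":["demo"]}}' | python memory_store.py
#   -> {"status": "ok", "id": "<deterministic uuid>"}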