# db.py • 4.83 kB
import sqlite3
from typing import Optional
import numpy as np
# Path of the SQLite database file opened by get_db_connection().
DB_PATH = 'storage/database.db'
def get_db_connection():
    """Create and return a SQLite connection to ``DB_PATH``.

    The directory holding the database file is created first when it is
    missing. The connection's ``row_factory`` is set to ``sqlite3.Row`` so
    that fetched rows can be indexed by column name.

    Returns:
        An open ``sqlite3.Connection``. Closing it is the caller's
        responsibility.
    """
    import os

    # Derive the directory from DB_PATH rather than repeating the literal,
    # and use exist_ok=True so that two callers creating the directory at the
    # same moment cannot make one of them fail with FileExistsError.
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:  # an empty dirname means "current directory": nothing to create
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Create the ``projects`` and ``cards`` tables if they do not exist.

    Both ``CREATE TABLE IF NOT EXISTS`` statements run inside one
    transaction. The connection is closed afterwards, including when a
    statement raises.
    """
    conn = get_db_connection()
    try:
        # ``with conn`` commits on success and rolls back on an exception;
        # it does not close the connection, hence the surrounding finally.
        with conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS projects (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    name_embedding BLOB
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS cards (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    project_id INTEGER NOT NULL,
                    question TEXT NOT NULL,
                    hint TEXT,
                    answer TEXT NOT NULL,
                    description TEXT,
                    embedding BLOB,
                    FOREIGN KEY(project_id) REFERENCES projects(id)
                )
            ''')
    finally:
        conn.close()
def add_project(name: str, name_embedding: bytes) -> dict:
    """Insert a project and return it as a dictionary.

    Args:
        name: Project name stored in the ``name`` column.
        name_embedding: Value stored in the ``name_embedding`` BLOB column.

    Returns:
        A dict with the new row's ``id`` and ``name`` plus a ``'type'`` key
        set to ``'project'``.
    """
    conn = get_db_connection()
    try:
        # ``with conn`` only commits or rolls back the transaction; the
        # finally clause is what actually releases the connection.
        with conn:
            cur = conn.execute(
                'INSERT INTO projects (name, name_embedding) VALUES (?, ?)',
                (name, name_embedding),
            )
            row = conn.execute(
                'SELECT id, name FROM projects WHERE id = ?',
                (cur.lastrowid,),
            ).fetchone()
    finally:
        conn.close()

    project = dict(row)
    project['type'] = 'project'
    return project
def get_all_projects():
    """Return the ``id`` and ``name`` of every project.

    Returns:
        A list of ``sqlite3.Row`` objects, empty when there are no projects.
    """
    conn = get_db_connection()
    try:
        with conn:
            return conn.execute('SELECT id, name FROM projects').fetchall()
    finally:
        # The context manager above does not close the connection itself.
        conn.close()
def find_project_id_by_name_embedding(query_embedding: np.ndarray) -> Optional[int]:
    """Return the id of the project whose name embedding best matches the query.

    Stored embeddings are decoded as ``float32`` buffers and compared with
    ``query_embedding`` by cosine similarity.

    Args:
        query_embedding: Vector to compare against each stored embedding.

    Returns:
        The id of the project with the highest similarity, or ``None`` when
        no project has a usable embedding with a similarity strictly
        greater than -1.
    """
    conn = get_db_connection()
    try:
        with conn:
            rows = conn.execute('SELECT id, name_embedding FROM projects').fetchall()
    finally:
        # ``with conn`` manages the transaction only; close explicitly.
        conn.close()

    # The query norm does not depend on the row, so compute it once.
    query_norm = np.linalg.norm(query_embedding)

    best_id, best_sim = None, -1
    for row in rows:
        blob = row['name_embedding']
        if not blob:
            continue
        emb = np.frombuffer(blob, dtype=np.float32)
        denom = query_norm * np.linalg.norm(emb)
        if denom == 0:
            # Cosine similarity is undefined for a zero-length vector; skip
            # it instead of dividing by zero and producing NaN.
            continue
        sim = float(np.dot(query_embedding, emb) / denom)
        if sim > best_sim:
            best_sim, best_id = sim, row['id']
    return best_id
def add_card(project_id: int, question: str, answer: str, hint: Optional[str], description: Optional[str], embedding: bytes) -> dict:
    """Insert a card and return it as a dictionary.

    Args:
        project_id: Id of the project the card belongs to.
        question: Text stored in the ``question`` column.
        answer: Text stored in the ``answer`` column.
        hint: Optional text stored in the ``hint`` column.
        description: Optional text stored in the ``description`` column.
        embedding: Value stored in the ``embedding`` BLOB column.

    Returns:
        A dict of the new row's columns without ``embedding``, plus a
        ``'type'`` key set to ``'card'``.
    """
    conn = get_db_connection()
    try:
        # ``with conn`` only commits or rolls back the transaction; the
        # finally clause is what actually releases the connection.
        with conn:
            cur = conn.execute(
                'INSERT INTO cards (project_id, question, hint, answer, description, embedding) VALUES (?, ?, ?, ?, ?, ?)',
                (project_id, question, hint, answer, description, embedding),
            )
            row = conn.execute(
                'SELECT * FROM cards WHERE id = ?',
                (cur.lastrowid,),
            ).fetchone()
    finally:
        conn.close()

    card = dict(row)
    # The raw embedding blob is not part of the returned card.
    card.pop('embedding', None)
    card['type'] = 'card'
    return card
def get_all_cards_by_project(project_id: int):
    """Return every card that belongs to the given project.

    Args:
        project_id: Value matched against the ``project_id`` column.

    Returns:
        A list of ``sqlite3.Row`` objects holding all card columns; empty
        when the project has no cards.
    """
    conn = get_db_connection()
    try:
        with conn:
            return conn.execute(
                'SELECT * FROM cards WHERE project_id = ?',
                (project_id,),
            ).fetchall()
    finally:
        # The context manager above does not close the connection itself.
        conn.close()
def get_random_card_by_project(project_id: int):
    """Return one randomly chosen card from the given project.

    Args:
        project_id: Value matched against the ``project_id`` column.

    Returns:
        A ``sqlite3.Row`` holding all card columns, or ``None`` when the
        project has no cards.
    """
    conn = get_db_connection()
    try:
        with conn:
            return conn.execute(
                'SELECT * FROM cards WHERE project_id = ? ORDER BY RANDOM() LIMIT 1',
                (project_id,),
            ).fetchone()
    finally:
        # The context manager above does not close the connection itself.
        conn.close()
def get_card_by_id(card_id: int):
    """Return the card with the given id.

    Args:
        card_id: Value matched against the ``id`` column.

    Returns:
        A ``sqlite3.Row`` holding all card columns, or ``None`` when no card
        has that id.
    """
    conn = get_db_connection()
    try:
        with conn:
            return conn.execute(
                'SELECT * FROM cards WHERE id = ?',
                (card_id,),
            ).fetchone()
    finally:
        # The context manager above does not close the connection itself.
        conn.close()
def search_cards_by_embedding(project_id: int, query_embedding: np.ndarray):
    """Return up to ten cards of a project ranked by similarity to the query.

    Stored embeddings are decoded as ``float32`` buffers and compared with
    ``query_embedding`` by cosine similarity. Cards without an embedding are
    ignored.

    Args:
        project_id: Value matched against the ``project_id`` column.
        query_embedding: Vector to compare against each card embedding.

    Returns:
        A list of at most ten dicts, most similar first. Each dict holds
        every column of the card row, including the raw ``embedding`` blob.
    """
    conn = get_db_connection()
    try:
        with conn:
            rows = conn.execute(
                'SELECT * FROM cards WHERE project_id = ?',
                (project_id,),
            ).fetchall()
    finally:
        # ``with conn`` manages the transaction only; close explicitly.
        conn.close()

    # The query norm does not depend on the row, so compute it once.
    query_norm = np.linalg.norm(query_embedding)

    scored = []
    for row in rows:
        blob = row['embedding']
        if not blob:
            continue
        emb = np.frombuffer(blob, dtype=np.float32)
        denom = query_norm * np.linalg.norm(emb)
        if denom == 0:
            # Cosine similarity is undefined for a zero-length vector.
            # Dividing would yield NaN, which makes the sort order below
            # unpredictable, so rank such cards last instead.
            sim = float('-inf')
        else:
            sim = float(np.dot(query_embedding, emb) / denom)
        scored.append((sim, dict(row)))

    # Sort on the score only; the sort is stable, so ties keep query order.
    scored.sort(reverse=True, key=lambda pair: pair[0])
    return [card for _, card in scored[:10]]
def global_search_cards_by_embedding(query_embedding: np.ndarray):
    """Return up to ten cards from all projects ranked by similarity.

    Stored embeddings are decoded as ``float32`` buffers and compared with
    ``query_embedding`` by cosine similarity. Cards without an embedding are
    ignored.

    Args:
        query_embedding: Vector to compare against each card embedding.

    Returns:
        A list of at most ten dicts, most similar first. Each dict holds
        every column of the card row, including the raw ``embedding`` blob.
    """
    conn = get_db_connection()
    try:
        with conn:
            rows = conn.execute('SELECT * FROM cards').fetchall()
    finally:
        # ``with conn`` manages the transaction only; close explicitly.
        conn.close()

    # The query norm does not depend on the row, so compute it once.
    query_norm = np.linalg.norm(query_embedding)

    scored = []
    for row in rows:
        blob = row['embedding']
        if not blob:
            continue
        emb = np.frombuffer(blob, dtype=np.float32)
        denom = query_norm * np.linalg.norm(emb)
        if denom == 0:
            # Cosine similarity is undefined for a zero-length vector.
            # Dividing would yield NaN, which makes the sort order below
            # unpredictable, so rank such cards last instead.
            sim = float('-inf')
        else:
            sim = float(np.dot(query_embedding, emb) / denom)
        scored.append((sim, dict(row)))

    # Sort on the score only; the sort is stable, so ties keep query order.
    scored.sort(reverse=True, key=lambda pair: pair[0])
    return [card for _, card in scored[:10]]
if __name__ == "__main__":
    # Executed as a script: create the tables, then report completion.
    init_db()
    print("Database initialized.")