# db_wrapper.py (2.9 kB)
"""
Thread-safe SQLite database wrapper for SmartCodeSearch
"""
import sqlite3
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Any, List, Tuple, Optional
class ThreadSafeDB:
    """Thread-safe SQLite database wrapper using thread-local storage.

    Each thread lazily opens its own ``sqlite3.Connection`` to ``db_path``
    and reuses it for every later call made from that thread. Every
    connection opened through this wrapper is also recorded in a
    lock-protected registry keyed by the owning thread, so that
    :meth:`close_all` can close the connections of all threads, not only
    the caller's.
    """

    def __init__(self, db_path: Path):
        """Create a wrapper for the SQLite database at *db_path*.

        No connection is opened here; each thread opens one on its first
        call to :meth:`get_connection`.
        """
        self.db_path = str(db_path)
        self.thread_local = threading.local()
        # threading.Thread -> sqlite3.Connection opened by that thread.
        # Guarded by _lock; lets close_all() reach other threads' connections.
        self._connections = {}
        self._lock = threading.Lock()

    def get_connection(self) -> sqlite3.Connection:
        """Return the current thread's connection, opening one if needed.

        A cached connection is reused only while it is still registered; one
        that :meth:`close_all` has closed (possibly from another thread) is
        replaced by a freshly opened connection.
        """
        current = threading.current_thread()
        cached = getattr(self.thread_local, "connection", None)
        if cached is not None:
            with self._lock:
                if self._connections.get(current) is cached:
                    return cached
        conn = self._open_connection()
        with self._lock:
            self._prune_dead_threads()
            self._connections[current] = conn
        # Cache only after setup succeeded, so a failed PRAGMA never leaves a
        # half-configured connection behind for later calls.
        self.thread_local.connection = conn
        return conn

    def _open_connection(self) -> sqlite3.Connection:
        """Open and configure a new connection; close it if setup fails."""
        # Each thread still gets its own connection; check_same_thread=False
        # is needed because close_all() closes connections that were created
        # by other threads, which sqlite3 would otherwise reject.
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            # Enable WAL mode for better concurrency
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
            conn.execute("PRAGMA cache_size=10000")
            conn.execute("PRAGMA temp_store=MEMORY")
        except Exception:
            conn.close()
            raise
        return conn

    def _prune_dead_threads(self) -> None:
        """Drop registry entries of threads that have exited (hold _lock).

        Entries are only forgotten, not closed: the connection is then
        released exactly as it would be without the registry, and any caller
        that still holds a reference keeps a usable connection.
        """
        for thread in [t for t in self._connections if not t.is_alive()]:
            del self._connections[thread]

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on this thread's connection.

        Commits when the ``with`` body finishes normally; if it raises an
        ``Exception``, rolls back and re-raises it. The cursor is always
        closed.
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            # Bare raise re-raises the active exception with its traceback.
            raise
        finally:
            cursor.close()

    def execute(self, query: str, params: Tuple = ()) -> None:
        """Execute a single statement with *params* and commit."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)

    def executemany(self, query: str, params: List[Tuple]) -> None:
        """Execute *query* once per parameter tuple in *params*, then commit."""
        with self.get_cursor() as cursor:
            cursor.executemany(query, params)

    def fetchone(self, query: str, params: Tuple = ()) -> Optional[Tuple]:
        """Execute *query* and return the first row, or ``None`` if none."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchone()

    def fetchall(self, query: str, params: Tuple = ()) -> List[Tuple]:
        """Execute *query* and return all result rows as a list."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()

    def close_thread_connection(self):
        """Close the calling thread's connection, if it has one."""
        conn = getattr(self.thread_local, "connection", None)
        if conn is None:
            return
        del self.thread_local.connection
        current = threading.current_thread()
        with self._lock:
            if self._connections.get(current) is conn:
                del self._connections[current]
        # close() is a no-op if close_all() already closed this connection.
        conn.close()

    def close_all(self):
        """Close every registered connection, in all threads.

        Intended for shutdown (call from main thread when shutting down).
        Other threads that call :meth:`get_connection` afterwards transparently
        get a new connection instead of the closed one.
        """
        self.close_thread_connection()
        with self._lock:
            remaining = list(self._connections.values())
            self._connections.clear()
        for conn in remaining:
            conn.close()