"""Database operations for expense tracker using SQLite."""
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from .models import Receipt, LineItem, ItemStats
# Default location of the SQLite database: data/expenses.db inside the
# directory one level above the directory containing this module.
DEFAULT_DB_PATH = Path(__file__).parent.parent.joinpath("data", "expenses.db")
def get_connection(db_path: Path = DEFAULT_DB_PATH) -> sqlite3.Connection:
    """Open a configured SQLite connection to ``db_path``.

    The parent directory of ``db_path`` is created if it does not exist.
    The returned connection yields :class:`sqlite3.Row` rows (columns are
    accessible by name) and has foreign-key enforcement switched on.

    Args:
        db_path: Location of the database file. A plain ``str`` path is
            also accepted and converted to :class:`~pathlib.Path`.

    Returns:
        An open connection. The caller is responsible for closing it.
    """
    db_path = Path(db_path)
    # sqlite3 cannot open a file whose directory is missing, so create it first.
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(db_path))
    try:
        conn.row_factory = sqlite3.Row  # Enable column access by name
        # Foreign-key enforcement is per-connection in SQLite (off in standard
        # builds); the ON DELETE CASCADE on items depends on it.
        conn.execute("PRAGMA foreign_keys = ON")
    except sqlite3.Error:
        # Don't leak the handle if configuration fails.
        conn.close()
        raise
    return conn
def init_database(db_path: Path = DEFAULT_DB_PATH) -> None:
    """Create the expense-tracker schema if it is not already present.

    Every statement uses ``IF NOT EXISTS``, so running this against an
    already-initialised database leaves it unchanged. The schema consists of:

    * ``receipts``: one row per receipt, indexed on ``purchase_date`` and
      ``store_name``.
    * ``items``: line items whose ``receipt_id`` references ``receipts(id)``
      with ``ON DELETE CASCADE``, indexed on ``item_type``, ``receipt_id``
      and ``(item_type, receipt_id)``.

    Args:
        db_path: Database file to initialise.
    """
    # Statements run strictly in this order: each table before its indexes.
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS receipts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        store_name TEXT NOT NULL,
        purchase_date TEXT NOT NULL,
        subtotal REAL,
        tax REAL,
        total REAL NOT NULL,
        created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_receipts_date ON receipts(purchase_date)",
        "CREATE INDEX IF NOT EXISTS idx_receipts_store ON receipts(store_name)",
        """
        CREATE TABLE IF NOT EXISTS items (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        receipt_id INTEGER NOT NULL,
        item_name_raw TEXT NOT NULL,
        item_type TEXT NOT NULL,
        quantity REAL DEFAULT 1.0,
        unit_price REAL,
        line_total REAL NOT NULL,
        created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY (receipt_id) REFERENCES receipts(id) ON DELETE CASCADE
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_items_type ON items(item_type)",
        "CREATE INDEX IF NOT EXISTS idx_items_receipt ON items(receipt_id)",
        # NOTE: despite its name this index covers (item_type, receipt_id),
        # not a date column. Renaming it would leave the old index behind on
        # existing databases because creation is guarded by IF NOT EXISTS.
        "CREATE INDEX IF NOT EXISTS idx_items_type_date ON items(item_type, receipt_id)",
    )

    connection = get_connection(db_path)
    try:
        for statement in schema_statements:
            connection.execute(statement)
        connection.commit()
    finally:
        connection.close()
def insert_receipt(receipt: Receipt, db_path: Path = DEFAULT_DB_PATH) -> int:
    """Persist a receipt's header fields and return the new row id.

    Only ``store_name``, ``purchase_date``, ``subtotal``, ``tax`` and
    ``total`` are written here. Line items are stored separately by
    ``insert_items``, which opens its own connection and commits independently.

    Args:
        receipt: The receipt whose header fields are inserted.
        db_path: Database file to write to.

    Returns:
        The ``id`` SQLite assigned to the new ``receipts`` row.
    """
    connection = get_connection(db_path)
    try:
        header_values = (
            receipt.store_name,
            receipt.purchase_date,
            receipt.subtotal,
            receipt.tax,
            receipt.total,
        )
        insert_cursor = connection.execute(
            """
            INSERT INTO receipts (store_name, purchase_date, subtotal, tax, total)
            VALUES (?, ?, ?, ?, ?)
            """,
            header_values,
        )
        connection.commit()
        new_id = insert_cursor.lastrowid
        return new_id
    finally:
        connection.close()
def insert_items(
    receipt_id: int, items: list[LineItem], db_path: Path = DEFAULT_DB_PATH
) -> None:
    """Insert every line item of one receipt with a single ``executemany``.

    Returns immediately, without opening a connection, when ``items`` is empty.

    Args:
        receipt_id: ``receipts.id`` that each item row will reference.
        items: Line items to store. Each contributes ``item_name_raw``,
            ``item_type``, ``quantity``, ``unit_price`` and ``line_total``.
        db_path: Database file to write to.
    """
    if not items:
        return

    connection = get_connection(db_path)
    try:
        item_rows = [
            (
                receipt_id,
                line.item_name_raw,
                line.item_type,
                line.quantity,
                line.unit_price,
                line.line_total,
            )
            for line in items
        ]
        connection.executemany(
            """
            INSERT INTO items (receipt_id, item_name_raw, item_type, quantity, unit_price, line_total)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            item_rows,
        )
        connection.commit()
    finally:
        connection.close()
def query_item_history(
    item_type: str,
    time_range_days: int = 365,
    db_path: Path = DEFAULT_DB_PATH,
) -> dict:
    """Return recent purchase history and summary statistics for one item type.

    Only receipts whose ``purchase_date`` is on or after the cutoff (today's
    local date minus ``time_range_days``, formatted ``YYYY-MM-DD`` and
    compared as text) are considered.

    Args:
        item_type: Exact ``items.item_type`` value to look up.
        time_range_days: Size of the look-back window in days.
        db_path: Database file to read.

    Returns:
        A plain ``dict`` (not an ``ItemStats`` instance) with two keys:

        * ``"purchases"``: list of dicts with keys ``date``, ``store``,
          ``item_name``, ``quantity``, ``unit_price`` and ``price``. There is
          one entry per matching line item, newest ``purchase_date`` first.
        * ``"stats"``: dict with keys
          - ``total_purchases``: number of distinct receipts;
          - ``last_purchase_date`` and ``first_purchase_date``;
          - ``average_days_between``: rounded to 1 decimal, or ``None``;
          - ``total_spent``: rounded to 2 decimals.
    """
    cutoff_date = (datetime.now() - timedelta(days=time_range_days)).strftime(
        "%Y-%m-%d"
    )
    conn = get_connection(db_path)
    try:
        purchases = _fetch_item_purchases(conn, item_type, cutoff_date)
        stats = _fetch_item_stats(conn, item_type, cutoff_date)
        return {"purchases": purchases, "stats": stats}
    finally:
        conn.close()


def _fetch_item_purchases(
    conn: sqlite3.Connection, item_type: str, cutoff_date: str
) -> list[dict]:
    """List line items of ``item_type`` bought on/after ``cutoff_date``, newest first."""
    cursor = conn.execute(
        """
        SELECT
            r.purchase_date,
            r.store_name,
            i.item_name_raw,
            i.quantity,
            i.unit_price,
            i.line_total
        FROM items i
        JOIN receipts r ON i.receipt_id = r.id
        WHERE i.item_type = ? AND r.purchase_date >= ?
        ORDER BY r.purchase_date DESC
        """,
        (item_type, cutoff_date),
    )
    return [
        {
            "date": row["purchase_date"],
            "store": row["store_name"],
            "item_name": row["item_name_raw"],
            "quantity": row["quantity"],
            "unit_price": row["unit_price"],
            "price": row["line_total"],
        }
        for row in cursor.fetchall()
    ]


def _fetch_item_stats(
    conn: sqlite3.Connection, item_type: str, cutoff_date: str
) -> dict:
    """Aggregate receipt count, date range, average gap and spend for ``item_type``."""
    row = conn.execute(
        """
        SELECT
            COUNT(DISTINCT r.id) as total_purchases,
            MAX(r.purchase_date) as last_purchase_date,
            MIN(r.purchase_date) as first_purchase_date,
            SUM(i.line_total) as total_spent
        FROM items i
        JOIN receipts r ON i.receipt_id = r.id
        WHERE i.item_type = ? AND r.purchase_date >= ?
        """,
        (item_type, cutoff_date),
    ).fetchone()

    purchase_count = row["total_purchases"] or 0
    avg_days = _average_days_between(
        row["first_purchase_date"], row["last_purchase_date"], purchase_count
    )
    return {
        "total_purchases": purchase_count,
        "last_purchase_date": row["last_purchase_date"],
        "first_purchase_date": row["first_purchase_date"],
        "average_days_between": round(avg_days, 1) if avg_days is not None else None,
        # SUM() is NULL when no rows match.
        "total_spent": round(row["total_spent"] or 0, 2),
    }


def _average_days_between(
    first_date: Optional[str], last_date: Optional[str], purchase_count: int
) -> Optional[float]:
    """Mean calendar-day gap between purchases, or None if it is undefined.

    The gap is undefined when there are fewer than two purchases, a date is
    missing, or every purchase falls on the same calendar day.
    Dates must be ISO-8601 strings (``datetime.fromisoformat`` raises
    ``ValueError`` otherwise).
    """
    if purchase_count < 2 or first_date is None or last_date is None:
        return None
    # Compare calendar dates rather than datetimes so that a time-of-day
    # component (if one is stored) cannot shrink the span below a full day.
    span_days = (
        datetime.fromisoformat(last_date).date()
        - datetime.fromisoformat(first_date).date()
    ).days
    if span_days <= 0:
        return None
    return span_days / (purchase_count - 1)
def get_all_item_types(db_path: Path = DEFAULT_DB_PATH) -> list[dict]:
    """Summarise every distinct item type recorded in the database.

    Args:
        db_path: Database file to read.

    Returns:
        One dict per distinct ``items.item_type``, with keys:

        * ``item_type``;
        * ``total_purchases``: number of distinct receipts containing it;
        * ``last_purchase_date``;
        * ``total_spent``: rounded to 2 decimals.

        Rows are ordered by most recent purchase first. Ties are broken by
        ``item_type`` so the order is deterministic.
    """
    conn = get_connection(db_path)
    try:
        rows = conn.execute("""
            SELECT
                i.item_type,
                COUNT(DISTINCT r.id) as total_purchases,
                MAX(r.purchase_date) as last_purchase_date,
                SUM(i.line_total) as total_spent
            FROM items i
            JOIN receipts r ON i.receipt_id = r.id
            GROUP BY i.item_type
            ORDER BY last_purchase_date DESC, i.item_type
        """).fetchall()
        return [
            {
                "item_type": row["item_type"],
                "total_purchases": row["total_purchases"],
                "last_purchase_date": row["last_purchase_date"],
                # Defensive NULL guard, matching query_item_history; round(None)
                # would raise TypeError.
                "total_spent": round(row["total_spent"] or 0, 2),
            }
            for row in rows
        ]
    finally:
        conn.close()