"""SQLite database client for structured data storage."""
import aiosqlite
import os
from datetime import datetime
from pathlib import Path
from typing import Optional
from .schemas import (
Statement,
StatementRecord,
HoldingRecord,
TransactionRecord,
BenchmarkRecord,
RiskMetricsRecord,
)
# Default database location, derived from where this module lives on disk:
#   investing-mcp/src/database/sqlite_client.py -> investments/db/statements.db
# i.e. walk four directory levels up from this file, then into db/statements.db.
_DEFAULT_DB_PATH = str(
    Path(__file__).parent.parent.parent.parent.resolve().joinpath("db", "statements.db")
)
class SQLiteClient:
    """Async SQLite client for investment statement data.

    Every method opens its own ``aiosqlite`` connection to ``db_path`` and
    closes it before returning; no connection is held between calls.
    """

    def __init__(self, db_path: str = _DEFAULT_DB_PATH):
        """Initialize SQLite client.

        Args:
            db_path: Path to SQLite database file. Its parent directory is
                created if it does not exist yet.
        """
        self.db_path = db_path
        parent_dir = os.path.dirname(db_path)
        # A bare filename such as "statements.db" has no directory part, and
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when the path actually names one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    @staticmethod
    async def _add_missing_columns(
        db: aiosqlite.Connection, table: str, columns: list[tuple[str, str]]
    ) -> None:
        """Add each of ``columns`` that ``table`` does not already have.

        Used to migrate databases created by an older schema version. Existing
        columns are read with ``PRAGMA table_info`` rather than attempting the
        ALTER and swallowing every exception, so genuine failures (locked or
        read-only database, I/O errors) are no longer hidden.

        Args:
            db: Open connection.
            table: Table name (an internal constant; it is interpolated into SQL).
            columns: ``(column_name, column_definition)`` pairs to ensure exist.
        """
        async with db.execute(f"PRAGMA table_info({table})") as cursor:
            # table_info rows are (cid, name, type, notnull, dflt_value, pk).
            existing = {row[1] for row in await cursor.fetchall()}
        for col_name, col_type in columns:
            if col_name not in existing:
                await db.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")

    async def initialize(self):
        """Initialize database schema.

        Creates every table and index that does not exist yet and adds the
        columns introduced after the first schema version to older databases,
        so calling it on an already-initialized database is harmless.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # Statements table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS statements (
                    statement_id TEXT PRIMARY KEY,
                    institution TEXT NOT NULL,
                    account_number TEXT NOT NULL,
                    statement_date TEXT NOT NULL,
                    indexed_at TEXT NOT NULL,
                    file_path TEXT NOT NULL,
                    pdf_path TEXT,
                    json_path TEXT,
                    current_balance_cad REAL,
                    return_mtd REAL,
                    return_qtd REAL,
                    return_ytd REAL,
                    return_1y REAL,
                    return_3y REAL,
                    return_5y REAL,
                    return_since_inception REAL
                )
            """)
            # Migration: performance columns missing from older statements tables.
            await self._add_missing_columns(db, "statements", [
                ("current_balance_cad", "REAL"),
                ("return_mtd", "REAL"),
                ("return_qtd", "REAL"),
                ("return_ytd", "REAL"),
                ("return_1y", "REAL"),
                ("return_3y", "REAL"),
                ("return_5y", "REAL"),
                ("return_since_inception", "REAL"),
            ])

            # Holdings table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS holdings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    statement_id TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    description TEXT,
                    quantity REAL NOT NULL,
                    currency TEXT NOT NULL,
                    cost_per_share REAL,
                    total_cost REAL,
                    market_price REAL,
                    market_value_cad REAL,
                    profit_loss REAL,
                    percent_return REAL,
                    security_type TEXT DEFAULT 'ETF',
                    annualized_return REAL,
                    FOREIGN KEY (statement_id) REFERENCES statements (statement_id)
                )
            """)
            # Migration: columns missing from older holdings tables.
            await self._add_missing_columns(db, "holdings", [
                ("security_type", "TEXT DEFAULT 'ETF'"),
                ("annualized_return", "REAL"),
            ])

            # Transactions table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS transactions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    statement_id TEXT NOT NULL,
                    transaction_date TEXT NOT NULL,
                    settle_date TEXT NOT NULL,
                    activity_type TEXT NOT NULL,
                    symbol TEXT,
                    description TEXT,
                    quantity REAL,
                    price REAL,
                    gross_amount REAL,
                    commission REAL,
                    net_amount REAL,
                    currency TEXT NOT NULL,
                    FOREIGN KEY (statement_id) REFERENCES statements (statement_id)
                )
            """)

            # Benchmarks table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS benchmarks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    statement_id TEXT NOT NULL,
                    name TEXT NOT NULL,
                    symbol TEXT,
                    return_mtd REAL,
                    return_qtd REAL,
                    return_ytd REAL,
                    return_1y REAL,
                    return_3y REAL,
                    return_5y REAL,
                    alpha_ytd REAL,
                    alpha_1y REAL,
                    FOREIGN KEY (statement_id) REFERENCES statements (statement_id)
                )
            """)

            # Risk metrics table (calculated from historical data)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS risk_metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    account_number TEXT NOT NULL,
                    calculation_date TEXT NOT NULL,
                    period_months INTEGER DEFAULT 12,
                    volatility REAL,
                    beta REAL,
                    sharpe_ratio REAL,
                    max_drawdown REAL,
                    var_95 REAL
                )
            """)

            # Create indexes
            await db.execute("CREATE INDEX IF NOT EXISTS idx_holdings_symbol ON holdings(symbol)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_holdings_statement ON holdings(statement_id)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_transactions_date ON transactions(transaction_date)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_transactions_statement ON transactions(statement_id)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_transactions_symbol ON transactions(symbol)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_statements_date ON statements(statement_date)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_statements_account ON statements(account_number)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_benchmarks_statement ON benchmarks(statement_id)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_risk_metrics_account ON risk_metrics(account_number)")
            await db.commit()

    @staticmethod
    async def _transaction_exists(
        db: aiosqlite.Connection, account_number: str, txn
    ) -> bool:
        """Return True if ``account_number`` already has a transaction matching ``txn``.

        Match key: transaction date, activity type, symbol, description,
        quantity, price, currency and net amount. NULL symbol/description are
        compared as '' and NULL quantity/price as 0. Net amount is compared
        with ``IS`` so that two NULL amounts also match (``=`` never matches
        NULL). The search is restricted to the given account so identical
        trades in different accounts are both kept.

        Args:
            db: Open connection (sees this connection's uncommitted inserts).
            account_number: Account whose stored transactions are searched.
            txn: A transaction from ``Statement.transactions``.
        """
        async with db.execute(
            """
            SELECT t.id
            FROM transactions t
            JOIN statements s ON t.statement_id = s.statement_id
            WHERE s.account_number = ?
              AND t.transaction_date = ?
              AND t.activity_type = ?
              AND COALESCE(t.symbol, '') = ?
              AND COALESCE(t.description, '') = ?
              AND COALESCE(t.quantity, 0) = ?
              AND COALESCE(t.price, 0) = ?
              AND t.currency = ?
              AND t.net_amount IS ?
            LIMIT 1
            """,
            (
                account_number,
                txn.transaction_date.isoformat(),
                txn.activity_type,
                txn.symbol or '',
                txn.description or '',
                txn.quantity or 0,
                txn.price or 0,
                txn.currency,
                txn.net_amount,
            ),
        ) as cursor:
            return (await cursor.fetchone()) is not None

    async def insert_statement(self, statement: Statement):
        """Insert a statement with all its data.

        The statement row is upserted (``INSERT OR REPLACE``). Its holdings and
        benchmarks are deleted and re-inserted, so re-indexing the same
        statement does not duplicate them. Transactions are never deleted;
        instead a transaction is skipped when an identical one (see
        ``_transaction_exists``) is already stored for the same account.
        Everything is committed with a single commit at the end.

        Args:
            statement: Statement object to insert
        """
        statement_id = statement.statement_id
        summary = statement.summary
        async with aiosqlite.connect(self.db_path) as db:
            # Delete existing holdings/benchmarks for this statement to prevent duplicates.
            await db.execute("DELETE FROM holdings WHERE statement_id = ?", (statement_id,))
            await db.execute("DELETE FROM benchmarks WHERE statement_id = ?", (statement_id,))

            # Insert statement record with performance data. This must precede
            # the transaction de-duplication, which joins on statements to
            # find each stored transaction's account.
            await db.execute(
                """
                INSERT OR REPLACE INTO statements
                    (statement_id, institution, account_number, statement_date, indexed_at,
                     file_path, pdf_path, json_path, current_balance_cad,
                     return_mtd, return_qtd, return_ytd, return_1y, return_3y, return_5y,
                     return_since_inception)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    statement_id,
                    statement.institution,
                    summary.account_number,
                    summary.statement_date.isoformat(),
                    statement.indexed_at.isoformat(),
                    statement.file_path,
                    statement.pdf_path,
                    statement.json_path,
                    summary.current_balance_cad,
                    summary.return_mtd,
                    summary.return_qtd,
                    summary.return_ytd,
                    summary.return_1y,
                    summary.return_3y,
                    summary.return_5y,
                    summary.return_since_inception,
                ),
            )

            # Insert holdings
            await db.executemany(
                """
                INSERT INTO holdings
                    (statement_id, symbol, description, quantity, currency, cost_per_share,
                     total_cost, market_price, market_value_cad, profit_loss, percent_return,
                     security_type, annualized_return)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    (
                        statement_id,
                        holding.symbol,
                        holding.description,
                        holding.quantity,
                        holding.currency,
                        holding.cost_per_share,
                        holding.total_cost,
                        holding.market_price,
                        holding.market_value_cad,
                        holding.profit_loss,
                        holding.percent_return,
                        holding.security_type,
                        holding.annualized_return,
                    )
                    for holding in statement.holdings
                ],
            )

            # Insert transactions, skipping ones already stored for this account.
            for txn in statement.transactions:
                if await self._transaction_exists(db, summary.account_number, txn):
                    continue
                await db.execute(
                    """
                    INSERT INTO transactions
                        (statement_id, transaction_date, settle_date, activity_type, symbol,
                         description, quantity, price, gross_amount, commission, net_amount, currency)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        statement_id,
                        txn.transaction_date.isoformat(),
                        txn.settle_date.isoformat(),
                        txn.activity_type,
                        txn.symbol,
                        txn.description,
                        txn.quantity,
                        txn.price,
                        txn.gross_amount,
                        txn.commission,
                        txn.net_amount,
                        txn.currency,
                    ),
                )

            # Insert benchmarks
            await db.executemany(
                """
                INSERT INTO benchmarks
                    (statement_id, name, symbol, return_mtd, return_qtd, return_ytd,
                     return_1y, return_3y, return_5y, alpha_ytd, alpha_1y)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    (
                        statement_id,
                        benchmark.name,
                        benchmark.symbol,
                        benchmark.return_mtd,
                        benchmark.return_qtd,
                        benchmark.return_ytd,
                        benchmark.return_1y,
                        benchmark.return_3y,
                        benchmark.return_5y,
                        benchmark.alpha_ytd,
                        benchmark.alpha_1y,
                    )
                    for benchmark in statement.benchmarks
                ],
            )

            await db.commit()

    async def get_statement(self, statement_id: str) -> Optional[dict]:
        """Get statement by ID.

        Args:
            statement_id: Statement ID

        Returns:
            The statements row as a dict, or None if no such statement exists.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM statements WHERE statement_id = ?",
                (statement_id,)
            ) as cursor:
                row = await cursor.fetchone()
                return dict(row) if row else None

    async def list_statements(self, limit: int = 100) -> list[dict]:
        """List statements, newest ``statement_date`` first.

        Args:
            limit: Maximum number of statements to return

        Returns:
            List of statement rows as dicts.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM statements ORDER BY statement_date DESC LIMIT ?",
                (limit,)
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def get_holdings_by_symbol(self, symbol: str) -> list[dict]:
        """Get all holdings for a symbol across all statements.

        Args:
            symbol: Stock/ETF symbol (matched exactly).

        Returns:
            Holding rows, each extended with its statement's ``statement_date``
            and ``account_number``, newest statement first.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                """
                SELECT h.*, s.statement_date, s.account_number
                FROM holdings h
                JOIN statements s ON h.statement_id = s.statement_id
                WHERE h.symbol = ?
                ORDER BY s.statement_date DESC
                """,
                (symbol,)
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def get_transactions_by_date_range(
        self, start_date: str, end_date: str, account_number: Optional[str] = None
    ) -> list[dict]:
        """Get transactions within an inclusive date range.

        Dates are compared as strings, so they should use the same ISO format
        the transactions were stored with.

        Args:
            start_date: Start date (ISO format), inclusive
            end_date: End date (ISO format), inclusive
            account_number: Optional account number filter; a falsy value
                (None or "") disables the filter.

        Returns:
            Transaction rows, each extended with ``account_number`` and
            ``statement_date``, newest transaction first.
        """
        query = """
            SELECT t.*, s.account_number, s.statement_date
            FROM transactions t
            JOIN statements s ON t.statement_id = s.statement_id
            WHERE t.transaction_date >= ? AND t.transaction_date <= ?
        """
        params: tuple = (start_date, end_date)
        if account_number:
            query += " AND s.account_number = ?"
            params += (account_number,)
        query += " ORDER BY t.transaction_date DESC"

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def get_account_balance(self, account_number: str, date: str) -> Optional[dict]:
        """Get account balance at a specific date.

        Args:
            account_number: Account number
            date: Date (ISO format); compared to ``statement_date`` as a string.

        Returns:
            The most recent statement row dated on or before ``date``, or None
            if the account has no such statement.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                """
                SELECT * FROM statements
                WHERE account_number = ? AND statement_date <= ?
                ORDER BY statement_date DESC
                LIMIT 1
                """,
                (account_number, date)
            ) as cursor:
                row = await cursor.fetchone()
                return dict(row) if row else None

    async def get_stats(self) -> dict:
        """Get database statistics.

        Returns:
            Dict with ``total_statements``, ``total_holdings``,
            ``total_transactions``, ``unique_symbols`` (distinct holding
            symbols) and ``unique_accounts``.
        """
        # Each statistic is a single-value COUNT query.
        count_queries = {
            "total_statements": "SELECT COUNT(*) FROM statements",
            "total_holdings": "SELECT COUNT(*) FROM holdings",
            "total_transactions": "SELECT COUNT(*) FROM transactions",
            "unique_symbols": "SELECT COUNT(DISTINCT symbol) FROM holdings",
            "unique_accounts": "SELECT COUNT(DISTINCT account_number) FROM statements",
        }
        stats = {}
        async with aiosqlite.connect(self.db_path) as db:
            for key, sql in count_queries.items():
                async with db.execute(sql) as cursor:
                    stats[key] = (await cursor.fetchone())[0]
        return stats

    async def clear_all(self):
        """Clear all data from the database.

        Deletes every row from the benchmarks, risk_metrics, transactions,
        holdings and statements tables. The tables and indexes themselves are
        kept.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # Child tables first, then the statements they reference.
            await db.execute("DELETE FROM benchmarks")
            await db.execute("DELETE FROM risk_metrics")
            await db.execute("DELETE FROM transactions")
            await db.execute("DELETE FROM holdings")
            await db.execute("DELETE FROM statements")
            await db.commit()

    async def get_performance_history(
        self, account_number: str, limit: int = 12
    ) -> list[dict]:
        """Get performance history for an account over time.

        Args:
            account_number: Account number
            limit: Maximum number of statements to return

        Returns:
            Statement id, date, balance and return columns for the account,
            newest statement first.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                """
                SELECT statement_id, statement_date, current_balance_cad,
                       return_mtd, return_qtd, return_ytd, return_1y,
                       return_3y, return_5y, return_since_inception
                FROM statements
                WHERE account_number = ?
                ORDER BY statement_date DESC
                LIMIT ?
                """,
                (account_number, limit)
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def get_benchmarks_for_statement(self, statement_id: str) -> list[dict]:
        """Get benchmark comparisons for a statement.

        Args:
            statement_id: Statement ID

        Returns:
            List of benchmark rows as dicts (empty if none were stored).
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM benchmarks WHERE statement_id = ?",
                (statement_id,)
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def insert_risk_metrics(self, metrics: dict):
        """Insert calculated risk metrics.

        Args:
            metrics: Risk metrics dictionary. ``account_number`` and
                ``calculation_date`` are required (KeyError if missing);
                ``period_months`` defaults to 12; ``volatility``, ``beta``,
                ``sharpe_ratio``, ``max_drawdown`` and ``var_95`` default to
                None (stored as NULL).
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                """
                INSERT INTO risk_metrics
                    (account_number, calculation_date, period_months, volatility,
                     beta, sharpe_ratio, max_drawdown, var_95)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    metrics["account_number"],
                    metrics["calculation_date"],
                    metrics.get("period_months", 12),
                    metrics.get("volatility"),
                    metrics.get("beta"),
                    metrics.get("sharpe_ratio"),
                    metrics.get("max_drawdown"),
                    metrics.get("var_95"),
                ),
            )
            await db.commit()

    async def get_risk_metrics(
        self, account_number: str, limit: int = 1
    ) -> list[dict]:
        """Get risk metrics for an account.

        Args:
            account_number: Account number
            limit: Number of most recent calculations to return

        Returns:
            Risk metrics rows, newest ``calculation_date`` first.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                """
                SELECT * FROM risk_metrics
                WHERE account_number = ?
                ORDER BY calculation_date DESC
                LIMIT ?
                """,
                (account_number, limit)
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def get_accounts(self) -> list[dict]:
        """Get list of all accounts with their latest balance.

        Relies on SQLite's documented handling of bare columns in an aggregate
        query that uses ``max()``: ``institution``, ``current_balance_cad`` and
        ``return_ytd`` are taken from the row that supplied
        ``MAX(statement_date)``, i.e. each account's latest statement. This is
        SQLite-specific SQL.

        Returns:
            One dict per account (``account_number``, ``institution``,
            ``latest_date``, ``current_balance_cad``, ``return_ytd``), ordered
            by account number.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                """
                SELECT account_number, institution,
                       MAX(statement_date) as latest_date,
                       current_balance_cad,
                       return_ytd
                FROM statements
                GROUP BY account_number
                ORDER BY account_number
                """
            ) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]