"""
DuckDB-based data storage for efficient SQL queries on market data.
This module provides in-memory storage for price history and option chain data,
allowing the LLM to run SQL queries instead of processing large raw datasets.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
import duckdb
import sqlglot
from sqlglot import exp
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lifetime of a stored dataset in seconds (3600 s = 1 hour). The store_*
# methods add this to the load time to compute a dataset's expires_at.
CACHE_TTL_SECONDS = 3600
class DataStorage:
"""In-memory DuckDB storage for market data."""
def __init__(self):
"""Initialize in-memory DuckDB connection."""
self._conn = duckdb.connect(":memory:")
self._init_schema()
def _init_schema(self):
"""Create tables for storing market data."""
# Price history table
self._conn.execute("""
CREATE TABLE IF NOT EXISTS price_history (
symbol VARCHAR,
datetime TIMESTAMP,
open DOUBLE,
high DOUBLE,
low DOUBLE,
close DOUBLE,
volume BIGINT,
PRIMARY KEY (symbol, datetime)
)
""")
# Option chain table
self._conn.execute("""
CREATE TABLE IF NOT EXISTS options (
symbol VARCHAR,
underlying_symbol VARCHAR,
underlying_price DOUBLE,
option_type VARCHAR,
strike DOUBLE,
expiration DATE,
days_to_expiration INTEGER,
bid DOUBLE,
ask DOUBLE,
last DOUBLE,
mark DOUBLE,
volume BIGINT,
open_interest BIGINT,
implied_volatility DOUBLE,
delta DOUBLE,
gamma DOUBLE,
theta DOUBLE,
vega DOUBLE,
rho DOUBLE,
in_the_money BOOLEAN,
intrinsic_value DOUBLE,
extrinsic_value DOUBLE,
time_value DOUBLE
)
""")
# Metadata table to track what's loaded
self._conn.execute("""
CREATE TABLE IF NOT EXISTS data_metadata (
data_type VARCHAR,
symbol VARCHAR,
loaded_at TIMESTAMP,
expires_at TIMESTAMP,
record_count INTEGER,
params VARCHAR,
PRIMARY KEY (data_type, symbol)
)
""")
def is_cache_valid(self, data_type: str, symbol: str) -> dict[str, Any]:
"""
Check if cached data exists and is still valid (not expired).
Args:
data_type: Type of data ('price_history' or 'options')
symbol: Stock ticker symbol
Returns:
Dict with 'valid' bool, and if valid: 'loaded_at', 'expires_at', 'record_count', 'ttl_remaining'
"""
symbol = symbol.upper()
now = datetime.now()
result = self._conn.execute(
"""
SELECT loaded_at, expires_at, record_count, params
FROM data_metadata
WHERE data_type = ? AND symbol = ?
""",
[data_type, symbol],
).fetchone()
if not result:
return {"valid": False, "reason": "not_loaded"}
loaded_at, expires_at, record_count, params = result
if expires_at is None or now > expires_at:
# Expired - clean up the stale data
self._evict_expired(data_type, symbol)
return {"valid": False, "reason": "expired"}
ttl_remaining = int((expires_at - now).total_seconds())
return {
"valid": True,
"loaded_at": str(loaded_at),
"expires_at": str(expires_at),
"record_count": record_count,
"ttl_remaining_seconds": ttl_remaining,
"params": params,
}
def _evict_expired(self, data_type: str, symbol: str):
"""Remove expired data from cache."""
symbol = symbol.upper()
if data_type == "price_history":
self._conn.execute("DELETE FROM price_history WHERE symbol = ?", [symbol])
elif data_type == "options":
self._conn.execute("DELETE FROM options WHERE underlying_symbol = ?", [symbol])
self._conn.execute(
"DELETE FROM data_metadata WHERE data_type = ? AND symbol = ?",
[data_type, symbol],
)
logger.info(f"Evicted expired {data_type} cache for {symbol}")
def evict_all_expired(self):
"""Remove all expired data from cache."""
now = datetime.now()
expired = self._conn.execute(
"SELECT data_type, symbol FROM data_metadata WHERE expires_at < ?",
[now],
).fetchall()
for data_type, symbol in expired:
self._evict_expired(data_type, symbol)
return len(expired)
def store_price_history(
self, symbol: str, candles: list[dict], params: dict[str, Any] | None = None
) -> int:
"""
Store price history data for a symbol.
Args:
symbol: Stock ticker symbol
candles: List of OHLCV candle dicts
params: Optional parameters used to fetch the data
Returns:
Number of records stored
"""
symbol = symbol.upper()
# Clear existing data for this symbol
self._conn.execute("DELETE FROM price_history WHERE symbol = ?", [symbol])
# Insert new data
for candle in candles:
self._conn.execute(
"""
INSERT INTO price_history (symbol, datetime, open, high, low, close, volume)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
[
symbol,
candle.get("datetime"),
candle.get("open"),
candle.get("high"),
candle.get("low"),
candle.get("close"),
candle.get("volume"),
],
)
# Update metadata with TTL
import json
now = datetime.now()
expires_at = now + timedelta(seconds=CACHE_TTL_SECONDS)
self._conn.execute(
"""
INSERT OR REPLACE INTO data_metadata (data_type, symbol, loaded_at, expires_at, record_count, params)
VALUES ('price_history', ?, ?, ?, ?, ?)
""",
[symbol, now, expires_at, len(candles), json.dumps(params) if params else None],
)
logger.info(f"Stored {len(candles)} price history records for {symbol} (expires: {expires_at})")
return len(candles)
def store_options(
self,
underlying_symbol: str,
underlying_price: float | None,
calls: list[dict],
puts: list[dict],
params: dict[str, Any] | None = None,
) -> int:
"""
Store option chain data for a symbol.
Args:
underlying_symbol: Underlying stock ticker
underlying_price: Current underlying price
calls: List of call option dicts
puts: List of put option dicts
params: Optional parameters used to fetch the data
Returns:
Number of records stored
"""
underlying_symbol = underlying_symbol.upper()
# Clear existing data for this underlying
self._conn.execute("DELETE FROM options WHERE underlying_symbol = ?", [underlying_symbol])
total_count = 0
# Insert calls
for opt in calls:
self._insert_option(underlying_symbol, underlying_price, "CALL", opt)
total_count += 1
# Insert puts
for opt in puts:
self._insert_option(underlying_symbol, underlying_price, "PUT", opt)
total_count += 1
# Update metadata with TTL
import json
now = datetime.now()
expires_at = now + timedelta(seconds=CACHE_TTL_SECONDS)
self._conn.execute(
"""
INSERT OR REPLACE INTO data_metadata (data_type, symbol, loaded_at, expires_at, record_count, params)
VALUES ('options', ?, ?, ?, ?, ?)
""",
[underlying_symbol, now, expires_at, total_count, json.dumps(params) if params else None],
)
logger.info(f"Stored {total_count} option contracts for {underlying_symbol} (expires: {expires_at})")
return total_count
def _insert_option(
self,
underlying_symbol: str,
underlying_price: float | None,
option_type: str,
opt: dict,
):
"""Insert a single option contract."""
self._conn.execute(
"""
INSERT INTO options (
symbol, underlying_symbol, underlying_price, option_type,
strike, expiration, days_to_expiration,
bid, ask, last, mark, volume, open_interest,
implied_volatility, delta, gamma, theta, vega, rho,
in_the_money, intrinsic_value, extrinsic_value, time_value
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[
opt.get("symbol"),
underlying_symbol,
underlying_price,
option_type,
opt.get("strike"),
opt.get("expiration"),
opt.get("days_to_expiration"),
opt.get("bid"),
opt.get("ask"),
opt.get("last"),
opt.get("mark"),
opt.get("volume"),
opt.get("open_interest"),
opt.get("implied_volatility"),
opt.get("delta"),
opt.get("gamma"),
opt.get("theta"),
opt.get("vega"),
opt.get("rho"),
opt.get("in_the_money"),
opt.get("intrinsic_value"),
opt.get("extrinsic_value"),
opt.get("time_value"),
],
)
def _validate_sql(self, sql: str) -> None:
"""
Validate that SQL is a read-only SELECT query using sqlglot parser.
Args:
sql: SQL query to validate
Raises:
ValueError: If query is not a valid read-only SELECT
"""
try:
# Parse the SQL using sqlglot (use duckdb dialect for compatibility)
parsed = sqlglot.parse(sql, dialect="duckdb")
except sqlglot.errors.ParseError as e:
raise ValueError(f"SQL parse error: {e}")
if not parsed:
raise ValueError("Empty or invalid SQL query")
# Check each statement in the parsed result
for statement in parsed:
if statement is None:
continue
# Only allow SELECT statements (which includes CTEs with WITH...SELECT)
if not isinstance(statement, exp.Select):
stmt_type = type(statement).__name__
raise ValueError(
f"Only SELECT queries are allowed. Got: {stmt_type}"
)
# Check for subqueries that might be DML
for node in statement.walk():
# Disallow any DML expressions
disallowed = (exp.Insert, exp.Update, exp.Delete, exp.Drop,
exp.Create, exp.Alter, exp.Merge)
if isinstance(node, disallowed):
raise ValueError(
f"Query contains disallowed operation: {type(node).__name__}"
)
def query(self, sql: str) -> dict[str, Any]:
"""
Execute a SQL query and return results.
Args:
sql: SQL query to execute (SELECT only for safety)
Returns:
Dict with columns and rows
Raises:
ValueError: If query is not a SELECT statement
"""
# Validate SQL using sqlglot parser
self._validate_sql(sql)
try:
result = self._conn.execute(sql)
columns = [desc[0] for desc in result.description]
rows = result.fetchall()
# Convert to list of dicts for easier JSON serialization
data = [dict(zip(columns, row)) for row in rows]
return {
"success": True,
"columns": columns,
"row_count": len(data),
"data": data,
}
except Exception as e:
logger.error(f"Query failed: {e}")
return {
"success": False,
"error": str(e),
"columns": [],
"row_count": 0,
"data": [],
}
def get_available_data(self) -> dict[str, Any]:
"""Get metadata about what data is currently loaded."""
result = self._conn.execute("""
SELECT data_type, symbol, loaded_at, record_count, params
FROM data_metadata
ORDER BY loaded_at DESC
""")
rows = result.fetchall()
data = []
for row in rows:
data.append(
{
"data_type": row[0],
"symbol": row[1],
"loaded_at": str(row[2]) if row[2] else None,
"record_count": row[3],
"params": row[4],
}
)
return {"datasets": data}
def get_schema_info(self) -> dict[str, Any]:
"""Get schema information for available tables."""
return {
"tables": {
"price_history": {
"description": "OHLCV price history data",
"columns": [
{"name": "symbol", "type": "VARCHAR", "description": "Stock ticker"},
{"name": "datetime", "type": "TIMESTAMP", "description": "Candle timestamp"},
{"name": "open", "type": "DOUBLE", "description": "Opening price"},
{"name": "high", "type": "DOUBLE", "description": "High price"},
{"name": "low", "type": "DOUBLE", "description": "Low price"},
{"name": "close", "type": "DOUBLE", "description": "Closing price"},
{"name": "volume", "type": "BIGINT", "description": "Trading volume"},
],
},
"options": {
"description": "Option chain data with Greeks",
"columns": [
{"name": "symbol", "type": "VARCHAR", "description": "Option contract symbol"},
{"name": "underlying_symbol", "type": "VARCHAR", "description": "Underlying stock ticker"},
{"name": "underlying_price", "type": "DOUBLE", "description": "Current underlying price"},
{"name": "option_type", "type": "VARCHAR", "description": "CALL or PUT"},
{"name": "strike", "type": "DOUBLE", "description": "Strike price"},
{"name": "expiration", "type": "DATE", "description": "Expiration date"},
{"name": "days_to_expiration", "type": "INTEGER", "description": "Days until expiration"},
{"name": "bid", "type": "DOUBLE", "description": "Bid price"},
{"name": "ask", "type": "DOUBLE", "description": "Ask price"},
{"name": "last", "type": "DOUBLE", "description": "Last trade price"},
{"name": "mark", "type": "DOUBLE", "description": "Mark price (mid)"},
{"name": "volume", "type": "BIGINT", "description": "Contract volume"},
{"name": "open_interest", "type": "BIGINT", "description": "Open interest"},
{"name": "implied_volatility", "type": "DOUBLE", "description": "Implied volatility"},
{"name": "delta", "type": "DOUBLE", "description": "Delta Greek"},
{"name": "gamma", "type": "DOUBLE", "description": "Gamma Greek"},
{"name": "theta", "type": "DOUBLE", "description": "Theta Greek"},
{"name": "vega", "type": "DOUBLE", "description": "Vega Greek"},
{"name": "rho", "type": "DOUBLE", "description": "Rho Greek"},
{"name": "in_the_money", "type": "BOOLEAN", "description": "Whether option is ITM"},
{"name": "intrinsic_value", "type": "DOUBLE", "description": "Intrinsic value"},
{"name": "extrinsic_value", "type": "DOUBLE", "description": "Extrinsic value"},
{"name": "time_value", "type": "DOUBLE", "description": "Time value"},
],
},
},
"example_queries": {
"volume_profile": """
SELECT
FLOOR(close / 10) * 10 as price_level,
SUM(volume) as total_volume,
COUNT(*) as candle_count
FROM price_history
WHERE symbol = 'TSLA'
GROUP BY 1
ORDER BY 2 DESC
LIMIT 10""",
"high_iv_options": """
SELECT symbol, strike, expiration, implied_volatility, delta, volume
FROM options
WHERE underlying_symbol = 'TSLA'
AND option_type = 'CALL'
AND implied_volatility > 0.5
ORDER BY implied_volatility DESC
LIMIT 20""",
"options_by_expiration": """
SELECT expiration, option_type, COUNT(*) as contracts, SUM(volume) as total_volume
FROM options
WHERE underlying_symbol = 'TSLA'
GROUP BY 1, 2
ORDER BY 1, 2""",
"atm_options": """
SELECT symbol, option_type, strike, expiration, bid, ask, delta, implied_volatility
FROM options
WHERE underlying_symbol = 'TSLA'
AND ABS(delta) BETWEEN 0.4 AND 0.6
ORDER BY expiration, strike""",
},
}
# Module-wide DataStorage instance shared by all callers of get_storage().
# Stays None until get_storage() creates it on first use.
_storage: DataStorage | None = None
def get_storage() -> DataStorage:
    """Return the shared DataStorage, building it on the first call."""
    global _storage
    if _storage is not None:
        return _storage
    # First use: create the instance and keep it for later calls.
    _storage = DataStorage()
    return _storage