"""Tests for DuckDB storage module."""
import pytest
from schwab_mcp.storage import DataStorage, get_storage
class TestDataStorage:
    """Exercise DataStorage: table creation, storing rows, querying, metadata."""

    @staticmethod
    def _bar(when, open_, high, low, close, volume):
        """Build one candle dict with the keys used throughout these tests."""
        return {
            "datetime": when,
            "open": open_,
            "high": high,
            "low": low,
            "close": close,
            "volume": volume,
        }

    def setup_method(self):
        """Construct a new DataStorage before every test method."""
        self.storage = DataStorage()

    def test_init_creates_tables(self):
        """A newly constructed storage lists the three expected tables."""
        outcome = self.storage.query(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
        )
        assert outcome["success"]

        found = [row["table_name"] for row in outcome["data"]]
        for expected in ("price_history", "options", "data_metadata"):
            assert expected in found

    def test_store_price_history(self):
        """Three stored candles come back as three rows with matching values."""
        bars = [
            self._bar("2025-01-01T10:00:00", 100, 105, 99, 104, 1000000),
            self._bar("2025-01-02T10:00:00", 104, 110, 103, 108, 1200000),
            self._bar("2025-01-03T10:00:00", 108, 112, 106, 110, 900000),
        ]

        stored = self.storage.store_price_history("TSLA", bars, {"period": "3 days"})
        assert stored == 3

        # Read the rows back in chronological order and spot-check both ends.
        outcome = self.storage.query("SELECT * FROM price_history WHERE symbol = 'TSLA' ORDER BY datetime")
        assert outcome["success"]
        assert outcome["row_count"] == 3
        rows = outcome["data"]
        assert rows[0]["close"] == 104
        assert rows[2]["volume"] == 900000

    def test_store_price_history_replaces_existing(self):
        """Storing a symbol a second time leaves only the second batch."""
        first_batch = [self._bar("2025-01-01T10:00:00", 100, 105, 99, 104, 1000000)]
        second_batch = [self._bar("2025-01-02T10:00:00", 200, 205, 199, 204, 2000000)]

        for batch in (first_batch, second_batch):
            self.storage.store_price_history("TSLA", batch)

        outcome = self.storage.query("SELECT * FROM price_history WHERE symbol = 'TSLA'")
        assert outcome["row_count"] == 1
        assert outcome["data"][0]["open"] == 200

    def test_store_options(self):
        """One call plus one put are stored and retrievable by option type."""
        call_contract = {
            "symbol": "TSLA250117C00400000",
            "strike": 400,
            "expiration": "2025-01-17",
            "days_to_expiration": 10,
            "bid": 5.0,
            "ask": 5.5,
            "last": 5.25,
            "mark": 5.25,
            "volume": 1000,
            "open_interest": 5000,
            "implied_volatility": 0.45,
            "delta": 0.35,
            "gamma": 0.02,
            "theta": -0.05,
            "vega": 0.10,
            "rho": 0.01,
            "in_the_money": False,
            "intrinsic_value": 0,
            "extrinsic_value": 5.25,
            "time_value": 5.25,
        }
        put_contract = {
            "symbol": "TSLA250117P00350000",
            "strike": 350,
            "expiration": "2025-01-17",
            "days_to_expiration": 10,
            "bid": 3.0,
            "ask": 3.5,
            "last": 3.25,
            "mark": 3.25,
            "volume": 800,
            "open_interest": 3000,
            "implied_volatility": 0.42,
            "delta": -0.25,
            "gamma": 0.015,
            "theta": -0.04,
            "vega": 0.08,
            "rho": -0.01,
            "in_the_money": False,
            "intrinsic_value": 0,
            "extrinsic_value": 3.25,
            "time_value": 3.25,
        }

        stored = self.storage.store_options("TSLA", 380.0, [call_contract], [put_contract])
        assert stored == 2

        # Call side
        call_rows = self.storage.query("SELECT * FROM options WHERE option_type = 'CALL'")
        assert call_rows["row_count"] == 1
        stored_call = call_rows["data"][0]
        assert stored_call["strike"] == 400
        assert stored_call["delta"] == 0.35

        # Put side
        put_rows = self.storage.query("SELECT * FROM options WHERE option_type = 'PUT'")
        assert put_rows["row_count"] == 1
        assert put_rows["data"][0]["strike"] == 350

    def test_query_select_only(self):
        """Each data-modifying statement passed to query() raises ValueError."""
        rejected_statements = (
            "DELETE FROM price_history",
            "DROP TABLE price_history",
            "INSERT INTO price_history VALUES ('X', NOW(), 1,1,1,1,1)",
            "UPDATE price_history SET close = 100",
            "TRUNCATE TABLE price_history",
        )
        for statement in rejected_statements:
            with pytest.raises(ValueError):
                self.storage.query(statement)

    def test_query_cte_allowed(self):
        """A query that opens with a WITH clause is accepted and evaluated."""
        self.storage.store_price_history(
            "TEST",
            [
                self._bar("2025-01-01", 100, 105, 99, 104, 1000000),
                self._bar("2025-01-02", 104, 110, 103, 108, 1200000),
            ],
        )

        # Assembled from newline-terminated pieces; the text sent to query()
        # starts and ends with a newline.
        cte_sql = (
            "\n"
            "WITH daily_data AS (\n"
            "SELECT symbol, close, volume\n"
            "FROM price_history\n"
            "WHERE symbol = 'TEST'\n"
            ")\n"
            "SELECT symbol, SUM(volume) as total_vol\n"
            "FROM daily_data\n"
            "GROUP BY symbol\n"
        )
        outcome = self.storage.query(cte_sql)

        assert outcome["success"]
        assert outcome["row_count"] == 1
        assert outcome["data"][0]["total_vol"] == 2200000

    def test_query_volume_profile(self):
        """Bucketing closes into 5-wide price levels keeps the total volume intact."""
        self.storage.store_price_history(
            "TEST",
            [
                self._bar("2025-01-01", 100, 105, 99, 104, 1000000),
                self._bar("2025-01-02", 104, 108, 102, 106, 1500000),
                self._bar("2025-01-03", 106, 115, 105, 114, 2000000),
                self._bar("2025-01-04", 114, 118, 110, 112, 1200000),
            ],
        )

        profile_sql = (
            "\n"
            "SELECT\n"
            "FLOOR(close / 5) * 5 as price_level,\n"
            "SUM(volume) as total_volume,\n"
            "COUNT(*) as candle_count\n"
            "FROM price_history\n"
            "WHERE symbol = 'TEST'\n"
            "GROUP BY 1\n"
            "ORDER BY 2 DESC\n"
        )
        outcome = self.storage.query(profile_sql)

        assert outcome["success"]
        assert outcome["row_count"] > 0

        # The per-level volumes must add up to the four candles' combined volume.
        combined_volume = sum(level["total_volume"] for level in outcome["data"])
        assert combined_volume == 5700000

    def test_get_available_data(self):
        """get_available_data() reports a single dataset after one store call."""
        self.storage.store_price_history(
            "AAPL", [self._bar("2025-01-01", 100, 105, 99, 104, 1000000)]
        )

        available = self.storage.get_available_data()
        assert "datasets" in available

        datasets = available["datasets"]
        assert len(datasets) == 1

        entry = datasets[0]
        assert entry["symbol"] == "AAPL"
        assert entry["data_type"] == "price_history"
        assert entry["record_count"] == 1

    def test_get_schema_info(self):
        """get_schema_info() describes both tables and includes example queries."""
        schema = self.storage.get_schema_info()
        assert "tables" in schema

        tables = schema["tables"]
        assert "price_history" in tables
        assert "options" in tables
        assert "example_queries" in schema

        # price_history column names
        price_columns = [column["name"] for column in tables["price_history"]["columns"]]
        for expected in ("symbol", "datetime", "volume"):
            assert expected in price_columns

        # options column names
        option_columns = [column["name"] for column in tables["options"]["columns"]]
        for expected in ("delta", "implied_volatility"):
            assert expected in option_columns

    def test_multiple_symbols(self):
        """Rows stored under two symbols are both present and sort by symbol."""
        per_symbol = (
            ("AAPL", self._bar("2025-01-01", 180, 185, 178, 183, 50000000)),
            ("TSLA", self._bar("2025-01-01", 380, 390, 375, 385, 80000000)),
        )
        for symbol, bar in per_symbol:
            self.storage.store_price_history(symbol, [bar])

        outcome = self.storage.query("SELECT symbol, close FROM price_history ORDER BY symbol")
        assert outcome["row_count"] == 2
        rows = outcome["data"]
        assert rows[0]["symbol"] == "AAPL"
        assert rows[1]["symbol"] == "TSLA"
class TestCacheTTL:
    """Checks on the status dict returned by DataStorage.is_cache_valid()."""

    @staticmethod
    def _single_candle():
        """Return a one-element candle list, built fresh on every call."""
        return [
            {
                "datetime": "2025-01-01",
                "open": 100,
                "high": 105,
                "low": 99,
                "close": 104,
                "volume": 1000000,
            }
        ]

    def setup_method(self):
        """Construct a new DataStorage before every test method."""
        self.storage = DataStorage()

    def test_cache_valid_after_store(self):
        """Right after storing, the cache entry is valid with most of its TTL left."""
        self.storage.store_price_history("AAPL", self._single_candle())

        status = self.storage.is_cache_valid("price_history", "AAPL")
        assert status["valid"] is True
        assert status["record_count"] == 1
        # Original author's note: the remaining TTL should be close to 3600.
        assert status["ttl_remaining_seconds"] > 3500

    def test_cache_invalid_when_not_loaded(self):
        """A symbol that was never stored is reported invalid with reason 'not_loaded'."""
        status = self.storage.is_cache_valid("price_history", "NOTLOADED")
        assert status["valid"] is False
        assert status["reason"] == "not_loaded"

    def test_cache_case_insensitive(self):
        """The cache entry is found whatever the letter case of the symbol."""
        self.storage.store_price_history("AAPL", self._single_candle())

        for spelling in ("aapl", "Aapl", "AAPL"):
            assert self.storage.is_cache_valid("price_history", spelling)["valid"] is True

    def test_options_cache(self):
        """Storing one call (and no puts) yields a valid 'options' cache entry."""
        call_contract = {
            "symbol": "AAPL250117C00200000",
            "strike": 200,
            "expiration": "2025-01-17",
            "days_to_expiration": 10,
            "bid": 5,
            "ask": 6,
            "last": 5.5,
            "mark": 5.5,
            "volume": 100,
            "open_interest": 500,
            "implied_volatility": 0.3,
            "delta": 0.5,
            "gamma": 0.02,
            "theta": -0.05,
            "vega": 0.1,
            "rho": 0.01,
            "in_the_money": False,
            "intrinsic_value": 0,
            "extrinsic_value": 5.5,
            "time_value": 5.5,
        }
        self.storage.store_options("AAPL", 195.0, [call_contract], [])

        status = self.storage.is_cache_valid("options", "AAPL")
        assert status["valid"] is True
        assert status["record_count"] == 1
class TestGetStorage:
    """Checks on the get_storage() accessor."""

    def test_returns_same_instance(self):
        """Two consecutive get_storage() calls hand back the very same object."""
        # NOTE: this relies on global state and may therefore interfere with
        # other tests; in production you might want to reset the global
        # between tests.
        first, second = get_storage(), get_storage()
        assert first is second