"""Tests for SQL query tools."""
import pytest
from schwab_mcp.storage import DataStorage, get_storage
from schwab_mcp.tools import query
# Reset global storage between tests
@pytest.fixture(autouse=True)
def reset_storage():
    """Clear the module-level storage reference around every test.

    ``schwab_mcp.storage._storage`` is set to ``None`` before the test body
    runs and again once it has finished, so state stored by one test is not
    visible to the next (presumably ``get_storage()`` rebuilds the instance
    on demand -- confirm against the storage module).
    """
    import schwab_mcp.storage as storage_mod

    def _clear():
        # Drop whatever storage object the module is currently holding.
        storage_mod._storage = None

    _clear()
    yield
    _clear()
class TestQueryMarketData:
    """Tests for the ``query_market_data`` tool.

    Each test awaits ``query.query_market_data`` with an argument dict and
    inspects the returned mapping (``success``, ``error``, ``row_count``,
    ``data`` and ``available_data`` keys are exercised below).
    """

    @pytest.mark.asyncio
    async def test_empty_query(self):
        """Test that an argument dict without SQL returns an error."""
        result = await query.query_market_data({})

        assert not result["success"]
        assert "No SQL query" in result["error"]

    @pytest.mark.asyncio
    async def test_invalid_query_returns_error(self):
        """Test that invalid (non-SELECT) queries return an error."""
        result = await query.query_market_data({"sql": "DELETE FROM price_history"})

        assert not result["success"]
        assert "SELECT" in result["error"]

    @pytest.mark.asyncio
    async def test_query_empty_table(self):
        """Test that querying a table with no rows succeeds with zero rows."""
        result = await query.query_market_data({"sql": "SELECT * FROM price_history"})

        assert result["success"]
        assert result["row_count"] == 0
        assert "available_data" in result

    @pytest.mark.asyncio
    async def test_query_with_data(self):
        """Test that a stored candle can be read back through SQL."""
        storage = get_storage()
        storage.store_price_history("AAPL", [
            {"datetime": "2025-01-01", "open": 180, "high": 185, "low": 178, "close": 183, "volume": 50000000}
        ])

        result = await query.query_market_data({
            "sql": "SELECT symbol, close, volume FROM price_history WHERE symbol = 'AAPL'"
        })

        assert result["success"]
        assert result["row_count"] == 1
        assert result["data"][0]["symbol"] == "AAPL"
        assert result["data"][0]["close"] == 183

    @pytest.mark.asyncio
    async def test_aggregation_query(self):
        """Test that aggregation queries work correctly.

        Three TSLA candles are stored, then COUNT / SUM / AVG / MAX / MIN are
        computed in one grouped query and every selected column is checked.
        """
        storage = get_storage()
        storage.store_price_history("TSLA", [
            {"datetime": "2025-01-01", "open": 380, "high": 390, "low": 375, "close": 385, "volume": 80000000},
            {"datetime": "2025-01-02", "open": 385, "high": 395, "low": 380, "close": 390, "volume": 70000000},
            {"datetime": "2025-01-03", "open": 390, "high": 400, "low": 385, "close": 395, "volume": 90000000},
        ])

        # The SQL literal is left flush-left so the text sent to the tool is
        # exactly the query string this test has always used.
        result = await query.query_market_data({
            "sql": """
SELECT
symbol,
COUNT(*) as candle_count,
SUM(volume) as total_volume,
AVG(close) as avg_close,
MAX(high) as max_high,
MIN(low) as min_low
FROM price_history
WHERE symbol = 'TSLA'
GROUP BY symbol
"""
        })

        assert result["success"]
        assert result["row_count"] == 1

        data = result["data"][0]
        # The grouping column must come back alongside the aggregates.
        assert data["symbol"] == "TSLA"
        assert data["candle_count"] == 3
        assert data["total_volume"] == 240000000
        # Closes are 385, 390 and 395, so the mean is 390; approx() keeps the
        # check independent of whether the backend reports an int or a float.
        assert data["avg_close"] == pytest.approx(390)
        assert data["max_high"] == 400
        assert data["min_low"] == 375
class TestGetDataSchema:
    """Tests for the ``get_data_schema`` tool."""

    @pytest.mark.asyncio
    async def test_returns_schema(self):
        """The schema response lists the tables, example queries and loaded data."""
        schema = await query.get_data_schema({})

        assert "tables" in schema
        tables = schema["tables"]
        for table_name in ("price_history", "options"):
            assert table_name in tables
        assert "example_queries" in schema
        assert "loaded_data" in schema

    @pytest.mark.asyncio
    async def test_shows_loaded_data(self):
        """A stored price history shows up as exactly one loaded dataset."""
        candle = {"datetime": "2025-01-01", "open": 180, "high": 185, "low": 178, "close": 183, "volume": 50000000}
        get_storage().store_price_history("AAPL", [candle])

        schema = await query.get_data_schema({})
        datasets = schema["loaded_data"]["datasets"]

        assert len(datasets) == 1
        assert datasets[0]["symbol"] == "AAPL"