"""
Tests for CryptoAPI class.
"""
import pytest
from src.crypto_mcp_server.crypto_api import CryptoAPI # type: ignore
class TestCryptoAPI:
    """Test suite for the synchronous CryptoAPI wrapper.

    NOTE(review): every test builds a real ``CryptoAPI`` for ``'binance'``
    and calls its data methods directly (nothing is mocked here), so the
    suite presumably needs network access to the exchange -- confirm, and
    consider marking it as an integration suite.
    """

    @staticmethod
    def _assert_has_fields(payload, fields):
        """Assert that every name in *fields* is a key of *payload*.

        All missing keys are reported at once so a failing run shows the
        complete gap rather than only the first absent field.
        """
        missing = [name for name in fields if name not in payload]
        assert not missing, f"missing fields: {missing}"

    @pytest.fixture
    def api(self):
        """Create CryptoAPI instance for testing."""
        return CryptoAPI(exchange_id='binance')

    def test_initialization(self, api):
        """Test API initialization."""
        assert api.exchange_id == 'binance'
        assert api.exchange is not None

    def test_initialization_invalid_exchange(self):
        """Test initialization with invalid exchange."""
        with pytest.raises(ValueError, match="not supported"):
            CryptoAPI(exchange_id='invalid_exchange_name')

    def test_get_current_price(self, api):
        """Test fetching current price."""
        result = api.get_current_price('BTC/USDT')

        # Check all required fields before looking at individual values.
        self._assert_has_fields(result, [
            'symbol', 'exchange', 'price', 'bid', 'ask',
            'high_24h', 'low_24h', 'volume_24h', 'timestamp', 'datetime',
        ])

        assert result['symbol'] == 'BTC/USDT'
        assert result['exchange'] == 'binance'
        assert isinstance(result['price'], (int, float))
        assert result['price'] > 0

    def test_get_current_price_invalid_symbol(self, api):
        """Test fetching price with invalid symbol."""
        # NOTE(review): ``Exception`` is very broad; narrow it once the
        # concrete error type raised by CryptoAPI is confirmed.
        with pytest.raises(Exception):
            api.get_current_price('INVALID/PAIR')

    def test_get_multiple_prices(self, api):
        """Test fetching multiple prices."""
        symbols = ['BTC/USDT', 'ETH/USDT']
        results = api.get_multiple_prices(symbols)

        assert len(results) == 2
        assert all('symbol' in r for r in results)

        # Check that we got data for both symbols
        returned_symbols = [r['symbol'] for r in results if 'error' not in r]
        assert 'BTC/USDT' in returned_symbols
        assert 'ETH/USDT' in returned_symbols

    def test_get_multiple_prices_with_invalid(self, api):
        """Test fetching multiple prices with one invalid symbol."""
        symbols = ['BTC/USDT', 'INVALID/PAIR']
        results = api.get_multiple_prices(symbols)

        assert len(results) == 2

        # One should succeed, one should have error
        valid_results = [r for r in results if 'error' not in r]
        error_results = [r for r in results if 'error' in r]
        assert len(valid_results) >= 1
        assert len(error_results) >= 1

    def test_get_historical_ohlcv(self, api):
        """Test fetching historical OHLCV data."""
        result = api.get_historical_ohlcv('BTC/USDT', timeframe='1d', limit=10)

        assert isinstance(result, list)
        assert 0 < len(result) <= 10

        # Check structure of first candle
        candle = result[0]
        self._assert_has_fields(
            candle,
            ['timestamp', 'datetime', 'open', 'high', 'low', 'close', 'volume'],
        )

        # Verify OHLC relationship: high and low must bracket open and close.
        assert candle['high'] >= candle['low']
        assert candle['high'] >= candle['open']
        assert candle['high'] >= candle['close']
        assert candle['low'] <= candle['open']
        assert candle['low'] <= candle['close']

    @pytest.mark.parametrize('timeframe', ['1h', '1d'])
    def test_get_historical_ohlcv_different_timeframes(self, api, timeframe):
        """Test fetching historical data with different timeframes.

        Parametrized so each timeframe is reported as its own test case and
        a failure for one does not hide the result of the other.
        """
        result = api.get_historical_ohlcv('BTC/USDT', timeframe=timeframe, limit=5)
        assert 0 < len(result) <= 5

    def test_get_orderbook(self, api):
        """Test fetching order book."""
        result = api.get_orderbook('BTC/USDT', limit=10)

        assert 'symbol' in result
        assert 'bids' in result
        assert 'asks' in result
        assert isinstance(result['bids'], list)
        assert isinstance(result['asks'], list)

        # Check orderbook structure; either side may legitimately be empty.
        if result['bids']:
            bid = result['bids'][0]
            assert len(bid) == 2  # [price, amount]
            assert isinstance(bid[0], (int, float))  # price
            assert isinstance(bid[1], (int, float))  # amount

        if result['asks']:
            ask = result['asks'][0]
            assert len(ask) == 2

    def test_get_market_summary(self, api):
        """Test fetching market summary."""
        result = api.get_market_summary('BTC/USDT')

        # 'spread' is listed explicitly because it is indexed below; a
        # missing key should fail as an assertion, not as a KeyError.
        self._assert_has_fields(result, [
            'symbol', 'exchange', 'last_price', 'bid', 'ask',
            'high_24h', 'low_24h', 'volume_24h', 'price_change_24h',
            'spread',
        ])

        # Verify spread calculation
        if result['spread'] is not None:
            assert result['spread'] >= 0
            # Compare with a tolerance: the operands are floats, so exact
            # equality would be sensitive to rounding.
            assert result['spread'] == pytest.approx(result['ask'] - result['bid'])

    def test_get_available_symbols(self, api):
        """Test fetching available symbols."""
        symbols = api.get_available_symbols()

        assert isinstance(symbols, list)
        assert len(symbols) > 0

        # Check that common pairs exist
        assert any('BTC' in s for s in symbols)
        assert any('ETH' in s for s in symbols)

    def test_search_symbols(self, api):
        """Test searching symbols."""
        results = api.search_symbols('BTC')

        assert isinstance(results, list)
        assert len(results) > 0

        # All results should contain 'BTC'
        assert all('BTC' in s for s in results)

    def test_search_symbols_case_insensitive(self, api):
        """Test that symbol search is case-insensitive."""
        results_upper = api.search_symbols('BTC')
        results_lower = api.search_symbols('btc')

        assert len(results_upper) == len(results_lower)
        assert set(results_upper) == set(results_lower)

    def test_get_supported_exchanges(self):
        """Test getting supported exchanges."""
        exchanges = CryptoAPI.get_supported_exchanges()

        assert isinstance(exchanges, list)
        assert len(exchanges) > 0

        # Check that major exchanges are included
        assert 'binance' in exchanges
        assert 'coinbase' in exchanges or 'coinbasepro' in exchanges
@pytest.mark.asyncio
class TestCryptoAPIAsync:
    """Async tests for CryptoAPI (for future async implementation)."""

    @pytest.fixture
    def api(self):
        """Create CryptoAPI instance."""
        return CryptoAPI(exchange_id='binance')

    async def test_concurrent_requests(self, api):
        """Test that several back-to-back price requests all succeed.

        The requests are issued sequentially through the synchronous
        ``get_current_price`` method; nothing is awaited, so this only
        simulates concurrent load. It is a placeholder until an async
        implementation exists to exercise real concurrency.
        """
        symbols = ['BTC/USDT', 'ETH/USDT', 'BNB/USDT']

        # Issue the requests one after another (simulated concurrency).
        results = [api.get_current_price(symbol) for symbol in symbols]

        assert len(results) == len(symbols)
        assert all('price' in r for r in results)
if __name__ == '__main__':
    # Running this file as a script hands it to pytest in verbose mode.
    pytest.main(args=[__file__, '-v'])