"""
Tests for caching functionality.
"""
import pytest
import time
from unittest.mock import Mock, MagicMock
from src.crypto_mcp_server.cache import Cache, CachedCryptoAPI # type: ignore
class TestCache:
    """Exercise the basic operations of ``Cache``.

    These tests treat TTL values as seconds: an entry stored with ``ttl=1``
    is expected to have expired after a 1.5 second sleep, while one stored
    with ``ttl=10`` is expected to survive it.
    """

    @pytest.fixture
    def cache(self):
        """Provide a fresh ``Cache`` built with ``default_ttl=2``."""
        return Cache(default_ttl=2)

    def test_initialization(self, cache):
        """A new cache reports its default TTL and starts with no entries."""
        assert cache.default_ttl == 2
        assert len(cache._cache) == 0

    def test_set_and_get(self, cache):
        """A stored value can be read back under the same key."""
        cache.set('key1', 'value1')
        assert cache.get('key1') == 'value1'

    def test_get_nonexistent_key(self, cache):
        """Looking up a key that was never stored yields ``None``."""
        assert cache.get('nonexistent') is None

    def test_set_with_custom_ttl(self, cache):
        """``set`` accepts an explicit ``ttl`` and the value is readable."""
        cache.set('key1', 'value1', ttl=10)
        assert cache.get('key1') == 'value1'

    def test_expiration(self, cache):
        """An entry is readable at first and gone once its TTL has passed."""
        cache.set('key1', 'value1', ttl=1)

        # Immediately after the write the entry must still be served.
        assert cache.get('key1') == 'value1'

        # Sleep past the 1-unit TTL so the entry becomes stale.
        time.sleep(1.5)

        assert cache.get('key1') is None

    def test_delete(self, cache):
        """``delete`` reports True for a present key and False afterwards."""
        cache.set('key1', 'value1')

        removed_first_time = cache.delete('key1')
        assert removed_first_time is True

        # The entry must no longer be retrievable.
        assert cache.get('key1') is None

        # A second delete has nothing left to remove.
        removed_second_time = cache.delete('key1')
        assert removed_second_time is False

    def test_clear(self, cache):
        """``clear`` empties the cache regardless of how many entries exist."""
        entries = (
            ('key1', 'value1'),
            ('key2', 'value2'),
            ('key3', 'value3'),
        )
        for key, value in entries:
            cache.set(key, value)
        assert len(cache._cache) == 3

        cache.clear()

        assert len(cache._cache) == 0
        assert cache.get('key1') is None

    def test_cleanup_expired(self, cache):
        """``cleanup_expired`` drops stale entries and returns how many."""
        # Two short-lived entries and one long-lived entry.
        cache.set('key1', 'value1', ttl=1)
        cache.set('key2', 'value2', ttl=10)
        cache.set('key3', 'value3', ttl=1)

        # Let only the short-lived entries go stale.
        time.sleep(1.5)

        expired_count = cache.cleanup_expired()

        assert expired_count == 2  # key1 and key3
        assert cache.get('key2') == 'value2'  # the long-lived entry survives

    def test_get_stats(self, cache):
        """``get_stats`` counts total, valid and expired entries.

        No cleanup is run here, and the expired entry is still expected to
        be included in ``total_entries`` after its TTL has passed.
        """
        cache.set('key1', 'value1', ttl=10)
        cache.set('key2', 'value2', ttl=1)

        before = cache.get_stats()
        assert before['total_entries'] == 2
        assert before['valid_entries'] == 2
        assert before['expired_entries'] == 0

        # Let the ttl=1 entry go stale.
        time.sleep(1.5)

        after = cache.get_stats()
        assert after['total_entries'] == 2
        assert after['valid_entries'] == 1
        assert after['expired_entries'] == 1

    def test_generate_key(self, cache):
        """``_generate_key`` is stable for equal inputs and varies otherwise."""
        baseline = cache._generate_key('arg1', 'arg2', kwarg1='value1')
        repeated = cache._generate_key('arg1', 'arg2', kwarg1='value1')
        altered = cache._generate_key('arg1', 'arg3', kwarg1='value1')

        # Identical arguments must map to the identical key.
        assert baseline == repeated

        # Changing one positional argument must change the key.
        assert baseline != altered

    def test_overwrite_existing_key(self, cache):
        """Writing to an existing key replaces the previously stored value."""
        for value in ('value1', 'value2'):
            cache.set('key1', value)

        assert cache.get('key1') == 'value2'
class TestCachedCryptoAPI:
    """Check that ``CachedCryptoAPI`` caches calls made to the wrapped API.

    The wrapped API is a ``Mock``; its ``call_count`` attributes are used to
    tell whether a request reached the API or was answered from the cache.
    """

    @pytest.fixture
    def mock_api(self):
        """Build a ``Mock`` API whose three methods return canned payloads."""
        api = Mock()

        # Canned reply for get_current_price.
        api.get_current_price.return_value = {
            'symbol': 'BTC/USDT',
            'price': 50000,
            'bid': 49999,
            'ask': 50001,
        }

        # Canned reply for get_historical_ohlcv: a single candle.
        api.get_historical_ohlcv.return_value = [
            {
                'timestamp': 1,
                'open': 50000,
                'high': 51000,
                'low': 49000,
                'close': 50500,
                'volume': 100,
            },
        ]

        # Canned reply for get_market_summary.
        api.get_market_summary.return_value = {
            'symbol': 'BTC/USDT',
            'last_price': 50000,
            'volume_24h': 1000,
        }

        return api

    @pytest.fixture
    def cached_api(self, mock_api):
        """Wrap the mock API in a ``CachedCryptoAPI`` with ``cache_ttl=60``."""
        return CachedCryptoAPI(mock_api, cache_ttl=60)

    def test_get_current_price_cached(self, cached_api, mock_api):
        """A repeated price request is answered from the cache."""
        price_endpoint = mock_api.get_current_price

        # The first request has nothing cached and must reach the API.
        first = cached_api.get_current_price('BTC/USDT')
        assert first['from_cache'] is False
        assert price_endpoint.call_count == 1

        # The repeat must be flagged as cached and leave the API untouched.
        second = cached_api.get_current_price('BTC/USDT')
        assert second['from_cache'] is True
        assert price_endpoint.call_count == 1

    def test_get_current_price_no_cache(self, cached_api, mock_api):
        """``use_cache=False`` sends every request to the API."""
        # Two identical requests, both bypassing the cache.
        for _ in range(2):
            reply = cached_api.get_current_price('BTC/USDT', use_cache=False)
            assert reply['from_cache'] is False

        assert mock_api.get_current_price.call_count == 2

    def test_get_historical_ohlcv_cached(self, cached_api, mock_api):
        """A repeated OHLCV request is cached and returns equal data."""
        ohlcv_endpoint = mock_api.get_historical_ohlcv

        first = cached_api.get_historical_ohlcv('BTC/USDT', '1d', 100)
        assert ohlcv_endpoint.call_count == 1

        # Same arguments again: no additional API call is expected.
        second = cached_api.get_historical_ohlcv('BTC/USDT', '1d', 100)
        assert ohlcv_endpoint.call_count == 1
        assert first == second

    def test_get_market_summary_cached(self, cached_api, mock_api):
        """A repeated market-summary request is answered from the cache."""
        summary_endpoint = mock_api.get_market_summary

        first = cached_api.get_market_summary('BTC/USDT')
        assert first['from_cache'] is False
        assert summary_endpoint.call_count == 1

        second = cached_api.get_market_summary('BTC/USDT')
        assert second['from_cache'] is True
        assert summary_endpoint.call_count == 1

    def test_clear_cache(self, cached_api, mock_api):
        """After ``clear_cache`` the next request reaches the API again."""
        price_endpoint = mock_api.get_current_price

        # Populate the cache with one request.
        cached_api.get_current_price('BTC/USDT')
        assert price_endpoint.call_count == 1

        cached_api.clear_cache()

        # With the cache emptied, the same request must be forwarded.
        cached_api.get_current_price('BTC/USDT')
        assert price_endpoint.call_count == 2

    def test_get_cache_stats(self, cached_api):
        """``get_cache_stats`` exposes the three entry-count fields."""
        stats = cached_api.get_cache_stats()

        for field in ('total_entries', 'valid_entries', 'expired_entries'):
            assert field in stats

    def test_different_symbols_different_cache(self, cached_api, mock_api):
        """Requests for different symbols are cached independently."""
        prices_by_symbol = {'BTC/USDT': 50000, 'ETH/USDT': 3000}

        def get_price_side_effect(symbol):
            # Unknown symbols yield None; known ones get a fresh dict per call.
            if symbol not in prices_by_symbol:
                return None
            return {'symbol': symbol, 'price': prices_by_symbol[symbol]}

        mock_api.get_current_price.side_effect = get_price_side_effect

        # One request per symbol; neither may be served from the other's entry.
        btc = cached_api.get_current_price('BTC/USDT')
        eth = cached_api.get_current_price('ETH/USDT')

        assert btc['price'] == 50000
        assert eth['price'] == 3000
        assert mock_api.get_current_price.call_count == 2
if __name__ == '__main__':
    # pytest.main() returns the session's exit code. Pass it on as the
    # process exit status so that running this file directly (e.g. from a
    # script or CI step) fails when a test fails instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))