"""
Tests for MCP Server functionality.
"""
import pytest
import json
from unittest.mock import Mock, AsyncMock, patch
from src.crypto_mcp_server.server import CryptoMCPServer # type: ignore
class TestCryptoMCPServer:
    """Construction and tool-registration checks for CryptoMCPServer."""

    @pytest.fixture
    def server(self):
        """Return a CryptoMCPServer built while ``CryptoAPI`` is patched out.

        The patch only covers construction; it is undone as soon as the
        fixture returns, before the test body runs.
        """
        with patch('src.crypto_mcp_server.server.CryptoAPI'):
            return CryptoMCPServer(exchange_id='binance', cache_ttl=60)

    def test_initialization(self, server):
        """The server exposes non-None ``server``, ``crypto_api`` and ``cached_api``."""
        for attribute in ('server', 'crypto_api', 'cached_api'):
            assert getattr(server, attribute) is not None

    @pytest.mark.asyncio
    async def test_list_tools(self, server):
        """Every expected tool name is reported by the tool manager."""
        registered = [
            tool.name
            for tool in await server.server._tool_manager.list_tools()
        ]
        required = (
            'get_crypto_price',
            'get_multiple_prices',
            'get_historical_data',
            'get_market_summary',
            'get_orderbook',
            'search_symbols',
            'get_supported_exchanges',
            'clear_cache',
            'get_cache_stats',
        )
        # Collect all absentees so a failure names every missing tool at once.
        missing = [name for name in required if name not in registered]
        assert not missing

    @pytest.mark.asyncio
    async def test_tool_schemas(self, server):
        """Each listed tool carries an object-typed ``inputSchema`` with properties."""
        listed = await server.server._tool_manager.list_tools()
        for tool in listed:
            assert hasattr(tool, 'inputSchema')
            schema = tool.inputSchema
            assert 'type' in schema
            assert schema['type'] == 'object'
            assert 'properties' in schema
class TestToolCalls:
    """Call each tool through the tool manager on a mock-backed server."""

    @staticmethod
    async def _invoke(server, tool_name, arguments):
        """Call ``tool_name`` and return its single text payload parsed as JSON.

        Asserts that exactly one content item came back before decoding the
        ``text`` attribute of that item.
        """
        contents = await server.server._tool_manager.call_tool(
            tool_name,
            arguments
        )
        assert len(contents) == 1
        return json.loads(contents[0].text)

    @pytest.fixture
    def mock_server(self):
        """Return a CryptoMCPServer whose ``crypto_api`` and ``cached_api`` are mocks."""
        # Canned return value for each API method the tools are expected to hit.
        canned_responses = {
            'get_current_price': {
                'symbol': 'BTC/USDT',
                'price': 50000,
                'exchange': 'binance'
            },
            'get_multiple_prices': [
                {'symbol': 'BTC/USDT', 'price': 50000},
                {'symbol': 'ETH/USDT', 'price': 3000}
            ],
            'get_historical_ohlcv': [
                {'timestamp': 1, 'open': 50000, 'close': 50500}
            ],
            'get_market_summary': {
                'symbol': 'BTC/USDT',
                'last_price': 50000
            },
            'get_orderbook': {
                'symbol': 'BTC/USDT',
                'bids': [[50000, 1.0]],
                'asks': [[50001, 1.0]]
            },
            'search_symbols': ['BTC/USDT', 'BTC/USD'],
            'get_supported_exchanges': ['binance', 'coinbase'],
        }
        with patch('src.crypto_mcp_server.server.CryptoAPI') as api_class:
            api = Mock()
            api_class.return_value = api
            for method_name, payload in canned_responses.items():
                getattr(api, method_name).return_value = payload

            srv = CryptoMCPServer(exchange_id='binance')
            srv.crypto_api = api

            # The cache facade shares the very same Mock objects as the API for
            # the pass-through methods, so configuring one configures the other.
            cache = Mock()
            for shared in (
                'get_current_price',
                'get_historical_ohlcv',
                'get_market_summary',
            ):
                setattr(cache, shared, getattr(api, shared))
            cache.clear_cache = Mock()
            cache.get_cache_stats = Mock(return_value={'total_entries': 5})
            srv.cached_api = cache
            return srv

    @pytest.mark.asyncio
    async def test_get_crypto_price_tool(self, mock_server):
        """get_crypto_price echoes the symbol and includes a price."""
        payload = await self._invoke(
            mock_server, 'get_crypto_price', {'symbol': 'BTC/USDT'}
        )
        assert 'symbol' in payload
        assert payload['symbol'] == 'BTC/USDT'
        assert 'price' in payload

    @pytest.mark.asyncio
    async def test_get_multiple_prices_tool(self, mock_server):
        """get_multiple_prices yields a two-element list for two symbols."""
        payload = await self._invoke(
            mock_server,
            'get_multiple_prices',
            {'symbols': ['BTC/USDT', 'ETH/USDT']}
        )
        assert isinstance(payload, list)
        assert len(payload) == 2

    @pytest.mark.asyncio
    async def test_get_historical_data_tool(self, mock_server):
        """get_historical_data yields a non-empty list."""
        payload = await self._invoke(
            mock_server,
            'get_historical_data',
            {'symbol': 'BTC/USDT', 'timeframe': '1d', 'limit': 10}
        )
        assert isinstance(payload, list)
        assert len(payload) > 0

    @pytest.mark.asyncio
    async def test_get_market_summary_tool(self, mock_server):
        """get_market_summary includes ``symbol`` and ``last_price``."""
        payload = await self._invoke(
            mock_server, 'get_market_summary', {'symbol': 'BTC/USDT'}
        )
        assert 'symbol' in payload
        assert 'last_price' in payload

    @pytest.mark.asyncio
    async def test_get_orderbook_tool(self, mock_server):
        """get_orderbook includes both ``bids`` and ``asks``."""
        payload = await self._invoke(
            mock_server, 'get_orderbook', {'symbol': 'BTC/USDT', 'limit': 10}
        )
        assert 'bids' in payload
        assert 'asks' in payload

    @pytest.mark.asyncio
    async def test_search_symbols_tool(self, mock_server):
        """search_symbols yields a list whose entries all contain the query."""
        payload = await self._invoke(
            mock_server, 'search_symbols', {'query': 'BTC'}
        )
        assert isinstance(payload, list)
        assert all('BTC' in s for s in payload)

    @pytest.mark.asyncio
    async def test_get_supported_exchanges_tool(self, mock_server):
        """get_supported_exchanges yields a list that contains binance."""
        payload = await self._invoke(
            mock_server, 'get_supported_exchanges', {}
        )
        assert isinstance(payload, list)
        assert 'binance' in payload

    @pytest.mark.asyncio
    async def test_clear_cache_tool(self, mock_server):
        """clear_cache reports a status and calls the cache's clear_cache once."""
        payload = await self._invoke(mock_server, 'clear_cache', {})
        assert 'status' in payload
        mock_server.cached_api.clear_cache.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_cache_stats_tool(self, mock_server):
        """get_cache_stats includes ``total_entries``."""
        payload = await self._invoke(mock_server, 'get_cache_stats', {})
        assert 'total_entries' in payload

    @pytest.mark.asyncio
    async def test_error_handling(self, mock_server):
        """A raising price lookup is reported as an ``error`` payload."""
        # This Mock is also installed on cached_api by the fixture, so the
        # failure is raised through either access path.
        failing_lookup = mock_server.crypto_api.get_current_price
        failing_lookup.side_effect = Exception("API Error")
        payload = await self._invoke(
            mock_server, 'get_crypto_price', {'symbol': 'INVALID/PAIR'}
        )
        assert 'error' in payload

    @pytest.mark.asyncio
    async def test_unknown_tool(self, mock_server):
        """An unregistered tool name is reported as an ``error`` payload."""
        payload = await self._invoke(mock_server, 'unknown_tool', {})
        assert 'error' in payload
if __name__ == '__main__':
    # pytest.main returns the run's exit code; hand it to the interpreter so
    # that running this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, '-v']))