Skip to main content
Glama
conftest.py17.2 kB
""" Shared test fixtures for the datalake-mcp-server test suite. Provides reusable mocks for external dependencies: - Spark/PySpark: Mock SparkSession, DataFrames, and Spark operations - Redis: Mock redis.Redis client and connection pool - KBase Auth: Mock aiohttp.ClientSession for auth API calls - HTTP: Mock httpx.Client for governance API calls """ import asyncio import concurrent.futures from typing import Any, Dict, List from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from pydantic import AnyHttpUrl, AnyUrl from src.service.app_state import RequestState from src.service.dependencies import auth, get_spark_session from src.service.exceptions import InvalidTokenError from src.service.kb_auth import AdminPermission, KBaseUser from src.settings import BERDLSettings # ============================================================================= # Settings Fixtures # ============================================================================= @pytest.fixture def mock_settings() -> BERDLSettings: """ Create mock BERDLSettings for testing. Returns: BERDLSettings with test-appropriate defaults. """ return BERDLSettings( KBASE_AUTH_TOKEN="test_token_12345", USER="testuser", MINIO_ENDPOINT_URL="localhost:9002", MINIO_ACCESS_KEY="test_access_key", MINIO_SECRET_KEY="test_secret_key", MINIO_SECURE=False, BERDL_REDIS_HOST="localhost", BERDL_REDIS_PORT=6379, SPARK_HOME="/usr/local/spark", SPARK_MASTER_URL=AnyUrl("spark://localhost:7077"), SPARK_CONNECT_URL=AnyUrl("sc://localhost:15002"), BERDL_HIVE_METASTORE_URI=AnyUrl("thrift://localhost:9083"), SPARK_WORKER_COUNT=1, SPARK_WORKER_CORES=1, SPARK_WORKER_MEMORY="2GiB", SPARK_MASTER_CORES=1, SPARK_MASTER_MEMORY="1GiB", GOVERNANCE_API_URL=AnyHttpUrl("http://localhost:8000"), BERDL_POD_IP="127.0.0.1", ) @pytest.fixture def mock_settings_patch(mock_settings): """ Patch get_settings to return mock settings. Usage: def test_something(mock_settings_patch): # get_settings() now returns mock_settings """ with patch("src.settings.get_settings", return_value=mock_settings): yield mock_settings # ============================================================================= # Spark Session Fixtures # ============================================================================= @pytest.fixture def mock_spark_row(): """Factory for creating mock Spark Row objects.""" def _create_row(data: Dict[str, Any]) -> MagicMock: row = MagicMock() row.asDict.return_value = data for key, value in data.items(): setattr(row, key, value) return row return _create_row @pytest.fixture def mock_spark_dataframe(mock_spark_row): """ Create a configurable mock Spark DataFrame. Usage: def test_query(mock_spark_dataframe): df = mock_spark_dataframe([{"id": 1, "name": "test"}]) df.collect() # Returns list of mock rows """ def _create_df(data: List[Dict[str, Any]] = None, count: int = None): if data is None: data = [] df = MagicMock() # Mock collect() to return rows rows = [mock_spark_row(row) for row in data] df.collect.return_value = rows # Mock count() df.count.return_value = count if count is not None else len(data) # Mock chained operations df.select.return_value = df df.filter.return_value = df df.where.return_value = df df.limit.return_value = df df.orderBy.return_value = df df.groupBy.return_value = df return df return _create_df @pytest.fixture def mock_spark_session(mock_spark_dataframe): """ Create a mock SparkSession for testing. The mock supports: - spark.sql(query) -> mock DataFrame - spark.table(name) -> mock DataFrame - spark.catalog.listDatabases() -> list of mock Database objects - spark.catalog.listTables(db) -> list of mock Table objects - spark.stop() -> no-op Usage: def test_query(mock_spark_session): spark = mock_spark_session() spark.sql("SELECT * FROM test").collect() """ def _create_session( sql_results: List[Dict[str, Any]] = None, table_results: List[Dict[str, Any]] = None, databases: List[str] = None, tables: Dict[str, List[str]] = None, ): if sql_results is None: sql_results = [] if table_results is None: table_results = sql_results if databases is None: databases = ["default", "testdb"] if tables is None: tables = {"default": ["table1"], "testdb": ["users", "orders"]} spark = MagicMock() # Mock sql() method sql_df = mock_spark_dataframe(sql_results) spark.sql.return_value = sql_df # Mock table() method table_df = mock_spark_dataframe(table_results) spark.table.return_value = table_df # Mock catalog catalog = MagicMock() # Mock listDatabases db_objects = [] for db_name in databases: db_obj = MagicMock() db_obj.name = db_name db_objects.append(db_obj) catalog.listDatabases.return_value = db_objects # Mock listTables def list_tables_side_effect(db_name): table_list = tables.get(db_name, []) table_objects = [] for table_name in table_list: table_obj = MagicMock() table_obj.name = table_name table_objects.append(table_obj) return table_objects catalog.listTables.side_effect = list_tables_side_effect spark.catalog = catalog # Mock stop spark.stop.return_value = None return spark return _create_session # ============================================================================= # Redis Fixtures # ============================================================================= @pytest.fixture def mock_redis_client(): """ Create a mock Redis client for testing. The mock supports: - client.get(key) -> bytes or None - client.set(name, value, ex=ttl) -> True - client.delete(key) -> int Usage: def test_cache(mock_redis_client): client = mock_redis_client({"key": b'{"data": "value"}'}) """ def _create_client(cache_data: Dict[str, bytes] = None): if cache_data is None: cache_data = {} client = MagicMock() stored_data = dict(cache_data) def get_side_effect(key): return stored_data.get(key) def set_side_effect(name, value, ex=None): stored_data[name] = value.encode() if isinstance(value, str) else value return True def delete_side_effect(*keys): count = 0 for key in keys: if key in stored_data: del stored_data[key] count += 1 return count client.get.side_effect = get_side_effect client.set.side_effect = set_side_effect client.delete.side_effect = delete_side_effect # Store reference to internal data for assertions client._stored_data = stored_data return client return _create_client @pytest.fixture def mock_redis_client_patch(mock_redis_client): """ Patch the Redis client factory to return a mock client. Usage: def test_cache(mock_redis_client_patch): client, patch = mock_redis_client_patch() # All redis operations now use the mock """ def _create_patched(cache_data: Dict[str, bytes] = None): client = mock_redis_client(cache_data) patcher = patch("src.cache.redis_cache._get_redis_client", return_value=client) patcher.start() return client, patcher return _create_patched # ============================================================================= # KBase Auth Fixtures # ============================================================================= @pytest.fixture def mock_kbase_user(): """Create a mock KBaseUser for testing.""" def _create_user( username: str = "testuser", admin_perm: AdminPermission = AdminPermission.NONE, ): return KBaseUser(user=username, admin_perm=admin_perm) return _create_user @pytest.fixture def mock_kbase_auth(mock_kbase_user): """ Create a mock KBaseAuth client for testing. Usage: def test_auth(mock_kbase_auth): auth = mock_kbase_auth() user = await auth.get_user("valid_token") """ def _create_auth( valid_tokens: Dict[str, str] = None, admin_tokens: List[str] = None, ): if valid_tokens is None: valid_tokens = {"valid_token": "testuser"} if admin_tokens is None: admin_tokens = [] auth = MagicMock() async def get_user_side_effect(token): if token in valid_tokens: username = valid_tokens[token] admin_perm = ( AdminPermission.FULL if token in admin_tokens else AdminPermission.NONE ) return mock_kbase_user(username, admin_perm) raise InvalidTokenError("KBase auth server reported token is invalid.") auth.get_user = AsyncMock(side_effect=get_user_side_effect) auth.close = AsyncMock() return auth return _create_auth @pytest.fixture def mock_aiohttp_session(): """ Create a mock aiohttp ClientSession for testing. Useful for testing KBaseAuth initialization and HTTP calls. """ def _create_session(responses: Dict[str, Any] = None): if responses is None: responses = {} session = MagicMock() async def mock_get(url, headers=None): response = MagicMock() response.status = 200 async def mock_json(): if url in responses: return responses[url] # Default auth service response return {"servicename": "Authentication Service"} async def mock_text(): return str(responses.get(url, "")) response.json = mock_json response.text = mock_text return response # Create async context manager context_manager = MagicMock() context_manager.__aenter__ = AsyncMock( side_effect=lambda: mock_get(context_manager._url, context_manager._headers) ) context_manager.__aexit__ = AsyncMock(return_value=None) def get_side_effect(url, headers=None): context_manager._url = url context_manager._headers = headers return context_manager session.get.side_effect = get_side_effect session.close = AsyncMock() session.closed = False return session return _create_session # ============================================================================= # HTTP Client Fixtures (for governance API) # ============================================================================= @pytest.fixture def mock_httpx_client(): """ Create a mock httpx Client for testing governance API calls. Usage: def test_governance(mock_httpx_client): client = mock_httpx_client({ "http://localhost/api": {"data": "value"} }) """ def _create_client(responses: Dict[str, Any] = None): if responses is None: responses = {} client = MagicMock() def get_side_effect(url, headers=None, params=None): response = MagicMock() response.status_code = 200 # Build full URL with params for lookup lookup_url = url if params: param_str = "&".join(f"{k}={v}" for k, v in params.items()) lookup_url = f"{url}?{param_str}" def json_func(): # Try exact match first, then base URL if lookup_url in responses: return responses[lookup_url] if url in responses: return responses[url] return {} response.json = json_func response.raise_for_status = MagicMock() return response client.get.side_effect = get_side_effect client.close = MagicMock() return client return _create_client # ============================================================================= # FastAPI Test Client Fixtures # ============================================================================= @pytest.fixture def mock_app_dependencies(mock_spark_session, mock_kbase_user, mock_settings): """ Create a fully mocked FastAPI app for testing routes. This fixture patches both auth and spark session dependencies. """ from src.main import create_application app = create_application() # Create mock spark session spark = mock_spark_session() # Create mock user user = mock_kbase_user() # Override dependencies def mock_get_spark(): yield spark def mock_auth(): return user app.dependency_overrides[get_spark_session] = mock_get_spark app.dependency_overrides[auth] = mock_auth return app, spark, user @pytest.fixture def test_client(mock_app_dependencies): """ Create a TestClient with mocked dependencies. Usage: def test_endpoint(test_client): app, spark, user = mock_app_dependencies client = TestClient(app) response = client.get("/health") """ app, spark, user = mock_app_dependencies return TestClient(app) @pytest.fixture def client(): """ Basic test client without mocked dependencies. Note: This requires actual services to be available. Use test_client fixture for unit tests with mocks. """ from src.main import create_application app = create_application() return TestClient(app) # ============================================================================= # Request/Response Fixtures # ============================================================================= @pytest.fixture def mock_request(mock_kbase_user): """ Create a mock FastAPI Request object for testing. Usage: def test_handler(mock_request): request = mock_request(user="testuser") """ def _create_request( user: str = "testuser", headers: Dict[str, str] = None, app: FastAPI = None, ): if headers is None: headers = {"Authorization": "Bearer valid_token"} request = MagicMock() request.headers = headers request.app = app or MagicMock() # Set up request state with user request.state = MagicMock() request.state._request_state = RequestState( user=mock_kbase_user(user) if user else None ) return request return _create_request # ============================================================================= # Concurrency Testing Fixtures # ============================================================================= @pytest.fixture def concurrent_executor(): """ Fixture for testing concurrent operations. Usage: def test_concurrent(concurrent_executor): results = concurrent_executor(my_func, args_list, max_workers=10) """ def _execute(func, args_list, max_workers=10): with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(func, *args) for args in args_list] results = [] exceptions = [] for future in concurrent.futures.as_completed(futures): try: results.append(future.result()) except Exception as e: exceptions.append(e) return results, exceptions return _execute @pytest.fixture def async_concurrent_executor(): """ Fixture for testing concurrent async operations. Usage: async def test_concurrent(async_concurrent_executor): results = await async_concurrent_executor(my_async_func, args_list) """ async def _execute(func, args_list, max_concurrent=10): semaphore = asyncio.Semaphore(max_concurrent) async def limited_func(*args): async with semaphore: return await func(*args) tasks = [limited_func(*args) for args in args_list] results = await asyncio.gather(*tasks, return_exceptions=True) return results return _execute

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server