#
# Copyright (C) 2017-2025 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import json
import re
import asyncio
import time
from pathlib import Path
from typing import Dict, Any, Optional, Union, TextIO, List, Coroutine, Callable
from unittest.mock import MagicMock
from aiohttp import ClientSession
from collections import OrderedDict
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, ConfigDict
import uvicorn
import threading
from dremioai import log
class MockResponse:
    """Mock of an aiohttp ``ClientResponse`` that serves canned string data.

    Provides the subset of the response API used by the tests: ``status``,
    ``headers``, ``request_info``, ``text()``, ``json()``,
    ``raise_for_status()``, ``content.read()`` and ``async with`` support.
    """

    def __init__(self, data: str, status: int = 200, headers: Optional[Dict] = None):
        """
        Args:
            data: Response body as a string.
            status: HTTP status code reported by the response.
            headers: Optional response headers; defaults to an empty dict.
        """
        self.data = data
        self.status = status
        self.headers = headers or {}
        # Fixed placeholder request metadata (always GET / http://mock.url).
        self.request_info = MagicMock()
        self.request_info.method = "GET"
        self.request_info.url = "http://mock.url"

    async def text(self) -> str:
        """Return the mock data as text."""
        return self.data

    async def json(self) -> Dict[str, Any]:
        """Return the mock data parsed as JSON (raises ``json.JSONDecodeError`` if invalid)."""
        return json.loads(self.data)

    def raise_for_status(self):
        """Raise a plain ``Exception("HTTP <status>")`` when ``status >= 400``."""
        if self.status >= 400:
            raise Exception(f"HTTP {self.status}")

    @property
    def content(self):
        """Mock streaming body exposing an async ``read(chunk_size=1024)``.

        Each ``read`` call returns the next chunk of at most ``chunk_size``
        bytes of the UTF-8 encoded body, and ``b""`` once the body is
        exhausted. A negative ``chunk_size`` returns everything that remains.
        The read position is stored on the response itself, so it is shared
        across every access of this property.
        """
        mock_content = MagicMock()

        async def read(chunk_size=1024):
            # Chunk the encoded bytes so chunk_size counts bytes, not characters.
            payload = self.data.encode("utf-8")
            start = getattr(self, "_read_position", 0)
            if chunk_size < 0:
                end = len(payload)
            else:
                end = min(start + chunk_size, len(payload))
            # Never move backwards; stays at len(payload) once drained.
            self._read_position = max(start, end)
            return payload[start:end]

        mock_content.read = read
        return mock_content

    async def __aenter__(self):
        """Async context manager entry; returns the response itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit; nothing to release."""
        pass
class HttpMockFramework:
    """Simple HTTP mock framework for testing transport.py.

    Endpoints are regular-expression patterns matched with ``re.search``
    against the requested URL; the first registered pattern that matches
    wins. Used as a context manager, it replaces ``ClientSession.get`` and
    ``ClientSession.post`` with functions returning ``MockResponse`` objects
    and restores the originals on exit.
    """

    def __init__(self, resources_dir: str = "tests/resources"):
        """
        Args:
            resources_dir: Directory that ``load_mock_data`` filenames are
                resolved against.
        """
        self.resources_dir = Path(resources_dir)
        # Insertion order matters: earlier patterns take precedence.
        self.mock_responses = OrderedDict()
        self.original_session = None
        # ClientSession methods saved by __enter__ and restored by __exit__.
        self.original_get = None
        self.original_post = None

    def load_mock_data(self, endpoint: str, filename: str) -> "HttpMockFramework":
        """
        Load mock data from a file for a specific endpoint.

        Args:
            endpoint: Regex pattern for the API endpoint (e.g. "/api/v3/catalog").
            filename: Path relative to ``resources_dir`` (e.g. "catalog/spaces.json").

        Returns:
            ``self``, so calls can be chained.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        file_path = self.resources_dir / filename
        if not file_path.exists():
            raise FileNotFoundError(f"Mock data file not found: {file_path}")
        # Explicit encoding so fixtures decode identically on every platform.
        self.mock_responses[endpoint] = file_path.read_text(encoding="utf-8")
        return self

    def add_mock_response(
        self, endpoint: str, response_data: Union[str, Dict]
    ) -> "HttpMockFramework":
        """
        Add a mock response directly without loading from a file.

        Args:
            endpoint: Regex pattern for the API endpoint to mock.
            response_data: Response body; a dict is JSON-serialized first.

        Returns:
            ``self``, so calls can be chained.
        """
        if isinstance(response_data, dict):
            response_data = json.dumps(response_data)
        self.mock_responses[endpoint] = response_data
        return self

    def _get_mock_response(self, url: str, method: str = "GET") -> "MockResponse":
        """Return the first mock whose pattern matches ``url``, else a 404 mock.

        ``method`` is accepted for call-site symmetry but is not used for matching.
        """
        for endpoint, data in self.mock_responses.items():
            if re.search(endpoint, url):
                return MockResponse(data)
        # Default response if no mock found
        return MockResponse('{"error": "No mock data found"}', status=404)

    def _mock_get(self, url: str, **kwargs) -> "MockResponse":
        """Stand-in for ClientSession.get; extra keyword arguments are ignored."""
        return self._get_mock_response(url, "GET")

    def _mock_post(self, url: str, **kwargs) -> "MockResponse":
        """Stand-in for ClientSession.post; extra keyword arguments are ignored."""
        return self._get_mock_response(url, "POST")

    def __enter__(self):
        """Start mocking by patching ClientSession.get/post at class level."""
        self.original_get = ClientSession.get
        self.original_post = ClientSession.post
        # Bound methods are assigned, so session instances call them without
        # the session being passed as an argument.
        ClientSession.get = self._mock_get
        ClientSession.post = self._mock_post
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop mocking and restore the saved methods (no-op if never entered)."""
        if self.original_get is not None:
            ClientSession.get = self.original_get
        if self.original_post is not None:
            ClientSession.post = self.original_post
        self.original_get = None
        self.original_post = None
# Convenience function for quick setup
def mock_http_client(mock_data: OrderedDict[str, str]) -> HttpMockFramework:
    """
    Build an HttpMockFramework pre-loaded from resource files.

    Args:
        mock_data: Mapping of endpoint pattern -> filename under tests/resources.
            Entries are loaded in the mapping's iteration order.

    Returns:
        The configured framework; enter it as a context manager to activate mocking.

    Raises:
        FileNotFoundError: if a referenced resource file does not exist.

    Example:
        with mock_http_client({
            "/api/v3/catalog": "catalog/spaces.json",
            "/api/v3/sql": "sql/job_status.json"
        }) as mock:
            # Your test code here
            pass
    """
    mock = HttpMockFramework()
    for pattern, resource_name in mock_data.items():
        mock.load_mock_data(pattern, resource_name)
    return mock
# Starlette Logging App Components
class LogEntry(BaseModel):
    """One captured HTTP request, stored as a single JSON line in the request log."""

    method: str
    url: str
    path: str
    query_params: Dict[str, Any]
    headers: Dict[str, Any]
    # Status of the response produced for this request; None until it is known.
    response_status: Optional[int] = None
    # Parsed JSON request body, or None when the body was absent or not JSON.
    # Typed as Any (not Dict) so JSON arrays and scalars validate too; a
    # Dict-only type made model_validate raise for such bodies.
    # NOTE: this name shadows pydantic's BaseModel.json(); kept for compatibility.
    json: Optional[Any] = None
    # Re-validate on attribute assignment (e.g. when response_status is set later).
    model_config = ConfigDict(validate_assignment=True)
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log all incoming requests to a JSON file or file object.

    Each request is written as one ``LogEntry`` JSON document per line.
    """

    def __init__(self, app, log_file: Union[str, io.TextIOWrapper, Path]):
        """
        Args:
            app: The downstream ASGI application.
            log_file: A path (opened in append mode, UTF-8) or an already-open
                text stream such as ``io.StringIO``.
        """
        super().__init__(app)
        # NOTE(review): a file opened here from a path is never explicitly closed.
        self.log_file = (
            open(log_file, "a", encoding="utf-8")
            if isinstance(log_file, (str, Path))
            else log_file
        )
        self.lock = asyncio.Lock()

    async def dispatch(self, request: Request, call_next):
        """Record the request, forward it, then append the log entry."""
        # Best effort: empty or non-JSON bodies are logged as None.
        try:
            body = await request.json()
        except Exception:
            body = None
        log_entry = LogEntry.model_validate(
            {
                "method": request.method,
                "url": str(request.url),
                "path": request.url.path,
                "query_params": dict(request.query_params),
                "headers": dict(request.headers),
                "json": body,
            }
        )
        try:
            response = await call_next(request)
            log_entry.response_status = response.status_code
            return response
        finally:
            # Log even when the handler raised (response_status stays None),
            # so failing requests are still visible.
            async with self.lock:
                self.log_file.write(f"{log_entry.model_dump_json()}\n")
                # Flush so entries in a real file are visible to readers right away.
                self.log_file.flush()
def create_catch_all_handler(mock_data: Optional[OrderedDict[str, str]] = None):
    """Create a catch-all request handler backed by an HttpMockFramework.

    Args:
        mock_data: Optional mapping of endpoint pattern -> filename under
            tests/resources. When omitted or empty, no mocks are registered
            and every request receives the framework's default 404 JSON body.

    Returns:
        An async handler that returns a ``JSONResponse`` built from the
        matching mock's body and status.

    Raises:
        FileNotFoundError: if a referenced resource file does not exist.
    """
    http_framework = HttpMockFramework()
    for endpoint, filename in (mock_data or {}).items():
        http_framework.load_mock_data(endpoint, filename)

    async def handler(request: Request):
        """Answer with the mock registered for the request path (404 if none)."""
        # Only the URL path is matched; the query string is not considered.
        mock_response = http_framework._get_mock_response(
            request.url.path, request.method
        )
        response_data = await mock_response.json()
        return JSONResponse(response_data, status_code=mock_response.status)

    return handler
def create_logging_app(
    mock_data: Optional[OrderedDict[str, str]] = None,
    log_file: Union[str, TextIO, Path] = "api_logs.json",
) -> Starlette:
    """Build a Starlette app that serves mock data on every path and method.

    Every request passes through ``LoggingMiddleware``, which appends it to
    ``log_file``.

    Args:
        mock_data: Optional endpoint-pattern -> resource-file mapping.
        log_file: Path or open text stream that receives the JSON-lines log.

    Returns:
        The configured Starlette application (debug mode enabled).
    """
    all_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
    catch_all_route = Route(
        "/{path:path}",
        create_catch_all_handler(mock_data),
        methods=all_methods,
    )
    request_logger = Middleware(LoggingMiddleware, log_file=log_file)
    return Starlette(
        debug=True,
        routes=[catch_all_route],
        middleware=[request_logger],
    )
def start_server(
    runner: Coroutine,
    should_exit: Callable[[bool], None],
    name: str,
    addtional_runners: Optional[List[Coroutine]] = None,
    startup_delay: float = 0.5,
):
    """Run ``runner`` and any extra coroutines on a fresh event loop in a daemon thread.

    Setting the returned ``stop_event`` causes ``should_exit(True)`` to be
    called from inside the loop; the loop finishes once every gathered
    coroutine has returned.

    Args:
        runner: Main coroutine to run (e.g. a server's ``serve()``).
        should_exit: Callback invoked with ``True`` after ``stop_event`` is set.
        name: Thread name, also used in log messages.
        addtional_runners: Extra coroutines gathered alongside ``runner``
            (the misspelled name is kept for backward compatibility).
        startup_delay: Seconds to sleep after starting the thread before
            returning, giving the server time to come up.

    Returns:
        A ``(server_thread, stop_event)`` tuple.
    """
    stop_event = threading.Event()

    def run_server():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def monitor_stop():
            # Wait in a worker thread so the event loop itself is not blocked.
            await asyncio.to_thread(stop_event.wait)
            log.logger(f"start_server {name}").info("Stop event set")
            should_exit(True)

        runners = [runner, monitor_stop()] + (
            addtional_runners if addtional_runners else []
        )
        log.logger(f"start_server {name}").info(
            f"Runners = {runners} ({addtional_runners})"
        )
        try:
            loop.run_until_complete(asyncio.gather(*runners))
        finally:
            # Close the loop even if a runner raised, so it is not leaked.
            loop.close()

    log.logger(f"start_server {name}").info(f"Starting server {name}")
    server_thread = threading.Thread(target=run_server, name=name, daemon=True)
    server_thread.start()
    time.sleep(startup_delay)
    return server_thread, stop_event
def start_server_with_app(
    app: Starlette,
    host: str = "127.0.0.1",
    port: int = 8000,
    log_level="warning",
    name="mcp-server",
    additional_runners: Optional[List[Coroutine]] = None,
):
    """Serve ``app`` with uvicorn on a background thread via ``start_server``.

    Args:
        app: The ASGI application to serve.
        host: Interface to bind.
        port: Port to bind.
        log_level: uvicorn log level.
        name: Name of the background thread.
        additional_runners: Extra coroutines to run on the same event loop.

    Returns:
        The ``(server_thread, stop_event)`` tuple from ``start_server``;
        setting the event sets the uvicorn server's ``should_exit`` flag.
    """
    uvicorn_server = uvicorn.Server(
        uvicorn.Config(
            app=app,
            host=host,
            port=port,
            log_level=log_level,
            access_log=False,
        )
    )

    def _request_exit(flag: bool):
        # Forward the stop request to the uvicorn server's exit flag.
        uvicorn_server.should_exit = flag

    serve_coro = uvicorn_server.serve()
    return start_server(serve_coro, _request_exit, name, additional_runners)
def start_logging_server(
    mock_data: Optional[OrderedDict[str, str]] = None,
    log_file: Union[str, TextIO, Path] = "api_logs.json",
    host: str = "127.0.0.1",
    port: int = 8000,
    log_level="warning",
):
    """Start the mock, request-logging server on a background thread.

    Args:
        mock_data: Optional endpoint-pattern -> resource-file mapping.
        log_file: Path or open text stream receiving the JSON-lines request log.
        host: Interface to bind.
        port: Port to bind.
        log_level: uvicorn log level.

    Returns:
        A ``(server_thread, stop_event)`` tuple; the thread is named "logging-server".
    """
    logging_app = create_logging_app(mock_data=mock_data, log_file=log_file)
    server_options = {
        "host": host,
        "port": port,
        "log_level": log_level,
        "name": "logging-server",
    }
    return start_server_with_app(logging_app, **server_options)
class ServerFixture:
    """Handle to a server running on a background thread.

    Call ``close()``, or use the fixture in a ``with`` statement, to stop it.
    """

    def __init__(
        self,
        url: str,
        stop_event: threading.Event,
        server_thread: threading.Thread,
        metrics_port: Optional[int] = None,
    ):
        """
        Args:
            url: Base URL of the server.
            stop_event: Event that ``close()`` sets to request shutdown.
            server_thread: Thread running the server; joined by ``close()``.
            metrics_port: Optional metrics port (stored only, not used here).
        """
        self.url = url
        self.stop_event = stop_event
        self.server_thread = server_thread
        self.metrics_port = metrics_port

    def close(self):
        """Request shutdown and wait up to 5 seconds for the server thread."""
        self.stop_event.set()
        self.server_thread.join(timeout=5)
        # is_alive() is still True here only if the join timed out.
        log.logger("ServerFixture").info(
            f"Server stopped, thread {self.server_thread.is_alive()}"
        )

    def __enter__(self):
        """Return the fixture itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server on leaving the ``with`` block; exceptions propagate."""
        self.close()
class LoggingServerFixture(ServerFixture):
    """ServerFixture for the logging server that also exposes captured requests."""

    def __init__(
        self,
        url: str,
        log: io.StringIO,
        stop_event: threading.Event,
        server_thread: threading.Thread,
    ):
        """
        Args:
            url: Base URL of the server.
            log: In-memory stream the server writes JSON-lines request entries to.
            stop_event: Event set by ``close()`` to request shutdown.
            server_thread: Thread running the server.
        """
        super().__init__(url, stop_event, server_thread)
        self.log = log

    def logs(self) -> List[LogEntry]:
        """Parse every captured line into a ``LogEntry``, in the order written."""
        captured_lines = self.log.getvalue().splitlines()
        return list(map(LogEntry.model_validate_json, captured_lines))
def create_pytest_logging_server_fixture(
    mock_data: Optional[OrderedDict[str, str]] = None,
    port: int = 8000,
    log_level="warning",
) -> LoggingServerFixture:
    """Start a logging server on 127.0.0.1 and wrap it in a fixture object.

    Requests are logged to an in-memory ``io.StringIO``, readable through
    ``LoggingServerFixture.logs()``.

    Args:
        mock_data: Optional endpoint-pattern -> resource-file mapping.
        port: Port to listen on.
        log_level: uvicorn log level for the server (previously ignored,
            now forwarded).

    Returns:
        A LoggingServerFixture; call ``close()`` to stop the server.
    """
    host = "127.0.0.1"
    log_file = io.StringIO()
    thread, stop_event = start_logging_server(
        mock_data=mock_data,
        log_file=log_file,
        host=host,
        port=port,
        log_level=log_level,
    )
    # Return server URL
    server_url = f"http://{host}:{port}"
    return LoggingServerFixture(
        url=server_url, log=log_file, stop_event=stop_event, server_thread=thread
    )