from __future__ import annotations
import copy
import logging
import multiprocessing
import socket
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import parse_qs, urlparse
import httpx
import uvicorn
from fastmcp import settings
from fastmcp.client.auth.oauth import OAuth
from fastmcp.utilities.http import find_available_port
if TYPE_CHECKING:
from fastmcp.server.server import FastMCP
@contextmanager
def temporary_settings(**kwargs: Any):
"""
Temporarily override FastMCP setting values.
Args:
**kwargs: The settings to override, including nested settings.
Example:
Temporarily override a setting:
```python
import fastmcp
from fastmcp.utilities.tests import temporary_settings
with temporary_settings(log_level='DEBUG'):
assert fastmcp.settings.log_level == 'DEBUG'
assert fastmcp.settings.log_level == 'INFO'
```
"""
old_settings = copy.deepcopy(settings)
try:
# apply the new settings
for attr, value in kwargs.items():
settings.set_setting(attr, value)
yield
finally:
# restore the old settings
for attr in kwargs:
settings.set_setting(attr, old_settings.get_setting(attr))
def _run_server(mcp_server: FastMCP, transport: Literal["sse"], port: int) -> None:
# Some Starlette apps are not pickleable, so we need to create them here based on the indicated transport
if transport == "sse":
app = mcp_server.http_app(transport="sse")
else:
raise ValueError(f"Invalid transport: {transport}")
uvicorn_server = uvicorn.Server(
config=uvicorn.Config(
app=app,
host="127.0.0.1",
port=port,
log_level="error",
)
)
uvicorn_server.run()
@contextmanager
def run_server_in_process(
server_fn: Callable[..., None],
*args,
provide_host_and_port: bool = True,
**kwargs,
) -> Generator[str, None, None]:
"""
Context manager that runs a FastMCP server in a separate process and
returns the server URL. When the context manager is exited, the server process is killed.
Args:
server_fn: The function that runs a FastMCP server. FastMCP servers are
not pickleable, so we need a function that creates and runs one.
*args: Arguments to pass to the server function.
provide_host_and_port: Whether to provide the host and port to the server function as kwargs.
**kwargs: Keyword arguments to pass to the server function.
Returns:
The server URL.
"""
host = "127.0.0.1"
port = find_available_port()
if provide_host_and_port:
kwargs |= {"host": host, "port": port}
proc = multiprocessing.Process(
target=server_fn, args=args, kwargs=kwargs, daemon=True
)
proc.start()
# Wait for server to be running
max_attempts = 10
attempt = 0
while attempt < max_attempts and proc.is_alive():
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((host, port))
break
except ConnectionRefusedError:
if attempt < 3:
time.sleep(0.01)
else:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield f"http://{host}:{port}"
proc.terminate()
proc.join(timeout=5)
if proc.is_alive():
# If it's still alive, then force kill it
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
raise RuntimeError("Server process failed to terminate even after kill")
@contextmanager
def caplog_for_fastmcp(caplog):
"""Context manager to capture logs from FastMCP loggers even when propagation is disabled."""
caplog.clear()
logger = logging.getLogger("FastMCP")
logger.addHandler(caplog.handler)
try:
yield
finally:
logger.removeHandler(caplog.handler)
class HeadlessOAuth(OAuth):
"""
OAuth provider that bypasses browser interaction for testing.
This simulates the complete OAuth flow programmatically by making HTTP requests
instead of opening a browser and running a callback server. Useful for automated testing.
"""
def __init__(self, mcp_url: str, **kwargs):
"""Initialize HeadlessOAuth with stored response tracking."""
self._stored_response = None
super().__init__(mcp_url, **kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
"""Make HTTP request to authorization URL and store response for callback handler."""
async with httpx.AsyncClient() as client:
response = await client.get(authorization_url, follow_redirects=False)
self._stored_response = response
async def callback_handler(self) -> tuple[str, str | None]:
"""Parse stored response and return (auth_code, state)."""
if not self._stored_response:
raise RuntimeError(
"No authorization response stored. redirect_handler must be called first."
)
response = self._stored_response
# Extract auth code from redirect location
if response.status_code == 302:
redirect_url = response.headers["location"]
parsed = urlparse(redirect_url)
query_params = parse_qs(parsed.query)
if "error" in query_params:
error = query_params["error"][0]
error_desc = query_params.get("error_description", ["Unknown error"])[0]
raise RuntimeError(
f"OAuth authorization failed: {error} - {error_desc}"
)
auth_code = query_params["code"][0]
state = query_params.get("state", [None])[0]
return auth_code, state
else:
raise RuntimeError(f"Authorization failed: {response.status_code}")