from __future__ import annotations
from typing import TYPE_CHECKING, ParamSpec, TypeVar
from mcp.server.auth.middleware.auth_context import (
get_access_token as _sdk_get_access_token,
)
from starlette.requests import Request
from fastmcp.server.auth import AccessToken
if TYPE_CHECKING:
from fastmcp.server.context import Context
# Generic typing helpers: a parameter specification and a return-type variable.
# NOTE(review): neither is referenced in the portion of the file shown here.
P = ParamSpec("P")
R = TypeVar("R")
# Public API of this module: the names exported by a star-import.
__all__ = [
"get_context",
"get_http_request",
"get_http_headers",
"get_access_token",
# Re-exported from fastmcp.server.auth.
"AccessToken",
]
# --- Context ---
def get_context() -> Context:
    """
    Return the context currently stored in `_current_context`.

    Raises:
        RuntimeError: If the stored context is None.
    """
    # Imported at call time; the module level only imports `Context` under
    # TYPE_CHECKING (presumably to avoid a circular import -- TODO confirm).
    from fastmcp.server.context import _current_context

    if (active := _current_context.get()) is None:
        raise RuntimeError("No active context found.")
    return active
# --- HTTP Request ---
def get_http_request() -> Request:
    """
    Return the HTTP request attached to the current MCP request context.

    Raises:
        RuntimeError: If no request context is set, or if the request
            context carries no request (its `request` attribute is None).
    """
    from mcp.server.lowlevel.server import request_ctx

    try:
        current = request_ctx.get().request
    except LookupError:
        # The context variable has no value: treat it like a missing request.
        current = None

    # Raised outside the handler so the LookupError is not chained onto it.
    if current is None:
        raise RuntimeError("No active HTTP request found.")
    return current
def get_http_headers(include_all: bool = False) -> dict[str, str]:
    """
    Collect the headers of the current HTTP request, if there is one.

    When no HTTP request is active, an empty dict is returned instead of
    raising. Header names in the result are lowercased and values are
    converted to `str`.

    Unless `include_all` is True, headers that cause problems when forwarded
    to downstream clients (such as `content-length`) are dropped.
    If `include_all` is True, every header is returned.
    """
    skipped: set[str] = (
        set()
        if include_all
        else {
            "host",
            "content-length",
            "connection",
            "transfer-encoding",
            "upgrade",
            "te",
            "keep-alive",
            "expect",
            "accept",
            # Proxy-related headers
            "proxy-authenticate",
            "proxy-authorization",
            "proxy-connection",
            # MCP-related headers
            "mcp-session-id",
        }
    )

    # Sanity check: the membership test below compares lowercased names, so
    # a mixed-case entry in the exclusion set would never match.
    if any(entry != entry.lower() for entry in skipped):
        raise ValueError("Excluded headers must be lowercase")

    try:
        normalized = (
            (key.lower(), val) for key, val in get_http_request().headers.items()
        )
        return {key: str(val) for key, val in normalized if key not in skipped}
    except RuntimeError:
        # get_http_request() raises RuntimeError when there is no request.
        return {}
# --- Access Token ---
def get_access_token() -> AccessToken | None:
    """
    Return the FastMCP access token for the current context.

    Returns:
        The access token if an authenticated user is available, None otherwise.

    Raises:
        TypeError: If the SDK returns a token that is not a FastMCP
            AccessToken and cannot be converted into one.
    """
    raw_token = _sdk_get_access_token()

    needs_conversion = raw_token is not None and not isinstance(
        raw_token, AccessToken
    )
    if not needs_conversion:
        return raw_token

    # Workaround for the SDK handing back a different token type: rebuild a
    # FastMCP AccessToken from the dumped fields when they are compatible.
    # Any failure along the way is reported as a TypeError.
    try:
        converted = AccessToken(**raw_token.model_dump())
    except Exception as exc:
        raise TypeError(
            f"Expected fastmcp.server.auth.auth.AccessToken, got {type(raw_token).__name__}. "
            "Ensure the SDK is using the correct AccessToken type."
        ) from exc
    return converted