roots.py • 2.47 kB
import inspect
from collections.abc import Awaitable, Callable
from typing import TypeAlias
import mcp.types
import pydantic
from mcp import ClientSession
from mcp.client.session import ListRootsFnT
from mcp.shared.context import LifespanContextT, RequestContext
# A static collection of roots: strings (treated as file URLs by
# convert_roots_list), ``mcp.types.Root`` objects, or a mix of both.
# ``list`` is invariant, so each homogeneous form is spelled out next to the
# mixed one.
RootsList: TypeAlias = (
    list[str]
    | list[mcp.types.Root]
    | list[str | mcp.types.Root]
)

# The request context passed to a roots handler.
_RootsContext: TypeAlias = RequestContext[ClientSession, LifespanContextT]

# A roots handler takes the request context and produces a ``RootsList``,
# either directly or via an awaitable.
RootsHandler: TypeAlias = (
    Callable[[_RootsContext], RootsList]
    | Callable[[_RootsContext], Awaitable[RootsList]]
)
def convert_roots_list(roots: RootsList) -> list[mcp.types.Root]:
    """Normalize a list of roots into ``mcp.types.Root`` objects.

    Each element may be:

    * an ``mcp.types.Root`` -- kept as-is;
    * a ``str`` -- passed to ``pydantic.FileUrl`` and wrapped in a ``Root``;
    * a ``pydantic.FileUrl`` -- wrapped in a ``Root``.

    Args:
        roots: The roots to normalize.

    Returns:
        A new list of ``mcp.types.Root`` objects, in input order.

    Raises:
        ValueError: If an element is none of the supported kinds. A string
            that is not a valid file URL raises whatever ``pydantic.FileUrl``
            raises for it.
    """
    roots_list: list[mcp.types.Root] = []
    for r in roots:
        if isinstance(r, mcp.types.Root):
            roots_list.append(r)
        elif isinstance(r, str):
            # Checked before FileUrl on purpose. In older pydantic 2.x
            # releases FileUrl is an ``Annotated`` alias rather than a class,
            # and isinstance() against it raises TypeError. FileUrl is not a
            # str subclass in pydantic 2, so this order does not change which
            # branch matches. NOTE(review): confirm against the pinned
            # pydantic version.
            roots_list.append(mcp.types.Root(uri=pydantic.FileUrl(r)))
        elif isinstance(r, pydantic.FileUrl):
            roots_list.append(mcp.types.Root(uri=r))
        else:
            raise ValueError(f"Invalid root: {r}")
    return roots_list
def create_roots_callback(
    handler: RootsList | RootsHandler,
) -> ListRootsFnT:
    """Build a ``list_roots`` callback from a static list or a handler.

    Args:
        handler: Either a list of roots (see ``RootsList``) or a callable
            that takes the request context and returns such a list, directly
            or as an awaitable.

    Returns:
        An async callback compatible with ``ListRootsFnT``.

    Raises:
        ValueError: If ``handler`` is neither a list nor callable.
    """
    if isinstance(handler, list):
        return _create_roots_callback_from_roots(handler)
    if callable(handler):
        # inspect.isfunction() would reject bound methods (sync or async),
        # functools.partial objects and callable instances, even though they
        # satisfy the RootsHandler contract. callable() accepts all of them,
        # and every plain function that isfunction() accepted.
        return _create_roots_callback_from_fn(handler)
    raise ValueError(f"Invalid roots handler: {handler}")
def _create_roots_callback_from_roots(
    roots: RootsList,
) -> ListRootsFnT:
    """Return an async callback that always reports the same roots.

    ``roots`` is normalized once, when the callback is created. Every
    invocation returns a ``ListRootsResult`` built from that normalized list.
    """
    fixed_roots = convert_roots_list(roots)

    async def _roots_callback(
        context: RequestContext[ClientSession, LifespanContextT],
    ) -> mcp.types.ListRootsResult:
        # The context is part of the callback signature but is not needed
        # here: the answer never depends on the request.
        return mcp.types.ListRootsResult(roots=fixed_roots)

    return _roots_callback
def _create_roots_callback_from_fn(
    fn: Callable[[RequestContext[ClientSession, LifespanContextT]], RootsList]
    | Callable[[RequestContext[ClientSession, LifespanContextT]], Awaitable[RootsList]],
) -> ListRootsFnT:
    """Wrap a user-supplied roots function as an async ``list_roots`` callback.

    ``fn`` is called with the request context on every invocation. It may
    return the roots list directly or an awaitable that resolves to it. The
    result is normalized with ``convert_roots_list``.

    Any ``Exception`` raised by ``fn``, by awaiting its result, or during
    conversion is not propagated. The callback returns
    ``mcp.types.ErrorData`` with code ``INTERNAL_ERROR`` instead.
    """

    async def _roots_callback(
        context: RequestContext[ClientSession, LifespanContextT],
    ) -> mcp.types.ListRootsResult | mcp.types.ErrorData:
        try:
            roots = fn(context)
            # Sync handlers return the list; async handlers return an awaitable.
            if inspect.isawaitable(roots):
                roots = await roots
            return mcp.types.ListRootsResult(roots=convert_roots_list(roots))
        except Exception as e:
            # An exception raised without arguments stringifies to "", which
            # would give the client a blank error message; fall back to the
            # exception's type name in that case.
            return mcp.types.ErrorData(
                code=mcp.types.INTERNAL_ERROR,
                message=str(e) or type(e).__name__,
            )

    return _roots_callback