# oauth_client.py (2.74 kB)
"""
Before running, start an MCP Resource Server (RS) and make sure this client
points at its URL (http://localhost:8001 by default).
To spin up the RS locally, see
examples/servers/simple-auth/README.md

Then cd to the `examples/snippets` directory and run:

    uv run oauth-client
"""
import asyncio
from urllib.parse import parse_qs, urlparse
from pydantic import AnyUrl
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class InMemoryTokenStorage(TokenStorage):
    """Demo token storage that keeps everything in process memory.

    Nothing is persisted: whatever is stored here is lost when the
    process exits. Suitable for examples only.
    """

    # Declared up front so the attribute types are visible at class level;
    # the actual values are assigned per instance in __init__.
    tokens: OAuthToken | None
    client_info: OAuthClientInformationFull | None

    def __init__(self) -> None:
        # Both slots start empty and are filled via set_tokens / set_client_info.
        self.tokens = None
        self.client_info = None

    async def get_tokens(self) -> OAuthToken | None:
        """Return the stored token set, or None if nothing was stored yet."""
        return self.tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Remember *tokens*, replacing any previously stored set."""
        self.tokens = tokens

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client information, or None if absent."""
        return self.client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Remember *client_info*, replacing any previous value."""
        self.client_info = client_info
async def handle_redirect(auth_url: str) -> None:
    """Show the authorization URL so the user can open it manually.

    Args:
        auth_url: The authorization URL to present to the user.
    """
    message = f"Visit: {auth_url}"
    print(message)
async def handle_callback() -> tuple[str, str | None]:
callback_url = input("Paste callback URL: ")
params = parse_qs(urlparse(callback_url).query)
return params["code"][0], params.get("state", [None])[0]
async def main(server_url: str = "http://localhost:8001"):
    """Run the OAuth client example.

    Authenticates against the MCP server using the authorization-code grant.
    The browser step is completed manually: handle_redirect prints the
    authorization URL and handle_callback reads back the pasted callback URL.
    Once connected, the example lists the server's tools and resources.

    Args:
        server_url: Base URL of the MCP resource server. The MCP endpoint is
            expected at ``<server_url>/mcp``. Defaults to the local demo
            server.
    """
    # Derive both the OAuth provider's server URL and the MCP endpoint from a
    # single value; stripping a trailing slash avoids producing "//mcp".
    base_url = server_url.rstrip("/")
    oauth_auth = OAuthClientProvider(
        server_url=base_url,
        client_metadata=OAuthClientMetadata(
            client_name="Example MCP Client",
            redirect_uris=[AnyUrl("http://localhost:3000/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope="user",
        ),
        storage=InMemoryTokenStorage(),
        redirect_handler=handle_redirect,
        callback_handler=handle_callback,
    )
    async with streamablehttp_client(f"{base_url}/mcp", auth=oauth_auth) as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print(f"Available tools: {[tool.name for tool in tools.tools]}")
            resources = await session.list_resources()
            print(f"Available resources: {[r.uri for r in resources.resources]}")
def run():
    """Synchronous wrapper: run main() to completion on a new asyncio event loop."""
    coro = main()
    asyncio.run(coro)
if __name__ == "__main__":
    # Only start the example when this file is executed directly, not on import.
    run()