import { Buffer } from "node:buffer";
import type { AuthRequest, OAuthHelpers } from "@cloudflare/workers-oauth-provider";
import {
addApprovedClient,
createOAuthState,
fetchUpstreamAuthToken,
generateCSRFProtection,
getUpstreamAuthorizeUrl,
isClientApproved,
OAuthError,
type Props,
renderApprovalDialog,
validateCSRFToken,
validateOAuthState,
} from "./workers-oauth-utils";
/**
 * The ambient `Env` type extended with an `OAUTH_PROVIDER` property holding
 * the `OAuthHelpers` object that `handleAccessRequest` uses to parse
 * authorization requests, look up clients and complete authorization.
 */
type EnvWithOauth = Env & { OAUTH_PROVIDER: OAuthHelpers };
/**
 * Routes the OAuth-related HTTP endpoints handled by this module.
 *
 * - `GET /authorize`: parses the authorization request. If the client is
 *   already approved, a state token is created and the user is redirected to
 *   the upstream authorize URL; otherwise the approval dialog is rendered
 *   together with a CSRF token.
 * - `POST /authorize`: validates the CSRF token and the submitted `state`,
 *   records the client as approved, creates a state token and redirects to
 *   the upstream authorize URL.
 * - `GET /callback`: validates the state token, exchanges the `code` query
 *   parameter for upstream tokens, verifies the ID token and completes the
 *   authorization, redirecting back to the client.
 * - Anything else: `404 Not Found`.
 *
 * Errors thrown by `verifyToken` in the callback branch are not caught here
 * and propagate to the caller.
 *
 * @param request - Incoming HTTP request.
 * @param env - Environment including the `OAUTH_PROVIDER` helpers.
 * @param _ctx - Execution context (unused).
 * @returns The HTTP response for the matched route.
 */
export async function handleAccessRequest(
  request: Request,
  env: EnvWithOauth,
  _ctx: ExecutionContext,
) {
  const { pathname, searchParams } = new URL(request.url);

  if (request.method === "GET" && pathname === "/authorize") {
    const oauthReqInfo = await env.OAUTH_PROVIDER.parseAuthRequest(request);
    const { clientId } = oauthReqInfo;
    if (!clientId) {
      return new Response("Invalid request", { status: 400 });
    }

    // Check if client is already approved
    if (await isClientApproved(request, clientId, env.COOKIE_ENCRYPTION_KEY)) {
      // Skip approval dialog but still create secure state
      const { stateToken } = await createOAuthState(oauthReqInfo, env.OAUTH_KV);
      return redirectToAccess(request, env, stateToken);
    }

    // Generate CSRF protection for the approval form
    const { token: csrfToken, setCookie } = generateCSRFProtection();

    return renderApprovalDialog(request, {
      client: await env.OAUTH_PROVIDER.lookupClient(clientId),
      csrfToken,
      server: {
        description: "This is a demo MCP Remote Server using Access for authentication.",
        logo: "https://avatars.githubusercontent.com/u/314135?s=200&v=4",
        name: "Cloudflare Access MCP Server",
      },
      setCookie,
      state: { oauthReqInfo },
    });
  }

  if (request.method === "POST" && pathname === "/authorize") {
    try {
      // Read form data once at top
      const formData = await request.formData();

      // Validate CSRF token - pass parsed FormData
      validateCSRFToken(formData, request);

      // Extract state from form data
      const encodedState = formData.get("state");
      if (!encodedState || typeof encodedState !== "string") {
        return new Response("Missing state in form data", { status: 400 });
      }

      // The decoded value is client-supplied JSON and may be `null` or a
      // non-object, so it is accessed defensively below.
      let state: { oauthReqInfo?: AuthRequest } | null;
      try {
        state = JSON.parse(atob(encodedState));
      } catch (_e) {
        return new Response("Invalid state data", { status: 400 });
      }

      const oauthReqInfo = state?.oauthReqInfo;
      if (!oauthReqInfo?.clientId) {
        return new Response("Invalid request", { status: 400 });
      }

      // Add client to approved list
      const approvedClientCookie = await addApprovedClient(
        request,
        oauthReqInfo.clientId,
        env.COOKIE_ENCRYPTION_KEY,
      );

      // Create OAuth state with CSRF protection
      const { stateToken } = await createOAuthState(oauthReqInfo, env.OAUTH_KV);

      return redirectToAccess(request, env, stateToken, {
        "Set-Cookie": approvedClientCookie,
      });
    } catch (error: unknown) {
      console.error("POST /authorize error:", error);
      if (error instanceof OAuthError) {
        return error.toResponse();
      }
      // Unexpected non-OAuth error: details are logged above and deliberately
      // not echoed to the client.
      return new Response("Internal server error", { status: 500 });
    }
  }

  if (request.method === "GET" && pathname === "/callback") {
    // Validate OAuth state (retrieves stored data from KV)
    let oauthReqInfo: AuthRequest;
    try {
      const result = await validateOAuthState(request, env.OAUTH_KV);
      oauthReqInfo = result.oauthReqInfo;
    } catch (error: unknown) {
      if (error instanceof OAuthError) {
        return error.toResponse();
      }
      // Unexpected non-OAuth error: log it so the failure is not silent.
      console.error("GET /callback error:", error);
      return new Response("Internal server error", { status: 500 });
    }

    if (!oauthReqInfo.clientId) {
      return new Response("Invalid OAuth request data", { status: 400 });
    }

    // Exchange the code for an access token
    const [accessToken, idToken, errResponse] = await fetchUpstreamAuthToken({
      client_id: env.ACCESS_CLIENT_ID,
      client_secret: env.ACCESS_CLIENT_SECRET,
      code: searchParams.get("code") ?? undefined,
      redirect_uri: new URL("/callback", request.url).href,
      upstream_url: env.ACCESS_TOKEN_URL,
    });
    if (errResponse) {
      return errResponse;
    }

    const idTokenClaims = await verifyToken(env, idToken);
    const user = {
      email: idTokenClaims.email,
      name: idTokenClaims.name,
      sub: idTokenClaims.sub,
    };

    // Return back to the MCP client a new token
    const { redirectTo } = await env.OAUTH_PROVIDER.completeAuthorization({
      metadata: {
        label: user.name,
      },
      // This will be available on this.props inside MyMCP
      props: {
        accessToken,
        email: user.email,
        login: user.sub,
        name: user.name,
      } as Props,
      request: oauthReqInfo,
      scope: oauthReqInfo.scope,
      userId: user.sub,
    });

    return Response.redirect(redirectTo, 302);
  }

  return new Response("Not Found", { status: 404 });
}
/**
 * Builds a `302` response that sends the browser to the upstream authorize
 * URL, using this Worker's `/callback` path as the redirect URI.
 *
 * @param request - Incoming request; its URL is the base for `/callback`.
 * @param env - Environment providing `ACCESS_CLIENT_ID` and
 *   `ACCESS_AUTHORIZATION_URL`.
 * @param stateToken - State value passed to the upstream authorize URL.
 * @param headers - Extra response headers; a `location` key in here is
 *   overridden by the computed redirect target.
 * @returns The redirect response with an empty body.
 */
async function redirectToAccess(
  request: Request,
  env: Env,
  stateToken: string,
  headers: Record<string, string> = {},
) {
  const callbackUrl = new URL("/callback", request.url).href;

  const location = getUpstreamAuthorizeUrl({
    client_id: env.ACCESS_CLIENT_ID,
    redirect_uri: callbackUrl,
    scope: "openid email profile",
    state: stateToken,
    upstream_url: env.ACCESS_AUTHORIZATION_URL,
  });

  // Caller-supplied headers go first so the computed location always wins.
  const responseHeaders = Object.assign({}, headers, { location });

  return new Response(null, { headers: responseHeaders, status: 302 });
}
/**
 * Helper to get the Access public key with the given key ID from the certs
 * endpoint.
 *
 * Fetches the JWKS document at `env.ACCESS_JWKS_URL`, selects the entry whose
 * `kid` equals the requested one and imports it as a non-extractable
 * RSASSA-PKCS1-v1_5 / SHA-256 key usable only for `verify`.
 *
 * @param env - Environment; `ACCESS_JWKS_URL` must be set.
 * @param kid - Key ID to look up in the JWKS document.
 * @returns The imported verification key.
 * @throws Error if the JWKS URL is not configured, the certs endpoint responds
 *   with a non-2xx status, or no key with a matching `kid` is present.
 */
async function fetchAccessPublicKey(env: Env, kid: string) {
  if (!env.ACCESS_JWKS_URL) {
    throw new Error("access jwks url not provided");
  }

  // TODO: cache this
  const resp = await fetch(env.ACCESS_JWKS_URL);
  if (!resp.ok) {
    // Fail early with a clear message instead of trying to parse an error page.
    throw new Error(`failed to fetch access jwks: HTTP ${resp.status}`);
  }

  const body = (await resp.json()) as {
    keys?: (JsonWebKey & { kid: string })[];
  } | null;

  // The response is external data: tolerate a missing or malformed `keys` list.
  const candidates = Array.isArray(body?.keys) ? body.keys : [];
  const jwk = candidates.find((key) => key?.kid === kid);
  if (!jwk) {
    throw new Error("access public key not found for token kid");
  }

  return crypto.subtle.importKey(
    "jwk",
    jwk,
    {
      hash: "SHA-256",
      name: "RSASSA-PKCS1-v1_5",
    },
    false,
    ["verify"],
  );
}
/**
 * Splits a compact JWT into its pieces. Only the form is checked (exactly
 * three dot-separated segments, with header and payload decodable as
 * base64url JSON); no signature or claim validation happens here.
 *
 * @param token - The compact JWT string.
 * @returns `data` (the `header.payload` text as it appears in the token),
 *   the decoded `header` and `payload` objects, and the still-encoded
 *   `signature` segment.
 * @throws Error if the token does not have exactly three segments;
 *   `JSON.parse` errors propagate for undecodable header or payload.
 */
function parseJWT(token: string) {
  const segments = token.split(".");
  if (segments.length !== 3) {
    throw new Error("token must have 3 parts");
  }

  const [encodedHeader, encodedPayload, signature] = segments;

  // Decode one base64url segment and parse it as JSON.
  const decodeSegment = (segment: string) =>
    JSON.parse(Buffer.from(segment, "base64url").toString());

  const header = decodeSegment(encodedHeader);
  const payload = decodeSegment(encodedPayload);

  return {
    data: `${encodedHeader}.${encodedPayload}`,
    header,
    payload,
    signature,
  };
}
/**
 * Validates the provided token using the Access public key set.
 *
 * Parses the token, looks up the public key named by the header's `kid`,
 * checks the RSASSA-PKCS1-v1_5 signature over the `header.payload` text and
 * then checks that the token carries a numeric `exp` claim that is still in
 * the future.
 *
 * NOTE(review): `aud` and `iss` are not checked in this function — confirm
 * whether they are validated elsewhere.
 *
 * @param env - Environment passed through to the public key lookup.
 * @param token - Compact JWT to validate.
 * @returns The decoded payload claims of the verified token.
 * @throws Error if the token is malformed, has no `kid`, fails signature
 *   verification, has no numeric `exp` claim, or is expired.
 */
async function verifyToken(env: Env, token: string) {
  const jwt = parseJWT(token);

  // The header is untrusted JSON; make sure there is a usable key ID before
  // going to the network.
  const kid: unknown = jwt.header?.kid;
  if (typeof kid !== "string" || kid === "") {
    throw new Error("token header missing kid");
  }

  const key = await fetchAccessPublicKey(env, kid);

  const verified = await crypto.subtle.verify(
    "RSASSA-PKCS1-v1_5",
    key,
    Buffer.from(jwt.signature, "base64url"),
    Buffer.from(jwt.data),
  );
  if (!verified) {
    throw new Error("failed to verify token");
  }

  const claims = jwt.payload;
  const now = Math.floor(Date.now() / 1000);

  // Validate expiration. A missing or non-numeric `exp` must not pass:
  // comparing `undefined` with a number is always false.
  if (typeof claims?.exp !== "number") {
    throw new Error("token missing exp claim");
  }
  // The current time must be strictly before `exp`.
  if (claims.exp <= now) {
    throw new Error("expired token");
  }

  return claims;
}