diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 592dc17..1ae0efd 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -1,134 +1,71 @@ -import pkceChallenge from "pkce-challenge"; +import { OAuthClientInformation, OAuthClientInformationSchema, OAuthClientProvider, OAuthTokens, OAuthTokensSchema } from "@modelcontextprotocol/sdk/client/auth.js"; import { SESSION_KEYS } from "./constants"; -import { z } from "zod"; -export const OAuthMetadataSchema = z.object({ - authorization_endpoint: z.string(), - token_endpoint: z.string(), -}); +export class InspectorOAuthClientProvider implements OAuthClientProvider { + get redirectUrl() { + return window.location.origin + "/oauth/callback"; + } -export type OAuthMetadata = z.infer; + get clientMetadata() { + return { + redirect_uris: [this.redirectUrl], + token_endpoint_auth_method: "none", + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + client_name: "MCP Inspector", + client_uri: "https://github.com/modelcontextprotocol/inspector", + }; + } -export const OAuthTokensSchema = z.object({ - access_token: z.string(), - refresh_token: z.string().optional(), - expires_in: z.number().optional(), -}); - -export type OAuthTokens = z.infer; - -export async function discoverOAuthMetadata( - serverUrl: string, -): Promise { - try { - const url = new URL("/.well-known/oauth-authorization-server", serverUrl); - const response = await fetch(url.toString()); - - if (response.ok) { - const metadata = await response.json(); - const validatedMetadata = OAuthMetadataSchema.parse({ - authorization_endpoint: metadata.authorization_endpoint, - token_endpoint: metadata.token_endpoint, - }); - return validatedMetadata; + async clientInformation() { + const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION); + if (!value) { + return undefined; } - } catch (error) { - console.warn("OAuth metadata discovery failed:", error); + + return await OAuthClientInformationSchema.parseAsync(JSON.parse(value)); } - // Fall back to default endpoints - const baseUrl = new URL(serverUrl); - const defaultMetadata = { - authorization_endpoint: new URL("/authorize", baseUrl).toString(), - token_endpoint: new URL("/token", baseUrl).toString(), - }; - return OAuthMetadataSchema.parse(defaultMetadata); -} - -export async function startOAuthFlow(serverUrl: string): Promise { - // Generate PKCE challenge - const challenge = await pkceChallenge(); - const codeVerifier = challenge.code_verifier; - const codeChallenge = challenge.code_challenge; - - // Store code verifier for later use - sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier); - - // Discover OAuth endpoints - const metadata = await discoverOAuthMetadata(serverUrl); - - // Build authorization URL - const authUrl = new URL(metadata.authorization_endpoint); - authUrl.searchParams.set("response_type", "code"); - authUrl.searchParams.set("code_challenge", codeChallenge); - authUrl.searchParams.set("code_challenge_method", "S256"); - authUrl.searchParams.set( - "redirect_uri", - window.location.origin + "/oauth/callback", - ); - - return authUrl.toString(); -} - -export async function handleOAuthCallback( - serverUrl: string, - code: string, -): Promise { - // Get stored code verifier - const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); - if (!codeVerifier) { - throw new Error("No code verifier found"); - } - - // Discover OAuth endpoints - const metadata = await discoverOAuthMetadata(serverUrl); - // Exchange code for tokens - const response = await fetch(metadata.token_endpoint, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - grant_type: "authorization_code", - code, - code_verifier: codeVerifier, - redirect_uri: window.location.origin + "/oauth/callback", - }), - }); - - if (!response.ok) { - throw new Error("Token exchange failed"); - } - - const tokens = await response.json(); - return OAuthTokensSchema.parse(tokens); -} - -export async function refreshAccessToken( - serverUrl: string, -): Promise { - const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN); - if (!refreshToken) { - throw new Error("No refresh token available"); - } - - const metadata = await discoverOAuthMetadata(serverUrl); - - const response = await fetch(metadata.token_endpoint, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - grant_type: "refresh_token", - refresh_token: refreshToken, - }), - }); - - if (!response.ok) { - throw new Error("Token refresh failed"); - } - - const tokens = await response.json(); - return OAuthTokensSchema.parse(tokens); + saveClientInformation(clientInformation: OAuthClientInformation) { + sessionStorage.setItem( + SESSION_KEYS.CLIENT_INFORMATION, + JSON.stringify(clientInformation), + ); + } + + async tokens() { + const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS); + if (!tokens) { + return undefined; + } + + return await OAuthTokensSchema.parseAsync(JSON.parse(tokens)); + } + + saveTokens(tokens: OAuthTokens) { + sessionStorage.setItem( + SESSION_KEYS.TOKENS, + JSON.stringify(tokens), + ); + } + + redirectToAuthorization(authorizationUrl: URL) { + window.location.href = authorizationUrl.href; + } + + saveCodeVerifier(codeVerifier: string) { + sessionStorage.setItem( + SESSION_KEYS.CODE_VERIFIER, + codeVerifier, + ); + } + + codeVerifier() { + const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); + if (!verifier) { + throw new Error("No code verifier saved for session"); + } + + return verifier; + } } diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index 13a2370..4051bec 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -2,6 +2,6 @@ export const SESSION_KEYS = { CODE_VERIFIER: "mcp_code_verifier", SERVER_URL: "mcp_server_url", - ACCESS_TOKEN: "mcp_access_token", - REFRESH_TOKEN: "mcp_refresh_token", + TOKENS: "mcp_tokens", + CLIENT_INFORMATION: "mcp_client_information", } as const; diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 6c42c3f..e23a532 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -16,9 +16,10 @@ import { import { useState } from "react"; import { toast } from "react-toastify"; import { z } from "zod"; -import { startOAuthFlow, refreshAccessToken } from "../auth"; import { SESSION_KEYS } from "../constants"; import { Notification, StdErrNotificationSchema } from "../notificationTypes"; +import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; +import { InspectorOAuthClientProvider } from "../auth"; const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000; @@ -121,45 +122,15 @@ export function useConnection({ } }; - const initiateOAuthFlow = async () => { - sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN); - sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN); - sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); - const redirectUrl = await startOAuthFlow(sseUrl); - window.location.href = redirectUrl; - }; - - const handleTokenRefresh = async () => { - try { - const tokens = await refreshAccessToken(sseUrl); - sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token); - if (tokens.refresh_token) { - sessionStorage.setItem( - SESSION_KEYS.REFRESH_TOKEN, - tokens.refresh_token, - ); - } - return tokens.access_token; - } catch (error) { - console.error("Token refresh failed:", error); - await initiateOAuthFlow(); - throw error; - } - }; - + const authProvider = new InspectorOAuthClientProvider(); const handleAuthError = async (error: unknown) => { if (error instanceof SseError && error.code === 401) { - if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { - try { - await handleTokenRefresh(); - return true; - } catch (error) { - console.error("Token refresh failed:", error); - } - } else { - await initiateOAuthFlow(); - } + sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); + + const result = await auth(authProvider, { serverUrl: sseUrl }) + return result === "AUTHORIZED"; } + return false; }; @@ -192,9 +163,9 @@ export function useConnection({ } const headers: HeadersInit = {}; - const accessToken = sessionStorage.getItem(SESSION_KEYS.ACCESS_TOKEN); - if (accessToken) { - headers["Authorization"] = `Bearer ${accessToken}`; + const tokens = await authProvider.tokens(); + if (tokens) { + headers["Authorization"] = `Bearer ${tokens.access_token}`; } const clientTransport = new SSEClientTransport(backendUrl, {