Refactor to use auth from SDK

This commit is contained in:
Justin Spahr-Summers
2025-02-11 16:39:07 +00:00
parent 6d930ecae7
commit eb6af47b21
3 changed files with 76 additions and 168 deletions

View File

@@ -1,134 +1,71 @@
import pkceChallenge from "pkce-challenge"; import { OAuthClientInformation, OAuthClientInformationSchema, OAuthClientProvider, OAuthTokens, OAuthTokensSchema } from "@modelcontextprotocol/sdk/client/auth.js";
import { SESSION_KEYS } from "./constants"; import { SESSION_KEYS } from "./constants";
import { z } from "zod";
export const OAuthMetadataSchema = z.object({ export class InspectorOAuthClientProvider implements OAuthClientProvider {
authorization_endpoint: z.string(), get redirectUrl() {
token_endpoint: z.string(), return window.location.origin + "/oauth/callback";
}); }
export type OAuthMetadata = z.infer<typeof OAuthMetadataSchema>; get clientMetadata() {
return {
redirect_uris: [this.redirectUrl],
token_endpoint_auth_method: "none",
grant_types: ["authorization_code", "refresh_token"],
response_types: ["code"],
client_name: "MCP Inspector",
client_uri: "https://github.com/modelcontextprotocol/inspector",
};
}
export const OAuthTokensSchema = z.object({ async clientInformation() {
access_token: z.string(), const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION);
refresh_token: z.string().optional(), if (!value) {
expires_in: z.number().optional(), return undefined;
});
export type OAuthTokens = z.infer<typeof OAuthTokensSchema>;
export async function discoverOAuthMetadata(
serverUrl: string,
): Promise<OAuthMetadata> {
try {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
const response = await fetch(url.toString());
if (response.ok) {
const metadata = await response.json();
const validatedMetadata = OAuthMetadataSchema.parse({
authorization_endpoint: metadata.authorization_endpoint,
token_endpoint: metadata.token_endpoint,
});
return validatedMetadata;
} }
} catch (error) {
console.warn("OAuth metadata discovery failed:", error); return await OAuthClientInformationSchema.parseAsync(JSON.parse(value));
} }
// Fall back to default endpoints saveClientInformation(clientInformation: OAuthClientInformation) {
const baseUrl = new URL(serverUrl); sessionStorage.setItem(
const defaultMetadata = { SESSION_KEYS.CLIENT_INFORMATION,
authorization_endpoint: new URL("/authorize", baseUrl).toString(), JSON.stringify(clientInformation),
token_endpoint: new URL("/token", baseUrl).toString(), );
}; }
return OAuthMetadataSchema.parse(defaultMetadata);
} async tokens() {
const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS);
export async function startOAuthFlow(serverUrl: string): Promise<string> { if (!tokens) {
// Generate PKCE challenge return undefined;
const challenge = await pkceChallenge(); }
const codeVerifier = challenge.code_verifier;
const codeChallenge = challenge.code_challenge; return await OAuthTokensSchema.parseAsync(JSON.parse(tokens));
}
// Store code verifier for later use
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier); saveTokens(tokens: OAuthTokens) {
sessionStorage.setItem(
// Discover OAuth endpoints SESSION_KEYS.TOKENS,
const metadata = await discoverOAuthMetadata(serverUrl); JSON.stringify(tokens),
);
// Build authorization URL }
const authUrl = new URL(metadata.authorization_endpoint);
authUrl.searchParams.set("response_type", "code"); redirectToAuthorization(authorizationUrl: URL) {
authUrl.searchParams.set("code_challenge", codeChallenge); window.location.href = authorizationUrl.href;
authUrl.searchParams.set("code_challenge_method", "S256"); }
authUrl.searchParams.set(
"redirect_uri", saveCodeVerifier(codeVerifier: string) {
window.location.origin + "/oauth/callback", sessionStorage.setItem(
); SESSION_KEYS.CODE_VERIFIER,
codeVerifier,
return authUrl.toString(); );
} }
export async function handleOAuthCallback( codeVerifier() {
serverUrl: string, const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
code: string, if (!verifier) {
): Promise<OAuthTokens> { throw new Error("No code verifier saved for session");
// Get stored code verifier }
const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
if (!codeVerifier) { return verifier;
throw new Error("No code verifier found"); }
}
// Discover OAuth endpoints
const metadata = await discoverOAuthMetadata(serverUrl);
// Exchange code for tokens
const response = await fetch(metadata.token_endpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "authorization_code",
code,
code_verifier: codeVerifier,
redirect_uri: window.location.origin + "/oauth/callback",
}),
});
if (!response.ok) {
throw new Error("Token exchange failed");
}
const tokens = await response.json();
return OAuthTokensSchema.parse(tokens);
}
export async function refreshAccessToken(
serverUrl: string,
): Promise<OAuthTokens> {
const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN);
if (!refreshToken) {
throw new Error("No refresh token available");
}
const metadata = await discoverOAuthMetadata(serverUrl);
const response = await fetch(metadata.token_endpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "refresh_token",
refresh_token: refreshToken,
}),
});
if (!response.ok) {
throw new Error("Token refresh failed");
}
const tokens = await response.json();
return OAuthTokensSchema.parse(tokens);
} }

View File

@@ -2,6 +2,6 @@
export const SESSION_KEYS = { export const SESSION_KEYS = {
CODE_VERIFIER: "mcp_code_verifier", CODE_VERIFIER: "mcp_code_verifier",
SERVER_URL: "mcp_server_url", SERVER_URL: "mcp_server_url",
ACCESS_TOKEN: "mcp_access_token", TOKENS: "mcp_tokens",
REFRESH_TOKEN: "mcp_refresh_token", CLIENT_INFORMATION: "mcp_client_information",
} as const; } as const;

View File

@@ -16,9 +16,10 @@ import {
import { useState } from "react"; import { useState } from "react";
import { toast } from "react-toastify"; import { toast } from "react-toastify";
import { z } from "zod"; import { z } from "zod";
import { startOAuthFlow, refreshAccessToken } from "../auth";
import { SESSION_KEYS } from "../constants"; import { SESSION_KEYS } from "../constants";
import { Notification, StdErrNotificationSchema } from "../notificationTypes"; import { Notification, StdErrNotificationSchema } from "../notificationTypes";
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
import { InspectorOAuthClientProvider } from "../auth";
const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000; const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000;
@@ -121,45 +122,15 @@ export function useConnection({
} }
}; };
const initiateOAuthFlow = async () => { const authProvider = new InspectorOAuthClientProvider();
sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN);
sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN);
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
const redirectUrl = await startOAuthFlow(sseUrl);
window.location.href = redirectUrl;
};
const handleTokenRefresh = async () => {
try {
const tokens = await refreshAccessToken(sseUrl);
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
if (tokens.refresh_token) {
sessionStorage.setItem(
SESSION_KEYS.REFRESH_TOKEN,
tokens.refresh_token,
);
}
return tokens.access_token;
} catch (error) {
console.error("Token refresh failed:", error);
await initiateOAuthFlow();
throw error;
}
};
const handleAuthError = async (error: unknown) => { const handleAuthError = async (error: unknown) => {
if (error instanceof SseError && error.code === 401) { if (error instanceof SseError && error.code === 401) {
if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
try {
await handleTokenRefresh(); const result = await auth(authProvider, { serverUrl: sseUrl })
return true; return result === "AUTHORIZED";
} catch (error) {
console.error("Token refresh failed:", error);
}
} else {
await initiateOAuthFlow();
}
} }
return false; return false;
}; };
@@ -192,9 +163,9 @@ export function useConnection({
} }
const headers: HeadersInit = {}; const headers: HeadersInit = {};
const accessToken = sessionStorage.getItem(SESSION_KEYS.ACCESS_TOKEN); const tokens = await authProvider.tokens();
if (accessToken) { if (tokens) {
headers["Authorization"] = `Bearer ${accessToken}`; headers["Authorization"] = `Bearer ${tokens.access_token}`;
} }
const clientTransport = new SSEClientTransport(backendUrl, { const clientTransport = new SSEClientTransport(backendUrl, {