Refactor to use auth from SDK

This commit is contained in:
Justin Spahr-Summers
2025-02-11 16:39:07 +00:00
parent 6d930ecae7
commit eb6af47b21
3 changed files with 76 additions and 168 deletions

View File

@@ -1,134 +1,71 @@
import pkceChallenge from "pkce-challenge"; import { OAuthClientInformation, OAuthClientInformationSchema, OAuthClientProvider, OAuthTokens, OAuthTokensSchema } from "@modelcontextprotocol/sdk/client/auth.js";
import { SESSION_KEYS } from "./constants"; import { SESSION_KEYS } from "./constants";
import { z } from "zod";
export const OAuthMetadataSchema = z.object({ export class InspectorOAuthClientProvider implements OAuthClientProvider {
authorization_endpoint: z.string(), get redirectUrl() {
token_endpoint: z.string(), return window.location.origin + "/oauth/callback";
});
export type OAuthMetadata = z.infer<typeof OAuthMetadataSchema>;
export const OAuthTokensSchema = z.object({
access_token: z.string(),
refresh_token: z.string().optional(),
expires_in: z.number().optional(),
});
export type OAuthTokens = z.infer<typeof OAuthTokensSchema>;
export async function discoverOAuthMetadata(
serverUrl: string,
): Promise<OAuthMetadata> {
try {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
const response = await fetch(url.toString());
if (response.ok) {
const metadata = await response.json();
const validatedMetadata = OAuthMetadataSchema.parse({
authorization_endpoint: metadata.authorization_endpoint,
token_endpoint: metadata.token_endpoint,
});
return validatedMetadata;
}
} catch (error) {
console.warn("OAuth metadata discovery failed:", error);
} }
// Fall back to default endpoints get clientMetadata() {
const baseUrl = new URL(serverUrl); return {
const defaultMetadata = { redirect_uris: [this.redirectUrl],
authorization_endpoint: new URL("/authorize", baseUrl).toString(), token_endpoint_auth_method: "none",
token_endpoint: new URL("/token", baseUrl).toString(), grant_types: ["authorization_code", "refresh_token"],
response_types: ["code"],
client_name: "MCP Inspector",
client_uri: "https://github.com/modelcontextprotocol/inspector",
}; };
return OAuthMetadataSchema.parse(defaultMetadata); }
}
export async function startOAuthFlow(serverUrl: string): Promise<string> { async clientInformation() {
// Generate PKCE challenge const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION);
const challenge = await pkceChallenge(); if (!value) {
const codeVerifier = challenge.code_verifier; return undefined;
const codeChallenge = challenge.code_challenge; }
// Store code verifier for later use return await OAuthClientInformationSchema.parseAsync(JSON.parse(value));
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier); }
// Discover OAuth endpoints saveClientInformation(clientInformation: OAuthClientInformation) {
const metadata = await discoverOAuthMetadata(serverUrl); sessionStorage.setItem(
SESSION_KEYS.CLIENT_INFORMATION,
// Build authorization URL JSON.stringify(clientInformation),
const authUrl = new URL(metadata.authorization_endpoint);
authUrl.searchParams.set("response_type", "code");
authUrl.searchParams.set("code_challenge", codeChallenge);
authUrl.searchParams.set("code_challenge_method", "S256");
authUrl.searchParams.set(
"redirect_uri",
window.location.origin + "/oauth/callback",
); );
}
return authUrl.toString(); async tokens() {
} const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS);
if (!tokens) {
export async function handleOAuthCallback( return undefined;
serverUrl: string, }
code: string,
): Promise<OAuthTokens> { return await OAuthTokensSchema.parseAsync(JSON.parse(tokens));
// Get stored code verifier }
const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
if (!codeVerifier) { saveTokens(tokens: OAuthTokens) {
throw new Error("No code verifier found"); sessionStorage.setItem(
} SESSION_KEYS.TOKENS,
JSON.stringify(tokens),
// Discover OAuth endpoints );
const metadata = await discoverOAuthMetadata(serverUrl); }
// Exchange code for tokens
const response = await fetch(metadata.token_endpoint, { redirectToAuthorization(authorizationUrl: URL) {
method: "POST", window.location.href = authorizationUrl.href;
headers: { }
"Content-Type": "application/json",
}, saveCodeVerifier(codeVerifier: string) {
body: JSON.stringify({ sessionStorage.setItem(
grant_type: "authorization_code", SESSION_KEYS.CODE_VERIFIER,
code, codeVerifier,
code_verifier: codeVerifier, );
redirect_uri: window.location.origin + "/oauth/callback", }
}),
}); codeVerifier() {
const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
if (!response.ok) { if (!verifier) {
throw new Error("Token exchange failed"); throw new Error("No code verifier saved for session");
} }
const tokens = await response.json(); return verifier;
return OAuthTokensSchema.parse(tokens); }
}
export async function refreshAccessToken(
serverUrl: string,
): Promise<OAuthTokens> {
const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN);
if (!refreshToken) {
throw new Error("No refresh token available");
}
const metadata = await discoverOAuthMetadata(serverUrl);
const response = await fetch(metadata.token_endpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "refresh_token",
refresh_token: refreshToken,
}),
});
if (!response.ok) {
throw new Error("Token refresh failed");
}
const tokens = await response.json();
return OAuthTokensSchema.parse(tokens);
} }

View File

@@ -2,6 +2,6 @@
export const SESSION_KEYS = { export const SESSION_KEYS = {
CODE_VERIFIER: "mcp_code_verifier", CODE_VERIFIER: "mcp_code_verifier",
SERVER_URL: "mcp_server_url", SERVER_URL: "mcp_server_url",
ACCESS_TOKEN: "mcp_access_token", TOKENS: "mcp_tokens",
REFRESH_TOKEN: "mcp_refresh_token", CLIENT_INFORMATION: "mcp_client_information",
} as const; } as const;

View File

@@ -16,9 +16,10 @@ import {
import { useState } from "react"; import { useState } from "react";
import { toast } from "react-toastify"; import { toast } from "react-toastify";
import { z } from "zod"; import { z } from "zod";
import { startOAuthFlow, refreshAccessToken } from "../auth";
import { SESSION_KEYS } from "../constants"; import { SESSION_KEYS } from "../constants";
import { Notification, StdErrNotificationSchema } from "../notificationTypes"; import { Notification, StdErrNotificationSchema } from "../notificationTypes";
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
import { InspectorOAuthClientProvider } from "../auth";
const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000; const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000;
@@ -121,45 +122,15 @@ export function useConnection({
} }
}; };
const initiateOAuthFlow = async () => { const authProvider = new InspectorOAuthClientProvider();
sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN);
sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN);
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
const redirectUrl = await startOAuthFlow(sseUrl);
window.location.href = redirectUrl;
};
const handleTokenRefresh = async () => {
try {
const tokens = await refreshAccessToken(sseUrl);
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
if (tokens.refresh_token) {
sessionStorage.setItem(
SESSION_KEYS.REFRESH_TOKEN,
tokens.refresh_token,
);
}
return tokens.access_token;
} catch (error) {
console.error("Token refresh failed:", error);
await initiateOAuthFlow();
throw error;
}
};
const handleAuthError = async (error: unknown) => { const handleAuthError = async (error: unknown) => {
if (error instanceof SseError && error.code === 401) { if (error instanceof SseError && error.code === 401) {
if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
try {
await handleTokenRefresh(); const result = await auth(authProvider, { serverUrl: sseUrl })
return true; return result === "AUTHORIZED";
} catch (error) {
console.error("Token refresh failed:", error);
}
} else {
await initiateOAuthFlow();
}
} }
return false; return false;
}; };
@@ -192,9 +163,9 @@ export function useConnection({
} }
const headers: HeadersInit = {}; const headers: HeadersInit = {};
const accessToken = sessionStorage.getItem(SESSION_KEYS.ACCESS_TOKEN); const tokens = await authProvider.tokens();
if (accessToken) { if (tokens) {
headers["Authorization"] = `Bearer ${accessToken}`; headers["Authorization"] = `Bearer ${tokens.access_token}`;
} }
const clientTransport = new SSEClientTransport(backendUrl, { const clientTransport = new SSEClientTransport(backendUrl, {