fix: store auth tokens with server-specific keys

Changes client information and access tokens to use server-specific keys in sessionStorage. This fixes issues where changing the server URL would try to use tokens from a different server.
This commit is contained in:
Glen Maddern
2025-04-16 10:14:08 +10:00
parent f7272d8d8c
commit df0b526a41
5 changed files with 55 additions and 20 deletions

View File

@@ -1,5 +1,5 @@
import { useEffect, useRef } from "react"; import { useEffect, useRef } from "react";
import { authProvider } from "../lib/auth"; import { InspectorOAuthClientProvider } from "../lib/auth";
import { SESSION_KEYS } from "../lib/constants"; import { SESSION_KEYS } from "../lib/constants";
import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
@@ -25,7 +25,10 @@ const OAuthCallback = () => {
} }
try { try {
const result = await auth(authProvider, { // Create an auth provider with the current server URL
const serverAuthProvider = new InspectorOAuthClientProvider(serverUrl);
const result = await auth(serverAuthProvider, {
serverUrl, serverUrl,
authorizationCode: code, authorizationCode: code,
}); });

View File

@@ -5,9 +5,14 @@ import {
OAuthTokens, OAuthTokens,
OAuthTokensSchema, OAuthTokensSchema,
} from "@modelcontextprotocol/sdk/shared/auth.js"; } from "@modelcontextprotocol/sdk/shared/auth.js";
import { SESSION_KEYS } from "./constants"; import { SESSION_KEYS, getServerSpecificKey } from "./constants";
export class InspectorOAuthClientProvider implements OAuthClientProvider {
constructor(private serverUrl: string) {
// Save the server URL to session storage
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, serverUrl);
}
class InspectorOAuthClientProvider implements OAuthClientProvider {
get redirectUrl() { get redirectUrl() {
return window.location.origin + "/oauth/callback"; return window.location.origin + "/oauth/callback";
} }
@@ -24,7 +29,11 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
} }
async clientInformation() { async clientInformation() {
const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION); const key = getServerSpecificKey(
SESSION_KEYS.CLIENT_INFORMATION,
this.serverUrl,
);
const value = sessionStorage.getItem(key);
if (!value) { if (!value) {
return undefined; return undefined;
} }
@@ -33,14 +42,16 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
} }
saveClientInformation(clientInformation: OAuthClientInformation) { saveClientInformation(clientInformation: OAuthClientInformation) {
sessionStorage.setItem( const key = getServerSpecificKey(
SESSION_KEYS.CLIENT_INFORMATION, SESSION_KEYS.CLIENT_INFORMATION,
JSON.stringify(clientInformation), this.serverUrl,
); );
sessionStorage.setItem(key, JSON.stringify(clientInformation));
} }
async tokens() { async tokens() {
const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS); const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
const tokens = sessionStorage.getItem(key);
if (!tokens) { if (!tokens) {
return undefined; return undefined;
} }
@@ -49,7 +60,8 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
} }
saveTokens(tokens: OAuthTokens) { saveTokens(tokens: OAuthTokens) {
sessionStorage.setItem(SESSION_KEYS.TOKENS, JSON.stringify(tokens)); const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
sessionStorage.setItem(key, JSON.stringify(tokens));
} }
redirectToAuthorization(authorizationUrl: URL) { redirectToAuthorization(authorizationUrl: URL) {
@@ -57,11 +69,19 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
} }
saveCodeVerifier(codeVerifier: string) { saveCodeVerifier(codeVerifier: string) {
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier); const key = getServerSpecificKey(
SESSION_KEYS.CODE_VERIFIER,
this.serverUrl,
);
sessionStorage.setItem(key, codeVerifier);
} }
codeVerifier() { codeVerifier() {
const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); const key = getServerSpecificKey(
SESSION_KEYS.CODE_VERIFIER,
this.serverUrl,
);
const verifier = sessionStorage.getItem(key);
if (!verifier) { if (!verifier) {
throw new Error("No code verifier saved for session"); throw new Error("No code verifier saved for session");
} }
@@ -69,5 +89,3 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
return verifier; return verifier;
} }
} }
export const authProvider = new InspectorOAuthClientProvider();

View File

@@ -8,6 +8,15 @@ export const SESSION_KEYS = {
CLIENT_INFORMATION: "mcp_client_information", CLIENT_INFORMATION: "mcp_client_information",
} as const; } as const;
// Generate server-specific session storage keys
export const getServerSpecificKey = (
baseKey: string,
serverUrl?: string,
): string => {
if (!serverUrl) return baseKey;
return `[${serverUrl}] ${baseKey}`;
};
export type ConnectionStatus = export type ConnectionStatus =
| "disconnected" | "disconnected"
| "connected" | "connected"

View File

@@ -45,9 +45,9 @@ jest.mock("@/hooks/use-toast", () => ({
// Mock the auth provider // Mock the auth provider
jest.mock("../../auth", () => ({ jest.mock("../../auth", () => ({
authProvider: { InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({
tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }), tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }),
}, })),
})); }));
describe("useConnection", () => { describe("useConnection", () => {

View File

@@ -28,10 +28,10 @@ import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
import { useState } from "react"; import { useState } from "react";
import { useToast } from "@/hooks/use-toast"; import { useToast } from "@/hooks/use-toast";
import { z } from "zod"; import { z } from "zod";
import { ConnectionStatus, SESSION_KEYS } from "../constants"; import { ConnectionStatus } from "../constants";
import { Notification, StdErrNotificationSchema } from "../notificationTypes"; import { Notification, StdErrNotificationSchema } from "../notificationTypes";
import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
import { authProvider } from "../auth"; import { InspectorOAuthClientProvider } from "../auth";
import packageJson from "../../../package.json"; import packageJson from "../../../package.json";
import { import {
getMCPProxyAddress, getMCPProxyAddress,
@@ -246,9 +246,10 @@ export function useConnection({
const handleAuthError = async (error: unknown) => { const handleAuthError = async (error: unknown) => {
if (error instanceof SseError && error.code === 401) { if (error instanceof SseError && error.code === 401) {
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); // Create a new auth provider with the current server URL
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
const result = await auth(authProvider, { serverUrl: sseUrl }); const result = await auth(serverAuthProvider, { serverUrl: sseUrl });
return result === "AUTHORIZED"; return result === "AUTHORIZED";
} }
@@ -292,8 +293,12 @@ export function useConnection({
// proxying through the inspector server first. // proxying through the inspector server first.
const headers: HeadersInit = {}; const headers: HeadersInit = {};
// Create an auth provider with the current server URL
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
// Use manually provided bearer token if available, otherwise use OAuth tokens // Use manually provided bearer token if available, otherwise use OAuth tokens
const token = bearerToken || (await authProvider.tokens())?.access_token; const token =
bearerToken || (await serverAuthProvider.tokens())?.access_token;
if (token) { if (token) {
const authHeaderName = headerName || "Authorization"; const authHeaderName = headerName || "Authorization";
headers[authHeaderName] = `Bearer ${token}`; headers[authHeaderName] = `Bearer ${token}`;