From df0b526a414c0a937780df96e425ba2480c9e7dc Mon Sep 17 00:00:00 2001 From: Glen Maddern Date: Wed, 16 Apr 2025 10:14:08 +1000 Subject: [PATCH] fix: store auth tokens with server-specific keys Changes client information and access tokens to use server-specific keys in sessionStorage. This fixes issues where changing the server URL would try to use tokens from a different server. --- client/src/components/OAuthCallback.tsx | 7 +++- client/src/lib/auth.ts | 40 ++++++++++++++----- client/src/lib/constants.ts | 9 +++++ .../hooks/__tests__/useConnection.test.tsx | 4 +- client/src/lib/hooks/useConnection.ts | 15 ++++--- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index cba38c3..a1cff48 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -1,5 +1,5 @@ import { useEffect, useRef } from "react"; -import { authProvider } from "../lib/auth"; +import { InspectorOAuthClientProvider } from "../lib/auth"; import { SESSION_KEYS } from "../lib/constants"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; @@ -25,7 +25,10 @@ const OAuthCallback = () => { } try { - const result = await auth(authProvider, { + // Create an auth provider with the current server URL + const serverAuthProvider = new InspectorOAuthClientProvider(serverUrl); + + const result = await auth(serverAuthProvider, { serverUrl, authorizationCode: code, }); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index ba610bc..1c0c6b6 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -5,9 +5,14 @@ import { OAuthTokens, OAuthTokensSchema, } from "@modelcontextprotocol/sdk/shared/auth.js"; -import { SESSION_KEYS } from "./constants"; +import { SESSION_KEYS, getServerSpecificKey } from "./constants"; + +export class InspectorOAuthClientProvider implements OAuthClientProvider { + constructor(private serverUrl: string) { + // Save the server URL to session storage + sessionStorage.setItem(SESSION_KEYS.SERVER_URL, serverUrl); + } -class InspectorOAuthClientProvider implements OAuthClientProvider { get redirectUrl() { return window.location.origin + "/oauth/callback"; } @@ -24,7 +29,11 @@ class InspectorOAuthClientProvider implements OAuthClientProvider { } async clientInformation() { - const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION); + const key = getServerSpecificKey( + SESSION_KEYS.CLIENT_INFORMATION, + this.serverUrl, + ); + const value = sessionStorage.getItem(key); if (!value) { return undefined; } @@ -33,14 +42,16 @@ class InspectorOAuthClientProvider implements OAuthClientProvider { } saveClientInformation(clientInformation: OAuthClientInformation) { - sessionStorage.setItem( + const key = getServerSpecificKey( SESSION_KEYS.CLIENT_INFORMATION, - JSON.stringify(clientInformation), + this.serverUrl, ); + sessionStorage.setItem(key, JSON.stringify(clientInformation)); } async tokens() { - const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS); + const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl); + const tokens = sessionStorage.getItem(key); if (!tokens) { return undefined; } @@ -49,7 +60,8 @@ class InspectorOAuthClientProvider implements OAuthClientProvider { } saveTokens(tokens: OAuthTokens) { - sessionStorage.setItem(SESSION_KEYS.TOKENS, JSON.stringify(tokens)); + const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl); + sessionStorage.setItem(key, JSON.stringify(tokens)); } redirectToAuthorization(authorizationUrl: URL) { @@ -57,11 +69,19 @@ class InspectorOAuthClientProvider implements OAuthClientProvider { } saveCodeVerifier(codeVerifier: string) { - sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier); + const key = getServerSpecificKey( + SESSION_KEYS.CODE_VERIFIER, + this.serverUrl, + ); + sessionStorage.setItem(key, codeVerifier); } codeVerifier() { - const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); + const key = getServerSpecificKey( + SESSION_KEYS.CODE_VERIFIER, + this.serverUrl, + ); + const verifier = sessionStorage.getItem(key); if (!verifier) { throw new Error("No code verifier saved for session"); } @@ -69,5 +89,3 @@ class InspectorOAuthClientProvider implements OAuthClientProvider { return verifier; } } - -export const authProvider = new InspectorOAuthClientProvider(); diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index e7fa14c..a03239a 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -8,6 +8,15 @@ export const SESSION_KEYS = { CLIENT_INFORMATION: "mcp_client_information", } as const; +// Generate server-specific session storage keys +export const getServerSpecificKey = ( + baseKey: string, + serverUrl?: string, +): string => { + if (!serverUrl) return baseKey; + return `[${serverUrl}] ${baseKey}`; +}; + export type ConnectionStatus = | "disconnected" | "connected" diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index c1d67d7..e191d6c 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -45,9 +45,9 @@ jest.mock("@/hooks/use-toast", () => ({ // Mock the auth provider jest.mock("../../auth", () => ({ - authProvider: { + InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({ tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }), - }, + })), })); describe("useConnection", () => { diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 485e8e3..d1e958f 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -28,10 +28,10 @@ import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; import { useState } from "react"; import { useToast } from "@/hooks/use-toast"; import { z } from "zod"; -import { ConnectionStatus, SESSION_KEYS } from "../constants"; +import { ConnectionStatus } from "../constants"; import { Notification, StdErrNotificationSchema } from "../notificationTypes"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; -import { authProvider } from "../auth"; +import { InspectorOAuthClientProvider } from "../auth"; import packageJson from "../../../package.json"; import { getMCPProxyAddress, @@ -246,9 +246,10 @@ export function useConnection({ const handleAuthError = async (error: unknown) => { if (error instanceof SseError && error.code === 401) { - sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); + // Create a new auth provider with the current server URL + const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl); - const result = await auth(authProvider, { serverUrl: sseUrl }); + const result = await auth(serverAuthProvider, { serverUrl: sseUrl }); return result === "AUTHORIZED"; } @@ -292,8 +293,12 @@ export function useConnection({ // proxying through the inspector server first. const headers: HeadersInit = {}; + // Create an auth provider with the current server URL + const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl); + // Use manually provided bearer token if available, otherwise use OAuth tokens - const token = bearerToken || (await authProvider.tokens())?.access_token; + const token = + bearerToken || (await serverAuthProvider.tokens())?.access_token; if (token) { const authHeaderName = headerName || "Authorization"; headers[authHeaderName] = `Bearer ${token}`;