fix: store auth tokens with server-specific keys
Changes client information and access tokens to use server-specific keys in sessionStorage. This fixes issues where changing the server URL would try to use tokens from a different server.
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import { authProvider } from "../lib/auth";
|
import { InspectorOAuthClientProvider } from "../lib/auth";
|
||||||
import { SESSION_KEYS } from "../lib/constants";
|
import { SESSION_KEYS } from "../lib/constants";
|
||||||
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
||||||
|
|
||||||
@@ -25,7 +25,10 @@ const OAuthCallback = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await auth(authProvider, {
|
// Create an auth provider with the current server URL
|
||||||
|
const serverAuthProvider = new InspectorOAuthClientProvider(serverUrl);
|
||||||
|
|
||||||
|
const result = await auth(serverAuthProvider, {
|
||||||
serverUrl,
|
serverUrl,
|
||||||
authorizationCode: code,
|
authorizationCode: code,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,9 +5,14 @@ import {
|
|||||||
OAuthTokens,
|
OAuthTokens,
|
||||||
OAuthTokensSchema,
|
OAuthTokensSchema,
|
||||||
} from "@modelcontextprotocol/sdk/shared/auth.js";
|
} from "@modelcontextprotocol/sdk/shared/auth.js";
|
||||||
import { SESSION_KEYS } from "./constants";
|
import { SESSION_KEYS, getServerSpecificKey } from "./constants";
|
||||||
|
|
||||||
|
export class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||||
|
constructor(private serverUrl: string) {
|
||||||
|
// Save the server URL to session storage
|
||||||
|
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, serverUrl);
|
||||||
|
}
|
||||||
|
|
||||||
class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|
||||||
get redirectUrl() {
|
get redirectUrl() {
|
||||||
return window.location.origin + "/oauth/callback";
|
return window.location.origin + "/oauth/callback";
|
||||||
}
|
}
|
||||||
@@ -24,7 +29,11 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async clientInformation() {
|
async clientInformation() {
|
||||||
const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION);
|
const key = getServerSpecificKey(
|
||||||
|
SESSION_KEYS.CLIENT_INFORMATION,
|
||||||
|
this.serverUrl,
|
||||||
|
);
|
||||||
|
const value = sessionStorage.getItem(key);
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
@@ -33,14 +42,16 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
saveClientInformation(clientInformation: OAuthClientInformation) {
|
saveClientInformation(clientInformation: OAuthClientInformation) {
|
||||||
sessionStorage.setItem(
|
const key = getServerSpecificKey(
|
||||||
SESSION_KEYS.CLIENT_INFORMATION,
|
SESSION_KEYS.CLIENT_INFORMATION,
|
||||||
JSON.stringify(clientInformation),
|
this.serverUrl,
|
||||||
);
|
);
|
||||||
|
sessionStorage.setItem(key, JSON.stringify(clientInformation));
|
||||||
}
|
}
|
||||||
|
|
||||||
async tokens() {
|
async tokens() {
|
||||||
const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS);
|
const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
|
||||||
|
const tokens = sessionStorage.getItem(key);
|
||||||
if (!tokens) {
|
if (!tokens) {
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
@@ -49,7 +60,8 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
saveTokens(tokens: OAuthTokens) {
|
saveTokens(tokens: OAuthTokens) {
|
||||||
sessionStorage.setItem(SESSION_KEYS.TOKENS, JSON.stringify(tokens));
|
const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
|
||||||
|
sessionStorage.setItem(key, JSON.stringify(tokens));
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectToAuthorization(authorizationUrl: URL) {
|
redirectToAuthorization(authorizationUrl: URL) {
|
||||||
@@ -57,11 +69,19 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
saveCodeVerifier(codeVerifier: string) {
|
saveCodeVerifier(codeVerifier: string) {
|
||||||
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier);
|
const key = getServerSpecificKey(
|
||||||
|
SESSION_KEYS.CODE_VERIFIER,
|
||||||
|
this.serverUrl,
|
||||||
|
);
|
||||||
|
sessionStorage.setItem(key, codeVerifier);
|
||||||
}
|
}
|
||||||
|
|
||||||
codeVerifier() {
|
codeVerifier() {
|
||||||
const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
|
const key = getServerSpecificKey(
|
||||||
|
SESSION_KEYS.CODE_VERIFIER,
|
||||||
|
this.serverUrl,
|
||||||
|
);
|
||||||
|
const verifier = sessionStorage.getItem(key);
|
||||||
if (!verifier) {
|
if (!verifier) {
|
||||||
throw new Error("No code verifier saved for session");
|
throw new Error("No code verifier saved for session");
|
||||||
}
|
}
|
||||||
@@ -69,5 +89,3 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
|||||||
return verifier;
|
return verifier;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const authProvider = new InspectorOAuthClientProvider();
|
|
||||||
|
|||||||
@@ -8,6 +8,15 @@ export const SESSION_KEYS = {
|
|||||||
CLIENT_INFORMATION: "mcp_client_information",
|
CLIENT_INFORMATION: "mcp_client_information",
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
|
// Generate server-specific session storage keys
|
||||||
|
export const getServerSpecificKey = (
|
||||||
|
baseKey: string,
|
||||||
|
serverUrl?: string,
|
||||||
|
): string => {
|
||||||
|
if (!serverUrl) return baseKey;
|
||||||
|
return `[${serverUrl}] ${baseKey}`;
|
||||||
|
};
|
||||||
|
|
||||||
export type ConnectionStatus =
|
export type ConnectionStatus =
|
||||||
| "disconnected"
|
| "disconnected"
|
||||||
| "connected"
|
| "connected"
|
||||||
|
|||||||
@@ -45,9 +45,9 @@ jest.mock("@/hooks/use-toast", () => ({
|
|||||||
|
|
||||||
// Mock the auth provider
|
// Mock the auth provider
|
||||||
jest.mock("../../auth", () => ({
|
jest.mock("../../auth", () => ({
|
||||||
authProvider: {
|
InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({
|
||||||
tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }),
|
tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }),
|
||||||
},
|
})),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe("useConnection", () => {
|
describe("useConnection", () => {
|
||||||
|
|||||||
@@ -28,10 +28,10 @@ import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
|
|||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { useToast } from "@/hooks/use-toast";
|
import { useToast } from "@/hooks/use-toast";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { ConnectionStatus, SESSION_KEYS } from "../constants";
|
import { ConnectionStatus } from "../constants";
|
||||||
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
||||||
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
||||||
import { authProvider } from "../auth";
|
import { InspectorOAuthClientProvider } from "../auth";
|
||||||
import packageJson from "../../../package.json";
|
import packageJson from "../../../package.json";
|
||||||
import {
|
import {
|
||||||
getMCPProxyAddress,
|
getMCPProxyAddress,
|
||||||
@@ -246,9 +246,10 @@ export function useConnection({
|
|||||||
|
|
||||||
const handleAuthError = async (error: unknown) => {
|
const handleAuthError = async (error: unknown) => {
|
||||||
if (error instanceof SseError && error.code === 401) {
|
if (error instanceof SseError && error.code === 401) {
|
||||||
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
|
// Create a new auth provider with the current server URL
|
||||||
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
||||||
|
|
||||||
const result = await auth(authProvider, { serverUrl: sseUrl });
|
const result = await auth(serverAuthProvider, { serverUrl: sseUrl });
|
||||||
return result === "AUTHORIZED";
|
return result === "AUTHORIZED";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,8 +293,12 @@ export function useConnection({
|
|||||||
// proxying through the inspector server first.
|
// proxying through the inspector server first.
|
||||||
const headers: HeadersInit = {};
|
const headers: HeadersInit = {};
|
||||||
|
|
||||||
|
// Create an auth provider with the current server URL
|
||||||
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
||||||
|
|
||||||
// Use manually provided bearer token if available, otherwise use OAuth tokens
|
// Use manually provided bearer token if available, otherwise use OAuth tokens
|
||||||
const token = bearerToken || (await authProvider.tokens())?.access_token;
|
const token =
|
||||||
|
bearerToken || (await serverAuthProvider.tokens())?.access_token;
|
||||||
if (token) {
|
if (token) {
|
||||||
const authHeaderName = headerName || "Authorization";
|
const authHeaderName = headerName || "Authorization";
|
||||||
headers[authHeaderName] = `Bearer ${token}`;
|
headers[authHeaderName] = `Bearer ${token}`;
|
||||||
|
|||||||
Reference in New Issue
Block a user