Merge branch 'main' into main
This commit is contained in:
@@ -5,9 +5,14 @@ import {
|
||||
OAuthTokens,
|
||||
OAuthTokensSchema,
|
||||
} from "@modelcontextprotocol/sdk/shared/auth.js";
|
||||
import { SESSION_KEYS } from "./constants";
|
||||
import { SESSION_KEYS, getServerSpecificKey } from "./constants";
|
||||
|
||||
export class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
constructor(private serverUrl: string) {
|
||||
// Save the server URL to session storage
|
||||
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, serverUrl);
|
||||
}
|
||||
|
||||
class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
get redirectUrl() {
|
||||
return window.location.origin + "/oauth/callback";
|
||||
}
|
||||
@@ -24,7 +29,11 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
}
|
||||
|
||||
async clientInformation() {
|
||||
const value = sessionStorage.getItem(SESSION_KEYS.CLIENT_INFORMATION);
|
||||
const key = getServerSpecificKey(
|
||||
SESSION_KEYS.CLIENT_INFORMATION,
|
||||
this.serverUrl,
|
||||
);
|
||||
const value = sessionStorage.getItem(key);
|
||||
if (!value) {
|
||||
return undefined;
|
||||
}
|
||||
@@ -33,14 +42,16 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
}
|
||||
|
||||
saveClientInformation(clientInformation: OAuthClientInformation) {
|
||||
sessionStorage.setItem(
|
||||
const key = getServerSpecificKey(
|
||||
SESSION_KEYS.CLIENT_INFORMATION,
|
||||
JSON.stringify(clientInformation),
|
||||
this.serverUrl,
|
||||
);
|
||||
sessionStorage.setItem(key, JSON.stringify(clientInformation));
|
||||
}
|
||||
|
||||
async tokens() {
|
||||
const tokens = sessionStorage.getItem(SESSION_KEYS.TOKENS);
|
||||
const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
|
||||
const tokens = sessionStorage.getItem(key);
|
||||
if (!tokens) {
|
||||
return undefined;
|
||||
}
|
||||
@@ -49,7 +60,8 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
}
|
||||
|
||||
saveTokens(tokens: OAuthTokens) {
|
||||
sessionStorage.setItem(SESSION_KEYS.TOKENS, JSON.stringify(tokens));
|
||||
const key = getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl);
|
||||
sessionStorage.setItem(key, JSON.stringify(tokens));
|
||||
}
|
||||
|
||||
redirectToAuthorization(authorizationUrl: URL) {
|
||||
@@ -57,17 +69,35 @@ class InspectorOAuthClientProvider implements OAuthClientProvider {
|
||||
}
|
||||
|
||||
saveCodeVerifier(codeVerifier: string) {
|
||||
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier);
|
||||
const key = getServerSpecificKey(
|
||||
SESSION_KEYS.CODE_VERIFIER,
|
||||
this.serverUrl,
|
||||
);
|
||||
sessionStorage.setItem(key, codeVerifier);
|
||||
}
|
||||
|
||||
codeVerifier() {
|
||||
const verifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
|
||||
const key = getServerSpecificKey(
|
||||
SESSION_KEYS.CODE_VERIFIER,
|
||||
this.serverUrl,
|
||||
);
|
||||
const verifier = sessionStorage.getItem(key);
|
||||
if (!verifier) {
|
||||
throw new Error("No code verifier saved for session");
|
||||
}
|
||||
|
||||
return verifier;
|
||||
}
|
||||
}
|
||||
|
||||
export const authProvider = new InspectorOAuthClientProvider();
|
||||
clear() {
|
||||
sessionStorage.removeItem(
|
||||
getServerSpecificKey(SESSION_KEYS.CLIENT_INFORMATION, this.serverUrl),
|
||||
);
|
||||
sessionStorage.removeItem(
|
||||
getServerSpecificKey(SESSION_KEYS.TOKENS, this.serverUrl),
|
||||
);
|
||||
sessionStorage.removeItem(
|
||||
getServerSpecificKey(SESSION_KEYS.CODE_VERIFIER, this.serverUrl),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,15 @@ export const SESSION_KEYS = {
|
||||
CLIENT_INFORMATION: "mcp_client_information",
|
||||
} as const;
|
||||
|
||||
// Generate server-specific session storage keys
|
||||
export const getServerSpecificKey = (
|
||||
baseKey: string,
|
||||
serverUrl?: string,
|
||||
): string => {
|
||||
if (!serverUrl) return baseKey;
|
||||
return `[${serverUrl}] ${baseKey}`;
|
||||
};
|
||||
|
||||
export type ConnectionStatus =
|
||||
| "disconnected"
|
||||
| "connected"
|
||||
|
||||
@@ -17,6 +17,8 @@ const mockClient = {
|
||||
connect: jest.fn().mockResolvedValue(undefined),
|
||||
close: jest.fn(),
|
||||
getServerCapabilities: jest.fn(),
|
||||
getServerVersion: jest.fn(),
|
||||
getInstructions: jest.fn(),
|
||||
setNotificationHandler: jest.fn(),
|
||||
setRequestHandler: jest.fn(),
|
||||
};
|
||||
@@ -43,9 +45,9 @@ jest.mock("@/hooks/use-toast", () => ({
|
||||
|
||||
// Mock the auth provider
|
||||
jest.mock("../../auth", () => ({
|
||||
authProvider: {
|
||||
InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({
|
||||
tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }),
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
describe("useConnection", () => {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useCallback, useEffect, useRef } from "react";
|
||||
import { useState, useCallback, useEffect, useRef, useMemo } from "react";
|
||||
import {
|
||||
ResourceReference,
|
||||
PromptReference,
|
||||
@@ -15,9 +15,11 @@ function debounce<T extends (...args: any[]) => PromiseLike<void>>(
|
||||
wait: number,
|
||||
): (...args: Parameters<T>) => void {
|
||||
let timeout: ReturnType<typeof setTimeout>;
|
||||
return function (...args: Parameters<T>) {
|
||||
return (...args: Parameters<T>) => {
|
||||
clearTimeout(timeout);
|
||||
timeout = setTimeout(() => func(...args), wait);
|
||||
timeout = setTimeout(() => {
|
||||
void func(...args);
|
||||
}, wait);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -58,8 +60,8 @@ export function useCompletionState(
|
||||
});
|
||||
}, [cleanup]);
|
||||
|
||||
const requestCompletions = useCallback(
|
||||
debounce(
|
||||
const requestCompletions = useMemo(() => {
|
||||
return debounce(
|
||||
async (
|
||||
ref: ResourceReference | PromptReference,
|
||||
argName: string,
|
||||
@@ -94,7 +96,7 @@ export function useCompletionState(
|
||||
loading: { ...prev.loading, [argName]: false },
|
||||
}));
|
||||
}
|
||||
} catch (err) {
|
||||
} catch {
|
||||
if (!abortController.signal.aborted) {
|
||||
setState((prev) => ({
|
||||
...prev,
|
||||
@@ -108,9 +110,8 @@ export function useCompletionState(
|
||||
}
|
||||
},
|
||||
debounceMs,
|
||||
),
|
||||
[handleCompletion, completionsSupported, cleanup, debounceMs],
|
||||
);
|
||||
);
|
||||
}, [handleCompletion, completionsSupported, cleanup, debounceMs]);
|
||||
|
||||
// Clear completions when support status changes
|
||||
useEffect(() => {
|
||||
|
||||
@@ -28,10 +28,10 @@ import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
|
||||
import { useState } from "react";
|
||||
import { useToast } from "@/hooks/use-toast";
|
||||
import { z } from "zod";
|
||||
import { ConnectionStatus, SESSION_KEYS } from "../constants";
|
||||
import { ConnectionStatus } from "../constants";
|
||||
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
||||
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
||||
import { authProvider } from "../auth";
|
||||
import { InspectorOAuthClientProvider } from "../auth";
|
||||
import packageJson from "../../../package.json";
|
||||
import {
|
||||
getMCPProxyAddress,
|
||||
@@ -48,6 +48,7 @@ interface UseConnectionOptions {
|
||||
sseUrl: string;
|
||||
env: Record<string, string>;
|
||||
bearerToken?: string;
|
||||
headerName?: string;
|
||||
config: InspectorConfig;
|
||||
onNotification?: (notification: Notification) => void;
|
||||
onStdErrNotification?: (notification: Notification) => void;
|
||||
@@ -64,6 +65,7 @@ export function useConnection({
|
||||
sseUrl,
|
||||
env,
|
||||
bearerToken,
|
||||
headerName,
|
||||
config,
|
||||
onNotification,
|
||||
onStdErrNotification,
|
||||
@@ -244,9 +246,10 @@ export function useConnection({
|
||||
|
||||
const handleAuthError = async (error: unknown) => {
|
||||
if (error instanceof SseError && error.code === 401) {
|
||||
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
|
||||
// Create a new auth provider with the current server URL
|
||||
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
||||
|
||||
const result = await auth(authProvider, { serverUrl: sseUrl });
|
||||
const result = await auth(serverAuthProvider, { serverUrl: sseUrl });
|
||||
return result === "AUTHORIZED";
|
||||
}
|
||||
|
||||
@@ -290,10 +293,15 @@ export function useConnection({
|
||||
// proxying through the inspector server first.
|
||||
const headers: HeadersInit = {};
|
||||
|
||||
// Create an auth provider with the current server URL
|
||||
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
||||
|
||||
// Use manually provided bearer token if available, otherwise use OAuth tokens
|
||||
const token = bearerToken || (await authProvider.tokens())?.access_token;
|
||||
const token =
|
||||
bearerToken || (await serverAuthProvider.tokens())?.access_token;
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
const authHeaderName = headerName || "Authorization";
|
||||
headers[authHeaderName] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const clientTransport = new SSEClientTransport(mcpProxyServerUrl, {
|
||||
@@ -332,8 +340,19 @@ export function useConnection({
|
||||
);
|
||||
}
|
||||
|
||||
let capabilities;
|
||||
try {
|
||||
await client.connect(clientTransport);
|
||||
|
||||
capabilities = client.getServerCapabilities();
|
||||
const initializeRequest = {
|
||||
method: "initialize",
|
||||
};
|
||||
pushHistory(initializeRequest, {
|
||||
capabilities,
|
||||
serverInfo: client.getServerVersion(),
|
||||
instructions: client.getInstructions(),
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to connect to MCP Server via the MCP Inspector Proxy: ${mcpProxyServerUrl}:`,
|
||||
@@ -350,8 +369,6 @@ export function useConnection({
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
const capabilities = client.getServerCapabilities();
|
||||
setServerCapabilities(capabilities ?? null);
|
||||
setCompletionsSupported(true); // Reset completions support on new connection
|
||||
|
||||
@@ -379,6 +396,8 @@ export function useConnection({
|
||||
|
||||
const disconnect = async () => {
|
||||
await mcpClient?.close();
|
||||
const authProvider = new InspectorOAuthClientProvider(sseUrl);
|
||||
authProvider.clear();
|
||||
setMcpClient(null);
|
||||
setConnectionStatus("disconnected");
|
||||
setCompletionsSupported(false);
|
||||
|
||||
@@ -43,7 +43,10 @@ const useTheme = (): [Theme, (mode: Theme) => void] => {
|
||||
document.documentElement.classList.toggle("dark", newTheme === "dark");
|
||||
}
|
||||
}, []);
|
||||
return useMemo(() => [theme, setThemeWithSideEffect], [theme]);
|
||||
return useMemo(
|
||||
() => [theme, setThemeWithSideEffect],
|
||||
[theme, setThemeWithSideEffect],
|
||||
);
|
||||
};
|
||||
|
||||
export default useTheme;
|
||||
|
||||
Reference in New Issue
Block a user