From a8ffc704f05cd1f8e0050cebd6f83a337fbddffc Mon Sep 17 00:00:00 2001 From: Pulkit Sharma Date: Fri, 4 Apr 2025 01:44:30 +0530 Subject: [PATCH] add support for progress flow --- client/src/App.tsx | 36 ++-- client/src/lib/configurationTypes.ts | 16 ++ client/src/lib/constants.ts | 8 + .../hooks/__tests__/useConnection.test.tsx | 165 ++++++++++++++++++ client/src/lib/hooks/useConnection.ts | 40 +++-- client/src/utils/configUtils.ts | 12 ++ 6 files changed, 238 insertions(+), 39 deletions(-) create mode 100644 client/src/lib/hooks/__tests__/useConnection.test.tsx diff --git a/client/src/App.tsx b/client/src/App.tsx index 7564544..61dcad7 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -45,10 +45,7 @@ import Sidebar from "./components/Sidebar"; import ToolsTab from "./components/ToolsTab"; import { DEFAULT_INSPECTOR_CONFIG } from "./lib/constants"; import { InspectorConfig } from "./lib/configurationTypes"; -import { - getMCPProxyAddress, - getMCPServerRequestTimeout, -} from "./utils/configUtils"; +import { getMCPProxyAddress } from "./utils/configUtils"; import { useToast } from "@/hooks/use-toast"; const params = new URLSearchParams(window.location.search); @@ -148,7 +145,7 @@ const App = () => { serverCapabilities, mcpClient, requestHistory, - makeRequest: makeConnectionRequest, + makeRequest, sendNotification, handleCompletion, completionsSupported, @@ -161,8 +158,7 @@ const App = () => { sseUrl, env, bearerToken, - proxyServerUrl: getMCPProxyAddress(config), - requestTimeout: getMCPServerRequestTimeout(config), + config, onNotification: (notification) => { setNotifications((prev) => [...prev, notification as ServerNotification]); }, @@ -279,13 +275,13 @@ const App = () => { setErrors((prev) => ({ ...prev, [tabKey]: null })); }; - const makeRequest = async ( + const makeConnectionRequest = async ( request: ClientRequest, schema: T, tabKey?: keyof typeof errors, ) => { try { - const response = await makeConnectionRequest(request, schema); + const response = await makeRequest(request, schema); if (tabKey !== undefined) { clearError(tabKey); } @@ -303,7 +299,7 @@ const App = () => { }; const listResources = async () => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "resources/list" as const, params: nextResourceCursor ? { cursor: nextResourceCursor } : {}, @@ -316,7 +312,7 @@ const App = () => { }; const listResourceTemplates = async () => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "resources/templates/list" as const, params: nextResourceTemplateCursor @@ -333,7 +329,7 @@ const App = () => { }; const readResource = async (uri: string) => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "resources/read" as const, params: { uri }, @@ -346,7 +342,7 @@ const App = () => { const subscribeToResource = async (uri: string) => { if (!resourceSubscriptions.has(uri)) { - await makeRequest( + await makeConnectionRequest( { method: "resources/subscribe" as const, params: { uri }, @@ -362,7 +358,7 @@ const App = () => { const unsubscribeFromResource = async (uri: string) => { if (resourceSubscriptions.has(uri)) { - await makeRequest( + await makeConnectionRequest( { method: "resources/unsubscribe" as const, params: { uri }, @@ -377,7 +373,7 @@ const App = () => { }; const listPrompts = async () => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "prompts/list" as const, params: nextPromptCursor ? { cursor: nextPromptCursor } : {}, @@ -390,7 +386,7 @@ const App = () => { }; const getPrompt = async (name: string, args: Record = {}) => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "prompts/get" as const, params: { name, arguments: args }, @@ -402,7 +398,7 @@ const App = () => { }; const listTools = async () => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "tools/list" as const, params: nextToolCursor ? { cursor: nextToolCursor } : {}, @@ -415,7 +411,7 @@ const App = () => { }; const callTool = async (name: string, params: Record) => { - const response = await makeRequest( + const response = await makeConnectionRequest( { method: "tools/call" as const, params: { @@ -437,7 +433,7 @@ const App = () => { }; const sendLogLevelRequest = async (level: LoggingLevel) => { - await makeRequest( + await makeConnectionRequest( { method: "logging/setLevel" as const, params: { level }, @@ -654,7 +650,7 @@ const App = () => { { - void makeRequest( + void makeConnectionRequest( { method: "ping" as const, }, diff --git a/client/src/lib/configurationTypes.ts b/client/src/lib/configurationTypes.ts index df9eb29..d0c1263 100644 --- a/client/src/lib/configurationTypes.ts +++ b/client/src/lib/configurationTypes.ts @@ -15,5 +15,21 @@ export type InspectorConfig = { * Maximum time in milliseconds to wait for a response from the MCP server before timing out. */ MCP_SERVER_REQUEST_TIMEOUT: ConfigItem; + + /** + * Whether to reset the timeout on progress notifications. Useful for long-running operations that send periodic progress updates. + * Refer: https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress/#progress-flow + */ + MCP_SERVER_REQUEST_TIMEOUT_RESET_ON_PROGRESS: ConfigItem; + + /** + * Maximum total time in milliseconds to wait for a response from the MCP server before timing out. Used in conjunction with MCP_SERVER_REQUEST_TIMEOUT_RESET_ON_PROGRESS. + * Refer: https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress/#progress-flow + */ + MCP_SERVER_REQUEST_TIMEOUT_MAX_TOTAL_TIMEOUT: ConfigItem; + + /** + * The full address of the MCP Proxy Server, in case it is running on a non-default address. Example: http://10.1.1.22:5577 + */ MCP_PROXY_FULL_ADDRESS: ConfigItem; }; diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index c370b34..9caf4bc 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -25,6 +25,14 @@ export const DEFAULT_INSPECTOR_CONFIG: InspectorConfig = { description: "Timeout for requests to the MCP server (ms)", value: 10000, }, + MCP_SERVER_REQUEST_TIMEOUT_RESET_ON_PROGRESS: { + description: "Reset timeout on progress notifications", + value: true, + }, + MCP_SERVER_REQUEST_TIMEOUT_MAX_TOTAL_TIMEOUT: { + description: "Maximum total timeout for requests sent to the MCP server (ms)", + value: 60000, + }, MCP_PROXY_FULL_ADDRESS: { description: "Set this if you are running the MCP Inspector Proxy on a non-default address. Example: http://10.1.1.22:5577", diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx new file mode 100644 index 0000000..7a96802 --- /dev/null +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -0,0 +1,165 @@ +import { renderHook, act } from "@testing-library/react"; +import { useConnection } from "../useConnection"; +import { z } from "zod"; +import { ClientRequest } from "@modelcontextprotocol/sdk/types.js"; +import { DEFAULT_INSPECTOR_CONFIG } from "../../constants"; + +// Mock fetch +global.fetch = jest.fn().mockResolvedValue({ + json: () => Promise.resolve({ status: "ok" }), +}); + +// Mock the SDK dependencies +const mockRequest = jest.fn().mockResolvedValue({ test: "response" }); +const mockClient = { + request: mockRequest, + notification: jest.fn(), + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn(), + getServerCapabilities: jest.fn(), + setNotificationHandler: jest.fn(), + setRequestHandler: jest.fn(), +}; + +jest.mock("@modelcontextprotocol/sdk/client/index.js", () => ({ + Client: jest.fn().mockImplementation(() => mockClient), +})); + +jest.mock("@modelcontextprotocol/sdk/client/sse.js", () => ({ + SSEClientTransport: jest.fn(), + SseError: jest.fn(), +})); + +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + auth: jest.fn().mockResolvedValue("AUTHORIZED"), +})); + +// Mock the toast hook +jest.mock("@/hooks/use-toast", () => ({ + useToast: () => ({ + toast: jest.fn(), + }), +})); + +// Mock the auth provider +jest.mock("../../auth", () => ({ + authProvider: { + tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }), + }, +})); + +describe("useConnection", () => { + const defaultProps = { + transportType: "sse" as const, + command: "", + args: "", + sseUrl: "http://localhost:8080", + env: {}, + config: DEFAULT_INSPECTOR_CONFIG, + }; + + describe("Request Configuration", () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test("uses the default config values in makeRequest", async () => { + const { result } = renderHook(() => useConnection(defaultProps)); + + // Connect the client + await act(async () => { + await result.current.connect(); + }); + + // Wait for state update + await act(async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); + }); + + const mockRequest: ClientRequest = { + method: "ping", + params: {}, + }; + + const mockSchema = z.object({ + test: z.string(), + }); + + await act(async () => { + await result.current.makeRequest(mockRequest, mockSchema); + }); + + expect(mockClient.request).toHaveBeenCalledWith( + mockRequest, + mockSchema, + expect.objectContaining({ + timeout: DEFAULT_INSPECTOR_CONFIG.MCP_SERVER_REQUEST_TIMEOUT.value, + maxTotalTimeout: + DEFAULT_INSPECTOR_CONFIG + .MCP_SERVER_REQUEST_TIMEOUT_MAX_TOTAL_TIMEOUT.value, + resetTimeoutOnProgress: + DEFAULT_INSPECTOR_CONFIG + .MCP_SERVER_REQUEST_TIMEOUT_RESET_ON_PROGRESS.value, + }), + ); + }); + + test("overrides the default config values when passed in options in makeRequest", async () => { + const { result } = renderHook(() => useConnection(defaultProps)); + + // Connect the client + await act(async () => { + await result.current.connect(); + }); + + // Wait for state update + await act(async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); + }); + + const mockRequest: ClientRequest = { + method: "ping", + params: {}, + }; + + const mockSchema = z.object({ + test: z.string(), + }); + + await act(async () => { + await result.current.makeRequest(mockRequest, mockSchema, { + timeout: 1000, + maxTotalTimeout: 2000, + resetTimeoutOnProgress: false, + }); + }); + + expect(mockClient.request).toHaveBeenCalledWith( + mockRequest, + mockSchema, + expect.objectContaining({ + timeout: 1000, + maxTotalTimeout: 2000, + resetTimeoutOnProgress: false, + }), + ); + }); + }); + + test("throws error when mcpClient is not connected", async () => { + const { result } = renderHook(() => useConnection(defaultProps)); + + const mockRequest: ClientRequest = { + method: "ping", + params: {}, + }; + + const mockSchema = z.object({ + test: z.string(), + }); + + await expect( + result.current.makeRequest(mockRequest, mockSchema), + ).rejects.toThrow("MCP client not connected"); + }); +}); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index bff01ce..d67c623 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -24,6 +24,7 @@ import { ToolListChangedNotificationSchema, PromptListChangedNotificationSchema, } from "@modelcontextprotocol/sdk/types.js"; +import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; import { useState } from "react"; import { useToast } from "@/hooks/use-toast"; import { z } from "zod"; @@ -32,6 +33,13 @@ import { Notification, StdErrNotificationSchema } from "../notificationTypes"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; import { authProvider } from "../auth"; import packageJson from "../../../package.json"; +import { + getMCPProxyAddress, + getMCPServerRequestMaxTotalTimeout, + resetRequestTimeoutOnProgress, +} from "@/utils/configUtils"; +import { getMCPServerRequestTimeout } from "@/utils/configUtils"; +import { InspectorConfig } from "../configurationTypes"; interface UseConnectionOptions { transportType: "stdio" | "sse"; @@ -39,9 +47,8 @@ interface UseConnectionOptions { args: string; sseUrl: string; env: Record; - proxyServerUrl: string; bearerToken?: string; - requestTimeout?: number; + config: InspectorConfig; onNotification?: (notification: Notification) => void; onStdErrNotification?: (notification: Notification) => void; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -50,21 +57,14 @@ interface UseConnectionOptions { getRoots?: () => any[]; } -interface RequestOptions { - signal?: AbortSignal; - timeout?: number; - suppressToast?: boolean; -} - export function useConnection({ transportType, command, args, sseUrl, env, - proxyServerUrl, bearerToken, - requestTimeout, + config, onNotification, onStdErrNotification, onPendingRequest, @@ -94,7 +94,7 @@ export function useConnection({ const makeRequest = async ( request: ClientRequest, schema: T, - options?: RequestOptions, + options?: RequestOptions & { suppressToast?: boolean }, ): Promise> => { if (!mcpClient) { throw new Error("MCP client not connected"); @@ -102,23 +102,25 @@ export function useConnection({ try { const abortController = new AbortController(); - const timeoutId = setTimeout(() => { - abortController.abort("Request timed out"); - }, options?.timeout ?? requestTimeout); - let response; try { response = await mcpClient.request(request, schema, { signal: options?.signal ?? abortController.signal, + resetTimeoutOnProgress: + options?.resetTimeoutOnProgress ?? + resetRequestTimeoutOnProgress(config), + timeout: options?.timeout ?? getMCPServerRequestTimeout(config), + maxTotalTimeout: + options?.maxTotalTimeout ?? + getMCPServerRequestMaxTotalTimeout(config), }); + pushHistory(request, response); } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); pushHistory(request, { error: errorMessage }); throw error; - } finally { - clearTimeout(timeoutId); } return response; @@ -211,7 +213,7 @@ export function useConnection({ const checkProxyHealth = async () => { try { - const proxyHealthUrl = new URL(`${proxyServerUrl}/health`); + const proxyHealthUrl = new URL(`${getMCPProxyAddress(config)}/health`); const proxyHealthResponse = await fetch(proxyHealthUrl); const proxyHealth = await proxyHealthResponse.json(); if (proxyHealth?.status !== "ok") { @@ -256,7 +258,7 @@ export function useConnection({ setConnectionStatus("error-connecting-to-proxy"); return; } - const mcpProxyServerUrl = new URL(`${proxyServerUrl}/sse`); + const mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`); mcpProxyServerUrl.searchParams.append("transportType", transportType); if (transportType === "stdio") { mcpProxyServerUrl.searchParams.append("command", command); diff --git a/client/src/utils/configUtils.ts b/client/src/utils/configUtils.ts index a6f2dd2..3295f7d 100644 --- a/client/src/utils/configUtils.ts +++ b/client/src/utils/configUtils.ts @@ -12,3 +12,15 @@ export const getMCPProxyAddress = (config: InspectorConfig): string => { export const getMCPServerRequestTimeout = (config: InspectorConfig): number => { return config.MCP_SERVER_REQUEST_TIMEOUT.value as number; }; + +export const resetRequestTimeoutOnProgress = ( + config: InspectorConfig, +): boolean => { + return config.MCP_SERVER_REQUEST_TIMEOUT_RESET_ON_PROGRESS.value as boolean; +}; + +export const getMCPServerRequestMaxTotalTimeout = ( + config: InspectorConfig, +): number => { + return config.MCP_SERVER_REQUEST_TIMEOUT_MAX_TOTAL_TIMEOUT.value as number; +};