* In useConnection.ts - In the case for "stdio", let the transportOptions be the same as "sse" because it will use that transport to the proxy regardless
487 lines
14 KiB
TypeScript
487 lines
14 KiB
TypeScript
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
|
|
import {
|
|
SSEClientTransport,
|
|
SseError,
|
|
SSEClientTransportOptions,
|
|
} from "@modelcontextprotocol/sdk/client/sse.js";
|
|
import {
|
|
StreamableHTTPClientTransport,
|
|
StreamableHTTPClientTransportOptions,
|
|
} from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
|
import {
|
|
ClientNotification,
|
|
ClientRequest,
|
|
CreateMessageRequestSchema,
|
|
ListRootsRequestSchema,
|
|
ResourceUpdatedNotificationSchema,
|
|
LoggingMessageNotificationSchema,
|
|
Request,
|
|
Result,
|
|
ServerCapabilities,
|
|
PromptReference,
|
|
ResourceReference,
|
|
McpError,
|
|
CompleteResultSchema,
|
|
ErrorCode,
|
|
CancelledNotificationSchema,
|
|
ResourceListChangedNotificationSchema,
|
|
ToolListChangedNotificationSchema,
|
|
PromptListChangedNotificationSchema,
|
|
Progress,
|
|
} from "@modelcontextprotocol/sdk/types.js";
|
|
import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
|
|
import { useState } from "react";
|
|
import { useToast } from "@/hooks/use-toast";
|
|
import { z } from "zod";
|
|
import { ConnectionStatus } from "../constants";
|
|
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
|
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
|
import { InspectorOAuthClientProvider } from "../auth";
|
|
import packageJson from "../../../package.json";
|
|
import {
|
|
getMCPProxyAddress,
|
|
getMCPServerRequestMaxTotalTimeout,
|
|
resetRequestTimeoutOnProgress,
|
|
} from "@/utils/configUtils";
|
|
import { getMCPServerRequestTimeout } from "@/utils/configUtils";
|
|
import { InspectorConfig } from "../configurationTypes";
|
|
|
|
interface UseConnectionOptions {
|
|
transportType: "stdio" | "sse" | "streamable-http";
|
|
command: string;
|
|
args: string;
|
|
sseUrl: string;
|
|
env: Record<string, string>;
|
|
bearerToken?: string;
|
|
headerName?: string;
|
|
config: InspectorConfig;
|
|
onNotification?: (notification: Notification) => void;
|
|
onStdErrNotification?: (notification: Notification) => void;
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
onPendingRequest?: (request: any, resolve: any, reject: any) => void;
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
getRoots?: () => any[];
|
|
}
|
|
|
|
export function useConnection({
|
|
transportType,
|
|
command,
|
|
args,
|
|
sseUrl,
|
|
env,
|
|
bearerToken,
|
|
headerName,
|
|
config,
|
|
onNotification,
|
|
onStdErrNotification,
|
|
onPendingRequest,
|
|
getRoots,
|
|
}: UseConnectionOptions) {
|
|
const [connectionStatus, setConnectionStatus] =
|
|
useState<ConnectionStatus>("disconnected");
|
|
const { toast } = useToast();
|
|
const [serverCapabilities, setServerCapabilities] =
|
|
useState<ServerCapabilities | null>(null);
|
|
const [mcpClient, setMcpClient] = useState<Client | null>(null);
|
|
const [requestHistory, setRequestHistory] = useState<
|
|
{ request: string; response?: string }[]
|
|
>([]);
|
|
const [completionsSupported, setCompletionsSupported] = useState(true);
|
|
|
|
const pushHistory = (request: object, response?: object) => {
|
|
setRequestHistory((prev) => [
|
|
...prev,
|
|
{
|
|
request: JSON.stringify(request),
|
|
response: response !== undefined ? JSON.stringify(response) : undefined,
|
|
},
|
|
]);
|
|
};
|
|
|
|
const makeRequest = async <T extends z.ZodType>(
|
|
request: ClientRequest,
|
|
schema: T,
|
|
options?: RequestOptions & { suppressToast?: boolean },
|
|
): Promise<z.output<T>> => {
|
|
if (!mcpClient) {
|
|
throw new Error("MCP client not connected");
|
|
}
|
|
try {
|
|
const abortController = new AbortController();
|
|
|
|
// prepare MCP Client request options
|
|
const mcpRequestOptions: RequestOptions = {
|
|
signal: options?.signal ?? abortController.signal,
|
|
resetTimeoutOnProgress:
|
|
options?.resetTimeoutOnProgress ??
|
|
resetRequestTimeoutOnProgress(config),
|
|
timeout: options?.timeout ?? getMCPServerRequestTimeout(config),
|
|
maxTotalTimeout:
|
|
options?.maxTotalTimeout ??
|
|
getMCPServerRequestMaxTotalTimeout(config),
|
|
};
|
|
|
|
// If progress notifications are enabled, add an onprogress hook to the MCP Client request options
|
|
// This is required by SDK to reset the timeout on progress notifications
|
|
if (mcpRequestOptions.resetTimeoutOnProgress) {
|
|
mcpRequestOptions.onprogress = (params: Progress) => {
|
|
// Add progress notification to `Server Notification` window in the UI
|
|
if (onNotification) {
|
|
onNotification({
|
|
method: "notification/progress",
|
|
params,
|
|
});
|
|
}
|
|
};
|
|
}
|
|
|
|
let response;
|
|
try {
|
|
response = await mcpClient.request(request, schema, mcpRequestOptions);
|
|
|
|
pushHistory(request, response);
|
|
} catch (error) {
|
|
const errorMessage =
|
|
error instanceof Error ? error.message : String(error);
|
|
pushHistory(request, { error: errorMessage });
|
|
throw error;
|
|
}
|
|
|
|
return response;
|
|
} catch (e: unknown) {
|
|
if (!options?.suppressToast) {
|
|
const errorString = (e as Error).message ?? String(e);
|
|
toast({
|
|
title: "Error",
|
|
description: errorString,
|
|
variant: "destructive",
|
|
});
|
|
}
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const handleCompletion = async (
|
|
ref: ResourceReference | PromptReference,
|
|
argName: string,
|
|
value: string,
|
|
signal?: AbortSignal,
|
|
): Promise<string[]> => {
|
|
if (!mcpClient || !completionsSupported) {
|
|
return [];
|
|
}
|
|
|
|
const request: ClientRequest = {
|
|
method: "completion/complete",
|
|
params: {
|
|
argument: {
|
|
name: argName,
|
|
value,
|
|
},
|
|
ref,
|
|
},
|
|
};
|
|
|
|
try {
|
|
const response = await makeRequest(request, CompleteResultSchema, {
|
|
signal,
|
|
suppressToast: true,
|
|
});
|
|
return response?.completion.values || [];
|
|
} catch (e: unknown) {
|
|
// Disable completions silently if the server doesn't support them.
|
|
// See https://github.com/modelcontextprotocol/specification/discussions/122
|
|
if (e instanceof McpError && e.code === ErrorCode.MethodNotFound) {
|
|
setCompletionsSupported(false);
|
|
return [];
|
|
}
|
|
|
|
// Unexpected errors - show toast and rethrow
|
|
toast({
|
|
title: "Error",
|
|
description: e instanceof Error ? e.message : String(e),
|
|
variant: "destructive",
|
|
});
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const sendNotification = async (notification: ClientNotification) => {
|
|
if (!mcpClient) {
|
|
const error = new Error("MCP client not connected");
|
|
toast({
|
|
title: "Error",
|
|
description: error.message,
|
|
variant: "destructive",
|
|
});
|
|
throw error;
|
|
}
|
|
|
|
try {
|
|
await mcpClient.notification(notification);
|
|
// Log successful notifications
|
|
pushHistory(notification);
|
|
} catch (e: unknown) {
|
|
if (e instanceof McpError) {
|
|
// Log MCP protocol errors
|
|
pushHistory(notification, { error: e.message });
|
|
}
|
|
toast({
|
|
title: "Error",
|
|
description: e instanceof Error ? e.message : String(e),
|
|
variant: "destructive",
|
|
});
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const checkProxyHealth = async () => {
|
|
try {
|
|
const proxyHealthUrl = new URL(`${getMCPProxyAddress(config)}/health`);
|
|
const proxyHealthResponse = await fetch(proxyHealthUrl);
|
|
const proxyHealth = await proxyHealthResponse.json();
|
|
if (proxyHealth?.status !== "ok") {
|
|
throw new Error("MCP Proxy Server is not healthy");
|
|
}
|
|
} catch (e) {
|
|
console.error("Couldn't connect to MCP Proxy Server", e);
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const handleAuthError = async (error: unknown) => {
|
|
if (error instanceof SseError && error.code === 401) {
|
|
// Create a new auth provider with the current server URL
|
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
|
|
const result = await auth(serverAuthProvider, { serverUrl: sseUrl });
|
|
return result === "AUTHORIZED";
|
|
}
|
|
|
|
return false;
|
|
};
|
|
|
|
const connect = async (_e?: unknown, retryCount: number = 0) => {
|
|
const client = new Client<Request, Notification, Result>(
|
|
{
|
|
name: "mcp-inspector",
|
|
version: packageJson.version,
|
|
},
|
|
{
|
|
capabilities: {
|
|
sampling: {},
|
|
roots: {
|
|
listChanged: true,
|
|
},
|
|
},
|
|
},
|
|
);
|
|
|
|
try {
|
|
await checkProxyHealth();
|
|
} catch {
|
|
setConnectionStatus("error-connecting-to-proxy");
|
|
return;
|
|
}
|
|
|
|
try {
|
|
// Inject auth manually instead of using SSEClientTransport, because we're
|
|
// proxying through the inspector server first.
|
|
const headers: HeadersInit = {};
|
|
|
|
// Create an auth provider with the current server URL
|
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
|
|
// Use manually provided bearer token if available, otherwise use OAuth tokens
|
|
const token =
|
|
bearerToken || (await serverAuthProvider.tokens())?.access_token;
|
|
if (token) {
|
|
const authHeaderName = headerName || "Authorization";
|
|
headers[authHeaderName] = `Bearer ${token}`;
|
|
}
|
|
|
|
// Create appropriate transport
|
|
let transportOptions:
|
|
| StreamableHTTPClientTransportOptions
|
|
| SSEClientTransportOptions;
|
|
|
|
let mcpProxyServerUrl;
|
|
switch (transportType) {
|
|
case "stdio":
|
|
mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/stdio`);
|
|
mcpProxyServerUrl.searchParams.append("command", command);
|
|
mcpProxyServerUrl.searchParams.append("args", args);
|
|
mcpProxyServerUrl.searchParams.append("env", JSON.stringify(env));
|
|
transportOptions = {
|
|
authProvider: serverAuthProvider,
|
|
eventSourceInit: {
|
|
fetch: (
|
|
url: string | URL | globalThis.Request,
|
|
init: RequestInit | undefined,
|
|
) => fetch(url, { ...init, headers }),
|
|
},
|
|
requestInit: {
|
|
headers,
|
|
},
|
|
};
|
|
break;
|
|
|
|
case "sse":
|
|
mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`);
|
|
mcpProxyServerUrl.searchParams.append("url", sseUrl);
|
|
transportOptions = {
|
|
authProvider: serverAuthProvider,
|
|
eventSourceInit: {
|
|
fetch: (
|
|
url: string | URL | globalThis.Request,
|
|
init: RequestInit | undefined,
|
|
) => fetch(url, { ...init, headers }),
|
|
},
|
|
requestInit: {
|
|
headers,
|
|
},
|
|
};
|
|
break;
|
|
|
|
case "streamable-http":
|
|
mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/mcp`);
|
|
mcpProxyServerUrl.searchParams.append("url", sseUrl);
|
|
transportOptions = {
|
|
authProvider: serverAuthProvider,
|
|
eventSourceInit: {
|
|
fetch: (
|
|
url: string | URL | globalThis.Request,
|
|
init: RequestInit | undefined,
|
|
) => fetch(url, { ...init, headers }),
|
|
},
|
|
requestInit: {
|
|
headers,
|
|
},
|
|
// TODO these should be configurable...
|
|
reconnectionOptions: {
|
|
maxReconnectionDelay: 30000,
|
|
initialReconnectionDelay: 1000,
|
|
reconnectionDelayGrowFactor: 1.5,
|
|
maxRetries: 2,
|
|
},
|
|
};
|
|
break;
|
|
}
|
|
(mcpProxyServerUrl as URL).searchParams.append(
|
|
"transportType",
|
|
transportType,
|
|
);
|
|
|
|
const clientTransport =
|
|
transportType === "streamable-http"
|
|
? new StreamableHTTPClientTransport(mcpProxyServerUrl as URL, {
|
|
sessionId: undefined,
|
|
...transportOptions,
|
|
})
|
|
: new SSEClientTransport(mcpProxyServerUrl as URL, transportOptions);
|
|
|
|
if (onNotification) {
|
|
[
|
|
CancelledNotificationSchema,
|
|
LoggingMessageNotificationSchema,
|
|
ResourceUpdatedNotificationSchema,
|
|
ResourceListChangedNotificationSchema,
|
|
ToolListChangedNotificationSchema,
|
|
PromptListChangedNotificationSchema,
|
|
].forEach((notificationSchema) => {
|
|
client.setNotificationHandler(notificationSchema, onNotification);
|
|
});
|
|
|
|
client.fallbackNotificationHandler = (
|
|
notification: Notification,
|
|
): Promise<void> => {
|
|
onNotification(notification);
|
|
return Promise.resolve();
|
|
};
|
|
}
|
|
|
|
if (onStdErrNotification) {
|
|
client.setNotificationHandler(
|
|
StdErrNotificationSchema,
|
|
onStdErrNotification,
|
|
);
|
|
}
|
|
|
|
let capabilities;
|
|
try {
|
|
await client.connect(clientTransport);
|
|
|
|
capabilities = client.getServerCapabilities();
|
|
const initializeRequest = {
|
|
method: "initialize",
|
|
};
|
|
pushHistory(initializeRequest, {
|
|
capabilities,
|
|
serverInfo: client.getServerVersion(),
|
|
instructions: client.getInstructions(),
|
|
});
|
|
} catch (error) {
|
|
console.error(
|
|
`Failed to connect to MCP Server via the MCP Inspector Proxy: ${mcpProxyServerUrl}:`,
|
|
error,
|
|
);
|
|
const shouldRetry = await handleAuthError(error);
|
|
if (shouldRetry) {
|
|
return connect(undefined, retryCount + 1);
|
|
}
|
|
|
|
if (error instanceof SseError && error.code === 401) {
|
|
// Don't set error state if we're about to redirect for auth
|
|
return;
|
|
}
|
|
throw error;
|
|
}
|
|
setServerCapabilities(capabilities ?? null);
|
|
setCompletionsSupported(true); // Reset completions support on new connection
|
|
|
|
if (onPendingRequest) {
|
|
client.setRequestHandler(CreateMessageRequestSchema, (request) => {
|
|
return new Promise((resolve, reject) => {
|
|
onPendingRequest(request, resolve, reject);
|
|
});
|
|
});
|
|
}
|
|
|
|
if (getRoots) {
|
|
client.setRequestHandler(ListRootsRequestSchema, async () => {
|
|
return { roots: getRoots() };
|
|
});
|
|
}
|
|
|
|
setMcpClient(client);
|
|
setConnectionStatus("connected");
|
|
} catch (e) {
|
|
console.error(e);
|
|
setConnectionStatus("error");
|
|
}
|
|
};
|
|
|
|
const disconnect = async () => {
|
|
await mcpClient?.close();
|
|
const authProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
authProvider.clear();
|
|
setMcpClient(null);
|
|
setConnectionStatus("disconnected");
|
|
setCompletionsSupported(false);
|
|
setServerCapabilities(null);
|
|
};
|
|
|
|
return {
|
|
connectionStatus,
|
|
serverCapabilities,
|
|
mcpClient,
|
|
requestHistory,
|
|
makeRequest,
|
|
sendNotification,
|
|
handleCompletion,
|
|
completionsSupported,
|
|
connect,
|
|
disconnect,
|
|
};
|
|
}
|