420 lines
12 KiB
TypeScript
420 lines
12 KiB
TypeScript
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
|
|
import {
|
|
SSEClientTransport,
|
|
SseError,
|
|
} from "@modelcontextprotocol/sdk/client/sse.js";
|
|
import {
|
|
ClientNotification,
|
|
ClientRequest,
|
|
CreateMessageRequestSchema,
|
|
ListRootsRequestSchema,
|
|
ResourceUpdatedNotificationSchema,
|
|
LoggingMessageNotificationSchema,
|
|
Request,
|
|
Result,
|
|
ServerCapabilities,
|
|
PromptReference,
|
|
ResourceReference,
|
|
McpError,
|
|
CompleteResultSchema,
|
|
ErrorCode,
|
|
CancelledNotificationSchema,
|
|
ResourceListChangedNotificationSchema,
|
|
ToolListChangedNotificationSchema,
|
|
PromptListChangedNotificationSchema,
|
|
Progress,
|
|
} from "@modelcontextprotocol/sdk/types.js";
|
|
import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
|
|
import { useState } from "react";
|
|
import { useToast } from "@/hooks/use-toast";
|
|
import { z } from "zod";
|
|
import { ConnectionStatus } from "../constants";
|
|
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
|
import { auth } from "@modelcontextprotocol/sdk/client/auth.js";
|
|
import { InspectorOAuthClientProvider } from "../auth";
|
|
import packageJson from "../../../package.json";
|
|
import {
|
|
getMCPProxyAddress,
|
|
getMCPServerRequestMaxTotalTimeout,
|
|
resetRequestTimeoutOnProgress,
|
|
} from "@/utils/configUtils";
|
|
import { getMCPServerRequestTimeout } from "@/utils/configUtils";
|
|
import { InspectorConfig } from "../configurationTypes";
|
|
|
|
interface UseConnectionOptions {
|
|
transportType: "stdio" | "sse";
|
|
command: string;
|
|
args: string;
|
|
sseUrl: string;
|
|
env: Record<string, string>;
|
|
bearerToken?: string;
|
|
headerName?: string;
|
|
config: InspectorConfig;
|
|
onNotification?: (notification: Notification) => void;
|
|
onStdErrNotification?: (notification: Notification) => void;
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
onPendingRequest?: (request: any, resolve: any, reject: any) => void;
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
getRoots?: () => any[];
|
|
}
|
|
|
|
export function useConnection({
|
|
transportType,
|
|
command,
|
|
args,
|
|
sseUrl,
|
|
env,
|
|
bearerToken,
|
|
headerName,
|
|
config,
|
|
onNotification,
|
|
onStdErrNotification,
|
|
onPendingRequest,
|
|
getRoots,
|
|
}: UseConnectionOptions) {
|
|
const [connectionStatus, setConnectionStatus] =
|
|
useState<ConnectionStatus>("disconnected");
|
|
const { toast } = useToast();
|
|
const [serverCapabilities, setServerCapabilities] =
|
|
useState<ServerCapabilities | null>(null);
|
|
const [mcpClient, setMcpClient] = useState<Client | null>(null);
|
|
const [requestHistory, setRequestHistory] = useState<
|
|
{ request: string; response?: string }[]
|
|
>([]);
|
|
const [completionsSupported, setCompletionsSupported] = useState(true);
|
|
|
|
const pushHistory = (request: object, response?: object) => {
|
|
setRequestHistory((prev) => [
|
|
...prev,
|
|
{
|
|
request: JSON.stringify(request),
|
|
response: response !== undefined ? JSON.stringify(response) : undefined,
|
|
},
|
|
]);
|
|
};
|
|
|
|
const makeRequest = async <T extends z.ZodType>(
|
|
request: ClientRequest,
|
|
schema: T,
|
|
options?: RequestOptions & { suppressToast?: boolean },
|
|
): Promise<z.output<T>> => {
|
|
if (!mcpClient) {
|
|
throw new Error("MCP client not connected");
|
|
}
|
|
try {
|
|
const abortController = new AbortController();
|
|
|
|
// prepare MCP Client request options
|
|
const mcpRequestOptions: RequestOptions = {
|
|
signal: options?.signal ?? abortController.signal,
|
|
resetTimeoutOnProgress:
|
|
options?.resetTimeoutOnProgress ??
|
|
resetRequestTimeoutOnProgress(config),
|
|
timeout: options?.timeout ?? getMCPServerRequestTimeout(config),
|
|
maxTotalTimeout:
|
|
options?.maxTotalTimeout ??
|
|
getMCPServerRequestMaxTotalTimeout(config),
|
|
};
|
|
|
|
// If progress notifications are enabled, add an onprogress hook to the MCP Client request options
|
|
// This is required by SDK to reset the timeout on progress notifications
|
|
if (mcpRequestOptions.resetTimeoutOnProgress) {
|
|
mcpRequestOptions.onprogress = (params: Progress) => {
|
|
// Add progress notification to `Server Notification` window in the UI
|
|
if (onNotification) {
|
|
onNotification({
|
|
method: "notification/progress",
|
|
params,
|
|
});
|
|
}
|
|
};
|
|
}
|
|
|
|
let response;
|
|
try {
|
|
response = await mcpClient.request(request, schema, mcpRequestOptions);
|
|
|
|
pushHistory(request, response);
|
|
} catch (error) {
|
|
const errorMessage =
|
|
error instanceof Error ? error.message : String(error);
|
|
pushHistory(request, { error: errorMessage });
|
|
throw error;
|
|
}
|
|
|
|
return response;
|
|
} catch (e: unknown) {
|
|
if (!options?.suppressToast) {
|
|
const errorString = (e as Error).message ?? String(e);
|
|
toast({
|
|
title: "Error",
|
|
description: errorString,
|
|
variant: "destructive",
|
|
});
|
|
}
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const handleCompletion = async (
|
|
ref: ResourceReference | PromptReference,
|
|
argName: string,
|
|
value: string,
|
|
signal?: AbortSignal,
|
|
): Promise<string[]> => {
|
|
if (!mcpClient || !completionsSupported) {
|
|
return [];
|
|
}
|
|
|
|
const request: ClientRequest = {
|
|
method: "completion/complete",
|
|
params: {
|
|
argument: {
|
|
name: argName,
|
|
value,
|
|
},
|
|
ref,
|
|
},
|
|
};
|
|
|
|
try {
|
|
const response = await makeRequest(request, CompleteResultSchema, {
|
|
signal,
|
|
suppressToast: true,
|
|
});
|
|
return response?.completion.values || [];
|
|
} catch (e: unknown) {
|
|
// Disable completions silently if the server doesn't support them.
|
|
// See https://github.com/modelcontextprotocol/specification/discussions/122
|
|
if (e instanceof McpError && e.code === ErrorCode.MethodNotFound) {
|
|
setCompletionsSupported(false);
|
|
return [];
|
|
}
|
|
|
|
// Unexpected errors - show toast and rethrow
|
|
toast({
|
|
title: "Error",
|
|
description: e instanceof Error ? e.message : String(e),
|
|
variant: "destructive",
|
|
});
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const sendNotification = async (notification: ClientNotification) => {
|
|
if (!mcpClient) {
|
|
const error = new Error("MCP client not connected");
|
|
toast({
|
|
title: "Error",
|
|
description: error.message,
|
|
variant: "destructive",
|
|
});
|
|
throw error;
|
|
}
|
|
|
|
try {
|
|
await mcpClient.notification(notification);
|
|
// Log successful notifications
|
|
pushHistory(notification);
|
|
} catch (e: unknown) {
|
|
if (e instanceof McpError) {
|
|
// Log MCP protocol errors
|
|
pushHistory(notification, { error: e.message });
|
|
}
|
|
toast({
|
|
title: "Error",
|
|
description: e instanceof Error ? e.message : String(e),
|
|
variant: "destructive",
|
|
});
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const checkProxyHealth = async () => {
|
|
try {
|
|
const proxyHealthUrl = new URL(`${getMCPProxyAddress(config)}/health`);
|
|
const proxyHealthResponse = await fetch(proxyHealthUrl);
|
|
const proxyHealth = await proxyHealthResponse.json();
|
|
if (proxyHealth?.status !== "ok") {
|
|
throw new Error("MCP Proxy Server is not healthy");
|
|
}
|
|
} catch (e) {
|
|
console.error("Couldn't connect to MCP Proxy Server", e);
|
|
throw e;
|
|
}
|
|
};
|
|
|
|
const handleAuthError = async (error: unknown) => {
|
|
if (error instanceof SseError && error.code === 401) {
|
|
// Create a new auth provider with the current server URL
|
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
|
|
const result = await auth(serverAuthProvider, { serverUrl: sseUrl });
|
|
return result === "AUTHORIZED";
|
|
}
|
|
|
|
return false;
|
|
};
|
|
|
|
const connect = async (_e?: unknown, retryCount: number = 0) => {
|
|
const client = new Client<Request, Notification, Result>(
|
|
{
|
|
name: "mcp-inspector",
|
|
version: packageJson.version,
|
|
},
|
|
{
|
|
capabilities: {
|
|
sampling: {},
|
|
roots: {
|
|
listChanged: true,
|
|
},
|
|
},
|
|
},
|
|
);
|
|
|
|
try {
|
|
await checkProxyHealth();
|
|
} catch {
|
|
setConnectionStatus("error-connecting-to-proxy");
|
|
return;
|
|
}
|
|
const mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`);
|
|
mcpProxyServerUrl.searchParams.append("transportType", transportType);
|
|
if (transportType === "stdio") {
|
|
mcpProxyServerUrl.searchParams.append("command", command);
|
|
mcpProxyServerUrl.searchParams.append("args", args);
|
|
mcpProxyServerUrl.searchParams.append("env", JSON.stringify(env));
|
|
} else {
|
|
mcpProxyServerUrl.searchParams.append("url", sseUrl);
|
|
}
|
|
|
|
try {
|
|
// Inject auth manually instead of using SSEClientTransport, because we're
|
|
// proxying through the inspector server first.
|
|
const headers: HeadersInit = {};
|
|
|
|
// Create an auth provider with the current server URL
|
|
const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
|
|
// Use manually provided bearer token if available, otherwise use OAuth tokens
|
|
const token =
|
|
bearerToken || (await serverAuthProvider.tokens())?.access_token;
|
|
if (token) {
|
|
const authHeaderName = headerName || "Authorization";
|
|
headers[authHeaderName] = `Bearer ${token}`;
|
|
}
|
|
|
|
const clientTransport = new SSEClientTransport(mcpProxyServerUrl, {
|
|
eventSourceInit: {
|
|
fetch: (url, init) => fetch(url, { ...init, headers }),
|
|
},
|
|
requestInit: {
|
|
headers,
|
|
},
|
|
});
|
|
|
|
if (onNotification) {
|
|
[
|
|
CancelledNotificationSchema,
|
|
LoggingMessageNotificationSchema,
|
|
ResourceUpdatedNotificationSchema,
|
|
ResourceListChangedNotificationSchema,
|
|
ToolListChangedNotificationSchema,
|
|
PromptListChangedNotificationSchema,
|
|
].forEach((notificationSchema) => {
|
|
client.setNotificationHandler(notificationSchema, onNotification);
|
|
});
|
|
|
|
client.fallbackNotificationHandler = (
|
|
notification: Notification,
|
|
): Promise<void> => {
|
|
onNotification(notification);
|
|
return Promise.resolve();
|
|
};
|
|
}
|
|
|
|
if (onStdErrNotification) {
|
|
client.setNotificationHandler(
|
|
StdErrNotificationSchema,
|
|
onStdErrNotification,
|
|
);
|
|
}
|
|
|
|
let capabilities;
|
|
try {
|
|
await client.connect(clientTransport);
|
|
|
|
capabilities = client.getServerCapabilities();
|
|
const initializeRequest = {
|
|
method: "initialize",
|
|
};
|
|
pushHistory(initializeRequest, {
|
|
capabilities,
|
|
serverInfo: client.getServerVersion(),
|
|
instructions: client.getInstructions(),
|
|
});
|
|
} catch (error) {
|
|
console.error(
|
|
`Failed to connect to MCP Server via the MCP Inspector Proxy: ${mcpProxyServerUrl}:`,
|
|
error,
|
|
);
|
|
const shouldRetry = await handleAuthError(error);
|
|
if (shouldRetry) {
|
|
return connect(undefined, retryCount + 1);
|
|
}
|
|
|
|
if (error instanceof SseError && error.code === 401) {
|
|
// Don't set error state if we're about to redirect for auth
|
|
return;
|
|
}
|
|
throw error;
|
|
}
|
|
setServerCapabilities(capabilities ?? null);
|
|
setCompletionsSupported(true); // Reset completions support on new connection
|
|
|
|
if (onPendingRequest) {
|
|
client.setRequestHandler(CreateMessageRequestSchema, (request) => {
|
|
return new Promise((resolve, reject) => {
|
|
onPendingRequest(request, resolve, reject);
|
|
});
|
|
});
|
|
}
|
|
|
|
if (getRoots) {
|
|
client.setRequestHandler(ListRootsRequestSchema, async () => {
|
|
return { roots: getRoots() };
|
|
});
|
|
}
|
|
|
|
setMcpClient(client);
|
|
setConnectionStatus("connected");
|
|
} catch (e) {
|
|
console.error(e);
|
|
setConnectionStatus("error");
|
|
}
|
|
};
|
|
|
|
const disconnect = async () => {
|
|
await mcpClient?.close();
|
|
const authProvider = new InspectorOAuthClientProvider(sseUrl);
|
|
authProvider.clear();
|
|
setMcpClient(null);
|
|
setConnectionStatus("disconnected");
|
|
setCompletionsSupported(false);
|
|
setServerCapabilities(null);
|
|
};
|
|
|
|
return {
|
|
connectionStatus,
|
|
serverCapabilities,
|
|
mcpClient,
|
|
requestHistory,
|
|
makeRequest,
|
|
sendNotification,
|
|
handleCompletion,
|
|
completionsSupported,
|
|
connect,
|
|
disconnect,
|
|
};
|
|
}
|