Merge branch 'main' into tool-input-improvements
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import { useDraggablePane } from "./lib/hooks/useDraggablePane";
|
||||
import { useConnection } from "./lib/hooks/useConnection";
|
||||
import {
|
||||
ClientRequest,
|
||||
CompatibilityCallToolResult,
|
||||
@@ -10,15 +8,17 @@ import {
|
||||
ListPromptsResultSchema,
|
||||
ListResourcesResultSchema,
|
||||
ListResourceTemplatesResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
ListToolsResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
Resource,
|
||||
ResourceTemplate,
|
||||
Root,
|
||||
ServerNotification,
|
||||
Tool,
|
||||
} from "@modelcontextprotocol/sdk/types.js";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import React, { Suspense, useEffect, useRef, useState } from "react";
|
||||
import { useConnection } from "./lib/hooks/useConnection";
|
||||
import { useDraggablePane } from "./lib/hooks/useDraggablePane";
|
||||
|
||||
import { StdErrNotification } from "./lib/notificationTypes";
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
MessageSquare,
|
||||
} from "lucide-react";
|
||||
|
||||
import { toast } from "react-toastify";
|
||||
import { z } from "zod";
|
||||
import "./App.css";
|
||||
import ConsoleTab from "./components/ConsoleTab";
|
||||
@@ -49,6 +50,17 @@ const PROXY_PORT = params.get("proxyPort") ?? "3000";
|
||||
const PROXY_SERVER_URL = `http://localhost:${PROXY_PORT}`;
|
||||
|
||||
const App = () => {
|
||||
// Handle OAuth callback route
|
||||
if (window.location.pathname === "/oauth/callback") {
|
||||
const OAuthCallback = React.lazy(
|
||||
() => import("./components/OAuthCallback"),
|
||||
);
|
||||
return (
|
||||
<Suspense fallback={<div>Loading...</div>}>
|
||||
<OAuthCallback />
|
||||
</Suspense>
|
||||
);
|
||||
}
|
||||
const [resources, setResources] = useState<Resource[]>([]);
|
||||
const [resourceTemplates, setResourceTemplates] = useState<
|
||||
ResourceTemplate[]
|
||||
@@ -71,8 +83,14 @@ const App = () => {
|
||||
return localStorage.getItem("lastArgs") || "";
|
||||
});
|
||||
|
||||
const [sseUrl, setSseUrl] = useState<string>("http://localhost:3001/sse");
|
||||
const [transportType, setTransportType] = useState<"stdio" | "sse">("stdio");
|
||||
const [sseUrl, setSseUrl] = useState<string>(() => {
|
||||
return localStorage.getItem("lastSseUrl") || "http://localhost:3001/sse";
|
||||
});
|
||||
const [transportType, setTransportType] = useState<"stdio" | "sse">(() => {
|
||||
return (
|
||||
(localStorage.getItem("lastTransportType") as "stdio" | "sse") || "stdio"
|
||||
);
|
||||
});
|
||||
const [notifications, setNotifications] = useState<ServerNotification[]>([]);
|
||||
const [stdErrNotifications, setStdErrNotifications] = useState<
|
||||
StdErrNotification[]
|
||||
@@ -190,6 +208,31 @@ const App = () => {
|
||||
localStorage.setItem("lastArgs", args);
|
||||
}, [args]);
|
||||
|
||||
useEffect(() => {
|
||||
localStorage.setItem("lastSseUrl", sseUrl);
|
||||
}, [sseUrl]);
|
||||
|
||||
useEffect(() => {
|
||||
localStorage.setItem("lastTransportType", transportType);
|
||||
}, [transportType]);
|
||||
|
||||
// Auto-connect if serverUrl is provided in URL params (e.g. after OAuth callback)
|
||||
useEffect(() => {
|
||||
const serverUrl = params.get("serverUrl");
|
||||
if (serverUrl) {
|
||||
setSseUrl(serverUrl);
|
||||
setTransportType("sse");
|
||||
// Remove serverUrl from URL without reloading the page
|
||||
const newUrl = new URL(window.location.href);
|
||||
newUrl.searchParams.delete("serverUrl");
|
||||
window.history.replaceState({}, "", newUrl.toString());
|
||||
// Show success toast for OAuth
|
||||
toast.success("Successfully authenticated with OAuth");
|
||||
// Connect to the server
|
||||
connectMcpServer();
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
fetch(`${PROXY_SERVER_URL}/config`)
|
||||
.then((response) => response.json())
|
||||
|
||||
54
client/src/components/OAuthCallback.tsx
Normal file
54
client/src/components/OAuthCallback.tsx
Normal file
@@ -0,0 +1,54 @@
|
||||
import { useEffect, useRef } from "react";
|
||||
import { handleOAuthCallback } from "../lib/auth";
|
||||
import { SESSION_KEYS } from "../lib/constants";
|
||||
|
||||
const OAuthCallback = () => {
|
||||
const hasProcessedRef = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
const handleCallback = async () => {
|
||||
// Skip if we've already processed this callback
|
||||
if (hasProcessedRef.current) {
|
||||
return;
|
||||
}
|
||||
hasProcessedRef.current = true;
|
||||
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const code = params.get("code");
|
||||
const serverUrl = sessionStorage.getItem(SESSION_KEYS.SERVER_URL);
|
||||
|
||||
if (!code || !serverUrl) {
|
||||
console.error("Missing code or server URL");
|
||||
window.location.href = "/";
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const tokens = await handleOAuthCallback(serverUrl, code);
|
||||
// Store both access and refresh tokens
|
||||
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
|
||||
if (tokens.refresh_token) {
|
||||
sessionStorage.setItem(
|
||||
SESSION_KEYS.REFRESH_TOKEN,
|
||||
tokens.refresh_token,
|
||||
);
|
||||
}
|
||||
// Redirect back to the main app with server URL to trigger auto-connect
|
||||
window.location.href = `/?serverUrl=${encodeURIComponent(serverUrl)}`;
|
||||
} catch (error) {
|
||||
console.error("OAuth callback error:", error);
|
||||
window.location.href = "/";
|
||||
}
|
||||
};
|
||||
|
||||
void handleCallback();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center h-screen">
|
||||
<p className="text-lg text-gray-500">Processing OAuth callback...</p>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default OAuthCallback;
|
||||
@@ -94,11 +94,20 @@ const ToolsTab = ({
|
||||
className="max-w-full h-auto"
|
||||
/>
|
||||
)}
|
||||
{item.type === "resource" && (
|
||||
<pre className="bg-gray-50 dark:bg-gray-800 dark:text-gray-100 whitespace-pre-wrap break-words p-4 rounded text-sm overflow-auto max-h-64">
|
||||
{JSON.stringify(item.resource, null, 2)}
|
||||
</pre>
|
||||
)}
|
||||
{item.type === "resource" &&
|
||||
(item.resource?.mimeType?.startsWith("audio/") ? (
|
||||
<audio
|
||||
controls
|
||||
src={`data:${item.resource.mimeType};base64,${item.resource.blob}`}
|
||||
className="w-full"
|
||||
>
|
||||
<p>Your browser does not support audio playback</p>
|
||||
</audio>
|
||||
) : (
|
||||
<pre className="bg-gray-50 dark:bg-gray-800 dark:text-gray-100 whitespace-pre-wrap break-words p-4 rounded text-sm overflow-auto max-h-64">
|
||||
{JSON.stringify(item.resource, null, 2)}
|
||||
</pre>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
|
||||
134
client/src/lib/auth.ts
Normal file
134
client/src/lib/auth.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
import pkceChallenge from "pkce-challenge";
|
||||
import { SESSION_KEYS } from "./constants";
|
||||
import { z } from "zod";
|
||||
|
||||
export const OAuthMetadataSchema = z.object({
|
||||
authorization_endpoint: z.string(),
|
||||
token_endpoint: z.string(),
|
||||
});
|
||||
|
||||
export type OAuthMetadata = z.infer<typeof OAuthMetadataSchema>;
|
||||
|
||||
export const OAuthTokensSchema = z.object({
|
||||
access_token: z.string(),
|
||||
refresh_token: z.string().optional(),
|
||||
expires_in: z.number().optional(),
|
||||
});
|
||||
|
||||
export type OAuthTokens = z.infer<typeof OAuthTokensSchema>;
|
||||
|
||||
export async function discoverOAuthMetadata(
|
||||
serverUrl: string,
|
||||
): Promise<OAuthMetadata> {
|
||||
try {
|
||||
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
|
||||
const response = await fetch(url.toString());
|
||||
|
||||
if (response.ok) {
|
||||
const metadata = await response.json();
|
||||
const validatedMetadata = OAuthMetadataSchema.parse({
|
||||
authorization_endpoint: metadata.authorization_endpoint,
|
||||
token_endpoint: metadata.token_endpoint,
|
||||
});
|
||||
return validatedMetadata;
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("OAuth metadata discovery failed:", error);
|
||||
}
|
||||
|
||||
// Fall back to default endpoints
|
||||
const baseUrl = new URL(serverUrl);
|
||||
const defaultMetadata = {
|
||||
authorization_endpoint: new URL("/authorize", baseUrl).toString(),
|
||||
token_endpoint: new URL("/token", baseUrl).toString(),
|
||||
};
|
||||
return OAuthMetadataSchema.parse(defaultMetadata);
|
||||
}
|
||||
|
||||
export async function startOAuthFlow(serverUrl: string): Promise<string> {
|
||||
// Generate PKCE challenge
|
||||
const challenge = await pkceChallenge();
|
||||
const codeVerifier = challenge.code_verifier;
|
||||
const codeChallenge = challenge.code_challenge;
|
||||
|
||||
// Store code verifier for later use
|
||||
sessionStorage.setItem(SESSION_KEYS.CODE_VERIFIER, codeVerifier);
|
||||
|
||||
// Discover OAuth endpoints
|
||||
const metadata = await discoverOAuthMetadata(serverUrl);
|
||||
|
||||
// Build authorization URL
|
||||
const authUrl = new URL(metadata.authorization_endpoint);
|
||||
authUrl.searchParams.set("response_type", "code");
|
||||
authUrl.searchParams.set("code_challenge", codeChallenge);
|
||||
authUrl.searchParams.set("code_challenge_method", "S256");
|
||||
authUrl.searchParams.set(
|
||||
"redirect_uri",
|
||||
window.location.origin + "/oauth/callback",
|
||||
);
|
||||
|
||||
return authUrl.toString();
|
||||
}
|
||||
|
||||
export async function handleOAuthCallback(
|
||||
serverUrl: string,
|
||||
code: string,
|
||||
): Promise<OAuthTokens> {
|
||||
// Get stored code verifier
|
||||
const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
|
||||
if (!codeVerifier) {
|
||||
throw new Error("No code verifier found");
|
||||
}
|
||||
|
||||
// Discover OAuth endpoints
|
||||
const metadata = await discoverOAuthMetadata(serverUrl);
|
||||
// Exchange code for tokens
|
||||
const response = await fetch(metadata.token_endpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
grant_type: "authorization_code",
|
||||
code,
|
||||
code_verifier: codeVerifier,
|
||||
redirect_uri: window.location.origin + "/oauth/callback",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Token exchange failed");
|
||||
}
|
||||
|
||||
const tokens = await response.json();
|
||||
return OAuthTokensSchema.parse(tokens);
|
||||
}
|
||||
|
||||
export async function refreshAccessToken(
|
||||
serverUrl: string,
|
||||
): Promise<OAuthTokens> {
|
||||
const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN);
|
||||
if (!refreshToken) {
|
||||
throw new Error("No refresh token available");
|
||||
}
|
||||
|
||||
const metadata = await discoverOAuthMetadata(serverUrl);
|
||||
|
||||
const response = await fetch(metadata.token_endpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
grant_type: "refresh_token",
|
||||
refresh_token: refreshToken,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Token refresh failed");
|
||||
}
|
||||
|
||||
const tokens = await response.json();
|
||||
return OAuthTokensSchema.parse(tokens);
|
||||
}
|
||||
7
client/src/lib/constants.ts
Normal file
7
client/src/lib/constants.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
// OAuth-related session storage keys
|
||||
export const SESSION_KEYS = {
|
||||
CODE_VERIFIER: "mcp_code_verifier",
|
||||
SERVER_URL: "mcp_server_url",
|
||||
ACCESS_TOKEN: "mcp_access_token",
|
||||
REFRESH_TOKEN: "mcp_refresh_token",
|
||||
} as const;
|
||||
@@ -1,5 +1,8 @@
|
||||
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
|
||||
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
|
||||
import {
|
||||
SSEClientTransport,
|
||||
SseError,
|
||||
} from "@modelcontextprotocol/sdk/client/sse.js";
|
||||
import {
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
@@ -12,8 +15,10 @@ import {
|
||||
} from "@modelcontextprotocol/sdk/types.js";
|
||||
import { useState } from "react";
|
||||
import { toast } from "react-toastify";
|
||||
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
||||
import { z } from "zod";
|
||||
import { startOAuthFlow, refreshAccessToken } from "../auth";
|
||||
import { SESSION_KEYS } from "../constants";
|
||||
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
|
||||
|
||||
const DEFAULT_REQUEST_TIMEOUT_MSEC = 10000;
|
||||
|
||||
@@ -116,7 +121,49 @@ export function useConnection({
|
||||
}
|
||||
};
|
||||
|
||||
const connect = async () => {
|
||||
const initiateOAuthFlow = async () => {
|
||||
sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN);
|
||||
sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN);
|
||||
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
|
||||
const redirectUrl = await startOAuthFlow(sseUrl);
|
||||
window.location.href = redirectUrl;
|
||||
};
|
||||
|
||||
const handleTokenRefresh = async () => {
|
||||
try {
|
||||
const tokens = await refreshAccessToken(sseUrl);
|
||||
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
|
||||
if (tokens.refresh_token) {
|
||||
sessionStorage.setItem(
|
||||
SESSION_KEYS.REFRESH_TOKEN,
|
||||
tokens.refresh_token,
|
||||
);
|
||||
}
|
||||
return tokens.access_token;
|
||||
} catch (error) {
|
||||
console.error("Token refresh failed:", error);
|
||||
await initiateOAuthFlow();
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const handleAuthError = async (error: unknown) => {
|
||||
if (error instanceof SseError && error.code === 401) {
|
||||
if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) {
|
||||
try {
|
||||
await handleTokenRefresh();
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error("Token refresh failed:", error);
|
||||
}
|
||||
} else {
|
||||
await initiateOAuthFlow();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const connect = async (_e?: unknown, retryCount: number = 0) => {
|
||||
try {
|
||||
const client = new Client<Request, Notification, Result>(
|
||||
{
|
||||
@@ -144,7 +191,20 @@ export function useConnection({
|
||||
backendUrl.searchParams.append("url", sseUrl);
|
||||
}
|
||||
|
||||
const clientTransport = new SSEClientTransport(backendUrl);
|
||||
const headers: HeadersInit = {};
|
||||
const accessToken = sessionStorage.getItem(SESSION_KEYS.ACCESS_TOKEN);
|
||||
if (accessToken) {
|
||||
headers["Authorization"] = `Bearer ${accessToken}`;
|
||||
}
|
||||
|
||||
const clientTransport = new SSEClientTransport(backendUrl, {
|
||||
eventSourceInit: {
|
||||
fetch: (url, init) => fetch(url, { ...init, headers }),
|
||||
},
|
||||
requestInit: {
|
||||
headers,
|
||||
},
|
||||
});
|
||||
|
||||
if (onNotification) {
|
||||
client.setNotificationHandler(
|
||||
@@ -160,7 +220,21 @@ export function useConnection({
|
||||
);
|
||||
}
|
||||
|
||||
await client.connect(clientTransport);
|
||||
try {
|
||||
await client.connect(clientTransport);
|
||||
} catch (error) {
|
||||
console.error("Failed to connect to MCP server:", error);
|
||||
const shouldRetry = await handleAuthError(error);
|
||||
if (shouldRetry) {
|
||||
return connect(undefined, retryCount + 1);
|
||||
}
|
||||
|
||||
if (error instanceof SseError && error.code === 401) {
|
||||
// Don't set error state if we're about to redirect for auth
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
const capabilities = client.getServerCapabilities();
|
||||
setServerCapabilities(capabilities ?? null);
|
||||
|
||||
Reference in New Issue
Block a user