diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index abbeb7c..2577a41 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -278,15 +278,26 @@ export function useConnection({ setConnectionStatus("error-connecting-to-proxy"); return; } - const mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`); - mcpProxyServerUrl.searchParams.append("transportType", transportType); - if (transportType === "stdio") { - mcpProxyServerUrl.searchParams.append("command", command); - mcpProxyServerUrl.searchParams.append("args", args); - mcpProxyServerUrl.searchParams.append("env", JSON.stringify(env)); - } else { - mcpProxyServerUrl.searchParams.append("url", sseUrl); + let mcpProxyServerUrl; + switch (transportType) { + case "stdio": + mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/stdio`); + mcpProxyServerUrl.searchParams.append("command", command); + mcpProxyServerUrl.searchParams.append("args", args); + mcpProxyServerUrl.searchParams.append("env", JSON.stringify(env)); + break; + case "sse": + mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`); + mcpProxyServerUrl.searchParams.append("url", sseUrl); + break; + + case "streamable-http": + mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/mcp`); + mcpProxyServerUrl.searchParams.append("url", sseUrl); + break; } + (mcpProxyServerUrl as URL).searchParams.append("transportType", transportType); + try { // Inject auth manually instead of using SSEClientTransport, because we're @@ -304,7 +315,7 @@ export function useConnection({ headers[authHeaderName] = `Bearer ${token}`; } - const clientTransport = new SSEClientTransport(mcpProxyServerUrl, { + const clientTransport = new SSEClientTransport(mcpProxyServerUrl as URL, { eventSourceInit: { fetch: (url, init) => fetch(url, { ...init, headers }), }, diff --git a/server/src/index.ts b/server/src/index.ts index e966910..08a88e7 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -97,7 +97,9 @@ const createTransport = async (req: express.Request): Promise => { console.log("Connected to SSE transport"); return transport; } else if (transportType === "streamable-http") { - const headers: HeadersInit = {}; + const headers: HeadersInit = { + Accept: "text/event-stream, application/json" + }; for (const key of STREAMABLE_HTTP_HEADERS_PASSTHROUGH) { if (req.headers[key] === undefined) { @@ -127,9 +129,79 @@ const createTransport = async (req: express.Request): Promise => { let backingServerTransport: Transport | undefined; -app.get("/sse", async (req, res) => { + +app.get("/mcp", async (req, res) => { try { - console.log("New SSE connection"); + console.log("New streamable-http connection"); + + try { + await backingServerTransport?.close(); + backingServerTransport = await createTransport(req); + } catch (error) { + if (error instanceof SseError && error.code === 401) { + console.error( + "Received 401 Unauthorized from MCP server:", + error.message, + ); + res.status(401).json(error); + return; + } + + throw error; + } + + console.log("Connected MCP client to backing server transport"); + + const webAppTransport = new SSEServerTransport("/mcp", res); + webAppTransports.push(webAppTransport); + console.log("Created web app transport"); + + await webAppTransport.start(); + + if (backingServerTransport instanceof StdioClientTransport) { + backingServerTransport.stderr!.on("data", (chunk) => { + webAppTransport.send({ + jsonrpc: "2.0", + method: "notifications/stderr", + params: { + content: chunk.toString(), + }, + }); + }); + } + + mcpProxy({ + transportToClient: webAppTransport, + transportToServer: backingServerTransport, + }); + + console.log("Set up MCP proxy"); + } catch (error) { + console.error("Error in /sse route:", error); + res.status(500).json(error); + } +}); + +app.post("/mcp", async (req, res) => { + try { + const sessionId = req.query.sessionId; + console.log(`Received message for sessionId ${sessionId}`); + + const transport = webAppTransports.find((t) => t.sessionId === sessionId); + if (!transport) { + res.status(404).end("Session not found"); + return; + } + await transport.handlePostMessage(req, res); + } catch (error) { + console.error("Error in /mcp route:", error); + res.status(500).json(error); + } +}); + +app.get("/stdio", async (req, res) => { + try { + console.log("New connection"); try { await backingServerTransport?.close(); @@ -150,15 +222,12 @@ app.get("/sse", async (req, res) => { console.log("Connected MCP client to backing server transport"); const webAppTransport = new SSEServerTransport("/message", res); - console.log("Created web app transport"); - webAppTransports.push(webAppTransport); + console.log("Created web app transport"); await webAppTransport.start(); - - if (backingServerTransport instanceof StdioClientTransport) { - backingServerTransport.stderr!.on("data", (chunk) => { + (backingServerTransport as StdioClientTransport).stderr!.on("data", (chunk) => { webAppTransport.send({ jsonrpc: "2.0", method: "notifications/stderr", @@ -167,8 +236,48 @@ app.get("/sse", async (req, res) => { }, }); }); + + mcpProxy({ + transportToClient: webAppTransport, + transportToServer: backingServerTransport, + }); + + console.log("Set up MCP proxy"); + } catch (error) { + console.error("Error in /stdio route:", error); + res.status(500).json(error); + } +}); + +app.get("/sse", async (req, res) => { + try { + console.log("New SSE connection. NOTE: The sse transport is deprecated and has been replaced by streamable-http"); + + try { + await backingServerTransport?.close(); + backingServerTransport = await createTransport(req); + } catch (error) { + if (error instanceof SseError && error.code === 401) { + console.error( + "Received 401 Unauthorized from MCP server:", + error.message, + ); + res.status(401).json(error); + return; + } + + throw error; } + console.log("Connected MCP client to backing server transport"); + + const webAppTransport = new SSEServerTransport("/message", res); + webAppTransports.push(webAppTransport); + + console.log("Created web app transport"); + + await webAppTransport.start(); + mcpProxy({ transportToClient: webAppTransport, transportToServer: backingServerTransport,