diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index a7439df..2a9e27a 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -24,9 +24,12 @@ const OAuthCallback = () => { } try { - const accessToken = await handleOAuthCallback(serverUrl, code); - // Store the access token for future use - sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, accessToken); + const tokens = await handleOAuthCallback(serverUrl, code); + // Store both access and refresh tokens + sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token); + if (tokens.refresh_token) { + sessionStorage.setItem(SESSION_KEYS.REFRESH_TOKEN, tokens.refresh_token); + } // Redirect back to the main app with server URL to trigger auto-connect window.location.href = `/?serverUrl=${encodeURIComponent(serverUrl)}`; } catch (error) { diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 0417731..7d70a31 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -6,6 +6,12 @@ export interface OAuthMetadata { token_endpoint: string; } +export interface OAuthTokens { + access_token: string; + refresh_token?: string; + expires_in?: number; +} + export async function discoverOAuthMetadata( serverUrl: string, ): Promise { @@ -60,7 +66,7 @@ export async function startOAuthFlow(serverUrl: string): Promise { export async function handleOAuthCallback( serverUrl: string, code: string, -): Promise { +): Promise { // Get stored code verifier const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); if (!codeVerifier) { @@ -69,7 +75,6 @@ export async function handleOAuthCallback( // Discover OAuth endpoints const metadata = await discoverOAuthMetadata(serverUrl); - // Exchange code for tokens const response = await fetch(metadata.token_endpoint, { method: "POST", @@ -89,5 +94,32 @@ export async function handleOAuthCallback( } const data = await response.json(); - return data.access_token; + return data; +} + +export async function refreshAccessToken(serverUrl: string): Promise { + const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN); + if (!refreshToken) { + throw new Error("No refresh token available"); + } + + const metadata = await discoverOAuthMetadata(serverUrl); + + const response = await fetch(metadata.token_endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + grant_type: "refresh_token", + refresh_token: refreshToken + }), + }); + + if (!response.ok) { + throw new Error("Token refresh failed"); + } + + const data = await response.json(); + return data; } diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index e302b52..13a2370 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -3,4 +3,5 @@ export const SESSION_KEYS = { CODE_VERIFIER: "mcp_code_verifier", SERVER_URL: "mcp_server_url", ACCESS_TOKEN: "mcp_access_token", + REFRESH_TOKEN: "mcp_refresh_token", } as const; diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index de2d29e..58ea0a8 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -16,7 +16,7 @@ import { import { useState } from "react"; import { toast } from "react-toastify"; import { z } from "zod"; -import { startOAuthFlow } from "../auth"; +import { startOAuthFlow, refreshAccessToken } from "../auth"; import { SESSION_KEYS } from "../constants"; import { Notification, StdErrNotificationSchema } from "../notificationTypes"; @@ -121,6 +121,24 @@ export function useConnection({ } }; + const handleTokenRefresh = async () => { + try { + const tokens = await refreshAccessToken(sseUrl); + sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token); + if (tokens.refresh_token) { + sessionStorage.setItem(SESSION_KEYS.REFRESH_TOKEN, tokens.refresh_token); + } + return tokens.access_token; + } catch (error) { + console.error("Token refresh failed:", error); + // Clear tokens and redirect to home + sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN); + sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN); + window.location.href = "/"; + throw error; + } + }; + const connect = async () => { try { const client = new Client( @@ -157,7 +175,19 @@ export function useConnection({ const clientTransport = new SSEClientTransport(backendUrl, { eventSourceInit: { - fetch: (url, init) => fetch(url, { ...init, headers }), + fetch: async (url, init) => { + const response = await fetch(url, { ...init, headers }); + + if (response.status === 401 && sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { + // Try to refresh the token + const newAccessToken = await handleTokenRefresh(); + headers["Authorization"] = `Bearer ${newAccessToken}`; + // Retry the request with new token + return fetch(url, { ...init, headers }); + } + + return response; + }, }, requestInit: { headers,