diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 9ceafa2..cd90d7f 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -144,7 +144,23 @@ export function useConnection({ } }; - const connect = async () => { + const handleAuthError = async (error: unknown) => { + if (error instanceof SseError && error.code === 401) { + if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { + try { + await handleTokenRefresh(); + return true; + } catch (error) { + console.error("Token refresh failed:", error); + } + } else { + await initiateOAuthFlow(); + } + } + return false; + }; + + const connect = async (_e?: unknown, retryCount: number = 0) => { try { const client = new Client( { @@ -180,28 +196,7 @@ export function useConnection({ const clientTransport = new SSEClientTransport(backendUrl, { eventSourceInit: { - fetch: async (url, init) => { - const response = await fetch(url, { ...init, headers }); - - if (response.status === 401) { - if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { - try { - const newAccessToken = await handleTokenRefresh(); - headers["Authorization"] = `Bearer ${newAccessToken}`; - return fetch(url, { ...init, headers }); - } catch (error) { - console.error("Token refresh failed:", error); - } - } - - if (sessionStorage.getItem(SESSION_KEYS.ACCESS_TOKEN)) { - await initiateOAuthFlow(); - return new Response(); - } - } - - return response; - }, + fetch: (url, init) => fetch(url, { ...init, headers }), }, requestInit: { headers, @@ -226,11 +221,17 @@ export function useConnection({ await client.connect(clientTransport); } catch (error) { console.error("Failed to connect to MCP server:", error); + const shouldRetry = await handleAuthError(error); + if (shouldRetry) { + return connect(undefined, retryCount + 1); + } + if (error instanceof SseError && error.code === 401) { - await initiateOAuthFlow(); + // Don't set error state if we're about to redirect for auth return; } - throw error; + setConnectionStatus("error"); + return; } const capabilities = client.getServerCapabilities();