Merge pull request #28 from modelcontextprotocol/justin/sampling

Add tab and approval flow for server -> client sampling
This commit is contained in:
Justin Spahr-Summers
2024-10-28 15:24:05 +00:00
committed by GitHub
2 changed files with 130 additions and 11 deletions

View File

@@ -1,18 +1,10 @@
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import {
CallToolResultSchema,
ClientRequest,
CreateMessageRequestSchema,
CreateMessageResult,
EmptyResultSchema,
GetPromptResultSchema,
ListPromptsResultSchema,
@@ -24,16 +16,28 @@ import {
ServerNotification,
Tool,
} from "@modelcontextprotocol/sdk/types.js";
import { useEffect, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import {
Bell,
Files,
Hammer,
Hash,
MessageSquare,
Play,
Send,
Terminal,
} from "lucide-react";
import { useEffect, useRef, useState } from "react";
import { AnyZodObject } from "zod";
import "./App.css";
@@ -43,6 +47,7 @@ import PingTab from "./components/PingTab";
import PromptsTab, { Prompt } from "./components/PromptsTab";
import RequestsTab from "./components/RequestsTabs";
import ResourcesTab from "./components/ResourcesTab";
import SamplingTab, { PendingRequest } from "./components/SamplingTab";
import Sidebar from "./components/Sidebar";
import ToolsTab from "./components/ToolsTab";
@@ -77,6 +82,32 @@ const App = () => {
const [mcpClient, setMcpClient] = useState<Client | null>(null);
const [notifications, setNotifications] = useState<ServerNotification[]>([]);
const [pendingSampleRequests, setPendingSampleRequests] = useState<
Array<
PendingRequest & {
resolve: (result: CreateMessageResult) => void;
reject: (error: Error) => void;
}
>
>([]);
const nextRequestId = useRef(0);
const handleApproveSampling = (id: number, result: CreateMessageResult) => {
setPendingSampleRequests((prev) => {
const request = prev.find((r) => r.id === id);
request?.resolve(result);
return prev.filter((r) => r.id !== id);
});
};
const handleRejectSampling = (id: number) => {
setPendingSampleRequests((prev) => {
const request = prev.find((r) => r.id === id);
request?.reject(new Error("Sampling request rejected"));
return prev.filter((r) => r.id !== id);
});
};
const [selectedResource, setSelectedResource] = useState<Resource | null>(
null,
);
@@ -229,6 +260,15 @@ const App = () => {
},
);
client.setRequestHandler(CreateMessageRequestSchema, (request) => {
return new Promise<CreateMessageResult>((resolve, reject) => {
setPendingSampleRequests((prev) => [
...prev,
{ id: nextRequestId.current++, request, resolve, reject },
]);
});
});
setMcpClient(client);
setConnectionStatus("connected");
} catch (e) {
@@ -314,6 +354,15 @@ const App = () => {
<Bell className="w-4 h-4 mr-2" />
Ping
</TabsTrigger>
<TabsTrigger value="sampling" className="relative">
<Hash className="w-4 h-4 mr-2" />
Sampling
{pendingSampleRequests.length > 0 && (
<span className="absolute -top-1 -right-1 bg-red-500 text-white text-xs rounded-full h-4 w-4 flex items-center justify-center">
{pendingSampleRequests.length}
</span>
)}
</TabsTrigger>
</TabsList>
<div className="w-full">
@@ -362,6 +411,11 @@ const App = () => {
);
}}
/>
<SamplingTab
pendingRequests={pendingSampleRequests}
onApprove={handleApproveSampling}
onReject={handleRejectSampling}
/>
</div>
</Tabs>
) : (

View File

@@ -0,0 +1,65 @@
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { TabsContent } from "@/components/ui/tabs";
import {
CreateMessageRequest,
CreateMessageResult,
} from "@modelcontextprotocol/sdk/types.js";
export type PendingRequest = {
id: number;
request: CreateMessageRequest;
};
export type Props = {
pendingRequests: PendingRequest[];
onApprove: (id: number, result: CreateMessageResult) => void;
onReject: (id: number) => void;
};
const SamplingTab = ({ pendingRequests, onApprove, onReject }: Props) => {
const handleApprove = (id: number) => {
// For now, just return a stub response
onApprove(id, {
model: "stub-model",
stopReason: "endTurn",
role: "assistant",
content: {
type: "text",
text: "This is a stub response.",
},
});
};
return (
<TabsContent value="sampling" className="h-96">
<Alert>
<AlertDescription>
When the server requests LLM sampling, requests will appear here for
approval.
</AlertDescription>
</Alert>
<div className="mt-4 space-y-4">
<h3 className="text-lg font-semibold">Recent Requests</h3>
{pendingRequests.map((request) => (
<div key={request.id} className="p-4 border rounded-lg space-y-4">
<pre className="bg-gray-50 p-2 rounded">
{JSON.stringify(request.request, null, 2)}
</pre>
<div className="flex space-x-2">
<Button onClick={() => handleApprove(request.id)}>Approve</Button>
<Button variant="outline" onClick={() => onReject(request.id)}>
Reject
</Button>
</div>
</div>
))}
{pendingRequests.length === 0 && (
<p className="text-gray-500">No pending requests</p>
)}
</div>
</TabsContent>
);
};
export default SamplingTab;