// ibibrahim's picture
// Create demo (#1)
// c675f75 verified
import React, {
useState,
useEffect,
useCallback,
useRef,
useMemo,
} from "react";
import { openDB, type IDBPDatabase } from "idb";
import { Play, Plus, RotateCcw, Wrench } from "lucide-react";
import { useLLM } from "./hooks/useLLM";
import type { Tool } from "./components/ToolItem";
import {
extractFunctionAndRenderer,
generateSchemaFromCode,
getErrorMessage,
isMobileOrTablet,
} from "./utils";
import { DB_NAME, STORE_NAME, SETTINGS_STORE_NAME } from "./constants/db";
import { DEFAULT_TOOLS, TEMPLATE } from "./tools";
import ToolResultRenderer from "./components/ToolResultRenderer";
import ToolCallIndicator from "./components/ToolCallIndicator";
import ToolItem from "./components/ToolItem";
import ResultBlock from "./components/ResultBlock";
import ExamplePrompts from "./components/ExamplePrompts";
import { LoadingScreen } from "./components/LoadingScreen";
/** One executed (or rejected) tool call, as displayed in the chat UI. */
interface RenderInfo {
  /** `<tool_call>…</tool_call>` text for the call; empty when no valid call was found. */
  call: string;
  /** Arguments the tool actually ran with (declared params first, then extras). */
  input?: Record<string, any>;
  /** Return value of the tool after a JSON round-trip, when it succeeded. */
  result?: any;
  /** Renderer source extracted from the tool's code, if it has one. */
  renderer?: string;
  /** Error message, when the call failed. */
  error?: string;
}
/** A plain chat message from the system prompt, the user or the model. */
interface BaseMessage {
  content: string;
  role: "system" | "user" | "assistant";
}
/** Message that feeds tool results back into the conversation. */
interface ToolMessage {
  /** JSON array with one result (or null) per executed call. */
  content: string;
  role: "tool";
  /** Rich data for the UI. */
  renderInfo: RenderInfo[];
}
/** Any transcript entry; discriminate on `role`. */
type Message = BaseMessage | ToolMessage;
/**
 * Shared, lazily opened IndexedDB connection. Reset to `null` when opening
 * fails or the browser terminates the connection, so the next `getDB()` call
 * reopens it instead of reusing a dead handle.
 */
let dbPromise: Promise<IDBPDatabase> | null = null;

/**
 * Returns the app's IndexedDB connection, opening it on first use.
 *
 * When the database is created (schema version 1), the upgrade callback
 * creates the tools store (auto-incrementing `id` key) and the settings store
 * (keyed by `key`) if they are missing. Reusing one connection avoids opening
 * a new, never-closed connection on every read and every debounced save.
 */
async function getDB(): Promise<IDBPDatabase> {
  if (dbPromise === null) {
    dbPromise = openDB(DB_NAME, 1, {
      upgrade(db) {
        if (!db.objectStoreNames.contains(STORE_NAME)) {
          db.createObjectStore(STORE_NAME, {
            keyPath: "id",
            autoIncrement: true,
          });
        }
        if (!db.objectStoreNames.contains(SETTINGS_STORE_NAME)) {
          db.createObjectStore(SETTINGS_STORE_NAME, { keyPath: "key" });
        }
      },
      // The browser closed the connection abnormally; reopen on next use.
      terminated() {
        dbPromise = null;
      },
    }).catch((error: unknown) => {
      // Do not cache a failed open; let the next caller retry.
      dbPromise = null;
      throw error;
    });
  }
  return dbPromise;
}
const App: React.FC = () => {
// Chat transcript, including tool-result messages.
const [messages, setMessages] = useState<Message[]>([]);
// Tool definitions shown in the tools panel (persisted to IndexedDB).
const [tools, setTools] = useState<Tool[]>([]);
const [input, setInput] = useState<string>("");
const [isGenerating, setIsGenerating] = useState<boolean>(false);
// Evaluated once per mount; only used to pick the default model size.
const isMobile = useMemo(isMobileOrTablet, []);
const [selectedModel, setSelectedModel] = useState<string>(
  isMobile ? "350M" : "1B",
);
const [isModelDropdownOpen, setIsModelDropdownOpen] =
  useState<boolean>(false);
const [isToolsPanelVisible, setIsToolsPanelVisible] =
  useState<boolean>(false);
const chatContainerRef = useRef<HTMLDivElement>(null);
// Pending debounced-save timers keyed by tool id. `ReturnType<typeof setTimeout>`
// type-checks with DOM-only typings as well as with Node typings, unlike
// `NodeJS.Timeout`.
const debounceTimers = useRef<Record<number, ReturnType<typeof setTimeout>>>(
  {},
);
const toolsContainerRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const {
  isLoading,
  isReady,
  error,
  progress,
  loadModel,
  generateResponse,
  clearPastKeyValues,
} = useLLM(selectedModel);
/**
 * Loads the persisted tools into state. On first run (empty store) it seeds
 * the store with DEFAULT_TOOLS, using their position as the id. Every loaded
 * tool starts expanded.
 */
const loadTools = useCallback(async (): Promise<void> => {
  const db = await getDB();
  const stored: Tool[] = await db.getAll(STORE_NAME);
  if (stored.length > 0) {
    setTools(stored.map((t) => ({ ...t, isCollapsed: false })));
    return;
  }
  const seeded: Tool[] = Object.entries(DEFAULT_TOOLS).map(
    ([name, code], index) => ({
      id: index,
      name,
      code,
      enabled: true,
      isCollapsed: false,
    }),
  );
  // Write all defaults in a single readwrite transaction.
  const tx = db.transaction(STORE_NAME, "readwrite");
  await Promise.all(seeded.map((tool) => tx.store.put(tool)));
  await tx.done;
  setTools(seeded);
}, []);
// Load (and, on first run, seed) the tool list on mount. Failures are logged
// rather than left as an unhandled promise rejection.
useEffect(() => {
  loadTools().catch((loadError: unknown) => {
    console.error("Failed to load tools:", loadError);
  });
}, [loadTools]);
// Keep the chat scrolled to the newest message whenever the transcript changes.
useEffect(() => {
  const container = chatContainerRef.current;
  if (container) {
    container.scrollTop = container.scrollHeight;
  }
}, [messages]);
/** Writes one tool record to the tools store (insert or replace by id). */
const updateToolInDB = async (tool: Tool): Promise<void> => {
  const db = await getDB();
  await db.put(STORE_NAME, tool);
};
/**
 * Persists `tool` 300 ms after the last change to that tool id, so rapid edits
 * (e.g. typing in the code editor) produce a single write. Tools without an id
 * are ignored.
 */
const saveToolDebounced = (tool: Tool): void => {
  const { id } = tool;
  if (id === undefined) return;
  const pending = debounceTimers.current[id];
  if (pending) {
    clearTimeout(pending);
  }
  debounceTimers.current[id] = setTimeout(() => {
    // The timer has fired; drop its bookkeeping entry.
    delete debounceTimers.current[id];
    updateToolInDB(tool).catch((saveError: unknown) => {
      console.error("Failed to save tool:", saveError);
    });
  }, 300);
};
/**
 * Starts a fresh thread: empties the visible transcript and tells the LLM hook
 * to drop its cached past key/values.
 */
const clearChat = useCallback((): void => {
  setMessages([]);
  clearPastKeyValues();
}, [clearPastKeyValues]);
/**
 * Inserts a new tool built from TEMPLATE, appends it to state using the id
 * assigned by IndexedDB, scrolls the tools panel to the bottom and resets the
 * chat.
 */
const addTool = async (): Promise<void> => {
  const draft: Omit<Tool, "id"> = {
    name: "new_tool",
    code: TEMPLATE,
    enabled: true,
    isCollapsed: false,
  };
  const db = await getDB();
  const newId = (await db.add(STORE_NAME, draft)) as number;
  const scrollToolsToBottom = (): void => {
    const container = toolsContainerRef.current;
    if (container) {
      container.scrollTop = container.scrollHeight;
    }
  };
  setTools((previous) => {
    // Defer the scroll so it runs after the new item has been rendered.
    setTimeout(scrollToolsToBottom, 0);
    return [...previous, { ...draft, id: newId }];
  });
  clearChat();
};
/**
 * Deletes a tool from IndexedDB and from state, then resets the chat.
 * Any pending debounced save for the tool is cancelled first so it cannot
 * re-create the record.
 */
const deleteTool = async (id: number): Promise<void> => {
  const pending = debounceTimers.current[id];
  if (pending) {
    clearTimeout(pending);
    delete debounceTimers.current[id];
  }
  const db = await getDB();
  await db.delete(STORE_NAME, id);
  // Functional update: `tools` from this render may be stale after the await.
  setTools((previous) => previous.filter((tool) => tool.id !== id));
  clearChat();
};
/** Flips a tool's `enabled` flag in state and schedules the change to be saved. */
const toggleToolEnabled = (id: number): void => {
  let flipped: Tool | undefined;
  const nextTools = tools.map((tool) => {
    if (tool.id !== id) return tool;
    flipped = { ...tool, enabled: !tool.enabled };
    return flipped;
  });
  setTools(nextTools);
  if (flipped) saveToolDebounced(flipped);
};
/** Toggles whether a tool's editor is collapsed (UI state only, not persisted here). */
const toggleToolCollapsed = (id: number): void => {
  const nextTools = tools.map((tool) => {
    if (tool.id !== id) return tool;
    return { ...tool, isCollapsed: !tool.isCollapsed };
  });
  setTools(nextTools);
};
/** Forces a tool's editor open (UI state only, not persisted here). */
const expandTool = (id: number): void => {
  const nextTools = tools.map((tool) => {
    if (tool.id !== id) return tool;
    return { ...tool, isCollapsed: false };
  });
  setTools(nextTools);
};
/**
 * Applies edited source to a tool. The tool's name is re-derived from the
 * schema generated for the function part of the new code. The change is then
 * scheduled for a debounced save.
 */
const handleToolCodeChange = (id: number, newCode: string): void => {
  let edited: Tool | undefined;
  const nextTools = tools.map((tool) => {
    if (tool.id !== id) return tool;
    const { functionCode } = extractFunctionAndRenderer(newCode);
    const derivedName = generateSchemaFromCode(functionCode).name;
    edited = { ...tool, code: newCode, name: derivedName };
    return edited;
  });
  setTools(nextTools);
  if (edited) saveToolDebounced(edited);
};
/** A tool invocation requested by the model inside a `<tool_call>` block. */
interface ToolCallPayload {
  name: string;
  arguments?: Record<string, any>;
}
/**
 * Coerces the `arguments` field of a parsed tool call into a plain object.
 * Accepts an object, or a JSON string encoding one (some models emit
 * stringified arguments). Anything else becomes `{}`: missing, null, arrays,
 * primitives, unparsable strings.
 */
const normalizeToolArguments = (raw: unknown): Record<string, any> => {
  let value = raw;
  if (typeof value === "string") {
    try {
      value = JSON.parse(value);
    } catch {
      return {};
    }
  }
  return typeof value === "object" && value !== null && !Array.isArray(value)
    ? (value as Record<string, any>)
    : {};
};
/**
 * Extracts every well-formed tool call from a model response.
 *
 * Looks for `<tool_call>{…}</tool_call>` blocks (non-greedy, may span lines).
 * Keeps only payloads that parse as JSON and carry a string `name`.
 * Malformed payloads are skipped, because a partial or garbled call must not
 * abort the whole turn.
 */
const extractToolCalls = (text: string): ToolCallPayload[] => {
  const toolCalls: ToolCallPayload[] = [];
  for (const match of text.matchAll(/<tool_call>([\s\S]*?)<\/tool_call>/g)) {
    try {
      const parsed = JSON.parse(match[1].trim());
      if (parsed && typeof parsed.name === "string") {
        toolCalls.push({
          name: parsed.name,
          arguments: normalizeToolArguments(parsed.arguments),
        });
      }
    } catch {
      // ignore malformed tool call payloads
    }
  }
  return toolCalls;
};
/**
 * Executes one tool call requested by the model.
 *
 * Finds an enabled tool with the requested name. Each declared parameter is
 * filled from the call's named arguments; if absent, from the schema default;
 * if still absent and the parameter is optional, with `undefined`. The
 * function body is then evaluated as an AsyncFunction and awaited. NOTE: the
 * body runs via the Function constructor with page privileges.
 *
 * @returns the JSON-serialized result, the tool's renderer source (if any),
 *   and the resolved arguments (declared parameters first, then any extra
 *   keys the model supplied).
 * @throws if the tool is missing or disabled, a required argument is
 *   missing, the function body cannot be parsed, or the tool itself throws.
 */
const executeToolCall = async (
  toolCall: ToolCallPayload,
): Promise<{
  serializedResult: string;
  rendererCode?: string;
  input: Record<string, any>;
}> => {
  const tool = tools.find(
    (candidate) => candidate.name === toolCall.name && candidate.enabled,
  );
  if (tool === undefined) {
    throw new Error(`Tool '${toolCall.name}' not found or is disabled.`);
  }

  const { functionCode, rendererCode } = extractFunctionAndRenderer(tool.code);
  const schema = generateSchemaFromCode(functionCode);
  const declared = schema.parameters?.properties ?? {};
  const paramNames = Object.keys(declared);
  const requiredNames = schema.parameters?.required ?? [];
  const supplied = toolCall.arguments ?? {};
  const owns = (target: unknown, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(target, key);

  // Positional values for the generated function, in declaration order.
  const positional: any[] = [];
  // Null-prototype map so names like "constructor" cannot hit inherited members.
  const resolved: Record<string, any> = Object.create(null);
  for (const name of paramNames) {
    const config = declared[name] ?? {};
    let value: any;
    if (owns(supplied, name)) {
      value = supplied[name];
    } else if (owns(config, "default")) {
      value = config.default;
    } else if (requiredNames.includes(name)) {
      throw new Error(`Missing required argument: ${name}`);
    } else {
      value = undefined;
    }
    positional.push(value);
    resolved[name] = value;
  }
  // Undeclared arguments are not passed to the function but are still reported.
  for (const [key, value] of Object.entries(supplied)) {
    if (!owns(resolved, key)) {
      resolved[key] = value;
    }
  }

  const bodyMatch = functionCode.match(/function[^{]+\{([\s\S]*)\}/);
  if (!bodyMatch) {
    throw new Error(
      "Could not parse function body. Ensure it's a standard `function` declaration.",
    );
  }
  // AsyncFunction constructor, so tool bodies may use `await`.
  const AsyncFunction = Object.getPrototypeOf(
    async function () {},
  ).constructor;
  const run = new AsyncFunction(...paramNames, bodyMatch[1]);
  const result = await run(...positional);
  return {
    serializedResult: JSON.stringify(result),
    rendererCode,
    input: resolved,
  };
};
/**
 * Executes the model's tool calls sequentially, in order.
 *
 * On success, each entry carries:
 * - the result after a JSON round-trip, or the raw serialized value when it
 *   does not parse;
 * - the renderer source;
 * - the resolved arguments.
 *
 * On failure, an entry carries only the error message. An empty list yields a
 * single error entry.
 */
const executeToolCalls = async (
  toolCalls: ToolCallPayload[],
): Promise<RenderInfo[]> => {
  if (!toolCalls.length) {
    return [{ call: "", error: "No valid tool calls found." }];
  }
  const decode = (serialized: string): unknown => {
    try {
      return JSON.parse(serialized);
    } catch {
      return serialized;
    }
  };
  const entries: RenderInfo[] = [];
  for (const call of toolCalls) {
    const display = `<tool_call>${JSON.stringify(call)}</tool_call>`;
    try {
      const outcome = await executeToolCall(call);
      entries.push({
        call: display,
        result: decode(outcome.serializedResult),
        renderer: outcome.rendererCode,
        input: outcome.input,
      });
    } catch (failure) {
      entries.push({
        call: display,
        error: getErrorMessage(failure),
      });
    }
  }
  return entries;
};
/**
 * Runs one complete chat turn for `userText`.
 *
 * 1. Appends the user message, clears the input box and marks the app as
 *    generating.
 * 2. Asks the model for a reply, streaming tokens into a placeholder
 *    assistant bubble.
 * 3. If the reply contains `<tool_call>` blocks, executes them, appends a
 *    `tool` message with the results, and asks the model again.
 * 4. The turn ends on the first reply without tool calls.
 *
 * Any failure is shown as an assistant error message. Shared by the text
 * input and the example prompts so both paths behave identically.
 */
const runChatTurn = async (userText: string): Promise<void> => {
  // Upper bound on consecutive tool-call rounds, so a model that keeps
  // emitting tool calls cannot leave the UI stuck in the generating state.
  const maxToolRounds = 10;
  // Private working transcript. Every setMessages call below receives a fresh
  // copy, so no array React already holds as state is mutated later.
  const history: Message[] = [
    ...messages,
    { role: "user", content: userText },
  ];
  setMessages([...history]);
  setInput("");
  setIsGenerating(true);
  try {
    const toolSchemas = tools
      .filter((tool) => tool.enabled)
      .map((tool) => generateSchemaFromCode(tool.code));
    for (let round = 0; ; round++) {
      if (round >= maxToolRounds) {
        throw new Error(
          `Stopped after ${maxToolRounds} consecutive tool-call rounds.`,
        );
      }
      const messagesForGeneration = [...history];
      // Placeholder bubble that the token callback fills in as the reply streams.
      setMessages([...history, { role: "assistant", content: "" }]);
      let accumulatedContent = "";
      const response = await generateResponse(
        messagesForGeneration,
        toolSchemas,
        (token: string) => {
          accumulatedContent += token;
          setMessages((current) => {
            const updated = [...current];
            updated[updated.length - 1] = {
              role: "assistant",
              content: accumulatedContent,
            };
            return updated;
          });
        },
      );
      history.push({ role: "assistant", content: response });
      const toolCalls = extractToolCalls(response);
      if (toolCalls.length === 0) {
        setMessages([...history]);
        break;
      }
      const toolResults = await executeToolCalls(toolCalls);
      const toolMessage: ToolMessage = {
        role: "tool",
        // The model only sees the raw results; the UI uses renderInfo.
        content: JSON.stringify(toolResults.map((r) => r.result ?? null)),
        renderInfo: toolResults,
      };
      history.push(toolMessage);
      setMessages([...history]);
    }
  } catch (error) {
    const errorMessage = getErrorMessage(error);
    setMessages([
      ...history,
      {
        role: "assistant",
        content: `Error generating response: ${errorMessage}`,
      },
    ]);
  } finally {
    setIsGenerating(false);
    // Defer focusing until React has re-enabled the input.
    setTimeout(() => inputRef.current?.focus(), 0);
  }
};
/** Sends the text currently typed in the input box (ignored if blank or busy). */
const handleSendMessage = async (): Promise<void> => {
  if (!input.trim() || !isReady || isGenerating) return;
  await runChatTurn(input);
};
/** Loads the currently selected model; failures are only logged here. */
const loadSelectedModel = useCallback(async (): Promise<void> => {
  try {
    await loadModel();
  } catch (error) {
    console.error("Failed to load model:", error);
  }
}, [selectedModel, loadModel]);
/** Persists the chosen model id in the settings store (best effort). */
const saveSelectedModel = useCallback(async (modelId: string) => {
  try {
    const db = await getDB();
    await db.put(SETTINGS_STORE_NAME, {
      key: "selectedModelId",
      value: modelId,
    });
  } catch (error) {
    console.error("Failed to save selected model ID:", error);
  }
}, []);
/** Restores a previously saved model id from the settings store, if any. */
const loadSelectedModelId = useCallback(async (): Promise<void> => {
  try {
    const db = await getDB();
    const stored = await db.get(SETTINGS_STORE_NAME, "selectedModelId");
    if (stored && stored.value) {
      setSelectedModel(stored.value);
    }
  } catch (error) {
    console.error("Failed to load selected model ID:", error);
  }
}, []);
useEffect(() => {
  loadSelectedModelId();
}, [loadSelectedModelId]);
/** Selects a model from the dropdown, closes it and persists the choice. */
const handleModelSelect = async (modelId: string) => {
  setSelectedModel(modelId);
  setIsModelDropdownOpen(false);
  await saveSelectedModel(modelId);
};
/** Sends one of the example prompts as if the user had typed it. */
const handleExampleClick = async (messageText: string): Promise<void> => {
  if (!isReady || isGenerating) return;
  await runChatTurn(messageText);
};
return (
  <div className="font-sans min-h-screen bg-gradient-to-br from-[#031b4e] via-[#06183d] to-[#010409] text-gray-100 text-[16px] md:text-[17px]">
    {/* Until a model is ready, show the model picker / loading screen. */}
    {!isReady ? (
      <LoadingScreen
        isLoading={isLoading}
        progress={progress}
        error={error}
        loadSelectedModel={loadSelectedModel}
        selectedModelId={selectedModel}
        isModelDropdownOpen={isModelDropdownOpen}
        setIsModelDropdownOpen={setIsModelDropdownOpen}
        handleModelSelect={handleModelSelect}
      />
    ) : (
      <div className="flex h-screen text-gray-100 w-full gap-6 py-10 px-8">
        {/* Chat column */}
        <div className="flex-1 flex flex-col p-6 bg-white/5 backdrop-blur-lg border border-white/10 rounded-3xl shadow-[0_35px_65px_rgba(3,27,78,0.55)] min-h-0">
          <div className="flex items-center justify-between mb-6">
            <div className="space-y-1">
              <span className="inline-flex items-center gap-2 text-xs font-semibold uppercase tracking-[0.35em] text-[#78a9ff]">
                IBM Granite
              </span>
              <h1 className="text-3xl font-semibold text-white">
                Granite-4.0 Tool Studio
              </h1>
            </div>
            <div className="flex items-center gap-3">
              <button
                disabled={isGenerating}
                onClick={clearChat}
                className={`h-10 flex items-center px-4 rounded-full font-semibold text-sm transition-all border ${
                  isGenerating
                    ? "border-white/15 bg-white/10 text-[#a6c8ff] opacity-50 cursor-not-allowed"
                    : "border-white/20 bg-white/8 text-[#d0e2ff] hover:border-[#78a9ff]/50 hover:bg-[#0f62fe]/15"
                }`}
                title="Clear chat"
              >
                <RotateCcw size={14} className="mr-2" /> Reset Thread
              </button>
              <button
                onClick={() =>
                  setIsToolsPanelVisible((previous) => !previous)
                }
                className={`h-10 flex items-center px-4 rounded-full font-semibold text-sm transition-all border ${
                  isToolsPanelVisible
                    ? "border-[#78a9ff]/60 bg-[#0f62fe]/25 text-white shadow-[0_10px_25px_rgba(15,98,254,0.25)]"
                    : "border-white/20 bg-white/8 text-[#d0e2ff] hover:border-[#78a9ff]/50 hover:bg-[#0f62fe]/15"
                }`}
              >
                <Wrench size={16} className="mr-2" />
                {isToolsPanelVisible ? "Hide Tools" : "Show Tools"}
              </button>
            </div>
          </div>
          <div
            ref={chatContainerRef}
            className="flex-grow bg-[#0b1e3f]/80 border border-white/10 rounded-2xl p-6 overflow-y-auto mb-6 space-y-5 shadow-inner min-h-0"
          >
            {messages.length === 0 && isReady ? (
              <ExamplePrompts onExampleClick={handleExampleClick} />
            ) : (
              messages.map((msg, index) => {
                const key = `${msg.role}-${index}`;
                if (msg.role === "user") {
                  return (
                    <div key={key} className="flex justify-end">
                      <div className="px-4 py-3 rounded-2xl max-w-md bg-[#0f62fe]/30 border border-[#78a9ff]/40 shadow-[0_20px_45px_rgba(10,49,140,0.25)]">
                        <p className="text-md text-white whitespace-pre-wrap">
                          {msg.content}
                        </p>
                      </div>
                    </div>
                  );
                }
                if (msg.role === "assistant") {
                  const isToolCall = msg.content.includes("<tool_call>");
                  if (isToolCall) {
                    // A tool-call bubble is "running" until the tool message
                    // with its results follows it in the transcript.
                    const nextMessage = messages[index + 1];
                    const isCompleted = nextMessage?.role === "tool";
                    // Narrow on the discriminant instead of casting.
                    const hasError =
                      nextMessage?.role === "tool" &&
                      nextMessage.renderInfo.some((info) => !!info.error);
                    return (
                      <div key={key} className="flex justify-start">
                        <div className="px-4 py-3 rounded-2xl bg-white/8 border border-[#0f62fe]/30 shadow-[0_18px_50px_rgba(0,0,0,0.35)]">
                          <ToolCallIndicator
                            content={msg.content}
                            isRunning={!isCompleted}
                            hasError={hasError}
                          />
                        </div>
                      </div>
                    );
                  }
                  return (
                    <div key={key} className="flex justify-start">
                      <div className="px-4 py-3 rounded-2xl max-w-md bg-white/8 border border-white/15 shadow-[0_18px_50px_rgba(0,0,0,0.35)]">
                        <p className="text-md text-[#d0e2ff] whitespace-pre-wrap">
                          {msg.content}
                        </p>
                      </div>
                    </div>
                  );
                }
                if (msg.role === "tool") {
                  // Only errors and results that have a renderer are shown.
                  const visibleToolResults = msg.renderInfo.filter(
                    (info) =>
                      info.error || (info.result != null && info.renderer),
                  );
                  if (visibleToolResults.length === 0) return null;
                  return (
                    <div key={key} className="flex justify-start">
                      <div className="p-4 rounded-2xl bg-white/8 border border-white/15 max-w-lg shadow-[0_18px_50px_rgba(0,0,0,0.35)]">
                        <div className="space-y-4">
                          {visibleToolResults.map((info, idx) => (
                            <div className="flex flex-col gap-2" key={idx}>
                              <div className="text-xs text-[#a6c8ff] font-mono">
                                {info.call}
                              </div>
                              {info.error ? (
                                <ResultBlock error={info.error} />
                              ) : (
                                <ToolResultRenderer
                                  result={info.result}
                                  rendererCode={info.renderer}
                                  input={info.input}
                                />
                              )}
                            </div>
                          ))}
                        </div>
                      </div>
                    </div>
                  );
                }
                return null;
              })
            )}
          </div>
          <div className="flex items-center gap-3">
            <div className="flex flex-1 items-center bg-white/5 border border-white/10 rounded-2xl overflow-hidden shadow-[0_15px_45px_rgba(0,0,0,0.35)]">
              <input
                ref={inputRef}
                type="text"
                value={input}
                onChange={(e) => setInput(e.target.value)}
                // Ignore Enter while an IME composition is active (e.g. CJK
                // input), where it confirms the composition rather than sending.
                onKeyDown={(e) =>
                  e.key === "Enter" &&
                  !e.nativeEvent.isComposing &&
                  !isGenerating &&
                  isReady &&
                  handleSendMessage()
                }
                disabled={isGenerating || !isReady}
                className="flex-grow bg-transparent px-5 py-3 text-lg text-white placeholder:text-[#a6c8ff]/70 focus:outline-none disabled:opacity-40"
                placeholder={
                  isReady
                    ? "Type your message here..."
                    : "Load a Granite model to enable chat"
                }
              />
              <button
                onClick={handleSendMessage}
                disabled={isGenerating || !isReady}
                className="h-full px-5 py-3 bg-[#0f62fe] hover:bg-[#0043ce] disabled:bg-[#0f62fe]/40 disabled:cursor-not-allowed text-white font-semibold transition-all"
              >
                <Play size={28} />
              </button>
            </div>
          </div>
        </div>
        {/* Tools column (toggled from the header) */}
        {isToolsPanelVisible && (
          <div className="w-full md:w-1/2 flex flex-col p-6 bg-white/5 backdrop-blur-lg border border-white/10 rounded-3xl shadow-[0_35px_65px_rgba(3,27,78,0.55)] min-h-0">
            <div className="flex justify-between items-center mb-6">
              <div>
                <span className="text-xs font-semibold uppercase tracking-[0.25em] text-[#78a9ff]">
                  Tool Workspace
                </span>
                <h2 className="text-2xl font-semibold text-white mt-1">
                  Tools
                </h2>
              </div>
              <button
                onClick={addTool}
                className="flex items-center bg-gradient-to-r from-[#0f62fe] to-[#4589ff] hover:brightness-110 text-white font-semibold py-2 px-4 rounded-full transition-all shadow-[0_15px_35px_rgba(15,98,254,0.35)]"
              >
                <Plus size={16} className="mr-2" /> Add Tool
              </button>
            </div>
            <div
              ref={toolsContainerRef}
              className="flex-grow bg-[#0b1e3f]/60 border border-white/10 rounded-2xl p-4 overflow-y-auto space-y-3"
            >
              {tools.map((tool) => (
                <ToolItem
                  key={tool.id}
                  tool={tool}
                  onToggleEnabled={() => toggleToolEnabled(tool.id)}
                  onToggleCollapsed={() => toggleToolCollapsed(tool.id)}
                  onExpand={() => expandTool(tool.id)}
                  onDelete={() => deleteTool(tool.id)}
                  onCodeChange={(newCode) =>
                    handleToolCodeChange(tool.id, newCode)
                  }
                />
              ))}
            </div>
          </div>
        )}
      </div>
    )}
  </div>
);
};
export default App;