|
|
import { useEffect, useState, useRef } from "react"; |
|
|
|
|
|
import Chat from "./components/Chat"; |
|
|
import ArrowRightIcon from "./components/icons/ArrowRightIcon"; |
|
|
import StopIcon from "./components/icons/StopIcon"; |
|
|
import Progress from "./components/Progress"; |
|
|
|
|
|
// True when the browser exposes the WebGPU API. The `typeof` guard keeps this
// module importable where `navigator` is not declared (SSR, tests, older Node).
const IS_WEBGPU_AVAILABLE =
  typeof navigator !== "undefined" && Boolean(navigator.gpu);

// Distance in pixels from the bottom of the chat pane within which the pane
// keeps following newly streamed output.
const STICKY_SCROLL_THRESHOLD = 120;

// Example prompts offered on an empty chat; frozen because it is shared state.
const EXAMPLES = Object.freeze([
  "Give me some tips to improve my time management skills.",
  "What is the difference between AI and ML?",
  "Write python code to compute the nth fibonacci number.",
]);
|
|
|
|
|
function App() { |
|
|
|
|
|
const worker = useRef(null); |
|
|
|
|
|
const textareaRef = useRef(null); |
|
|
const chatContainerRef = useRef(null); |
|
|
|
|
|
|
|
|
const [status, setStatus] = useState(null); |
|
|
const [error, setError] = useState(null); |
|
|
const [loadingMessage, setLoadingMessage] = useState(""); |
|
|
const [progressItems, setProgressItems] = useState([]); |
|
|
const [isRunning, setIsRunning] = useState(false); |
|
|
const [modelFiles, setModelFiles] = useState([]); |
|
|
|
|
|
|
|
|
const [input, setInput] = useState(""); |
|
|
const [messages, setMessages] = useState([]); |
|
|
const [tps, setTps] = useState(null); |
|
|
const [numTokens, setNumTokens] = useState(null); |
|
|
const [attachedFile, setAttachedFile] = useState(null); |
|
|
|
|
|
async function onEnter(message) { |
|
|
let fileText = ""; |
|
|
if (attachedFile) { |
|
|
if (attachedFile.name.endsWith(".txt")) { |
|
|
fileText = await attachedFile.text(); |
|
|
} else if (attachedFile.name.endsWith(".pdf")) { |
|
|
|
|
|
const pdfjsLib = await import("pdfjs-dist/build/pdf"); |
|
|
const workerSrc = (await import("pdfjs-dist/build/pdf.worker?url")).default; |
|
|
pdfjsLib.GlobalWorkerOptions.workerSrc = workerSrc; |
|
|
const arrayBuffer = await attachedFile.arrayBuffer(); |
|
|
const pdf = await pdfjsLib.getDocument({ data: arrayBuffer }).promise; |
|
|
let pdfText = ""; |
|
|
for (let i = 1; i <= pdf.numPages; i++) { |
|
|
const page = await pdf.getPage(i); |
|
|
const content = await page.getTextContent(); |
|
|
pdfText += content.items.map(item => item.str).join(" ") + "\n"; |
|
|
} |
|
|
fileText = pdfText; |
|
|
} |
|
|
} |
|
|
let fullPrompt = message; |
|
|
if (fileText) { |
|
|
fullPrompt += "\n\n--- File Content ---\n" + fileText; |
|
|
} |
|
|
let userMsg = { role: "user", content: fullPrompt }; |
|
|
setMessages((prev) => [...prev, userMsg]); |
|
|
setTps(null); |
|
|
setIsRunning(true); |
|
|
setInput(""); |
|
|
setAttachedFile(null); |
|
|
} |
|
|
|
|
|
function onInterrupt() { |
|
|
|
|
|
|
|
|
worker.current.postMessage({ type: "interrupt" }); |
|
|
} |
|
|
|
|
|
useEffect(() => { |
|
|
resizeInput(); |
|
|
}, [input]); |
|
|
|
|
|
function resizeInput() { |
|
|
if (!textareaRef.current) return; |
|
|
|
|
|
const target = textareaRef.current; |
|
|
target.style.height = "auto"; |
|
|
const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200); |
|
|
target.style.height = `${newHeight}px`; |
|
|
} |
|
|
|
|
|
|
|
|
useEffect(() => { |
|
|
|
|
|
if (!worker.current) { |
|
|
worker.current = new Worker(new URL("./worker.js", import.meta.url), { |
|
|
type: "module", |
|
|
}); |
|
|
worker.current.postMessage({ type: "check" }); |
|
|
} |
|
|
|
|
|
|
|
|
const onMessageReceived = (e) => { |
|
|
switch (e.data.status) { |
|
|
case "loading": |
|
|
setStatus("loading"); |
|
|
setLoadingMessage(e.data.data); |
|
|
break; |
|
|
|
|
|
case "initiate": |
|
|
setProgressItems((prev) => [...prev, e.data]); |
|
|
break; |
|
|
|
|
|
case "progress": |
|
|
setProgressItems((prev) => |
|
|
prev.map((item) => { |
|
|
if (item.file === e.data.file) { |
|
|
return { ...item, ...e.data }; |
|
|
} |
|
|
return item; |
|
|
}), |
|
|
); |
|
|
break; |
|
|
|
|
|
case "done": |
|
|
setProgressItems((prev) => |
|
|
prev.filter((item) => item.file !== e.data.file), |
|
|
); |
|
|
break; |
|
|
|
|
|
case "ready": |
|
|
setStatus("ready"); |
|
|
break; |
|
|
|
|
|
case "start": |
|
|
setMessages((prev) => [ |
|
|
...prev, |
|
|
{ role: "assistant", content: "" }, |
|
|
]); |
|
|
break; |
|
|
|
|
|
case "update": |
|
|
const { output, tps, numTokens } = e.data; |
|
|
setTps(tps); |
|
|
setNumTokens(numTokens); |
|
|
setMessages((prev) => { |
|
|
const cloned = [...prev]; |
|
|
const last = cloned.at(-1); |
|
|
cloned[cloned.length - 1] = { |
|
|
...last, |
|
|
content: last.content + output, |
|
|
}; |
|
|
return cloned; |
|
|
}); |
|
|
break; |
|
|
|
|
|
case "complete": |
|
|
setIsRunning(false); |
|
|
break; |
|
|
|
|
|
case "error": |
|
|
setError(e.data.data || "Unknown error during model loading."); |
|
|
setStatus(null); |
|
|
setLoadingMessage(""); |
|
|
break; |
|
|
} |
|
|
}; |
|
|
|
|
|
const onErrorReceived = (e) => { |
|
|
console.error("Worker error:", e); |
|
|
}; |
|
|
|
|
|
|
|
|
worker.current.addEventListener("message", onMessageReceived); |
|
|
worker.current.addEventListener("error", onErrorReceived); |
|
|
|
|
|
|
|
|
return () => { |
|
|
worker.current.removeEventListener("message", onMessageReceived); |
|
|
worker.current.removeEventListener("error", onErrorReceived); |
|
|
}; |
|
|
}, []); |
|
|
|
|
|
|
|
|
useEffect(() => { |
|
|
if (messages.filter((x) => x.role === "user").length === 0) { |
|
|
|
|
|
return; |
|
|
} |
|
|
if (messages.at(-1).role === "assistant") { |
|
|
|
|
|
return; |
|
|
} |
|
|
setTps(null); |
|
|
worker.current.postMessage({ type: "generate", data: messages }); |
|
|
}, [messages, isRunning]); |
|
|
|
|
|
useEffect(() => { |
|
|
if (!chatContainerRef.current || !isRunning) return; |
|
|
const element = chatContainerRef.current; |
|
|
if ( |
|
|
element.scrollHeight - element.scrollTop - element.clientHeight < |
|
|
STICKY_SCROLL_THRESHOLD |
|
|
) { |
|
|
element.scrollTop = element.scrollHeight; |
|
|
} |
|
|
}, [messages, isRunning]); |
|
|
|
|
|
return IS_WEBGPU_AVAILABLE ? ( |
|
|
<div className="flex flex-col h-screen mx-auto items justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900"> |
|
|
{status === null && messages.length === 0 && ( |
|
|
<div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative"> |
|
|
<div className="flex flex-col items-center mb-1 max-w-[360px] text-center"> |
|
|
<img |
|
|
src="logo.png" |
|
|
width="80%" |
|
|
height="auto" |
|
|
className="block rounded-2xl mb-6" |
|
|
></img> |
|
|
<h1 className="text-4xl font-bold mb-2">Granite-4.0 WebGPU</h1> |
|
|
<h2 className="text-xl"> |
|
|
A reliable and efficient AI chatbot <br />that runs locally in your |
|
|
browser. |
|
|
</h2> |
|
|
</div> |
|
|
|
|
|
<div className="flex flex-col items-center px-4 text-lg"> |
|
|
<p className="max-w-[514px] mb-4"> |
|
|
<br /> |
|
|
You are about to load{" "} |
|
|
<a |
|
|
href="https://huggingface.co/onnx-community/granite-4.0-micro-ONNX-web" |
|
|
target="_blank" |
|
|
rel="noreferrer" |
|
|
className="font-medium underline" |
|
|
> |
|
|
Granite-4.0 Micro |
|
|
</a> |
|
|
, a 3.4B parameter long-context instruct model optimized for in-browser inference. |
|
|
Everything runs entirely in your browser with{" "} |
|
|
<a |
|
|
href="https://huggingface.co/docs/transformers.js" |
|
|
target="_blank" |
|
|
rel="noreferrer" |
|
|
className="underline" |
|
|
> |
|
|
🤗 Transformers.js |
|
|
</a>{" "} |
|
|
and ONNX Runtime Web, meaning no data is sent to a server. Once |
|
|
loaded (≈ 2.3 GB), it can even be used offline. |
|
|
</p> |
|
|
|
|
|
{error && ( |
|
|
<div className="text-red-500 text-center mb-2"> |
|
|
<p className="mb-1"> |
|
|
Unable to load model due to the following error: |
|
|
</p> |
|
|
<p className="text-sm">{error}</p> |
|
|
</div> |
|
|
)} |
|
|
|
|
|
<div className="flex flex-col items-center gap-2"> |
|
|
<label className="border px-4 py-2 rounded-lg bg-blue-100 text-blue-700 cursor-pointer hover:bg-blue-200"> |
|
|
Select model directory |
|
|
<input |
|
|
type="file" |
|
|
webkitdirectory="true" |
|
|
directory="true" |
|
|
multiple |
|
|
className="hidden" |
|
|
onChange={e => { |
|
|
const files = Array.from(e.target.files); |
|
|
setModelFiles(files); |
|
|
}} |
|
|
disabled={status !== null || error !== null} |
|
|
/> |
|
|
</label> |
|
|
<button |
|
|
className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none" |
|
|
onClick={() => { |
|
|
worker.current.postMessage({ type: "load", files: modelFiles }); |
|
|
setStatus("loading"); |
|
|
}} |
|
|
disabled={status !== null || error !== null} |
|
|
> |
|
|
{modelFiles.length > 0 ? `Load selected model directory` : `Load default model`} |
|
|
</button> |
|
|
{modelFiles.length > 0 && ( |
|
|
<span className="text-xs text-gray-600">Selected files: {modelFiles.map(f => f.name).join(", ")}</span> |
|
|
)} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
)} |
|
|
{status === "loading" && ( |
|
|
<> |
|
|
<div className="w-full max-w-[500px] text-left mx-auto p-4 bottom-0 mt-auto"> |
|
|
<p className="text-center mb-1">{loadingMessage}</p> |
|
|
{progressItems.map(({ file, progress, total }, i) => ( |
|
|
<Progress |
|
|
key={i} |
|
|
text={file} |
|
|
percentage={progress} |
|
|
total={total} |
|
|
/> |
|
|
))} |
|
|
</div> |
|
|
</> |
|
|
)} |
|
|
|
|
|
{status === "ready" && ( |
|
|
<div |
|
|
ref={chatContainerRef} |
|
|
className="overflow-y-auto scrollbar-thin w-full flex flex-col items-center h-full" |
|
|
> |
|
|
<Chat messages={messages} /> |
|
|
{messages.length === 0 && ( |
|
|
<div> |
|
|
{EXAMPLES.map((msg, i) => ( |
|
|
<div |
|
|
key={i} |
|
|
className="m-1 border dark:border-gray-600 rounded-md p-2 bg-gray-100 dark:bg-gray-700 cursor-pointer" |
|
|
onClick={() => onEnter(msg)} |
|
|
> |
|
|
{msg} |
|
|
</div> |
|
|
))} |
|
|
</div> |
|
|
)} |
|
|
<p className="text-center text-sm min-h-6 text-gray-500 dark:text-gray-300"> |
|
|
{tps && messages.length > 0 && ( |
|
|
<> |
|
|
{!isRunning && ( |
|
|
<span> |
|
|
Generated {numTokens} tokens in{" "} |
|
|
{(numTokens / tps).toFixed(2)} seconds ( |
|
|
</span> |
|
|
)} |
|
|
{ |
|
|
<> |
|
|
<span className="font-medium text-center mr-1 text-black dark:text-white"> |
|
|
{tps.toFixed(2)} |
|
|
</span> |
|
|
<span className="text-gray-500 dark:text-gray-300"> |
|
|
tokens/second |
|
|
</span> |
|
|
</> |
|
|
} |
|
|
{!isRunning && ( |
|
|
<> |
|
|
<span className="mr-1">).</span> |
|
|
<span |
|
|
className="underline cursor-pointer" |
|
|
onClick={() => { |
|
|
worker.current.postMessage({ type: "reset" }); |
|
|
setMessages([]); |
|
|
}} |
|
|
> |
|
|
Reset |
|
|
</span> |
|
|
</> |
|
|
)} |
|
|
</> |
|
|
)} |
|
|
</p> |
|
|
</div> |
|
|
)} |
|
|
|
|
|
<div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex items-center gap-2"> |
|
|
<textarea |
|
|
ref={textareaRef} |
|
|
className="scrollbar-thin w-[420px] dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed" |
|
|
placeholder="Type your message..." |
|
|
type="text" |
|
|
rows={1} |
|
|
value={input} |
|
|
disabled={status !== "ready"} |
|
|
title={status === "ready" ? "Model is ready" : "Model not loaded yet"} |
|
|
onKeyDown={(e) => { |
|
|
if ( |
|
|
input.length > 0 && |
|
|
!isRunning && |
|
|
e.key === "Enter" && |
|
|
!e.shiftKey |
|
|
) { |
|
|
e.preventDefault(); // Prevent default behavior of Enter key |
|
|
onEnter(input); |
|
|
} |
|
|
}} |
|
|
onInput={(e) => setInput(e.target.value)} |
|
|
/> |
|
|
<label |
|
|
className={`flex items-center px-2 py-2 bg-blue-100 text-blue-700 rounded cursor-pointer hover:bg-blue-200 ${status !== "ready" ? "opacity-50 cursor-not-allowed" : ""}`} |
|
|
> |
|
|
📎 Attach |
|
|
<input |
|
|
type="file" |
|
|
accept=".txt,.pdf" |
|
|
className="hidden" |
|
|
onChange={e => { |
|
|
const file = e.target.files[0]; |
|
|
if (file && ![".txt", ".pdf"].some(ext => file.name.toLowerCase().endsWith(ext))) { |
|
|
alert("Only .txt and .pdf files are allowed."); |
|
|
e.target.value = ""; |
|
|
return; |
|
|
} |
|
|
setAttachedFile(file); |
|
|
}} |
|
|
disabled={status !== "ready"} |
|
|
/> |
|
|
</label> |
|
|
{isRunning ? ( |
|
|
<div className="cursor-pointer" onClick={onInterrupt}> |
|
|
<StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" /> |
|
|
</div> |
|
|
) : input.length > 0 || attachedFile ? ( |
|
|
<div className="cursor-pointer" onClick={() => onEnter(input)}> |
|
|
<ArrowRightIcon |
|
|
className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`} |
|
|
/> |
|
|
</div> |
|
|
) : ( |
|
|
<div> |
|
|
<ArrowRightIcon |
|
|
className={`h-8 w-8 p-1 bg-gray-200 dark:bg-gray-600 text-gray-50 dark:text-gray-800 rounded-md absolute right-3 bottom-3`} |
|
|
/> |
|
|
</div> |
|
|
)} |
|
|
{attachedFile && ( |
|
|
<span className="ml-2 text-xs text-gray-600">{attachedFile.name}</span> |
|
|
)} |
|
|
</div> |
|
|
|
|
|
<p className="text-xs text-gray-400 text-center mb-3"> |
|
|
Disclaimer: Generated content may be inaccurate or false. |
|
|
</p> |
|
|
</div> |
|
|
) : ( |
|
|
<div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center"> |
|
|
WebGPU is not supported |
|
|
<br /> |
|
|
by this browser :( |
|
|
</div> |
|
|
); |
|
|
} |
|
|
|
|
|
export default App; |
|
|
|