Commit 726a6f0
Parent(s): e5b1920

Added support for local ONNX/model loading logic from worker.js

Files changed:
- src/App.jsx (+53 -40)
- src/worker.js (+5 -10)
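Taken together, the two diffs below extend the main-thread ↔ worker message protocol so that a user-selected model directory can be forwarded to the worker. A minimal sketch of the contract as it stands after this commit, read off the hunks below (the "check"/"generate" payloads and any "progress" fields beyond `file` are not visible in this diff and are assumed):

// Main thread -> worker (src/App.jsx):
worker.current.postMessage({ type: "load", files: modelFiles }); // files: File[] from the picker, possibly []
// "check" and "generate" messages are unchanged by this commit.

// Worker -> main thread (handled by onMessageReceived in src/App.jsx):
// { status: "loading",  data }                    // human-readable progress text
// { status: "progress", file, ... }               // per-file download progress
// { status: "done",     file }                    // one model file finished
// { status: "ready" }                             // pipeline ready for "generate"
// { status: "start" }                             // generation started
// { status: "update",  output, tps, numTokens }   // streamed token chunk plus stats
// { status: "complete" }                          // generation finished
// { status: "error",   data }                     // failure; surfaced via setError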
src/App.jsx (CHANGED)

@@ -26,6 +26,7 @@ function App() {
   const [loadingMessage, setLoadingMessage] = useState("");
   const [progressItems, setProgressItems] = useState([]);
   const [isRunning, setIsRunning] = useState(false);
+  const [modelFiles, setModelFiles] = useState([]);

   // Inputs and outputs
   const [input, setInput] = useState("");

@@ -100,7 +101,6 @@ function App() {
     const onMessageReceived = (e) => {
       switch (e.data.status) {
         case "loading":
-          // Model file start load: add a new progress item to the list.
           setStatus("loading");
           setLoadingMessage(e.data.data);
           break;

@@ -110,7 +110,6 @@ function App() {
           break;

         case "progress":
-          // Model file progress: update one of the progress items.
           setProgressItems((prev) =>
             prev.map((item) => {
               if (item.file === e.data.file) {

@@ -122,53 +121,45 @@ function App() {
           break;

         case "done":
-          // Model file loaded: remove the progress item from the list.
           setProgressItems((prev) =>
             prev.filter((item) => item.file !== e.data.file),
           );
           break;

         case "ready":
-          // Pipeline ready: the worker is ready to accept messages.
           setStatus("ready");
           break;

         case "start":
-          {
-            // Start generation
-            setMessages((prev) => [
-              ...prev,
-              { role: "assistant", content: "" },
-            ]);
-          }
+          setMessages((prev) => [
+            ...prev,
+            { role: "assistant", content: "" },
+          ]);
           break;

         case "update":
-          {
-            // Generation update: update the output text.
-            // Parse messages
-            const { output, tps, numTokens } = e.data;
-            setTps(tps);
-            setNumTokens(numTokens);
-            setMessages((prev) => {
-              const cloned = [...prev];
-              const last = cloned.at(-1);
-              cloned[cloned.length - 1] = {
-                ...last,
-                content: last.content + output,
-              };
-              return cloned;
-            });
-          }
+          const { output, tps, numTokens } = e.data;
+          setTps(tps);
+          setNumTokens(numTokens);
+          setMessages((prev) => {
+            const cloned = [...prev];
+            const last = cloned.at(-1);
+            cloned[cloned.length - 1] = {
+              ...last,
+              content: last.content + output,
+            };
+            return cloned;
+          });
           break;

         case "complete":
-          // Generation complete: re-enable the "Generate" button
           setIsRunning(false);
           break;

         case "error":
-          setError(e.data.data);
+          setError(e.data.data || "Unknown error during model loading.");
+          setStatus(null);
+          setLoadingMessage("");
           break;
       }
     };

@@ -266,16 +257,36 @@ function App() {
           </div>
         )}

-        <button
-          className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
-          onClick={() => {
-            worker.current.postMessage({ type: "load" });
-            setStatus("loading");
-          }}
-          disabled={status !== null || error !== null}
-        >
-          Load model
-        </button>
+        <div className="flex flex-col items-center gap-2">
+          <label className="border px-4 py-2 rounded-lg bg-blue-100 text-blue-700 cursor-pointer hover:bg-blue-200">
+            Select model directory
+            <input
+              type="file"
+              webkitdirectory="true"
+              directory="true"
+              multiple
+              className="hidden"
+              onChange={e => {
+                const files = Array.from(e.target.files);
+                setModelFiles(files);
+              }}
+              disabled={status !== null || error !== null}
+            />
+          </label>
+          <button
+            className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+            onClick={() => {
+              worker.current.postMessage({ type: "load", files: modelFiles });
+              setStatus("loading");
+            }}
+            disabled={status !== null || error !== null}
+          >
+            {modelFiles.length > 0 ? `Load selected model directory` : `Load default model`}
+          </button>
+          {modelFiles.length > 0 && (
+            <span className="text-xs text-gray-600">Selected files: {modelFiles.map(f => f.name).join(", ")}</span>
+          )}
+        </div>
       </div>
     </div>
   )}

@@ -376,7 +387,9 @@ function App() {
             }}
             onInput={(e) => setInput(e.target.value)}
           />
-          <label
+          <label
+            className={`flex items-center px-2 py-2 bg-blue-100 text-blue-700 rounded cursor-pointer hover:bg-blue-200 ${status !== "ready" ? "opacity-50 cursor-not-allowed" : ""}`}
+          >
             📎 Attach
             <input
               type="file"
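Two implementation notes on the picker above: `webkitdirectory` is a non-standard attribute, but it is supported by the Chromium-based browsers where WebGPU is available, and React forwards it untouched as a string attribute; and File objects are structured-cloneable, so they can be posted to the worker as-is, without reading them on the main thread first. A minimal standalone sketch of the same pattern, with hypothetical component and file names (it assumes a Vite-style bundler that understands `new Worker(new URL(...))`):

import { useEffect, useRef, useState } from "react";

// Hypothetical minimal picker: select a directory, post its files to a worker.
export default function ModelPicker() {
  const worker = useRef(null);
  const [files, setFiles] = useState([]);

  useEffect(() => {
    // Create the worker once on mount.
    worker.current ??= new Worker(new URL("./worker.js", import.meta.url), {
      type: "module",
    });
  }, []);

  return (
    <>
      <input
        type="file"
        webkitdirectory="true"
        multiple
        onChange={(e) => setFiles(Array.from(e.target.files))}
      />
      <button
        disabled={files.length === 0}
        onClick={() => worker.current.postMessage({ type: "load", files })}
      >
        Load {files.length} file(s)
      </button>
    </>
  );
}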
src/worker.js (CHANGED)

@@ -30,18 +30,15 @@ async function check() {
  */
 class TextGenerationPipeline {
   static model_id = "onnx-community/granite-4.0-micro-ONNX-web";
-
   static async getInstance(progress_callback = null) {
     this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
       progress_callback,
     });
-
     this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
       dtype: "q4f16",
       device: "webgpu",
       progress_callback,
     });
-
     return Promise.all([this.tokenizer, this.model]);
   }
 }

@@ -114,18 +111,16 @@ async function generate(messages) {
   });
 }

-async function load() {
+async function load(files = null) {
   self.postMessage({
     status: "loading",
-    data: "Loading model...",
+    data: files && files.length > 0 ? `Loading model from selected directory` : "Loading model...",
   });

   // Load the pipeline and save it for future use.
   const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
-    // We also add a progress callback to the pipeline so that we can
-    // track model loading.
     self.postMessage(x);
-  });
+  }, files);

   self.postMessage({
     status: "loading",

@@ -139,7 +134,7 @@ async function load() {
 }
 // Listen for messages from the main thread
 self.addEventListener("message", async (e) => {
-  const { type, data } = e.data;
+  const { type, data, files } = e.data;

   switch (type) {
     case "check":

@@ -147,7 +142,7 @@ self.addEventListener("message", async (e) => {
       break;

     case "load":
-      load();
+      await load(files || null);
       break;

     case "generate":
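Note that `load()` now threads `files` into `TextGenerationPipeline.getInstance(progress_callback, files)`, but the `getInstance` shown above still declares only `progress_callback`, so how the selected files are actually consumed is not visible in this diff. One plausible way to wire them up, as a sketch only: transformers.js exposes custom-cache settings (`env.useCustomCache` / `env.customCache`, an object implementing `match` and `put`), which can serve the selected File objects by file name before any network request is made. The `FileCache` class below is hypothetical and not part of this commit:

import { env } from "@huggingface/transformers";

// Hypothetical cache that resolves model requests against user-selected files.
class FileCache {
  constructor(files) {
    // Key by base name, e.g. "tokenizer.json" or "model_q4f16.onnx".
    this.files = new Map(files.map((f) => [f.name, f]));
  }
  async match(request) {
    // transformers.js passes cache keys as strings (paths/URLs).
    const key = typeof request === "string" ? request : request.url;
    const file = this.files.get(key.split("/").pop());
    return file ? new Response(file) : undefined; // miss -> fall through to network
  }
  async put(_request, _response) {
    // The selected files are the source of truth; nothing to store.
  }
}

// Would be called from load(files) before from_pretrained() runs.
function useLocalModelFiles(files) {
  env.useBrowserCache = false; // don't let stale downloads shadow the local files
  env.useCustomCache = true;
  env.customCache = new FileCache(files);
}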