carljmosca committed
Commit 726a6f0 · Parent: e5b1920

Added support for local ONNX model loading in worker.js

Files changed (2)
  1. src/App.jsx +53 -40
  2. src/worker.js +5 -10
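The change spans both sides of the app's worker boundary: the React UI gains a directory picker and forwards the selected files with its "load" request, and the worker threads those files through to the pipeline loader. A minimal stand-alone sketch of the resulting message contract, inferred from the diffs below (the handler body is illustrative, not the actual code in App.jsx):

// main.js — sketch of the main-thread side of the new load handshake
const worker = new Worker(new URL("./worker.js", import.meta.url), {
  type: "module",
});

worker.addEventListener("message", (e) => {
  switch (e.data.status) {
    case "loading":  console.log(e.data.data); break;                  // coarse status text
    case "progress": console.log(e.data.file, e.data.progress); break; // per-file progress
    case "ready":    console.log("worker ready to generate"); break;
    case "error":    console.error(e.data.data); break;
  }
});

// An empty list falls back to the default remote model; a FileList from
// the new directory picker requests a local load instead.
worker.postMessage({ type: "load", files: [] });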
src/App.jsx CHANGED
@@ -26,6 +26,7 @@ function App() {
   const [loadingMessage, setLoadingMessage] = useState("");
   const [progressItems, setProgressItems] = useState([]);
   const [isRunning, setIsRunning] = useState(false);
+  const [modelFiles, setModelFiles] = useState([]);
 
   // Inputs and outputs
   const [input, setInput] = useState("");
@@ -100,7 +101,6 @@ function App() {
     const onMessageReceived = (e) => {
       switch (e.data.status) {
         case "loading":
-          // Model file start load: add a new progress item to the list.
           setStatus("loading");
           setLoadingMessage(e.data.data);
           break;
@@ -110,7 +110,6 @@ function App() {
           break;
 
         case "progress":
-          // Model file progress: update one of the progress items.
           setProgressItems((prev) =>
             prev.map((item) => {
               if (item.file === e.data.file) {
@@ -122,53 +121,45 @@ function App() {
           break;
 
         case "done":
-          // Model file loaded: remove the progress item from the list.
           setProgressItems((prev) =>
             prev.filter((item) => item.file !== e.data.file),
           );
           break;
 
         case "ready":
-          // Pipeline ready: the worker is ready to accept messages.
           setStatus("ready");
           break;
 
         case "start":
-          {
-            // Start generation
-            setMessages((prev) => [
-              ...prev,
-              { role: "assistant", content: "" },
-            ]);
-          }
+          setMessages((prev) => [
+            ...prev,
+            { role: "assistant", content: "" },
+          ]);
           break;
 
         case "update":
-          {
-            // Generation update: update the output text.
-            // Parse messages
-            const { output, tps, numTokens } = e.data;
-            setTps(tps);
-            setNumTokens(numTokens);
-            setMessages((prev) => {
-              const cloned = [...prev];
-              const last = cloned.at(-1);
-              cloned[cloned.length - 1] = {
-                ...last,
-                content: last.content + output,
-              };
-              return cloned;
-            });
-          }
+          const { output, tps, numTokens } = e.data;
+          setTps(tps);
+          setNumTokens(numTokens);
+          setMessages((prev) => {
+            const cloned = [...prev];
+            const last = cloned.at(-1);
+            cloned[cloned.length - 1] = {
+              ...last,
+              content: last.content + output,
+            };
+            return cloned;
+          });
           break;
 
         case "complete":
-          // Generation complete: re-enable the "Generate" button
           setIsRunning(false);
           break;
 
         case "error":
-          setError(e.data.data);
+          setError(e.data.data || "Unknown error during model loading.");
+          setStatus(null);
+          setLoadingMessage("");
           break;
       }
     };
@@ -266,16 +257,36 @@ function App() {
           </div>
         )}
 
-        <button
-          className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
-          onClick={() => {
-            worker.current.postMessage({ type: "load" });
-            setStatus("loading");
-          }}
-          disabled={status !== null || error !== null}
-        >
-          Load model
-        </button>
+        <div className="flex flex-col items-center gap-2">
+          <label className="border px-4 py-2 rounded-lg bg-blue-100 text-blue-700 cursor-pointer hover:bg-blue-200">
+            Select model directory
+            <input
+              type="file"
+              webkitdirectory="true"
+              directory="true"
+              multiple
+              className="hidden"
+              onChange={(e) => {
+                const files = Array.from(e.target.files);
+                setModelFiles(files);
+              }}
+              disabled={status !== null || error !== null}
+            />
+          </label>
+          <button
+            className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+            onClick={() => {
+              worker.current.postMessage({ type: "load", files: modelFiles });
+              setStatus("loading");
+            }}
+            disabled={status !== null || error !== null}
+          >
+            {modelFiles.length > 0 ? "Load selected model directory" : "Load default model"}
+          </button>
+          {modelFiles.length > 0 && (
+            <span className="text-xs text-gray-600">
+              Selected files: {modelFiles.map((f) => f.name).join(", ")}
+            </span>
+          )}
+        </div>
       </div>
     </div>
   )}
@@ -376,7 +387,9 @@ function App() {
           }}
           onInput={(e) => setInput(e.target.value)}
         />
-        <label className="flex items-center px-2 py-2 bg-blue-100 text-blue-700 rounded cursor-pointer hover:bg-blue-200">
+        <label
+          className={`flex items-center px-2 py-2 bg-blue-100 text-blue-700 rounded cursor-pointer hover:bg-blue-200 ${status !== "ready" ? "opacity-50 cursor-not-allowed" : ""}`}
+        >
           📎 Attach
           <input
             type="file"
src/worker.js CHANGED
@@ -30,18 +30,15 @@ async function check() {
  */
 class TextGenerationPipeline {
   static model_id = "onnx-community/granite-4.0-micro-ONNX-web";
-
   static async getInstance(progress_callback = null) {
     this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
       progress_callback,
     });
-
     this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
       dtype: "q4f16",
       device: "webgpu",
       progress_callback,
     });
-
     return Promise.all([this.tokenizer, this.model]);
   }
 }
@@ -114,18 +111,16 @@ async function generate(messages) {
   });
 }
 
-async function load() {
+async function load(files = null) {
   self.postMessage({
     status: "loading",
-    data: "Loading model...",
+    data: files && files.length > 0 ? `Loading model from selected directory` : "Loading model...",
   });
 
   // Load the pipeline and save it for future use.
   const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
-    // We also add a progress callback to the pipeline so that we can
-    // track model loading.
     self.postMessage(x);
-  });
+  }, files);
 
   self.postMessage({
     status: "loading",
@@ -139,7 +134,7 @@ async function load() {
 }
 // Listen for messages from the main thread
 self.addEventListener("message", async (e) => {
-  const { type, data } = e.data;
+  const { type, data, files } = e.data;
 
   switch (type) {
     case "check":
@@ -147,7 +142,7 @@ self.addEventListener("message", async (e) => {
       break;
 
     case "load":
-      load();
+      await load(files || null);
       break;
 
     case "generate":