hantech commited on
Commit
bfc3921
·
verified ·
1 Parent(s): 55b4471

Update worker.js

Browse files
Files changed (1) hide show
  1. worker.js +62 -63
worker.js CHANGED
@@ -1,60 +1,54 @@
1
  import { pipeline, env, cos_sim } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.19/dist/transformers.min.js';
2
 
3
- // Cấu hình WebGPU
4
- env.backends.onnx.wasm.proxy = false;
 
5
 
6
- // --- CẤU HÌNH MODEL (KHÔNG ĐỔI THEO YÊU CẦU) ---
7
- // 1. Embedding Model
8
- const EMBEDDING_MODEL_ID = 'onnx-community/embeddinggemma-300m-ONNX';
 
9
 
10
- // 2. LLM Model: Granite 4.0 Nano
11
- // Lưu ý: Nếu phiên bản ONNX của Granite 4.0 chưa public dưới ID này,
12
- // bạn cần trỏ tới đúng repo onnx (ví dụ ibm-granite/granite-3.0-2b-instruct nếu 4.0 chưa có onnx).
13
- // Tuy nhiên, tôi giữ nguyên tham chiếu "Granite" như yêu cầu.
14
- const LLM_MODEL_ID = 'onnx-community/granite-4.0-1b-ONNX-web'; // Kiểm tra lại tên chính xác trên HF Hub nếu lỗi
15
 
16
  let extractor = null;
17
  let generator = null;
18
- let vectorStore = []; // Lưu trữ chunks và vectors: { text: string, vector: number[] }
19
 
20
  async function initModels() {
21
  try {
22
- self.postMessage({ type: 'status', payload: "Đang tải Embedding Model (có thể lâu lần đầu)..." });
23
-
24
- // Thêm cờ { use_cache: true } dù mặc định nó đã có, để đảm bảo
25
  extractor = await pipeline('feature-extraction', EMBEDDING_MODEL_ID, {
26
  device: 'webgpu',
27
  dtype: 'fp32',
28
- use_cache: true
29
  });
 
 
 
 
30
 
31
- self.postMessage({ type: 'status', payload: "Đang tải LLM Granite 4.0 (Model nặng, vui lòng chờ)..." });
32
-
33
  generator = await pipeline('text-generation', LLM_MODEL_ID, {
34
  device: 'webgpu',
35
  dtype: 'q4',
36
  use_external_data_format: true,
37
- // Thêm progress_callback để UI không tưởng bị treo
38
- progress_callback: (data) => {
39
- if (data.status === 'progress') {
40
- // Gửi tiến độ về main thread để hiện loading bar nếu cần
41
- self.postMessage({
42
- type: 'download_progress',
43
- payload: { file: data.file, progress: data.progress }
44
- });
45
- }
46
- }
47
  });
 
48
 
49
  self.postMessage({ type: 'init_complete' });
50
  } catch (e) {
51
- console.error(e); // Log chi tiết ra console
52
- self.postMessage({ type: 'error', payload: "Lỗi tải model (Kiểm tra Console F12): " + e.message });
53
  }
54
  }
55
 
56
- // Xử chunking văn bản
57
  function chunkText(text, chunkSize = 300, overlap = 50) {
 
58
  const sentences = text.match(/[^.!?]+[.!?]+|[^.!?]+$/g) || [text];
59
  let chunks = [];
60
  let currentChunk = "";
@@ -62,49 +56,57 @@ function chunkText(text, chunkSize = 300, overlap = 50) {
62
  for (let sentence of sentences) {
63
  if ((currentChunk + sentence).length > chunkSize) {
64
  chunks.push(currentChunk.trim());
65
- currentChunk = sentence.slice(-overlap); // Overlap đơn giản
66
  } else {
67
  currentChunk += " " + sentence;
68
  }
69
  }
70
  if (currentChunk) chunks.push(currentChunk.trim());
 
 
71
  return chunks;
72
  }
73
 
74
- // Tạo embeddings cho văn bản
75
  async function ingestText(text) {
76
  const chunks = chunkText(text);
77
- vectorStore = []; // Reset store
78
-
79
- for (const chunk of chunks) {
80
- const output = await extractor(chunk, { pooling: 'mean', normalize: true });
 
 
 
81
  vectorStore.push({
82
- text: chunk,
83
  vector: output.data
84
  });
 
 
85
  }
86
- console.log(`Đã index ${vectorStore.length} đoạn văn bản.`);
 
 
87
  }
88
 
89
- // Tìm kiếm RAG
90
  async function retrieve(query) {
 
91
  const queryOutput = await extractor(query, { pooling: 'mean', normalize: true });
92
  const queryVector = queryOutput.data;
93
 
94
- // Tính Cosine Similarity
95
- const scored = vectorStore.map(item => {
96
- return {
97
- text: item.text,
98
- score: cos_sim(queryVector, item.vector)
99
- };
100
- });
101
 
102
- // Lấy top 3 đoạn liên quan nhất
103
  scored.sort((a, b) => b.score - a.score);
104
- return scored.slice(0, 3).map(i => i.text).join("\n\n");
 
 
 
 
105
  }
106
 
107
- // Xử lý tin nhắn từ Main Thread
108
  self.onmessage = async (e) => {
109
  if (!extractor || !generator) {
110
  await initModels();
@@ -115,35 +117,33 @@ self.onmessage = async (e) => {
115
  if (type === 'ingest_text') {
116
  await ingestText(payload);
117
  } else if (type === 'query') {
118
- // 1. Retrieve Context
 
 
 
119
  const context = await retrieve(payload);
120
 
121
- // 2. Tạo Prompt cho Granite
122
- // Định dạng prompt cơ bản cho instruction tuned model
123
  const prompt = `<|system|>
124
- Bạn là trợ lý AI hữu ích. Hãy trả lời câu hỏi dựa trên ngữ cảnh được cung cấp bên dưới bằng Tiếng Việt.
125
- Ngữ cảnh:
126
  ${context}
127
  <|user|>
128
  ${payload}
129
- <|assistant|>
130
- `;
131
-
132
- // 3. Generate Answer
133
  try {
134
  const output = await generator(prompt, {
135
  max_new_tokens: 256,
136
- temperature: 0.7,
137
  do_sample: true,
 
 
138
  });
139
 
140
- // Lấy phần trả lời sau tag assistant (tuỳ thuộc format model)
141
  let answer = output[0].generated_text;
142
- // Cắt bớt phần prompt nếu cần thiết
143
  if (answer.includes("<|assistant|>")) {
144
  answer = answer.split("<|assistant|>")[1];
145
  }
146
-
147
  self.postMessage({ type: 'answer', payload: answer });
148
  } catch (err) {
149
  self.postMessage({ type: 'error', payload: err.message });
@@ -151,5 +151,4 @@ ${payload}
151
  }
152
  };
153
 
154
- // Khởi tạo ngay khi worker chạy
155
  initModels();
 
1
  import { pipeline, env, cos_sim } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.19/dist/transformers.min.js';
2
 
3
+ // Cấu hình
4
+ env.allowLocalModels = false;
5
+ env.useBrowserCache = true;
6
 
7
+ // Hàm log helper gửi về Main Thread
8
+ function log(msg) {
9
+ self.postMessage({ type: 'log', payload: msg });
10
+ }
11
 
12
+ const EMBEDDING_MODEL_ID = 'onnx-community/embeddinggemma-300m-ONNX';
13
+ // Dùng ID Granite 4.0 (Nếu lỗi 404 hãy thử ibm-granite/granite-3.0-2b-instruct)
14
+ const LLM_MODEL_ID = 'ibm-granite/granite-4.0-350m-instruct';
 
 
15
 
16
  let extractor = null;
17
  let generator = null;
18
+ let vectorStore = [];
19
 
20
  async function initModels() {
21
  try {
22
+ log("🚀 Bắt đầu tải Embedding Model (Gemma)...");
23
+ self.postMessage({ type: 'status', payload: "Đang tải Embedding..." });
24
+
25
  extractor = await pipeline('feature-extraction', EMBEDDING_MODEL_ID, {
26
  device: 'webgpu',
27
  dtype: 'fp32',
28
+ progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
29
  });
30
+ log("✅ Embedding Model đã tải xong.");
31
+
32
+ log("🚀 Bắt đầu tải LLM (Granite 4.0)...");
33
+ self.postMessage({ type: 'status', payload: "Đang tải LLM..." });
34
 
 
 
35
  generator = await pipeline('text-generation', LLM_MODEL_ID, {
36
  device: 'webgpu',
37
  dtype: 'q4',
38
  use_external_data_format: true,
39
+ progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
 
 
 
 
 
 
 
 
 
40
  });
41
+ log("✅ LLM đã tải xong.");
42
 
43
  self.postMessage({ type: 'init_complete' });
44
  } catch (e) {
45
+ self.postMessage({ type: 'error', payload: e.message });
 
46
  }
47
  }
48
 
49
+ // Hàm chia nhỏ văn bản (Chunking) có log
50
  function chunkText(text, chunkSize = 300, overlap = 50) {
51
+ log(`Bắt đầu chia nhỏ văn bản dài ${text.length} ký tự...`);
52
  const sentences = text.match(/[^.!?]+[.!?]+|[^.!?]+$/g) || [text];
53
  let chunks = [];
54
  let currentChunk = "";
 
56
  for (let sentence of sentences) {
57
  if ((currentChunk + sentence).length > chunkSize) {
58
  chunks.push(currentChunk.trim());
59
+ currentChunk = sentence.slice(-overlap);
60
  } else {
61
  currentChunk += " " + sentence;
62
  }
63
  }
64
  if (currentChunk) chunks.push(currentChunk.trim());
65
+
66
+ log(`-> Đã chia thành ${chunks.length} chunks.`);
67
  return chunks;
68
  }
69
 
70
+ // Tạo Embeddings
71
  async function ingestText(text) {
72
  const chunks = chunkText(text);
73
+ vectorStore = [];
74
+
75
+ log("🔄 Đang tạo vector cho từng chunk (Quá trình này tốn GPU)...");
76
+
77
+ let startTime = performance.now();
78
+ for (let i = 0; i < chunks.length; i++) {
79
+ const output = await extractor(chunks[i], { pooling: 'mean', normalize: true });
80
  vectorStore.push({
81
+ text: chunks[i],
82
  vector: output.data
83
  });
84
+ // Log mỗi 5 chunks để không spam
85
+ if ((i + 1) % 5 === 0) log(`Processed ${i + 1}/${chunks.length} chunks...`);
86
  }
87
+ let endTime = performance.now();
88
+
89
+ log(`✅ Indexing hoàn tất trong ${((endTime - startTime)/1000).toFixed(2)}s.`);
90
  }
91
 
 
92
  async function retrieve(query) {
93
+ log(`🔍 Đang tìm kiếm ngữ cảnh cho: "${query}"`);
94
  const queryOutput = await extractor(query, { pooling: 'mean', normalize: true });
95
  const queryVector = queryOutput.data;
96
 
97
+ const scored = vectorStore.map(item => ({
98
+ text: item.text,
99
+ score: cos_sim(queryVector, item.vector)
100
+ }));
 
 
 
101
 
 
102
  scored.sort((a, b) => b.score - a.score);
103
+
104
+ const top3 = scored.slice(0, 3);
105
+ log(`-> Tìm thấy 3 đoạn khớp nhất (Scores: ${top3.map(i => i.score.toFixed(2)).join(', ')})`);
106
+
107
+ return top3.map(i => i.text).join("\n\n");
108
  }
109
 
 
110
  self.onmessage = async (e) => {
111
  if (!extractor || !generator) {
112
  await initModels();
 
117
  if (type === 'ingest_text') {
118
  await ingestText(payload);
119
  } else if (type === 'query') {
120
+ if (vectorStore.length === 0) {
121
+ log("⚠️ Cảnh báo: Chưa có dữ liệu PDF nào được index.");
122
+ }
123
+
124
  const context = await retrieve(payload);
125
 
 
 
126
  const prompt = `<|system|>
127
+ Bạn là trợ lý AI hữu ích. Trả lời dựa trên ngữ cảnh sau bằng Tiếng Việt:
 
128
  ${context}
129
  <|user|>
130
  ${payload}
131
+ <|assistant|>`;
132
+
133
+ log("🤖 LLM đang suy nghĩ...");
 
134
  try {
135
  const output = await generator(prompt, {
136
  max_new_tokens: 256,
137
+ temperature: 0.6,
138
  do_sample: true,
139
+ // Log mỗi khi sinh ra token mới (cẩn thận spam nếu bật)
140
+ // callback_function: (x) => log("Generated token...")
141
  });
142
 
 
143
  let answer = output[0].generated_text;
 
144
  if (answer.includes("<|assistant|>")) {
145
  answer = answer.split("<|assistant|>")[1];
146
  }
 
147
  self.postMessage({ type: 'answer', payload: answer });
148
  } catch (err) {
149
  self.postMessage({ type: 'error', payload: err.message });
 
151
  }
152
  };
153
 
 
154
  initModels();