hantech's picture
Update worker.js
bfc3921 verified
import { pipeline, env, cos_sim } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.19/dist/transformers.min.js';
// Cấu hình
env.allowLocalModels = false;
env.useBrowserCache = true;
// Hàm log helper gửi về Main Thread
function log(msg) {
self.postMessage({ type: 'log', payload: msg });
}
const EMBEDDING_MODEL_ID = 'onnx-community/embeddinggemma-300m-ONNX';
// Dùng ID Granite 4.0 (Nếu lỗi 404 hãy thử ibm-granite/granite-3.0-2b-instruct)
const LLM_MODEL_ID = 'ibm-granite/granite-4.0-350m-instruct';
let extractor = null;
let generator = null;
let vectorStore = [];
async function initModels() {
try {
log("🚀 Bắt đầu tải Embedding Model (Gemma)...");
self.postMessage({ type: 'status', payload: "Đang tải Embedding..." });
extractor = await pipeline('feature-extraction', EMBEDDING_MODEL_ID, {
device: 'webgpu',
dtype: 'fp32',
progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
});
log("✅ Embedding Model đã tải xong.");
log("🚀 Bắt đầu tải LLM (Granite 4.0)...");
self.postMessage({ type: 'status', payload: "Đang tải LLM..." });
generator = await pipeline('text-generation', LLM_MODEL_ID, {
device: 'webgpu',
dtype: 'q4',
use_external_data_format: true,
progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
});
log("✅ LLM đã tải xong.");
self.postMessage({ type: 'init_complete' });
} catch (e) {
self.postMessage({ type: 'error', payload: e.message });
}
}
// Hàm chia nhỏ văn bản (Chunking) có log
function chunkText(text, chunkSize = 300, overlap = 50) {
log(`Bắt đầu chia nhỏ văn bản dài ${text.length} ký tự...`);
const sentences = text.match(/[^.!?]+[.!?]+|[^.!?]+$/g) || [text];
let chunks = [];
let currentChunk = "";
for (let sentence of sentences) {
if ((currentChunk + sentence).length > chunkSize) {
chunks.push(currentChunk.trim());
currentChunk = sentence.slice(-overlap);
} else {
currentChunk += " " + sentence;
}
}
if (currentChunk) chunks.push(currentChunk.trim());
log(`-> Đã chia thành ${chunks.length} chunks.`);
return chunks;
}
// Tạo Embeddings
async function ingestText(text) {
const chunks = chunkText(text);
vectorStore = [];
log("🔄 Đang tạo vector cho từng chunk (Quá trình này tốn GPU)...");
let startTime = performance.now();
for (let i = 0; i < chunks.length; i++) {
const output = await extractor(chunks[i], { pooling: 'mean', normalize: true });
vectorStore.push({
text: chunks[i],
vector: output.data
});
// Log mỗi 5 chunks để không spam
if ((i + 1) % 5 === 0) log(`Processed ${i + 1}/${chunks.length} chunks...`);
}
let endTime = performance.now();
log(`✅ Indexing hoàn tất trong ${((endTime - startTime)/1000).toFixed(2)}s.`);
}
async function retrieve(query) {
log(`🔍 Đang tìm kiếm ngữ cảnh cho: "${query}"`);
const queryOutput = await extractor(query, { pooling: 'mean', normalize: true });
const queryVector = queryOutput.data;
const scored = vectorStore.map(item => ({
text: item.text,
score: cos_sim(queryVector, item.vector)
}));
scored.sort((a, b) => b.score - a.score);
const top3 = scored.slice(0, 3);
log(`-> Tìm thấy 3 đoạn khớp nhất (Scores: ${top3.map(i => i.score.toFixed(2)).join(', ')})`);
return top3.map(i => i.text).join("\n\n");
}
self.onmessage = async (e) => {
if (!extractor || !generator) {
await initModels();
}
const { type, payload } = e.data;
if (type === 'ingest_text') {
await ingestText(payload);
} else if (type === 'query') {
if (vectorStore.length === 0) {
log("⚠️ Cảnh báo: Chưa có dữ liệu PDF nào được index.");
}
const context = await retrieve(payload);
const prompt = `<|system|>
Bạn là trợ lý AI hữu ích. Trả lời dựa trên ngữ cảnh sau bằng Tiếng Việt:
${context}
<|user|>
${payload}
<|assistant|>`;
log("🤖 LLM đang suy nghĩ...");
try {
const output = await generator(prompt, {
max_new_tokens: 256,
temperature: 0.6,
do_sample: true,
// Log mỗi khi sinh ra token mới (cẩn thận spam nếu bật)
// callback_function: (x) => log("Generated token...")
});
let answer = output[0].generated_text;
if (answer.includes("<|assistant|>")) {
answer = answer.split("<|assistant|>")[1];
}
self.postMessage({ type: 'answer', payload: answer });
} catch (err) {
self.postMessage({ type: 'error', payload: err.message });
}
}
};
initModels();