File size: 5,211 Bytes
671f7c3
 
bfc3921
 
 
671f7c3
bfc3921
 
 
 
671f7c3
bfc3921
 
 
671f7c3
 
 
bfc3921
671f7c3
 
 
bfc3921
 
 
671f7c3
 
209ec3a
bfc3921
671f7c3
bfc3921
 
 
 
671f7c3
 
 
209ec3a
 
bfc3921
671f7c3
bfc3921
671f7c3
 
 
bfc3921
671f7c3
 
 
bfc3921
671f7c3
bfc3921
671f7c3
 
 
 
 
 
 
bfc3921
671f7c3
 
 
 
 
bfc3921
 
671f7c3
 
 
bfc3921
671f7c3
 
bfc3921
 
 
 
 
 
 
671f7c3
bfc3921
671f7c3
 
bfc3921
 
671f7c3
bfc3921
 
 
671f7c3
 
 
bfc3921
671f7c3
 
 
bfc3921
 
 
 
671f7c3
 
bfc3921
 
 
 
 
671f7c3
 
 
 
 
 
 
 
 
 
 
 
bfc3921
 
 
 
671f7c3
 
 
bfc3921
671f7c3
 
 
bfc3921
 
 
671f7c3
 
 
bfc3921
671f7c3
bfc3921
 
671f7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import { pipeline, env, cos_sim } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.19/dist/transformers.min.js';

// Cấu hình
env.allowLocalModels = false;
env.useBrowserCache = true; 

// Hàm log helper gửi về Main Thread
function log(msg) {
    self.postMessage({ type: 'log', payload: msg });
}

const EMBEDDING_MODEL_ID = 'onnx-community/embeddinggemma-300m-ONNX';
// Dùng ID Granite 4.0 (Nếu lỗi 404 hãy thử ibm-granite/granite-3.0-2b-instruct)
const LLM_MODEL_ID = 'ibm-granite/granite-4.0-350m-instruct'; 

let extractor = null;
let generator = null;
let vectorStore = [];

async function initModels() {
    try {
        log("🚀 Bắt đầu tải Embedding Model (Gemma)...");
        self.postMessage({ type: 'status', payload: "Đang tải Embedding..." });

        extractor = await pipeline('feature-extraction', EMBEDDING_MODEL_ID, {
            device: 'webgpu',
            dtype: 'fp32',
            progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
        });
        log("✅ Embedding Model đã tải xong.");

        log("🚀 Bắt đầu tải LLM (Granite 4.0)...");
        self.postMessage({ type: 'status', payload: "Đang tải LLM..." });

        generator = await pipeline('text-generation', LLM_MODEL_ID, {
            device: 'webgpu',
            dtype: 'q4',
            use_external_data_format: true,
            progress_callback: (data) => self.postMessage({ type: 'download_progress', payload: data })
        });
        log("✅ LLM đã tải xong.");

        self.postMessage({ type: 'init_complete' });
    } catch (e) {
        self.postMessage({ type: 'error', payload: e.message });
    }
}

// Hàm chia nhỏ văn bản (Chunking) có log
function chunkText(text, chunkSize = 300, overlap = 50) {
    log(`Bắt đầu chia nhỏ văn bản dài ${text.length} ký tự...`);
    const sentences = text.match(/[^.!?]+[.!?]+|[^.!?]+$/g) || [text];
    let chunks = [];
    let currentChunk = "";
    
    for (let sentence of sentences) {
        if ((currentChunk + sentence).length > chunkSize) {
            chunks.push(currentChunk.trim());
            currentChunk = sentence.slice(-overlap);
        } else {
            currentChunk += " " + sentence;
        }
    }
    if (currentChunk) chunks.push(currentChunk.trim());
    
    log(`-> Đã chia thành ${chunks.length} chunks.`);
    return chunks;
}

// Tạo Embeddings
async function ingestText(text) {
    const chunks = chunkText(text);
    vectorStore = []; 
    
    log("🔄 Đang tạo vector cho từng chunk (Quá trình này tốn GPU)...");
    
    let startTime = performance.now();
    for (let i = 0; i < chunks.length; i++) {
        const output = await extractor(chunks[i], { pooling: 'mean', normalize: true });
        vectorStore.push({
            text: chunks[i],
            vector: output.data
        });
        // Log mỗi 5 chunks để không spam
        if ((i + 1) % 5 === 0) log(`Processed ${i + 1}/${chunks.length} chunks...`);
    }
    let endTime = performance.now();
    
    log(`✅ Indexing hoàn tất trong ${((endTime - startTime)/1000).toFixed(2)}s.`);
}

async function retrieve(query) {
    log(`🔍 Đang tìm kiếm ngữ cảnh cho: "${query}"`);
    const queryOutput = await extractor(query, { pooling: 'mean', normalize: true });
    const queryVector = queryOutput.data;

    const scored = vectorStore.map(item => ({
        text: item.text,
        score: cos_sim(queryVector, item.vector)
    }));

    scored.sort((a, b) => b.score - a.score);
    
    const top3 = scored.slice(0, 3);
    log(`-> Tìm thấy 3 đoạn khớp nhất (Scores: ${top3.map(i => i.score.toFixed(2)).join(', ')})`);
    
    return top3.map(i => i.text).join("\n\n");
}

self.onmessage = async (e) => {
    if (!extractor || !generator) {
        await initModels();
    }

    const { type, payload } = e.data;

    if (type === 'ingest_text') {
        await ingestText(payload);
    } else if (type === 'query') {
        if (vectorStore.length === 0) {
            log("⚠️ Cảnh báo: Chưa có dữ liệu PDF nào được index.");
        }

        const context = await retrieve(payload);
        
        const prompt = `<|system|>
Bạn là trợ lý AI hữu ích. Trả lời dựa trên ngữ cảnh sau bằng Tiếng Việt:
${context}
<|user|>
${payload}
<|assistant|>`;

        log("🤖 LLM đang suy nghĩ...");
        try {
            const output = await generator(prompt, {
                max_new_tokens: 256,
                temperature: 0.6,
                do_sample: true,
                // Log mỗi khi sinh ra token mới (cẩn thận spam nếu bật)
                // callback_function: (x) => log("Generated token...") 
            });
            
            let answer = output[0].generated_text;
            if (answer.includes("<|assistant|>")) {
                answer = answer.split("<|assistant|>")[1];
            }
            self.postMessage({ type: 'answer', payload: answer });
        } catch (err) {
            self.postMessage({ type: 'error', payload: err.message });
        }
    }
};

initModels();