import { isWebGPUAvailable } from "./webGpu";
import {
  updateSearchResults,
  updateResponse,
  getSearchResults,
  getQuery,
  updateSearchPromise,
  getSearchPromise,
  updateTextGenerationState,
  updateSearchState,
  updateModelLoadingProgress,
  getTextGenerationState,
  getSettings,
  listenToSettingsChanges,
} from "./pubSub";
import { search } from "./search";
import { addLogEntry } from "./logEntries";
import { getSystemPrompt } from "./systemPrompt";
import prettyMilliseconds from "pretty-ms";
import { getOpenAiClient } from "./openai";
import { getSearchTokenHash } from "./searchTokenHash";
import {
  ChatCompletionCreateParamsStreaming,
  ChatCompletionMessageParam,
} from "openai/resources/chat/completions.mjs";
import gptTokenizer from "gpt-tokenizer";
import { ChatOptions } from "@mlc-ai/web-llm";
import { defaultSettings } from "./settings";
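
/**
 * Searches the web for the current query and, when AI responses are
 * enabled, streams an answer with the configured inference backend:
 * an OpenAI-compatible API, the internal /inference endpoint, WebLLM
 * (when WebGPU is available), or Wllama as the CPU fallback.
 */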
export async function searchAndRespond() {
  if (getQuery() === "") return;
  document.title = getQuery();
  updateResponse("");
  updateSearchResults({ textResults: [], imageResults: [] });
  updateSearchPromise(startSearch(getQuery()));
  if (!getSettings().enableAiResponse) return;
  const responseGenerationStartTime = new Date().getTime();
  updateTextGenerationState("loadingModel");
  try {
    const settings = getSettings();
    if (settings.inferenceType === "openai") {
      await generateTextWithOpenAI();
    } else if (settings.inferenceType === "internal") {
      await generateTextWithInternalApi();
    } else {
      await canDownloadModels();
      updateTextGenerationState("loadingModel");
      try {
        if (!isWebGPUAvailable) throw Error("WebGPU is not available.");
        if (!settings.enableWebGpu) throw Error("WebGPU is disabled.");
        await generateTextWithWebLlm();
      } catch (error) {
        addLogEntry(`Skipping text generation with WebLLM: ${error}`);
        addLogEntry(`Starting text generation with Wllama`);
        await generateTextWithWllama();
      }
    }
    if (getTextGenerationState() !== "interrupted") {
      updateTextGenerationState("completed");
    }
  } catch (error) {
    addLogEntry(`Error generating text: ${error}`);
    updateTextGenerationState("failed");
  }
  addLogEntry(
    `Response generation took ${prettyMilliseconds(
      new Date().getTime() - responseGenerationStartTime,
      { verbose: true },
    )}`,
  );
}
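
/**
 * Streams a response from the OpenAI-compatible endpoint configured in the
 * settings, pushing rate-limited updates to the UI as chunks arrive.
 */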
async function generateTextWithOpenAI() {
  const settings = getSettings();
  const openai = getOpenAiClient({
    baseURL: settings.openAiApiBaseUrl,
    apiKey: settings.openAiApiKey,
  });
  await canStartResponding();
  updateTextGenerationState("preparingToGenerate");
  const completion = await openai.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    model: settings.openAiApiModel,
    messages: [
      {
        role: "user",
        content: getSystemPrompt(getFormattedSearchResults(true)),
      },
      { role: "assistant", content: "Ok!" },
      { role: "user", content: getQuery() },
    ],
  });
  let streamedMessage = "";
  for await (const chunk of completion) {
    const deltaContent = chunk.choices[0].delta.content;
    if (deltaContent) streamedMessage += deltaContent;
    if (getTextGenerationState() === "interrupted") {
      completion.controller.abort();
    } else if (getTextGenerationState() !== "generating") {
      updateTextGenerationState("generating");
    }
    updateResponseRateLimited(streamedMessage);
  }
  updateResponse(streamedMessage);
}
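
/**
 * Streams a response from the internal /inference endpoint, authenticating
 * with the search token hash and parsing the server-sent event stream
 * manually.
 */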
async function generateTextWithInternalApi() {
  await canStartResponding();
  updateTextGenerationState("preparingToGenerate");
  const inferenceUrl = new URL("/inference", self.location.origin);
  const tokenPrefix = "Bearer ";
  const token = await getSearchTokenHash();
  const response = await fetch(inferenceUrl.toString(), {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `${tokenPrefix}${token}`,
    },
    body: JSON.stringify({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: [
        {
          role: "user",
          content: getSystemPrompt(getFormattedSearchResults(true)),
        },
        { role: "assistant", content: "Ok!" },
        { role: "user", content: getQuery() },
      ],
    } as ChatCompletionCreateParamsStreaming),
  });
  if (!response.ok || !response.body) {
    throw new Error(`HTTP error! status: ${response.status}`);
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  let streamedMessage = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    const chunk = decoder.decode(value);
    const lines = chunk.split("\n");
    const parsedLines = lines
      .map((line) => line.replace(/^data: /, "").trim())
      .filter((line) => line !== "" && line !== "[DONE]")
      .map((line) => JSON.parse(line));
    for (const parsedLine of parsedLines) {
      const deltaContent = parsedLine.choices[0].delta.content;
      if (deltaContent) streamedMessage += deltaContent;
      if (getTextGenerationState() === "interrupted") {
        reader.cancel();
      } else if (getTextGenerationState() !== "generating") {
        updateTextGenerationState("generating");
      }
      updateResponseRateLimited(streamedMessage);
    }
  }
  updateResponse(streamedMessage);
}
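
/**
 * Generates the response with WebLLM, loading the selected model in a Web
 * Worker when available and reporting download progress when the model is
 * not cached yet.
 */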
async function generateTextWithWebLlm() {
  const { CreateWebWorkerMLCEngine, CreateMLCEngine, hasModelInCache } =
    await import("@mlc-ai/web-llm");
  type InitProgressCallback = import("@mlc-ai/web-llm").InitProgressCallback;
  type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig;
  type ChatOptions = import("@mlc-ai/web-llm").ChatOptions;
  const selectedModelId = getSettings().webLlmModelId;
  addLogEntry(`Selected WebLLM model: ${selectedModelId}`);
  const isModelCached = await hasModelInCache(selectedModelId);
  let initProgressCallback: InitProgressCallback | undefined;
  if (isModelCached) {
    updateTextGenerationState("preparingToGenerate");
  } else {
    initProgressCallback = (report) => {
      updateModelLoadingProgress(Math.round(report.progress * 100));
    };
  }
  const engineConfig: MLCEngineConfig = {
    initProgressCallback,
    logLevel: "SILENT",
  };
  const chatOptions: ChatOptions = {
    repetition_penalty: defaultSettings.inferenceRepeatPenalty,
  };
  const engine = Worker
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
  if (getSettings().enableAiResponse) {
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");
    const completion = await engine.chat.completions.create({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: [
        {
          role: "user",
          content: getSystemPrompt(getFormattedSearchResults(true)),
        },
        { role: "assistant", content: "Ok!" },
        { role: "user", content: getQuery() },
      ],
    });
    let streamedMessage = "";
    for await (const chunk of completion) {
      const deltaContent = chunk.choices[0].delta.content;
      if (deltaContent) streamedMessage += deltaContent;
      if (getTextGenerationState() === "interrupted") {
        await engine.interruptGenerate();
      } else if (getTextGenerationState() !== "generating") {
        updateTextGenerationState("generating");
      }
      updateResponseRateLimited(streamedMessage);
    }
    updateResponse(streamedMessage);
  }
  addLogEntry(
    `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
  );
  engine.unload();
}
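
/**
 * Generates the response with Wllama (llama.cpp on WebAssembly), building a
 * model-specific prompt from the query and search results and honoring the
 * model's stop tokens and stop strings.
 */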
async function generateTextWithWllama() {
  const { initializeWllama, wllamaModels } = await import("./wllama");
  let loadingPercentage = 0;
  const model = wllamaModels[getSettings().wllamaModelId];
  const wllama = await initializeWllama(model.url, {
    wllama: {
      suppressNativeLog: true,
    },
    model: {
      n_threads: getSettings().cpuThreads,
      n_ctx: model.contextSize,
      cache_type_k: model.cacheType,
      embeddings: false,
      allowOffline: true,
      progressCallback: ({ loaded, total }) => {
        const progressPercentage = Math.round((loaded / total) * 100);
        if (loadingPercentage !== progressPercentage) {
          loadingPercentage = progressPercentage;
          updateModelLoadingProgress(progressPercentage);
        }
      },
    },
  });
  if (getSettings().enableAiResponse) {
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");
    const prompt = await model.buildPrompt(
      wllama,
      getQuery(),
      getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt),
    );
    let streamedMessage = "";
    await wllama.createCompletion(prompt, {
      stopTokens: model.stopTokens,
      sampling: model.sampling,
      onNewToken: (_token, _piece, currentText, { abortSignal }) => {
        if (getTextGenerationState() === "interrupted") {
          abortSignal();
        } else if (getTextGenerationState() !== "generating") {
          updateTextGenerationState("generating");
        }
        if (model.stopStrings) {
          for (const stopString of model.stopStrings) {
            if (
              currentText.slice(-(stopString.length * 2)).includes(stopString)
            ) {
              abortSignal();
              currentText = currentText.slice(0, -stopString.length);
              break;
            }
          }
        }
        streamedMessage = currentText;
        updateResponseRateLimited(streamedMessage);
      },
    });
    updateResponse(streamedMessage);
  }
  await wllama.exit();
}
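
/**
 * Formats the top text search results either as a numbered Markdown list
 * with URLs or as a plain bullet list, limited by the
 * "searchResultsToConsider" setting.
 */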
function getFormattedSearchResults(shouldIncludeUrl: boolean) {
  const searchResults = getSearchResults().textResults.slice(
    0,
    getSettings().searchResultsToConsider,
  );
  if (searchResults.length === 0) return "None.";
  if (shouldIncludeUrl) {
    return searchResults
      .map(
        ([title, snippet, url], index) =>
          `${index + 1}. [${title}](${url}) | ${snippet}`,
      )
      .join("\n");
  }
  return searchResults
    .map(([title, snippet]) => `- ${title} | ${snippet}`)
    .join("\n");
}
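
/** Extracts up to `limit` English keywords from the given text. */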
async function getKeywords(text: string, limit?: number) {
  return (await import("keyword-extractor")).default
    .extract(text, { language: "english" })
    .slice(0, limit);
}
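
/**
 * Runs the search, falling back to extracted keywords when the query is
 * very long or when the first attempt returns no text results, then
 * publishes the search state and results.
 */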
async function startSearch(query: string) {
  updateSearchState("running");
  let searchResults = await search(
    query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query,
    30,
  );
  if (searchResults.textResults.length === 0) {
    const queryKeywords = await getKeywords(query, 10);
    searchResults = await search(queryKeywords.join(" "), 30);
  }
  updateSearchState(
    searchResults.textResults.length === 0 ? "failed" : "completed",
  );
  updateSearchResults(searchResults);
  return searchResults;
}
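
/**
 * Waits for the pending search whenever search results are configured to be
 * included in the prompt, so generation only starts once they are available.
 */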
async function canStartResponding() {
  if (getSettings().searchResultsToConsider > 0) {
    updateTextGenerationState("awaitingSearchResults");
    await getSearchPromise();
  }
}
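
/**
 * Pushes partial responses to the UI at most ~12 times per second to avoid
 * excessive re-renders while streaming.
 */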
function updateResponseRateLimited(text: string) {
  const currentTime = Date.now();
  if (
    currentTime - updateResponseRateLimited.lastUpdateTime >=
    updateResponseRateLimited.updateInterval
  ) {
    updateResponse(text);
    updateResponseRateLimited.lastUpdateTime = currentTime;
  }
}
updateResponseRateLimited.lastUpdateTime = 0;
updateResponseRateLimited.updateInterval = 1000 / 12;
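
/** Signals that the user interrupted an ongoing chat generation. */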
class ChatGenerationError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "ChatGenerationError";
  }
}
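
/**
 * Streams a follow-up chat completion from the OpenAI-compatible endpoint,
 * throwing ChatGenerationError if the user interrupts generation.
 */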
async function generateChatWithOpenAI(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const settings = getSettings();
  const openai = getOpenAiClient({
    baseURL: settings.openAiApiBaseUrl,
    apiKey: settings.openAiApiKey,
  });
  const completion = await openai.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    model: settings.openAiApiModel,
    messages: messages as ChatCompletionMessageParam[],
  });
  let streamedMessage = "";
  for await (const chunk of completion) {
    const deltaContent = chunk.choices[0].delta.content;
    if (deltaContent) {
      streamedMessage += deltaContent;
      onUpdate(streamedMessage);
    }
    if (getTextGenerationState() === "interrupted") {
      completion.controller.abort();
      throw new ChatGenerationError("Chat generation interrupted");
    }
  }
  return streamedMessage;
}
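
/**
 * Streams a follow-up chat completion from the internal /inference endpoint,
 * throwing ChatGenerationError if the user interrupts generation.
 */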
async function generateChatWithInternalApi(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const inferenceUrl = new URL("/inference", self.location.origin);
  const tokenPrefix = "Bearer ";
  const token = await getSearchTokenHash();
  const response = await fetch(inferenceUrl.toString(), {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `${tokenPrefix}${token}`,
    },
    body: JSON.stringify({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages,
    } as ChatCompletionCreateParamsStreaming),
  });
  if (!response.ok || !response.body) {
    throw new Error(`HTTP error! status: ${response.status}`);
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  let streamedMessage = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    const chunk = decoder.decode(value);
    const lines = chunk.split("\n");
    const parsedLines = lines
      .map((line) => line.replace(/^data: /, "").trim())
      .filter((line) => line !== "" && line !== "[DONE]")
      .map((line) => JSON.parse(line));
    for (const parsedLine of parsedLines) {
      const deltaContent = parsedLine.choices[0].delta.content;
      if (deltaContent) {
        streamedMessage += deltaContent;
        onUpdate(streamedMessage);
      }
      if (getTextGenerationState() === "interrupted") {
        reader.cancel();
        throw new ChatGenerationError("Chat generation interrupted");
      }
    }
  }
  return streamedMessage;
}
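
/**
 * Streams a follow-up chat completion with WebLLM, unloading the engine once
 * the response is complete.
 */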
async function generateChatWithWebLlm(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const { CreateWebWorkerMLCEngine, CreateMLCEngine } = await import(
    "@mlc-ai/web-llm"
  );
  type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig;
  type ChatCompletionMessageParam =
    import("@mlc-ai/web-llm").ChatCompletionMessageParam;
  const selectedModelId = getSettings().webLlmModelId;
  addLogEntry(`Selected WebLLM model for chat: ${selectedModelId}`);
  const engineConfig: MLCEngineConfig = {
    logLevel: "SILENT",
  };
  const chatOptions: ChatOptions = {
    repetition_penalty: defaultSettings.inferenceRepeatPenalty,
  };
  const engine = Worker
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
  const completion = await engine.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    messages: messages as ChatCompletionMessageParam[],
  });
  let streamedMessage = "";
  for await (const chunk of completion) {
    const deltaContent = chunk.choices[0].delta.content;
    if (deltaContent) {
      streamedMessage += deltaContent;
      onUpdate(streamedMessage);
    }
    if (getTextGenerationState() === "interrupted") {
      await engine.interruptGenerate();
      throw new ChatGenerationError("Chat generation interrupted");
    }
  }
  addLogEntry(
    `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
  );
  engine.unload();
  return streamedMessage;
}
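
/**
 * Streams a follow-up chat completion with Wllama, prompting with the latest
 * user message and the formatted search results.
 */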
async function generateChatWithWllama(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const { initializeWllama, wllamaModels } = await import("./wllama");
  const model = wllamaModels[getSettings().wllamaModelId];
  const wllama = await initializeWllama(model.url, {
    wllama: {
      suppressNativeLog: true,
    },
    model: {
      n_threads: getSettings().cpuThreads,
      n_ctx: model.contextSize,
      cache_type_k: model.cacheType,
      embeddings: false,
      allowOffline: true,
    },
  });
  const prompt = await model.buildPrompt(
    wllama,
    messages[messages.length - 1].content,
    getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt),
  );
  let streamedMessage = "";
  await wllama.createCompletion(prompt, {
    stopTokens: model.stopTokens,
    sampling: model.sampling,
    onNewToken: (_token, _piece, currentText, { abortSignal }) => {
      if (getTextGenerationState() === "interrupted") {
        abortSignal();
        throw new ChatGenerationError("Chat generation interrupted");
      }
      if (model.stopStrings) {
        for (const stopString of model.stopStrings) {
          if (
            currentText.slice(-(stopString.length * 2)).includes(stopString)
          ) {
            abortSignal();
            currentText = currentText.slice(0, -stopString.length);
            break;
          }
        }
      }
      streamedMessage = currentText;
      onUpdate(streamedMessage);
    },
  });
  await wllama.exit();
  return streamedMessage;
}
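
/**
 * Generates a response for the follow-up chat: prepends the system prompt,
 * drops the oldest messages so the prompt stays under roughly 1280 tokens
 * (measured with gpt-tokenizer), and dispatches to the backend selected in
 * the settings.
 *
 * Illustrative usage (hypothetical call site):
 *   const reply = await generateChatResponse(
 *     [{ role: "user", content: "Tell me more about the first result." }],
 *     (partial) => updateResponse(partial),
 *   );
 */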
export async function generateChatResponse(
  newMessages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const settings = getSettings();
  let response = "";
  try {
    const allMessages = [
      {
        role: "user",
        content: getSystemPrompt(getFormattedSearchResults(true)),
      },
      { role: "assistant", content: "Ok!" },
      ...newMessages,
    ];
    const lastMessagesReversed: ChatMessage[] = [];
    let totalTokens = 0;
    for (const message of allMessages.reverse()) {
      const newTotalTokens =
        totalTokens + gptTokenizer.encode(message.content).length;
      if (newTotalTokens > 1280) break;
      totalTokens = newTotalTokens;
      lastMessagesReversed.push(message);
    }
    const lastMessages = lastMessagesReversed.reverse();
    if (settings.inferenceType === "openai") {
      response = await generateChatWithOpenAI(lastMessages, onUpdate);
    } else if (settings.inferenceType === "internal") {
      response = await generateChatWithInternalApi(lastMessages, onUpdate);
    } else {
      if (isWebGPUAvailable && settings.enableWebGpu) {
        response = await generateChatWithWebLlm(lastMessages, onUpdate);
      } else {
        response = await generateChatWithWllama(lastMessages, onUpdate);
      }
    }
  } catch (error) {
    if (error instanceof ChatGenerationError) {
      addLogEntry(`Chat generation interrupted: ${error.message}`);
    } else {
      addLogEntry(`Error generating chat response: ${error}`);
    }
    throw error;
  }
  return response;
}
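
/** A single message exchanged in the follow-up chat. */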
export interface ChatMessage {
  role: "user" | "assistant" | string;
  content: string;
}
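
/**
 * Resolves immediately when model downloads are allowed; otherwise waits
 * until the user grants permission through the settings.
 */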
function canDownloadModels(): Promise<void> {
  return new Promise((resolve) => {
    if (getSettings().allowAiModelDownload) {
      resolve();
    } else {
      updateTextGenerationState("awaitingModelDownloadAllowance");
      listenToSettingsChanges((settings) => {
        if (settings.allowAiModelDownload) {
          resolve();
        }
      });
    }
  });
}
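
/**
 * Shared streaming parameters (token limit, temperature, top_p and penalties)
 * applied to every chat completion request.
 */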
function getDefaultChatCompletionCreateParamsStreaming() {
  return {
    stream: true,
    max_tokens: 2048,
    temperature: defaultSettings.inferenceTemperature,
    top_p: defaultSettings.inferenceTopP,
    frequency_penalty: defaultSettings.inferenceFrequencyPenalty,
    presence_penalty: defaultSettings.inferencePresencePenalty,
  } as const;
}