import type { ConversationEntityMembers } from "$lib/state/conversations.svelte"; import type { ChatCompletionInputMessage, GenerationParameters, PipelineType, WidgetType } from "@huggingface/tasks"; import { getModelInputSnippet, openAIbaseUrl, stringifyGenerationConfig, stringifyMessages, type InferenceSnippet, type ModelDataMinimal, type SnippetInferenceProvider, } from "@huggingface/tasks"; const HFJS_METHODS: Partial> = { "text-classification": "textClassification", "token-classification": "tokenClassification", "table-question-answering": "tableQuestionAnswering", "question-answering": "questionAnswering", "translation": "translation", "summarization": "summarization", "feature-extraction": "featureExtraction", "text-generation": "textGeneration", "text2text-generation": "textGeneration", "fill-mask": "fillMask", "sentence-similarity": "sentenceSimilarity", }; export const snippetBasic = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { return [ ...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS ? [ { client: "huggingface.js", content: `\ import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const output = await client.${HFJS_METHODS[model.pipeline_tag]}({ model: "${model.id}", inputs: ${getModelInputSnippet(model)}, provider: "${provider}", }); console.log(output); `, }, ] : []), { client: "fetch", content: `\ async function query(data) { const response = await fetch( "https://router.huggingface.co/hf-inference/models/${model.id}", { headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", "Content-Type": "application/json", }, method: "POST", body: JSON.stringify(data), } ); const result = await response.json(); return result; } query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { console.log(JSON.stringify(response)); });`, }, ]; }; export const snippetTextGeneration = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider, providerModelId?: string, opts?: { streaming?: boolean; messages?: ChatCompletionInputMessage[]; temperature?: GenerationParameters["temperature"]; max_tokens?: GenerationParameters["max_tokens"]; top_p?: GenerationParameters["top_p"]; structured_output?: ConversationEntityMembers["structuredOutput"]; } ): InferenceSnippet[] => { if (model.tags.includes("conversational")) { // Conversational model detected, so we display a code snippet that features the Messages API const streaming = opts?.streaming ?? true; const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[]; const messages = opts?.messages ?? exampleMessages; const messagesStr = stringifyMessages(messages, { indent: "\t" }); const config = { ...(opts?.temperature ? { temperature: opts.temperature } : undefined), max_tokens: opts?.max_tokens ?? 500, ...(opts?.top_p ? { top_p: opts.top_p } : undefined), ...(opts?.structured_output?.enabled ? { response_format: JSON.stringify( { type: "json_schema", json_schema: JSON.parse(opts.structured_output.schema ?? ""), }, null, 6 ), } : undefined), }; const configStr = stringifyGenerationConfig(config, { indent: "\n\t", attributeValueConnector: ": ", }); if (streaming) { return [ { client: "huggingface.js", content: `import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); let out = ""; const stream = client.chatCompletionStream({ model: "${model.id}", messages: ${messagesStr}, provider: "${provider}", ${configStr} }); for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { const newContent = chunk.choices[0].delta.content; out += newContent; console.log(newContent); } }`, }, { client: "openai", content: `import { OpenAI } from "openai"; const client = new OpenAI({ baseURL: "${openAIbaseUrl(provider)}", apiKey: "${accessToken || `{API_TOKEN}`}" }); data. let out = ""; const stream = await client.chat.completions.create({ model: "${providerModelId ?? model.id}", messages: ${messagesStr}, ${configStr} stream: true, }); for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { const newContent = chunk.choices[0].delta.content; out += newContent; console.log(newContent); } }`, }, ]; } else { return [ { client: "huggingface.js", content: `import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const chatCompletion = await client.chatCompletion({ model: "${model.id}", messages: ${messagesStr}, provider: "${provider}", ${configStr} }); console.log(chatCompletion.choices[0].message); `, }, { client: "openai", content: `import { OpenAI } from "openai"; const client = new OpenAI({ baseURL: "${openAIbaseUrl(provider)}", apiKey: "${accessToken || `{API_TOKEN}`}" }); const chatCompletion = await client.chat.completions.create({ model: "${providerModelId ?? model.id}", messages: ${messagesStr}, ${configStr} }); console.log(chatCompletion.choices[0].message); `, }, ]; } } else { return snippetBasic(model, accessToken, provider); } }; export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => { return [ { client: "fetch", content: `async function query(data) { const response = await fetch( "https://router.huggingface.co/hf-inference/models/${model.id}", { headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", "Content-Type": "application/json", }, method: "POST", body: JSON.stringify(data), } ); const result = await response.json(); return result; } query({"inputs": ${getModelInputSnippet( model )}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => { console.log(JSON.stringify(response)); });`, }, ]; }; export const snippetTextToImage = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { return [ { client: "huggingface.js", content: `\ import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const image = await client.textToImage({ model: "${model.id}", inputs: ${getModelInputSnippet(model)}, parameters: { num_inference_steps: 5 }, provider: "${provider}", }); /// Use the generated image (it's a Blob) `, }, ...(provider === "hf-inference" ? [ { client: "fetch", content: `async function query(data) { const response = await fetch( "https://router.huggingface.co/hf-inference/models/${model.id}", { headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", "Content-Type": "application/json", }, method: "POST", body: JSON.stringify(data), } ); const result = await response.blob(); return result; } query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { // Use image });`, }, ] : []), ]; }; export const snippetTextToVideo = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { return ["fal-ai", "replicate"].includes(provider) ? [ { client: "huggingface.js", content: `\ import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const video = await client.textToVideo({ model: "${model.id}", provider: "${provider}", inputs: ${getModelInputSnippet(model)}, parameters: { num_inference_steps: 5 }, }); // Use the generated video (it's a Blob) `, }, ] : []; }; export const snippetTextToAudio = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { if (provider !== "hf-inference") { return []; } const commonSnippet = `async function query(data) { const response = await fetch( "https://router.huggingface.co/hf-inference/models/${model.id}", { headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", "Content-Type": "application/json", }, method: "POST", body: JSON.stringify(data), } );`; if (model.library_name === "transformers") { return [ { client: "fetch", content: commonSnippet + ` const result = await response.blob(); return result; } query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { // Returns a byte object of the Audio wavform. Use it directly! });`, }, ]; } else { return [ { client: "fetch", content: commonSnippet + ` const result = await response.json(); return result; } query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { console.log(JSON.stringify(response)); });`, }, ]; } }; export const snippetAutomaticSpeechRecognition = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { return [ { client: "huggingface.js", content: `\ import { HfInference } from "@huggingface/inference"; const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const data = fs.readFileSync(${getModelInputSnippet(model)}); const output = await client.automaticSpeechRecognition({ data, model: "${model.id}", provider: "${provider}", }); console.log(output); `, }, ...(provider === "hf-inference" ? snippetFile(model, accessToken, provider) : []), ]; }; export const snippetFile = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { if (provider !== "hf-inference") { return []; } return [ { client: "fetch", content: `async function query(filename) { const data = fs.readFileSync(filename); const response = await fetch( "https://router.huggingface.co/hf-inference/models/${model.id}", { headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", "Content-Type": "application/json", }, method: "POST", body: data, } ); const result = await response.json(); return result; } query(${getModelInputSnippet(model)}).then((response) => { console.log(JSON.stringify(response)); });`, }, ]; }; export const jsSnippets: Partial< Record< PipelineType, ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider, providerModelId?: string, opts?: Record ) => InferenceSnippet[] > > = { // Same order as in tasks/src/pipelines.ts "text-classification": snippetBasic, "token-classification": snippetBasic, "table-question-answering": snippetBasic, "question-answering": snippetBasic, "zero-shot-classification": snippetZeroShotClassification, "translation": snippetBasic, "summarization": snippetBasic, "feature-extraction": snippetBasic, "text-generation": snippetTextGeneration, "image-text-to-text": snippetTextGeneration, "text2text-generation": snippetBasic, "fill-mask": snippetBasic, "sentence-similarity": snippetBasic, "automatic-speech-recognition": snippetAutomaticSpeechRecognition, "text-to-image": snippetTextToImage, "text-to-video": snippetTextToVideo, "text-to-speech": snippetTextToAudio, "text-to-audio": snippetTextToAudio, "audio-to-audio": snippetFile, "audio-classification": snippetFile, "image-classification": snippetFile, "image-to-text": snippetFile, "object-detection": snippetFile, "image-segmentation": snippetFile, }; export function getJsInferenceSnippet( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider, providerModelId?: string, opts?: Record ): InferenceSnippet[] { return model.pipeline_tag && model.pipeline_tag in jsSnippets ? (jsSnippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId, opts) ?? []) : []; }