import type { ConversationEntityMembers } from "$lib/state/conversations.svelte";
import type { ChatCompletionInputMessage, GenerationParameters, PipelineType } from "@huggingface/tasks";
import {
	HF_HUB_INFERENCE_PROXY_TEMPLATE,
	type InferenceSnippet,
	type ModelDataMinimal,
	type SnippetInferenceProvider,
	getModelInputSnippet,
	stringifyGenerationConfig,
	stringifyMessages,
} from "@huggingface/tasks";

export const snippetBasic = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	if (provider !== "hf-inference") {
		return [];
	}
	return [
		{
			client: "curl",
			content: `\
curl https://router.huggingface.co/hf-inference/models/${model.id} \\
	-X POST \\
	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
	-H 'Content-Type: application/json' \\
	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
		},
	];
};

export const snippetTextGeneration = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
	providerModelId?: string,
	opts?: {
		streaming?: boolean;
		messages?: ChatCompletionInputMessage[];
		temperature?: GenerationParameters["temperature"];
		max_tokens?: GenerationParameters["max_tokens"];
		top_p?: GenerationParameters["top_p"];
		structured_output?: ConversationEntityMembers["structuredOutput"];
	}
): InferenceSnippet[] => {
	if (model.tags.includes("conversational")) {
		const baseUrl =
			provider === "hf-inference"
				? `https://router.huggingface.co/hf-inference/models/${model.id}/v1/chat/completions`
				: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
		const modelId = providerModelId ?? model.id;

		// Conversational model detected, so we display a code snippet that features the Messages API
		const streaming = opts?.streaming ?? true;
		const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
		const messages = opts?.messages ?? exampleMessages;

		const config = {
			...(opts?.temperature ? { temperature: opts.temperature } : undefined),
			max_tokens: opts?.max_tokens ?? 500,
			...(opts?.top_p ? { top_p: opts.top_p } : undefined),
			...(opts?.structured_output?.enabled
				? {
						response_format: JSON.stringify(
							{
								type: "json_schema",
								json_schema: JSON.parse(
""), }, null, 6 ), } : undefined), }; return [ { client: "curl", content: `curl '${baseUrl}' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\ -H 'Content-Type: application/json' \\ --data '{ "model": "${modelId}", "messages": ${stringifyMessages(messages, { indent: "\t", attributeKeyQuotes: true, customContentEscaper: str => str.replace(/'/g, "'\\''"), })}, ${stringifyGenerationConfig(config, { indent: "\n ", attributeKeyQuotes: true, attributeValueConnector: ": ", })} "stream": ${!!streaming} }'`, }, ]; } else { return snippetBasic(model, accessToken, provider); } }; export const snippetZeroShotClassification = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { if (provider !== "hf-inference") { return []; } return [ { client: "curl", content: `curl https://router.huggingface.co/hf-inference/models/${model.id} \\ -X POST \\ -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\ -H 'Content-Type: application/json' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`, }, ]; }; export const snippetFile = ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider ): InferenceSnippet[] => { if (provider !== "hf-inference") { return []; } return [ { client: "curl", content: `curl https://router.huggingface.co/hf-inference/models/${model.id} \\ -X POST \\ --data-binary '@${getModelInputSnippet(model, true, true)}' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`, }, ]; }; export const curlSnippets: Partial< Record< PipelineType, ( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider, providerModelId?: string, opts?: Record ) => InferenceSnippet[] > > = { // Same order as in tasks/src/pipelines.ts "text-classification": snippetBasic, "token-classification": snippetBasic, "table-question-answering": snippetBasic, "question-answering": snippetBasic, "zero-shot-classification": snippetZeroShotClassification, "translation": snippetBasic, "summarization": snippetBasic, "feature-extraction": snippetBasic, "text-generation": snippetTextGeneration, "image-text-to-text": snippetTextGeneration, "text2text-generation": snippetBasic, "fill-mask": snippetBasic, "sentence-similarity": snippetBasic, "automatic-speech-recognition": snippetFile, "text-to-image": snippetBasic, "text-to-speech": snippetBasic, "text-to-audio": snippetBasic, "audio-to-audio": snippetFile, "audio-classification": snippetFile, "image-classification": snippetFile, "image-to-text": snippetFile, "object-detection": snippetFile, "image-segmentation": snippetFile, }; export function getCurlInferenceSnippet( model: ModelDataMinimal, accessToken: string, provider: SnippetInferenceProvider, providerModelId?: string, opts?: Record ): InferenceSnippet[] { return model.pipeline_tag && model.pipeline_tag in curlSnippets ? (curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId, opts) ?? []) : []; }