from collections.abc import AsyncGenerator

from huggingface_hub import AsyncInferenceClient

from src.api.models.provider_models import ModelConfig
from src.api.services.providers.utils.messages import build_messages
from src.config import settings
from src.utils.logger_util import setup_logging

logger = setup_logging()

# -----------------------
# Hugging Face client
# -----------------------
hf_key = settings.hugging_face.api_key
hf_client = AsyncInferenceClient(provider="auto", api_key=hf_key)


async def generate_huggingface(prompt: str, config: ModelConfig) -> tuple[str, None]:
    """Generate a response from Hugging Face for a given prompt and model configuration.

    Args:
        prompt (str): The input prompt.
        config (ModelConfig): The model configuration.

    Returns:
        tuple[str, None]: The generated response text and None in place of the
            finish-reason/model metadata that other providers return.

    """
    resp = await hf_client.chat.completions.create(
        model=config.primary_model,
        messages=build_messages(prompt),
        temperature=config.temperature,
        max_tokens=config.max_completion_tokens,
    )
    return resp.choices[0].message.content or "", None


def stream_huggingface(prompt: str, config: ModelConfig) -> AsyncGenerator[str, None]:
    """Stream a response from Hugging Face for a given prompt and model configuration.

    Args:
        prompt (str): The input prompt.
        config (ModelConfig): The model configuration.

    Returns:
        AsyncGenerator[str, None]: An asynchronous generator yielding response chunks.

    """

    async def gen() -> AsyncGenerator[str, None]:
        stream = await hf_client.chat.completions.create(
            model=config.primary_model,
            messages=build_messages(prompt),
            temperature=config.temperature,
            max_tokens=config.max_completion_tokens,
            stream=True,
        )
        async for chunk in stream:
            # Some providers emit trailing chunks (e.g. usage payloads) with no choices.
            if not chunk.choices:
                continue
            delta_text = getattr(chunk.choices[0].delta, "content", None)
            if delta_text:
                yield delta_text

    return gen()
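

# ---------------------------------------------------------------
# Example usage: a minimal sketch for manual testing only.
# It assumes ModelConfig can be instantiated directly with the
# fields referenced above (primary_model, temperature,
# max_completion_tokens); the model id below is a placeholder,
# not something this module requires.
# ---------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        cfg = ModelConfig(
            primary_model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
            temperature=0.7,
            max_completion_tokens=256,
        )

        # One-shot generation.
        text, _ = await generate_huggingface("Say hello in one sentence.", cfg)
        logger.info("Full response: %s", text)

        # Streaming generation, printed chunk by chunk.
        async for chunk in stream_huggingface("Count from one to three.", cfg):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())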