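"""Answer generation service.

Routes non-streaming and streaming requests to the configured LLM provider
(OpenAI, OpenRouter, or Hugging Face) and evaluates OpenRouter answers with
quality metrics.
"""
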
from collections.abc import AsyncGenerator, Callable

import opik

from src.api.models.api_models import SearchResult
from src.api.models.provider_models import MODEL_REGISTRY
from src.api.services.providers.huggingface_service import generate_huggingface, stream_huggingface
from src.api.services.providers.openai_service import generate_openai, stream_openai
from src.api.services.providers.openrouter_service import generate_openrouter, stream_openrouter
from src.api.services.providers.utils.evaluation_metrics import evaluate_metrics
from src.api.services.providers.utils.prompts import build_research_prompt
from src.utils.logger_util import setup_logging

logger = setup_logging()


# -----------------------
# Non-streaming answer generator
# -----------------------
@opik.track(name="generate_answer")
async def generate_answer(
    query: str,
    contexts: list[SearchResult],
    provider: str = "openrouter",
    selected_model: str | None = None,
) -> dict:
    """Generate a non-streaming answer using the specified LLM provider.

    Args:
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        selected_model (str | None): Optional model override; currently only passed
            through to the OpenRouter provider.

    Returns:
        dict: {"answer": str, "sources": list[str], "model": str | None,
        "finish_reason": str | None}

    """
    prompt = build_research_prompt(contexts, query=query)
    model_used: str | None = None
    finish_reason: str | None = None

    provider_lower = provider.lower()

    config = MODEL_REGISTRY.get_config(provider_lower)

    if provider_lower == "openai":
        answer, model_used = await generate_openai(prompt, config=config)
    elif provider_lower == "openrouter":
        try:
            answer, model_used, finish_reason = await generate_openrouter(
                prompt, config=config, selected_model=selected_model
            )
            metrics_results = await evaluate_metrics(answer, prompt)
            logger.info(f"G-Eval Faithfulness → {metrics_results}")
        except Exception as e:
            logger.error(f"Error occurred while generating answer from {provider_lower}: {e}")
            raise

    elif provider_lower == "huggingface":
        answer, model_used = await generate_huggingface(prompt, config=config)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    return {
        "answer": answer,
        "sources": [r.url for r in contexts],
        "model": model_used,
        "finish_reason": finish_reason,
    }
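
# Example usage (a minimal sketch; the `SearchResult(...)` fields shown are
# illustrative assumptions, not verified against api_models, and the call must
# run inside an async context):
#
#     contexts = [SearchResult(url="https://example.com", title="...", snippet="...")]
#     result = await generate_answer("What is retrieval-augmented generation?", contexts)
#     print(result["answer"], result["sources"], result["finish_reason"])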


# -----------------------
# Streaming answer generator
# -----------------------
@opik.track(name="get_streaming_function")
def get_streaming_function(
    provider: str,
    query: str,
    contexts: list[SearchResult],
    selected_model: str | None = None,
) -> Callable[[], AsyncGenerator[str, None]]:
    """Get a streaming function for the specified LLM provider.

    Args:
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        selected_model (str | None): Optional model override; currently only passed
            through to the OpenRouter provider.

    Returns:
        Callable[[], AsyncGenerator[str, None]]: A zero-argument function that returns
        an async generator yielding response chunks.

    """
    prompt = build_research_prompt(contexts, query=query)
    provider_lower = provider.lower()
    config = MODEL_REGISTRY.get_config(provider_lower)
    logger.info(f"Using model config: {config}")

    async def stream_gen() -> AsyncGenerator[str, None]:
        """Asynchronous generator that streams response chunks from the specified provider.

        Yields:
            str: The next chunk of the response.

        """
        buffer = []  # accumulate chunks so the full output can be scored after streaming

        if provider_lower == "openai":
            async for chunk in stream_openai(prompt, config=config):
                buffer.append(chunk)
                yield chunk

        elif provider_lower == "openrouter":
            try:
                async for chunk in stream_openrouter(
                    prompt, config=config, selected_model=selected_model
                ):
                    buffer.append(chunk)
                    yield chunk

                full_output = "".join(buffer)
                metrics_results = await evaluate_metrics(full_output, prompt)
                logger.info(f"Metrics results: {metrics_results}")

            except Exception as e:
                logger.error(f"Error occurred while streaming from {provider}: {e}")
                # Sentinel chunk so the consumer can detect a mid-stream failure.
                yield "__error__"

        elif provider_lower == "huggingface":
            async for chunk in stream_huggingface(prompt, config=config):
                buffer.append(chunk)
                yield chunk

        else:
            raise ValueError(f"Unknown provider: {provider}")

    return stream_gen
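

# Example usage (a minimal sketch, assuming the consumer treats the
# "__error__" sentinel yielded above as a mid-stream failure signal):
#
#     stream_fn = get_streaming_function("openrouter", query, contexts)
#     async for chunk in stream_fn():
#         if chunk == "__error__":
#             break
#         print(chunk, end="")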