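"""Answer-generation service layer.

Builds a research prompt from retrieved contexts and dispatches it to the
selected LLM provider (OpenAI, OpenRouter, or Hugging Face), in both
non-streaming and streaming form. OpenRouter outputs are additionally run
through evaluation metrics, and both entry points are traced with Opik.
"""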
from collections.abc import AsyncGenerator, Callable
import opik
from src.api.models.api_models import SearchResult
from src.api.models.provider_models import MODEL_REGISTRY
from src.api.services.providers.huggingface_service import generate_huggingface, stream_huggingface
from src.api.services.providers.openai_service import generate_openai, stream_openai
from src.api.services.providers.openrouter_service import generate_openrouter, stream_openrouter
from src.api.services.providers.utils.evaluation_metrics import evaluate_metrics
from src.api.services.providers.utils.prompts import build_research_prompt
from src.utils.logger_util import setup_logging
logger = setup_logging()
# -----------------------
# Non-streaming answer generator
# -----------------------
@opik.track(name="generate_answer")
async def generate_answer(
query: str,
contexts: list[SearchResult],
provider: str = "openrouter",
selected_model: str | None = None,
) -> dict:
"""Generate a non-streaming answer using the specified LLM provider.
Args:
query (str): The user's research query.
contexts (list[SearchResult]): List of context documents with metadata.
provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
Returns:
dict: {"answer": str, "sources": list[str], "model": Optional[str]}
"""
prompt = build_research_prompt(contexts, query=query)
model_used: str | None = None
finish_reason: str | None = None
provider_lower = provider.lower()
config = MODEL_REGISTRY.get_config(provider_lower)
if provider_lower == "openai":
answer, model_used = await generate_openai(prompt, config=config)
elif provider_lower == "openrouter":
try:
answer, model_used, finish_reason = await generate_openrouter(
prompt, config=config, selected_model=selected_model
)
metrics_results = await evaluate_metrics(answer, prompt)
logger.info(f"G-Eval Faithfulness → {metrics_results}")
except Exception as e:
logger.error(f"Error occurred while generating answer from {provider_lower}: {e}")
raise
elif provider_lower == "huggingface":
answer, model_used = await generate_huggingface(prompt, config=config)
else:
raise ValueError(f"Unknown provider: {provider}")
return {
"answer": answer,
"sources": [r.url for r in contexts],
"model": model_used,
"finish_reason": finish_reason,
}
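
# Example usage (illustrative sketch, not executed at import time). The
# SearchResult fields shown here (title, url, content) are assumptions based on
# how contexts are consumed above (r.url); adapt them to the actual model.
#
#   import asyncio
#
#   async def _demo() -> None:
#       contexts = [
#           SearchResult(
#               title="Example page",
#               url="https://example.com/rag",
#               content="Retrieval-augmented generation combines search with LLMs.",
#           )
#       ]
#       result = await generate_answer(
#           "What is retrieval-augmented generation?",
#           contexts,
#           provider="openrouter",
#       )
#       print(result["answer"], result["sources"], result["model"])
#
#   asyncio.run(_demo())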
# -----------------------
# Streaming answer generator
# -----------------------
@opik.track(name="get_streaming_function")
def get_streaming_function(
provider: str,
query: str,
contexts: list[SearchResult],
selected_model: str | None = None,
) -> Callable[[], AsyncGenerator[str, None]]:
"""Get a streaming function for the specified LLM provider.
Args:
provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
query (str): The user's research query.
contexts (list[SearchResult]): List of context documents with metadata.
Returns:
Callable[[], AsyncGenerator[str, None]]: A function that returns an async generator yielding
response chunks.
"""
prompt = build_research_prompt(contexts, query=query)
provider_lower = provider.lower()
config = MODEL_REGISTRY.get_config(provider_lower)
logger.info(f"Using model config: {config}")
async def stream_gen() -> AsyncGenerator[str, None]:
"""Asynchronous generator that streams response chunks from the specified provider.
Yields:
str: The next chunk of the response.
"""
        buffer: list[str] = []  # accumulate chunks so the full output can be evaluated after streaming
if provider_lower == "openai":
async for chunk in stream_openai(prompt, config=config):
buffer.append(chunk)
yield chunk
elif provider_lower == "openrouter":
try:
async for chunk in stream_openrouter(
prompt, config=config, selected_model=selected_model
):
buffer.append(chunk)
yield chunk
full_output = "".join(buffer)
metrics_results = await evaluate_metrics(full_output, prompt)
logger.info(f"Metrics results: {metrics_results}")
except Exception as e:
logger.error(f"Error occurred while streaming from {provider}: {e}")
yield "__error__"
elif provider_lower == "huggingface":
async for chunk in stream_huggingface(prompt, config=config):
buffer.append(chunk)
yield chunk
else:
raise ValueError(f"Unknown provider: {provider}")
return stream_gen
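
# Example usage (illustrative sketch). Consuming the streaming function from a
# FastAPI route; `app` and `retrieve_contexts` are hypothetical placeholders for
# the application's own router and retrieval step.
#
#   from fastapi.responses import StreamingResponse
#
#   @app.post("/research/stream")
#   async def research_stream(query: str) -> StreamingResponse:
#       contexts = await retrieve_contexts(query)  # hypothetical retrieval step
#       stream_fn = get_streaming_function("openrouter", query, contexts)
#       return StreamingResponse(stream_fn(), media_type="text/plain")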