# Search_Engine: src/api/services/generation_service.py
from collections.abc import AsyncGenerator, Callable
import opik
from src.api.models.api_models import SearchResult
from src.api.models.provider_models import MODEL_REGISTRY
from src.api.services.providers.huggingface_service import generate_huggingface, stream_huggingface
from src.api.services.providers.openai_service import generate_openai, stream_openai
from src.api.services.providers.openrouter_service import generate_openrouter, stream_openrouter
from src.api.services.providers.utils.evaluation_metrics import evaluate_metrics
from src.api.services.providers.utils.prompts import build_research_prompt
from src.utils.logger_util import setup_logging
logger = setup_logging()
# -----------------------
# Non-streaming answer generator
# -----------------------
@opik.track(name="generate_answer")
async def generate_answer(
query: str,
contexts: list[SearchResult],
provider: str = "openrouter",
selected_model: str | None = None,
) -> dict:
"""Generate a non-streaming answer using the specified LLM provider.

    Args:
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        selected_model (str | None): Optional model identifier forwarded to the
            OpenRouter provider; ignored by the other providers.

    Returns:
        dict: {"answer": str, "sources": list[str], "model": str | None,
            "finish_reason": str | None}
"""
prompt = build_research_prompt(contexts, query=query)
model_used: str | None = None
finish_reason: str | None = None
provider_lower = provider.lower()
config = MODEL_REGISTRY.get_config(provider_lower)
if provider_lower == "openai":
answer, model_used = await generate_openai(prompt, config=config)
elif provider_lower == "openrouter":
try:
answer, model_used, finish_reason = await generate_openrouter(
prompt, config=config, selected_model=selected_model
)
metrics_results = await evaluate_metrics(answer, prompt)
logger.info(f"G-Eval Faithfulness → {metrics_results}")
except Exception as e:
logger.error(f"Error occurred while generating answer from {provider_lower}: {e}")
raise
elif provider_lower == "huggingface":
answer, model_used = await generate_huggingface(prompt, config=config)
else:
raise ValueError(f"Unknown provider: {provider}")
return {
"answer": answer,
"sources": [r.url for r in contexts],
"model": model_used,
"finish_reason": finish_reason,
}
# -----------------------
# Streaming answer generator
# -----------------------
@opik.track(name="get_streaming_function")
def get_streaming_function(
provider: str,
query: str,
contexts: list[SearchResult],
selected_model: str | None = None,
) -> Callable[[], AsyncGenerator[str, None]]:
"""Get a streaming function for the specified LLM provider.

    Args:
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        selected_model (str | None): Optional model identifier forwarded to the
            OpenRouter provider; ignored by the other providers.

    Returns:
        Callable[[], AsyncGenerator[str, None]]: A function that returns an async
            generator yielding response chunks.
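
    Example:
        An illustrative sketch only; how the generator is consumed (for example
        by a server-sent-events endpoint) is an assumption, not part of this module::

            stream_fn = get_streaming_function("openrouter", query, contexts)
            async for chunk in stream_fn():
                print(chunk, end="")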
"""
prompt = build_research_prompt(contexts, query=query)
provider_lower = provider.lower()
config = MODEL_REGISTRY.get_config(provider_lower)
logger.info(f"Using model config: {config}")
async def stream_gen() -> AsyncGenerator[str, None]:
"""Asynchronous generator that streams response chunks from the specified provider.
Yields:
str: The next chunk of the response.
"""
buffer = [] # collect all chunks here
if provider_lower == "openai":
async for chunk in stream_openai(prompt, config=config):
buffer.append(chunk)
yield chunk
elif provider_lower == "openrouter":
try:
async for chunk in stream_openrouter(
prompt, config=config, selected_model=selected_model
):
buffer.append(chunk)
yield chunk
full_output = "".join(buffer)
metrics_results = await evaluate_metrics(full_output, prompt)
logger.info(f"Metrics results: {metrics_results}")
except Exception as e:
logger.error(f"Error occurred while streaming from {provider}: {e}")
yield "__error__"
elif provider_lower == "huggingface":
async for chunk in stream_huggingface(prompt, config=config):
buffer.append(chunk)
yield chunk
else:
raise ValueError(f"Unknown provider: {provider}")
return stream_gen
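

# -----------------------
# Illustrative usage (sketch)
# -----------------------
# A minimal sketch of how a route handler might consume the streaming function,
# assuming a FastAPI-style API layer; the endpoint shape, the StreamingResponse
# wiring, and how `contexts` is obtained are illustrative assumptions, not part
# of this module.
#
# from fastapi.responses import StreamingResponse
#
# async def stream_answer(query: str, contexts: list[SearchResult]) -> StreamingResponse:
#     stream_fn = get_streaming_function("openrouter", query, contexts)
#     return StreamingResponse(stream_fn(), media_type="text/event-stream")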