from collections.abc import AsyncGenerator, Callable

import opik

from src.api.models.api_models import SearchResult
from src.api.models.provider_models import MODEL_REGISTRY
from src.api.services.providers.huggingface_service import generate_huggingface, stream_huggingface
from src.api.services.providers.openai_service import generate_openai, stream_openai
from src.api.services.providers.openrouter_service import generate_openrouter, stream_openrouter
from src.api.services.providers.utils.evaluation_metrics import evaluate_metrics
from src.api.services.providers.utils.prompts import build_research_prompt
from src.utils.logger_util import setup_logging

logger = setup_logging()


# -----------------------
# Non-streaming answer generator
# -----------------------
async def generate_answer(
    query: str,
    contexts: list[SearchResult],
    provider: str = "openrouter",
    selected_model: str | None = None,
) -> dict:
    """Generate a non-streaming answer using the specified LLM provider.

    Args:
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        selected_model (str | None): Optional model override forwarded to the OpenRouter provider.

    Returns:
        dict: {"answer": str, "sources": list[str], "model": str | None, "finish_reason": str | None}
    """
    prompt = build_research_prompt(contexts, query=query)
    model_used: str | None = None
    finish_reason: str | None = None

    provider_lower = provider.lower()
    config = MODEL_REGISTRY.get_config(provider_lower)
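
    # Dispatch on the normalized provider name; only the OpenRouter path currently
    # runs G-Eval metric evaluation on the generated answer.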
    if provider_lower == "openai":
        answer, model_used = await generate_openai(prompt, config=config)
    elif provider_lower == "openrouter":
        try:
            answer, model_used, finish_reason = await generate_openrouter(
                prompt, config=config, selected_model=selected_model
            )
            metrics_results = await evaluate_metrics(answer, prompt)
            logger.info(f"G-Eval Faithfulness → {metrics_results}")
        except Exception as e:
            logger.error(f"Error occurred while generating answer from {provider_lower}: {e}")
            raise
    elif provider_lower == "huggingface":
        answer, model_used = await generate_huggingface(prompt, config=config)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    return {
        "answer": answer,
        "sources": [r.url for r in contexts],
        "model": model_used,
        "finish_reason": finish_reason,
    }
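

# A minimal usage sketch for generate_answer, assuming it is awaited from an async
# context (e.g. a FastAPI route handler) and that `results` is a list[SearchResult]
# already produced by the search layer; the names below are illustrative only:
#
#     response = await generate_answer(
#         query="what is retrieval-augmented generation?",
#         contexts=results,
#         provider="openrouter",
#     )
#     logger.info(response["answer"])
#     logger.info(f"model={response['model']} finish_reason={response['finish_reason']}")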


# -----------------------
# Streaming answer generator
# -----------------------
def get_streaming_function(
    provider: str,
    query: str,
    contexts: list[SearchResult],
    selected_model: str | None = None,
) -> Callable[[], AsyncGenerator[str, None]]:
    """Get a streaming function for the specified LLM provider.

    Args:
        provider (str): The LLM provider to use ("openai", "openrouter", "huggingface").
        query (str): The user's research query.
        contexts (list[SearchResult]): List of context documents with metadata.
        selected_model (str | None): Optional model override forwarded to the OpenRouter provider.

    Returns:
        Callable[[], AsyncGenerator[str, None]]: A function that returns an async generator
        yielding response chunks.
    """
    prompt = build_research_prompt(contexts, query=query)
    provider_lower = provider.lower()
    config = MODEL_REGISTRY.get_config(provider_lower)
    logger.info(f"Using model config: {config}")

    async def stream_gen() -> AsyncGenerator[str, None]:
        """Asynchronous generator that streams response chunks from the specified provider.

        Yields:
            str: The next chunk of the response.
        """
        buffer: list[str] = []  # collect chunks so the full output can be scored after streaming
        if provider_lower == "openai":
            async for chunk in stream_openai(prompt, config=config):
                buffer.append(chunk)
                yield chunk
        elif provider_lower == "openrouter":
            try:
                async for chunk in stream_openrouter(
                    prompt, config=config, selected_model=selected_model
                ):
                    buffer.append(chunk)
                    yield chunk
                full_output = "".join(buffer)
                metrics_results = await evaluate_metrics(full_output, prompt)
                logger.info(f"Metrics results: {metrics_results}")
            except Exception as e:
                logger.error(f"Error occurred while streaming from {provider}: {e}")
| yield "__error__" | |
| elif provider_lower == "huggingface": | |
| async for chunk in stream_huggingface(prompt, config=config): | |
| buffer.append(chunk) | |
| yield chunk | |
| else: | |
| raise ValueError(f"Unknown provider: {provider}") | |
| return stream_gen | |
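

# A minimal consumption sketch for the streaming path, assuming an async caller
# (e.g. a FastAPI StreamingResponse handler); `results` is an assumed
# list[SearchResult], and the "__error__" check mirrors the sentinel yielded above:
#
#     stream_fn = get_streaming_function("openrouter", "what is RAG?", results)
#     async for chunk in stream_fn():
#         if chunk == "__error__":
#             break  # surface an error to the client instead of more tokens
#         ...  # forward the chunk to the client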