import os
from collections.abc import AsyncGenerator
from typing import Any
import opik
from openai import AsyncOpenAI
from opik.integrations.openai import track_openai
from src.api.models.provider_models import ModelConfig
from src.api.services.providers.utils.messages import build_messages
from src.config import settings
from src.utils.logger_util import setup_logging
logger = setup_logging()
# -----------------------
# OpenRouter client
# -----------------------
openrouter_key = settings.openrouter.api_key
openrouter_url = settings.openrouter.api_url
async_openrouter_client = AsyncOpenAI(base_url=openrouter_url, api_key=openrouter_key)
# -----------------------
# Opik Observability
# -----------------------
os.environ["OPIK_API_KEY"] = settings.opik.api_key
os.environ["OPIK_PROJECT_NAME"] = settings.opik.project_name
async_openrouter_client = track_openai(async_openrouter_client)
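# Note: track_openai wraps the AsyncOpenAI client so that chat completion calls made
# through it are recorded in the Opik project configured above.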
# -----------------------
# Helper to build extra body for OpenRouter
# -----------------------
@opik.track(name="build_openrouter_extra")
def build_openrouter_extra(config: ModelConfig) -> dict[str, Any]:
"""Build the extra body for OpenRouter API requests based on the ModelConfig.
Args:
config (ModelConfig): The model configuration.
Returns:
dict[str, Any]: The extra body for OpenRouter API requests.
"""
body = {"provider": {"sort": config.provider_sort.value}}
if config.candidate_models:
body["models"] = list(config.candidate_models) # type: ignore
return body
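# Illustrative sketch of the resulting payload (assumed values): for a config whose
# provider_sort.value is "price" and whose candidate_models are ["model-a", "model-b"],
# build_openrouter_extra would return:
#     {"provider": {"sort": "price"}, "models": ["model-a", "model-b"]}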
# -----------------------
# Core OpenRouter functions
# -----------------------
@opik.track(name="generate_openrouter")
async def generate_openrouter(
prompt: str,
config: ModelConfig,
selected_model: str | None = None,
) -> tuple[str, str | None, str | None]:
"""Generate a response from OpenRouter for a given prompt and model configuration.
Args:
prompt (str): The input prompt.
config (ModelConfig): The model configuration.
selected_model (str | None): Optional specific model to use.
Returns:
tuple[str, str | None, str | None]: The generated response, model used, and finish reason.
"""
model_to_use = selected_model or config.primary_model
resp = await async_openrouter_client.chat.completions.create(
model=model_to_use,
messages=build_messages(prompt),
temperature=config.temperature,
max_completion_tokens=config.max_completion_tokens,
extra_body=build_openrouter_extra(config),
)
answer = resp.choices[0].message.content or ""
# Reasons: tool_calls, stop, length, content_filter, error
finish_reason = getattr(resp.choices[0], "native_finish_reason", None)
model_used = getattr(resp.choices[0], "model", None) or getattr(resp, "model", None)
logger.info(f"OpenRouter non-stream finish_reason: {finish_reason}")
if finish_reason == "length":
logger.warning("Response was truncated by token limit.")
logger.info(f"OpenRouter non-stream finished. Model used: {model_used}")
return answer, model_used, finish_reason
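# Example usage (illustrative sketch, not executed here; assumes a populated
# ModelConfig instance named `config`):
#
# async def demo_generate(config: ModelConfig) -> None:
#     answer, model_used, finish_reason = await generate_openrouter(
#         "Explain retrieval-augmented generation in one paragraph.", config
#     )
#     if finish_reason == "length":
#         print("Note: the answer was cut off by the token limit.")
#     print(f"[{model_used}] {answer}")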
@opik.track(name="stream_openrouter")
def stream_openrouter(
prompt: str,
config: ModelConfig,
selected_model: str | None = None,
) -> AsyncGenerator[str, None]:
"""Stream a response from OpenRouter for a given prompt and model configuration.
Args:
prompt (str): The input prompt.
config (ModelConfig): The model configuration.
selected_model (str | None): Optional specific model to use.
Returns:
AsyncGenerator[str, None]: An asynchronous generator yielding response chunks.
"""
async def gen() -> AsyncGenerator[str, None]:
"""Generate response chunks from OpenRouter.
        Yields:
            str: Response chunks.
"""
model_to_use = selected_model or config.primary_model
stream = await async_openrouter_client.chat.completions.create(
model=model_to_use,
messages=build_messages(prompt),
temperature=config.temperature,
max_completion_tokens=config.max_completion_tokens,
extra_body=build_openrouter_extra(config),
stream=True,
)
try:
first_chunk = await stream.__anext__()
model_used = getattr(first_chunk, "model", None)
if model_used:
yield f"__model_used__:{model_used}"
delta_text = getattr(first_chunk.choices[0].delta, "content", None)
if delta_text:
yield delta_text
except StopAsyncIteration:
return
last_finish_reason = None
async for chunk in stream:
delta_text = getattr(chunk.choices[0].delta, "content", None)
if delta_text:
yield delta_text
# Reasons: tool_calls, stop, length, content_filter, error
finish_reason = getattr(chunk.choices[0], "finish_reason", None)
if finish_reason:
last_finish_reason = finish_reason
logger.info(f"OpenRouter stream finished. Model used: {model_used}")
logger.warning(f"Final finish_reason: {last_finish_reason}")
# Yield a chunk to trigger truncation warning in UI
if last_finish_reason == "length":
yield "__truncated__"
return gen()
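# Example usage (illustrative sketch, not executed here). The stream interleaves two
# sentinel chunks with the response text: a "__model_used__:<model>" sentinel derived
# from the first chunk (yielded before its text) and a trailing "__truncated__" sentinel
# when the finish reason is "length", so a consumer should filter both out of the
# displayed text:
#
# async def demo_stream(config: ModelConfig) -> None:
#     async for chunk in stream_openrouter("Tell me a short story.", config):
#         if chunk.startswith("__model_used__:"):
#             model_used = chunk.removeprefix("__model_used__:")
#         elif chunk == "__truncated__":
#             print("\n[response truncated by token limit]")
#         else:
#             print(chunk, end="", flush=True)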
# ---------------------------------------
# Test Log Probs and Confidence Visualization
# ---------------------------------------
# import math
# def visualize_token_confidence(token_probs: list[tuple[str, float]]):
# """Print token probabilities as ASCII bars in the terminal."""
# for tok, prob in token_probs:
# bar_length = int(prob * 40) # scale bar to 40 chars max
# bar = "#" * bar_length
# print(f"{tok:>12}: [{bar:<40}] {prob:.2%}")
# async def generate_openrouter(
# prompt: str,
# config: ModelConfig,
# max_tokens: int | None = None) -> tuple[str, str | None, str | None]:
# """Generate a response from OpenRouter
# and log token-level statistics with confidence evolution."""
# resp = await async_openrouter_client.chat.completions.create(
# model=config.primary_model,
# messages=build_messages(prompt),
# temperature=config.temperature,
# max_completion_tokens=max_tokens or config.max_completion_tokens,
# extra_body={**build_openrouter_extra(config), "logprobs": True, "top_logprobs": 3},
# )
# choice = resp.choices[0]
# content = choice.message.content or ""
# finish_reason = getattr(choice, "native_finish_reason", None)
# model_used = getattr(choice, "model", None) or getattr(resp, "model", None)
# logger.info(f"OpenRouter non-stream finish_reason: {finish_reason}")
# if finish_reason == "length":
# logger.warning("Response was truncated by token limit.")
# # Extract logprobs
# token_logprobs_list = choice.logprobs
# tokens_logprobs = []
# token_probs = []
# if token_logprobs_list and hasattr(token_logprobs_list, "content"):
# for token_info in token_logprobs_list.content:
# tok = token_info.token
# logprob = token_info.logprob
# prob = math.exp(logprob)
# tokens_logprobs.append(logprob)
# token_probs.append((tok, prob))
# if tokens_logprobs:
# avg_logprob = sum(tokens_logprobs) / len(tokens_logprobs)
# avg_prob = math.exp(avg_logprob)
# most_confident = sorted(token_probs, key=lambda x: x[1], reverse=True)[:5]
# least_confident = sorted(token_probs, key=lambda x: x[1])[:5]
# logger.info(f"Temperature: {config.temperature}")
# logger.info(f"Max completion tokens: {config.max_completion_tokens}")
# logger.info(f"Average log probability: {avg_logprob:.4f} "
#             f"(≈ {avg_prob:.2%} avg token prob)")
# logger.info("Top 5 most confident tokens:")
# for tok, prob in most_confident:
# logger.info(f" '{tok}' ≈ {prob:.2%}")
# logger.info("Top 5 least confident tokens:")
# for tok, prob in least_confident:
# logger.info(f" '{tok}' ≈ {prob:.2%}")
# # Terminal visualization
# print("\nToken confidence evolution:")
# visualize_token_confidence(token_probs)
# else:
# logger.warning("No logprob information found in response.")
# return content, model_used, finish_reason