import math
from typing import Any, Dict, List, Optional

from openai import (
    APIConnectionError,
    APITimeoutError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    RateLimitError,
)
from openai.types.chat import ChatCompletion
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


def get_top_response_tokens(response: ChatCompletion) -> List[Token]:
    """Convert the logprobs of a chat completion into Token objects."""
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
        ]
        token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
        tokens.append(token)
    return tokens
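
# A minimal sketch of what get_top_response_tokens yields. The attribute names
# `text`, `prob`, and `top_candidates` are assumed from how Token is constructed
# above; check graphgen.bases.datatypes for the exact field names.
#
#   tokens = get_top_response_tokens(completion)
#   first = tokens[0]
#   print(first.text, first.prob)        # chosen token and its probability
#   for cand in first.top_candidates:    # one entry per top_logprobs candidate
#       print(cand.text, cand.prob)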


class OpenAIClient(BaseLLMWrapper):
    def __init__(
        self,
        *,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,  # number of top-k tokens to generate for each token
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        backend: str = "openai_api",
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.model = model
        self.api_key = api_key
        self.api_version = api_version  # required for Azure OpenAI
        self.base_url = base_url
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.token_usage: list = []
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()

        assert backend in (
            "openai_api",
            "azure_openai_api",
        ), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
        self.backend = backend

        self.__post_init__()

    def __post_init__(self):
        api_name = self.backend.replace("_", " ")
        assert (
            self.api_key is not None
        ), f"Please provide an api key to access {api_name}."

        if self.backend == "openai_api":
            self.client = AsyncOpenAI(
                api_key=self.api_key or "dummy", base_url=self.base_url
            )
        elif self.backend == "azure_openai_api":
            assert (
                self.api_version is not None
            ), f"Please provide api_version for {api_name}."
            assert (
                self.base_url is not None
            ), f"Please provide base_url for {api_name}."
            self.client = AsyncAzureOpenAI(
                api_key=self.api_key,
                azure_endpoint=self.base_url,
                api_version=self.api_version,
                azure_deployment=self.model,
            )
        else:
            raise ValueError(
                f"Unsupported backend {self.backend}. "
                "Use 'openai_api' or 'azure_openai_api'."
            )

    def _pre_generate(self, text: str, history: Optional[List[str]]) -> Dict:
        kwargs = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }
        if self.seed is not None:  # explicit check so a seed of 0 is not dropped
            kwargs["seed"] = self.seed
        if self.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        if history:
            assert (
                len(history) % 2 == 0
            ), "History should have an even number of elements."
            messages = history + messages
        kwargs["messages"] = messages
        return kwargs
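
    # Shape of the payload built above (illustrative values only; the actual
    # defaults come from BaseLLMWrapper):
    #
    #   {
    #       "temperature": 0.7,
    #       "top_p": 1.0,
    #       "max_tokens": 1024,
    #       "seed": 42,                                  # only when a seed is set
    #       "response_format": {"type": "json_object"},  # only in json_mode
    #       "messages": [
    #           {"role": "system", "content": "..."},
    #           {"role": "user", "content": "..."},
    #       ],
    #   }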

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> List[Token]:
        kwargs = self._pre_generate(text, history)
        if self.topk_per_token > 0:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = self.topk_per_token
        # Limit max_tokens to 1 to avoid long completions
        kwargs["max_tokens"] = 1

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model, **kwargs
        )
        tokens = get_top_response_tokens(completion)
        return tokens
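
    # Usage sketch: inspect the model's distribution over its first output
    # token (the completion is capped at one token above, so tokens[0] is the
    # whole result). Attribute names on Token are assumed as noted earlier.
    #
    #   tokens = await client.generate_topk_per_token("Is the sky blue? yes/no:")
    #   for cand in tokens[0].top_candidates:
    #       print(cand.text, cand.prob)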

    # The tenacity imports above indicate retry handling for transient API
    # errors; the backoff values below are reasonable defaults, not confirmed
    # by the surrounding code.
    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> str:
        kwargs = self._pre_generate(text, history)

        # Estimate the request size so the rate limiters can budget the call.
        prompt_tokens = 0
        for message in kwargs["messages"]:
            prompt_tokens += len(self.tokenizer.encode(message["content"]))
        estimated_tokens = prompt_tokens + kwargs["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(estimated_tokens, silent=True)

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model, **kwargs
        )
        if hasattr(completion, "usage"):
            self.token_usage.append(
                {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            )
        return self.filter_think_tags(completion.choices[0].message.content)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        raise NotImplementedError
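

if __name__ == "__main__":
    # Usage sketch, not part of the original module. Assumes OPENAI_API_KEY is
    # set in the environment and that BaseLLMWrapper provides usable defaults
    # for temperature / top_p / max_tokens / system_prompt / tokenizer.
    import asyncio
    import os

    async def _demo() -> None:
        client = OpenAIClient(
            model="gpt-4o-mini",
            api_key=os.environ["OPENAI_API_KEY"],
        )
        print(await client.generate_answer("Reply with a single word: hello"))

    asyncio.run(_demo())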