# graphgen/models/llm/api/openai_client.py
import math
from typing import Any, Dict, List, Optional

from openai import (
    APIConnectionError,
    APITimeoutError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    RateLimitError,
)
from openai.types.chat import ChatCompletion
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


def get_top_response_tokens(response: ChatCompletion) -> List[Token]:
    """Convert the logprobs of a chat completion into Token objects, each
    carrying its probability and its top-k candidate alternatives."""
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
        ]
        token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
        tokens.append(token)
    return tokens
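
# For reference, a sketch of the payload get_top_response_tokens consumes when a
# request is made with logprobs=True and top_logprobs=k (type names from the
# OpenAI v1 chat completions API; the tokens and values below are invented):
#
#   response.choices[0].logprobs.content == [
#       ChatCompletionTokenLogprob(
#           token="Paris",
#           logprob=-0.02,
#           top_logprobs=[
#               TopLogprob(token="Paris", logprob=-0.02),
#               TopLogprob(token="London", logprob=-4.1),
#           ],
#       ),
#       ...
#   ]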
class OpenAIClient(BaseLLMWrapper):
def __init__(
self,
*,
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
api_version: Optional[str] = None,
json_mode: bool = False,
seed: Optional[int] = None,
        topk_per_token: int = 5,  # number of top-k candidate tokens to return per generated token
request_limit: bool = False,
rpm: Optional[RPM] = None,
tpm: Optional[TPM] = None,
backend: str = "openai_api",
**kwargs: Any,
):
super().__init__(**kwargs)
self.model = model
self.api_key = api_key
self.api_version = api_version # required for Azure OpenAI
self.base_url = base_url
self.json_mode = json_mode
self.seed = seed
self.topk_per_token = topk_per_token
self.token_usage: list = []
self.request_limit = request_limit
self.rpm = rpm or RPM()
self.tpm = tpm or TPM()
assert (
backend in ("openai_api", "azure_openai_api")
), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
self.backend = backend
self.__post_init__()
def __post_init__(self):
api_name = self.backend.replace("_", " ")
assert self.api_key is not None, f"Please provide api key to access {api_name}."
if self.backend == "openai_api":
            # api_key is guaranteed non-None by the assert above
            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
elif self.backend == "azure_openai_api":
assert self.api_version is not None, f"Please provide api_version for {api_name}."
assert self.base_url is not None, f"Please provide base_url for {api_name}."
self.client = AsyncAzureOpenAI(
api_key=self.api_key,
azure_endpoint=self.base_url,
api_version=self.api_version,
azure_deployment=self.model,
)
else:
raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
    def _pre_generate(
        self, text: str, history: Optional[List[Dict[str, str]]]
    ) -> Dict:
        """Assemble the keyword arguments for a chat.completions.create call."""
kwargs = {
"temperature": self.temperature,
"top_p": self.top_p,
"max_tokens": self.max_tokens,
}
        if self.seed is not None:
kwargs["seed"] = self.seed
if self.json_mode:
kwargs["response_format"] = {"type": "json_object"}
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": text})
if history:
            assert len(history) % 2 == 0, "History should have an even number of elements."
messages = history + messages
kwargs["messages"] = messages
return kwargs
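    # Illustrative only: with system_prompt="You are helpful." and text="Hi"
    # (and whatever temperature/top_p/max_tokens BaseLLMWrapper provides),
    # _pre_generate returns a dict shaped like:
    #   {
    #       "temperature": 0.0,
    #       "top_p": 1.0,
    #       "max_tokens": 1024,
    #       "messages": [
    #           {"role": "system", "content": "You are helpful."},
    #           {"role": "user", "content": "Hi"},
    #       ],
    #   }
    # plus "seed"/"response_format" entries when seed or json_mode are set.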
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def generate_topk_per_token(
self,
text: str,
        history: Optional[List[Dict[str, str]]] = None,
**extra: Any,
) -> List[Token]:
kwargs = self._pre_generate(text, history)
if self.topk_per_token > 0:
kwargs["logprobs"] = True
kwargs["top_logprobs"] = self.topk_per_token
# Limit max_tokens to 1 to avoid long completions
kwargs["max_tokens"] = 1
completion = await self.client.chat.completions.create( # pylint: disable=E1125
model=self.model, **kwargs
)
tokens = get_top_response_tokens(completion)
return tokens
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def generate_answer(
self,
text: str,
        history: Optional[List[Dict[str, str]]] = None,
**extra: Any,
) -> str:
kwargs = self._pre_generate(text, history)
prompt_tokens = 0
for message in kwargs["messages"]:
prompt_tokens += len(self.tokenizer.encode(message["content"]))
estimated_tokens = prompt_tokens + kwargs["max_tokens"]
if self.request_limit:
await self.rpm.wait(silent=True)
await self.tpm.wait(estimated_tokens, silent=True)
completion = await self.client.chat.completions.create( # pylint: disable=E1125
model=self.model, **kwargs
)
        if completion.usage is not None:  # usage may be absent on some responses
self.token_usage.append(
{
"prompt_tokens": completion.usage.prompt_tokens,
"completion_tokens": completion.usage.completion_tokens,
"total_tokens": completion.usage.total_tokens,
}
)
return self.filter_think_tags(completion.choices[0].message.content)
async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
) -> List[Token]:
"""Generate probabilities for each token in the input."""
raise NotImplementedError
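
# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes OPENAI_API_KEY is set in the environment and that BaseLLMWrapper
# supplies the attributes used above (system_prompt, temperature, top_p,
# max_tokens, tokenizer, filter_think_tags); adjust names to your setup.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        client = OpenAIClient(
            model="gpt-4o-mini",
            api_key=os.environ["OPENAI_API_KEY"],
        )
        # Plain completion; RPM/TPM gating is off unless request_limit=True.
        print(await client.generate_answer("What is a knowledge graph?"))
        # Top-k alternatives for the first generated token.
        for tok in await client.generate_topk_per_token("The capital of France is"):
            print(tok)
        print(client.token_usage)  # accumulated usage stats per request

    asyncio.run(_demo())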