github-actions[bot] committed
Commit e83bd85 · 1 parent: 1189434

Auto-sync from demo at Fri Nov 21 14:48:34 UTC 2025
graphgen/models/llm/api/openai_client.py
CHANGED
```diff
@@ -2,7 +2,7 @@ import math
 from typing import Any, Dict, List, Optional
 
 import openai
-from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
+from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, AsyncAzureOpenAI, RateLimitError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,17 +35,20 @@ class OpenAIClient(BaseLLMWrapper):
         model: str = "gpt-4o-mini",
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
         json_mode: bool = False,
         seed: Optional[int] = None,
         topk_per_token: int = 5,  # number of topk tokens to generate for each token
         request_limit: bool = False,
         rpm: Optional[RPM] = None,
         tpm: Optional[TPM] = None,
+        backend: str = "openai_api",
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version  # required for Azure OpenAI
         self.base_url = base_url
         self.json_mode = json_mode
         self.seed = seed
@@ -56,13 +59,32 @@ class OpenAIClient(BaseLLMWrapper):
         self.rpm = rpm or RPM()
         self.tpm = tpm or TPM()
 
+        assert (
+            backend in ("openai_api", "azure_openai_api")
+        ), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
+        self.backend = backend
+
         self.__post_init__()
 
     def __post_init__(self):
-        assert self.api_key is not None, "Please provide api key to access openai api."
-        self.client = AsyncOpenAI(
-            api_key=self.api_key or "dummy", base_url=self.base_url
-        )
+
+        api_name = self.backend.replace("_", " ")
+        assert self.api_key is not None, f"Please provide api key to access {api_name}."
+        if self.backend == "openai_api":
+            self.client = AsyncOpenAI(
+                api_key=self.api_key or "dummy", base_url=self.base_url
+            )
+        elif self.backend == "azure_openai_api":
+            assert self.api_version is not None, f"Please provide api_version for {api_name}."
+            assert self.base_url is not None, f"Please provide base_url for {api_name}."
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                azure_endpoint=self.base_url,
+                api_version=self.api_version,
+                azure_deployment=self.model,
+            )
+        else:
+            raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
 
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
```
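Taken together, the change lets `OpenAIClient` build either an `AsyncOpenAI` or an `AsyncAzureOpenAI` client from the same constructor arguments: the Azure path additionally requires `api_version`, reuses `base_url` as the `azure_endpoint`, and passes `model` as the Azure deployment name. Below is a minimal standalone sketch of the same dispatch, not the repository's code; all keys, endpoints, and the function name are placeholders:

```python
# Standalone sketch of the backend dispatch added above; the credentials and
# endpoints below are placeholders, not values from the repository.
from typing import Optional

from openai import AsyncOpenAI, AsyncAzureOpenAI


def make_async_client(backend: str, model: str, api_key: str,
                      base_url: str, api_version: Optional[str] = None):
    if backend == "openai_api":
        # Any OpenAI-compatible endpoint: an API key and base URL suffice.
        return AsyncOpenAI(api_key=api_key, base_url=base_url)
    if backend == "azure_openai_api":
        # Azure also needs an API version, and routes requests by deployment
        # name rather than by the raw model name.
        assert api_version is not None, "Azure OpenAI requires api_version"
        return AsyncAzureOpenAI(
            api_key=api_key,
            azure_endpoint=base_url,  # e.g. https://<resource>.openai.azure.com
            api_version=api_version,  # e.g. "2024-02-01"
            azure_deployment=model,   # the deployment name in Azure
        )
    raise ValueError(f"Unsupported backend {backend!r}")


client = make_async_client(
    "azure_openai_api",
    model="gpt-4o-mini",                           # placeholder deployment name
    api_key="<azure-key>",                         # placeholder
    base_url="https://example.openai.azure.com",   # placeholder
    api_version="2024-02-01",
)
```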
graphgen/operators/init/init_llm.py
CHANGED
```diff
@@ -27,10 +27,11 @@ class LLMFactory:
             from graphgen.models.llm.api.http_client import HTTPClient
 
             return HTTPClient(**config)
-        if backend == "openai_api":
+        if backend in ("openai_api", "azure_openai_api"):
             from graphgen.models.llm.api.openai_client import OpenAIClient
-
-            return OpenAIClient(**config)
+            # pass in concrete backend to the OpenAIClient so that internally we can distinguish
+            # between OpenAI and Azure OpenAI
+            return OpenAIClient(**config, backend=backend)
         if backend == "ollama_api":
             from graphgen.models.llm.api.ollama_client import OllamaClient
 
```
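With the factory hook in place, an Azure configuration differs from a plain OpenAI one only by the extra `api_version` key and the backend name. The factory's public entry point is not visible in this hunk, so the `create()` call below is a hypothetical stand-in; the config keys mirror the `OpenAIClient` constructor parameters above, and all values are placeholders:

```python
# Hypothetical usage sketch; the factory's public method name isn't shown in
# this diff, so create() is a stand-in. All values are placeholders.
from graphgen.operators.init.init_llm import LLMFactory  # path from the file above

config = {
    "model": "gpt-4o-mini",                         # Azure: the deployment name
    "api_key": "<azure-key>",
    "base_url": "https://example.openai.azure.com",
    "api_version": "2024-02-01",                    # only needed for azure_openai_api
}

# Internally this would reach: OpenAIClient(**config, backend="azure_openai_api")
client = LLMFactory.create(backend="azure_openai_api", config=config)
```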