github-actions[bot] commited on
Commit
e83bd85
·
1 Parent(s): 1189434

Auto-sync from demo at Fri Nov 21 14:48:34 UTC 2025

Browse files
graphgen/models/llm/api/openai_client.py CHANGED
@@ -2,7 +2,7 @@ import math
2
  from typing import Any, Dict, List, Optional
3
 
4
  import openai
5
- from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
6
  from tenacity import (
7
  retry,
8
  retry_if_exception_type,
@@ -35,17 +35,20 @@ class OpenAIClient(BaseLLMWrapper):
35
  model: str = "gpt-4o-mini",
36
  api_key: Optional[str] = None,
37
  base_url: Optional[str] = None,
 
38
  json_mode: bool = False,
39
  seed: Optional[int] = None,
40
  topk_per_token: int = 5, # number of topk tokens to generate for each token
41
  request_limit: bool = False,
42
  rpm: Optional[RPM] = None,
43
  tpm: Optional[TPM] = None,
 
44
  **kwargs: Any,
45
  ):
46
  super().__init__(**kwargs)
47
  self.model = model
48
  self.api_key = api_key
 
49
  self.base_url = base_url
50
  self.json_mode = json_mode
51
  self.seed = seed
@@ -56,13 +59,32 @@ class OpenAIClient(BaseLLMWrapper):
56
  self.rpm = rpm or RPM()
57
  self.tpm = tpm or TPM()
58
 
 
 
 
 
 
59
  self.__post_init__()
60
 
61
  def __post_init__(self):
62
- assert self.api_key is not None, "Please provide api key to access openai api."
63
- self.client = AsyncOpenAI(
64
- api_key=self.api_key or "dummy", base_url=self.base_url
65
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def _pre_generate(self, text: str, history: List[str]) -> Dict:
68
  kwargs = {
 
2
  from typing import Any, Dict, List, Optional
3
 
4
  import openai
5
+ from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, AsyncAzureOpenAI, RateLimitError
6
  from tenacity import (
7
  retry,
8
  retry_if_exception_type,
 
35
  model: str = "gpt-4o-mini",
36
  api_key: Optional[str] = None,
37
  base_url: Optional[str] = None,
38
+ api_version: Optional[str] = None,
39
  json_mode: bool = False,
40
  seed: Optional[int] = None,
41
  topk_per_token: int = 5, # number of topk tokens to generate for each token
42
  request_limit: bool = False,
43
  rpm: Optional[RPM] = None,
44
  tpm: Optional[TPM] = None,
45
+ backend: str = "openai_api",
46
  **kwargs: Any,
47
  ):
48
  super().__init__(**kwargs)
49
  self.model = model
50
  self.api_key = api_key
51
+ self.api_version = api_version # required for Azure OpenAI
52
  self.base_url = base_url
53
  self.json_mode = json_mode
54
  self.seed = seed
 
59
  self.rpm = rpm or RPM()
60
  self.tpm = tpm or TPM()
61
 
62
+ assert (
63
+ backend in ("openai_api", "azure_openai_api")
64
+ ), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
65
+ self.backend = backend
66
+
67
  self.__post_init__()
68
 
69
  def __post_init__(self):
70
+
71
+ api_name = self.backend.replace("_", " ")
72
+ assert self.api_key is not None, f"Please provide api key to access {api_name}."
73
+ if self.backend == "openai_api":
74
+ self.client = AsyncOpenAI(
75
+ api_key=self.api_key or "dummy", base_url=self.base_url
76
+ )
77
+ elif self.backend == "azure_openai_api":
78
+ assert self.api_version is not None, f"Please provide api_version for {api_name}."
79
+ assert self.base_url is not None, f"Please provide base_url for {api_name}."
80
+ self.client = AsyncAzureOpenAI(
81
+ api_key=self.api_key,
82
+ azure_endpoint=self.base_url,
83
+ api_version=self.api_version,
84
+ azure_deployment=self.model,
85
+ )
86
+ else:
87
+ raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
88
 
89
  def _pre_generate(self, text: str, history: List[str]) -> Dict:
90
  kwargs = {
graphgen/operators/init/init_llm.py CHANGED
@@ -27,10 +27,11 @@ class LLMFactory:
27
  from graphgen.models.llm.api.http_client import HTTPClient
28
 
29
  return HTTPClient(**config)
30
- if backend == "openai_api":
31
  from graphgen.models.llm.api.openai_client import OpenAIClient
32
-
33
- return OpenAIClient(**config)
 
34
  if backend == "ollama_api":
35
  from graphgen.models.llm.api.ollama_client import OllamaClient
36
 
 
27
  from graphgen.models.llm.api.http_client import HTTPClient
28
 
29
  return HTTPClient(**config)
30
+ if backend in ("openai_api", "azure_openai_api"):
31
  from graphgen.models.llm.api.openai_client import OpenAIClient
32
+ # pass in concrete backend to the OpenAIClient so that internally we can distinguish
33
+ # between OpenAI and Azure OpenAI
34
+ return OpenAIClient(**config, backend=backend)
35
  if backend == "ollama_api":
36
  from graphgen.models.llm.api.ollama_client import OllamaClient
37