github-actions[bot] committed
Commit d02622b · 1 Parent(s): d2af8b0

Auto-sync from demo at Wed Oct 29 11:25:28 UTC 2025

graphgen/bases/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from .base_generator import BaseGenerator
  from .base_kg_builder import BaseKGBuilder
- from .base_llm_client import BaseLLMClient
+ from .base_llm_wrapper import BaseLLMWrapper
  from .base_partitioner import BasePartitioner
  from .base_reader import BaseReader
  from .base_splitter import BaseSplitter
graphgen/bases/base_generator.py CHANGED
@@ -1,7 +1,7 @@
  from abc import ABC, abstractmethod
  from typing import Any

- from graphgen.bases.base_llm_client import BaseLLMClient
+ from graphgen.bases.base_llm_wrapper import BaseLLMWrapper


  class BaseGenerator(ABC):
@@ -9,7 +9,7 @@ class BaseGenerator(ABC):
      Generate QAs based on given prompts.
      """

-     def __init__(self, llm_client: BaseLLMClient):
+     def __init__(self, llm_client: BaseLLMWrapper):
          self.llm_client = llm_client

      @staticmethod
graphgen/bases/base_kg_builder.py CHANGED
@@ -2,13 +2,13 @@ from abc import ABC, abstractmethod
  from collections import defaultdict
  from typing import Dict, List, Tuple

- from graphgen.bases.base_llm_client import BaseLLMClient
+ from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
  from graphgen.bases.base_storage import BaseGraphStorage
  from graphgen.bases.datatypes import Chunk


  class BaseKGBuilder(ABC):
-     def __init__(self, llm_client: BaseLLMClient):
+     def __init__(self, llm_client: BaseLLMWrapper):
          self.llm_client = llm_client
          self._nodes: Dict[str, List[dict]] = defaultdict(list)
          self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
graphgen/bases/{base_llm_client.py → base_llm_wrapper.py} RENAMED
@@ -8,7 +8,7 @@ from graphgen.bases.base_tokenizer import BaseTokenizer
  from graphgen.bases.datatypes import Token


- class BaseLLMClient(abc.ABC):
+ class BaseLLMWrapper(abc.ABC):
      """
      LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
      """
@@ -66,3 +66,9 @@ class BaseLLMClient(abc.ABC):
          think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
          filtered_text = think_pattern.sub("", text).strip()
          return filtered_text if filtered_text else text.strip()
+
+     def shutdown(self) -> None:
+         """Shutdown the LLM engine if applicable."""
+
+     def restart(self) -> None:
+         """Reinitialize the LLM engine if applicable."""
graphgen/graphgen.py CHANGED
@@ -1,11 +1,11 @@
  import asyncio
  import os
  import time
- from dataclasses import dataclass
  from typing import Dict, cast

  import gradio as gr

+ from graphgen.bases import BaseLLMWrapper
  from graphgen.bases.base_storage import StorageNameSpace
  from graphgen.bases.datatypes import Chunk
  from graphgen.models import (
@@ -20,6 +20,7 @@ from graphgen.operators import (
      build_text_kg,
      chunk_documents,
      generate_qas,
+     init_llm,
      judge_statement,
      partition_kg,
      quiz,
@@ -31,40 +32,28 @@ from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
  sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


- @dataclass
  class GraphGen:
-     unique_id: int = int(time.time())
-     working_dir: str = os.path.join(sys_path, "cache")
-
-     # llm
-     tokenizer_instance: Tokenizer = None
-     synthesizer_llm_client: OpenAIClient = None
-     trainee_llm_client: OpenAIClient = None
-
-     # webui
-     progress_bar: gr.Progress = None
-
-     def __post_init__(self):
-         self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+     def __init__(
+         self,
+         unique_id: int = int(time.time()),
+         working_dir: str = os.path.join(sys_path, "cache"),
+         tokenizer_instance: Tokenizer = None,
+         synthesizer_llm_client: OpenAIClient = None,
+         trainee_llm_client: OpenAIClient = None,
+         progress_bar: gr.Progress = None,
+     ):
+         self.unique_id: int = unique_id
+         self.working_dir: str = working_dir
+
+         # llm
+         self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
              model_name=os.getenv("TOKENIZER_MODEL")
          )

-         self.synthesizer_llm_client: OpenAIClient = (
-             self.synthesizer_llm_client
-             or OpenAIClient(
-                 model_name=os.getenv("SYNTHESIZER_MODEL"),
-                 api_key=os.getenv("SYNTHESIZER_API_KEY"),
-                 base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-                 tokenizer=self.tokenizer_instance,
-             )
-         )
-
-         self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
-             model_name=os.getenv("TRAINEE_MODEL"),
-             api_key=os.getenv("TRAINEE_API_KEY"),
-             base_url=os.getenv("TRAINEE_BASE_URL"),
-             tokenizer=self.tokenizer_instance,
-         )
+         self.synthesizer_llm_client: BaseLLMWrapper = (
+             synthesizer_llm_client or init_llm("synthesizer")
+         )
+         self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client

          self.full_docs_storage: JsonKVStorage = JsonKVStorage(
              self.working_dir, namespace="full_docs"
@@ -86,6 +75,9 @@ class GraphGen:
              namespace="qa",
          )

+         # webui
+         self.progress_bar: gr.Progress = progress_bar
+
      @async_to_sync_method
      async def insert(self, read_config: Dict, split_config: Dict):
          """
@@ -272,6 +264,12 @@ class GraphGen:
          )

          # TODO: assert trainee_llm_client is valid before judge
+         if not self.trainee_llm_client:
+             # TODO: shutdown existing synthesizer_llm_client properly
+             logger.info("No trainee LLM client provided, initializing a new one.")
+             self.synthesizer_llm_client.shutdown()
+             self.trainee_llm_client = init_llm("trainee")
+
          re_judge = quiz_and_judge_config["re_judge"]
          _update_relations = await judge_statement(
              self.trainee_llm_client,
@@ -279,9 +277,16 @@ class GraphGen:
              self.rephrase_storage,
              re_judge,
          )
+
          await self.rephrase_storage.index_done_callback()
          await _update_relations.index_done_callback()

+         logger.info("Shutting down trainee LLM client.")
+         self.trainee_llm_client.shutdown()
+         self.trainee_llm_client = None
+         logger.info("Restarting synthesizer LLM client.")
+         self.synthesizer_llm_client.restart()
+
      @async_to_sync_method
      async def generate(self, partition_config: Dict, generate_config: Dict):
          # Step 1: partition the graph
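
The constructor above keeps env-driven defaults: when no synthesizer client is injected, init_llm("synthesizer") builds one from SYNTHESIZER_* variables, and the trainee client stays None until the quiz-and-judge step spins one up and tears it down again. A rough usage sketch under those assumptions; the config dict keys are placeholders, not the real schema:

from graphgen.graphgen import GraphGen

# SYNTHESIZER_BACKEND / SYNTHESIZER_MODEL / ... are expected in the environment.
gg = GraphGen(working_dir="./cache")

# insert() is wrapped by async_to_sync_method, so it can be called synchronously.
gg.insert(
    read_config={"input_path": "data/docs"},  # placeholder keys
    split_config={"chunk_size": 1024},        # placeholder keys
)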
graphgen/models/__init__.py CHANGED
@@ -7,8 +7,7 @@ from .generator import (
      VQAGenerator,
  )
  from .kg_builder import LightRAGKGBuilder, MMKGBuilder
- from .llm.openai_client import OpenAIClient
- from .llm.topk_token_model import TopkTokenModel
+ from .llm import HTTPClient, OllamaClient, OpenAIClient
  from .partitioner import (
      AnchorBFSPartitioner,
      BFSPartitioner,
graphgen/models/kg_builder/light_rag_kg_builder.py CHANGED
@@ -2,7 +2,7 @@ import re
  from collections import Counter, defaultdict
  from typing import Dict, List, Tuple

- from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+ from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
  from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
  from graphgen.utils import (
      detect_main_language,
@@ -15,7 +15,7 @@ from graphgen.utils import (


  class LightRAGKGBuilder(BaseKGBuilder):
-     def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
+     def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
          super().__init__(llm_client)
          self.max_loop = max_loop

graphgen/models/llm/__init__.py CHANGED
@@ -0,0 +1,4 @@
from .api.http_client import HTTPClient
from .api.ollama_client import OllamaClient
from .api.openai_client import OpenAIClient
from .local.hf_wrapper import HuggingFaceWrapper
graphgen/models/llm/api/__init__.py ADDED
File without changes
graphgen/models/llm/api/http_client.py ADDED
@@ -0,0 +1,197 @@
import asyncio
import math
from typing import Any, Dict, List, Optional

import aiohttp
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class HTTPClient(BaseLLMWrapper):
    """
    A generic async HTTP client for LLMs compatible with OpenAI's chat/completions format.
    It uses aiohttp for making requests and includes retry logic and token usage tracking.
    Usage example:
        client = HTTPClient(
            model="gpt-4o-mini",
            base_url="http://localhost:8080",
            api_key="your_api_key",
            json_mode=True,
            seed=42,
            topk_per_token=5,
            request_limit=True,
        )

        answer = await client.generate_answer("Hello, world!")
        tokens = await client.generate_topk_per_token("Hello, world!")
    """

    _instance: Optional["HTTPClient"] = None
    _lock = asyncio.Lock()

    def __new__(cls, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        *,
        model: str,
        base_url: str,
        api_key: Optional[str] = None,
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        # Initialize only once in the singleton pattern
        if getattr(self, "_initialized", False):
            return
        self._initialized: bool = True
        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()

        self.token_usage: List[Dict[str, int]] = []
        self._session: Optional[aiohttp.ClientSession] = None

    @property
    def session(self) -> aiohttp.ClientSession:
        if self._session is None or self._session.closed:
            headers = (
                {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
            )
            self._session = aiohttp.ClientSession(headers=headers)
        return self._session

    async def close(self):
        if self._session and not self._session.closed:
            await self._session.close()

    def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]:
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        # chatml format: alternating user and assistant messages
        if history and isinstance(history[0], dict):
            messages.extend(history)

        messages.append({"role": "user", "content": text})

        body = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }
        if self.seed:
            body["seed"] = self.seed
        if self.json_mode:
            body["response_format"] = {"type": "json_object"}
        return body

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
    )
    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> str:
        body = self._build_body(text, history or [])
        prompt_tokens = sum(
            len(self.tokenizer.encode(m["content"])) for m in body["messages"]
        )
        est = prompt_tokens + body["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        async with self.session.post(
            f"{self.base_url}/chat/completions",
            json=body,
            timeout=aiohttp.ClientTimeout(total=60),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        msg = data["choices"][0]["message"]["content"]
        if "usage" in data:
            self.token_usage.append(
                {
                    "prompt_tokens": data["usage"]["prompt_tokens"],
                    "completion_tokens": data["usage"]["completion_tokens"],
                    "total_tokens": data["usage"]["total_tokens"],
                }
            )
        return self.filter_think_tags(msg)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
    )
    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> List[Token]:
        body = self._build_body(text, history or [])
        body["max_tokens"] = 1
        if self.topk_per_token > 0:
            body["logprobs"] = True
            body["top_logprobs"] = self.topk_per_token

        async with self.session.post(
            f"{self.base_url}/chat/completions",
            json=body,
            timeout=aiohttp.ClientTimeout(total=60),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        token_logprobs = data["choices"][0]["logprobs"]["content"]
        tokens = []
        for item in token_logprobs:
            candidates = [
                Token(t["token"], math.exp(t["logprob"])) for t in item["top_logprobs"]
            ]
            tokens.append(
                Token(
                    item["token"], math.exp(item["logprob"]), top_candidates=candidates
                )
            )
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "generate_inputs_prob is not implemented in HTTPClient"
        )
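
HTTPClient is a process-wide singleton holding one shared aiohttp session, so a typical call site drives it from a single event loop and closes the session when finished. A minimal driver sketch, assuming an OpenAI-compatible endpoint and that Tokenizer is importable from graphgen.models as elsewhere in this diff; the URL and model name are placeholders:

import asyncio

from graphgen.models import Tokenizer
from graphgen.models.llm.api.http_client import HTTPClient


async def main() -> None:
    client = HTTPClient(
        model="qwen2.5-7b-instruct",          # placeholder model name
        base_url="http://localhost:8000/v1",  # any OpenAI-compatible server
        api_key="EMPTY",
        tokenizer=Tokenizer("cl100k_base"),   # forwarded to BaseLLMWrapper
    )
    try:
        print(await client.generate_answer("Give me one fact about graphs."))
    finally:
        await client.close()  # release the shared aiohttp session


asyncio.run(main())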
graphgen/models/llm/api/ollama_client.py ADDED
@@ -0,0 +1,105 @@
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class OllamaClient(BaseLLMWrapper):
    """
    Requires a local or remote Ollama server to be running (default port 11434).
    The top_logprobs field is not yet implemented by the official API.
    """

    def __init__(
        self,
        *,
        model: str = "gemma3",
        base_url: str = "http://localhost:11434",
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "Ollama SDK is not installed. "
                "It is required to use OllamaClient. "
                "Please install it with `pip install ollama`."
            ) from e
        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()
        self.token_usage: List[Dict[str, int]] = []

        self.client = ollama.AsyncClient(host=self.base_url)

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> str:
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": text})

        options = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_predict": self.max_tokens,
        }
        if self.seed is not None:
            options["seed"] = self.seed

        prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
        est = prompt_tokens + self.max_tokens
        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        response = await self.client.chat(
            model=self.model_name,
            messages=messages,
            format="json" if self.json_mode else "",
            options=options,
            stream=False,
        )

        usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0)
        self.token_usage.append(
            {
                "prompt_tokens": usage[0],
                "completion_tokens": usage[1],
                "total_tokens": sum(usage),
            }
        )
        content = response["message"]["content"]
        return self.filter_think_tags(content)

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token top-k yet.")

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token logprobs yet.")
graphgen/models/llm/{openai_client.py → api/openai_client.py} RENAMED
@@ -10,7 +10,7 @@ from tenacity import (
      wait_exponential,
  )

- from graphgen.bases.base_llm_client import BaseLLMClient
+ from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
  from graphgen.bases.datatypes import Token
  from graphgen.models.llm.limitter import RPM, TPM

@@ -28,7 +28,7 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
      return tokens


- class OpenAIClient(BaseLLMClient):
+ class OpenAIClient(BaseLLMWrapper):
      def __init__(
          self,
          *,
@@ -105,8 +105,8 @@ class OpenAIClient(BaseLLMClient):
              kwargs["logprobs"] = True
              kwargs["top_logprobs"] = self.topk_per_token

-         # Limit max_tokens to 5 to avoid long completions
-         kwargs["max_tokens"] = 5
+         # Limit max_tokens to 1 to avoid long completions
+         kwargs["max_tokens"] = 1

          completion = await self.client.chat.completions.create(  # pylint: disable=E1125
              model=self.model_name, **kwargs
graphgen/models/llm/local/__init__.py ADDED
File without changes
graphgen/models/llm/local/hf_wrapper.py ADDED
@@ -0,0 +1,147 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class HuggingFaceWrapper(BaseLLMWrapper):
    """
    Async inference backend based on HuggingFace Transformers
    """

    def __init__(
        self,
        model: str,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        temperature=0.0,
        top_p=1.0,
        topk=5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                GenerationConfig,
            )
        except ImportError as exc:
            raise ImportError(
                "HuggingFaceWrapper requires torch, transformers and accelerate. "
                "Install them with: pip install torch transformers accelerate"
            ) from exc

        self.torch = torch
        self.AutoTokenizer = AutoTokenizer
        self.AutoModelForCausalLM = AutoModelForCausalLM
        self.GenerationConfig = GenerationConfig

        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
        self.model.eval()
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": extra.get("max_new_tokens", 512),
            "do_sample": self.temperature > 0,
            "temperature": self.temperature if self.temperature > 0 else 1.0,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # Add top_p and top_k only if temperature > 0
        if self.temperature > 0:
            gen_kwargs.update(top_p=self.top_p, top_k=self.topk)

        gen_config = self.GenerationConfig(**gen_kwargs)

        with self.torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gen_config)

        gen = out[0, inputs.input_ids.shape[-1] :]
        return self.tokenizer.decode(gen, skip_special_tokens=True)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        with self.torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                temperature=1.0,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        scores = out.scores[0][0]  # (vocab,)
        probs = self.torch.softmax(scores, dim=-1)
        top_probs, top_idx = self.torch.topk(probs, k=self.topk)

        tokens = []
        for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
            tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        ids = self.tokenizer.encode(full)
        logprobs = []

        for i in range(1, len(ids) + 1):
            trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1]
            inputs = self.torch.tensor([trunc]).to(self.model.device)

            with self.torch.no_grad():
                logits = self.model(inputs).logits[0, -1, :]
                probs = self.torch.softmax(logits, dim=-1)

            true_id = ids[i - 1]
            logprobs.append(
                Token(
                    self.tokenizer.decode([true_id]),
                    float(probs[true_id].cpu()),
                )
            )
        return logprobs
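
A quick smoke test for the Transformers backend, assuming torch and transformers are installed and a small instruction-tuned checkpoint is available; the model id is a placeholder:

import asyncio

from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

# Placeholder checkpoint; any causal LM on the Hub should work.
wrapper = HuggingFaceWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")


async def main() -> None:
    print(await wrapper.generate_answer("What is a knowledge graph?", max_new_tokens=64))
    # Inspect the top-k candidates for the next token of a prompt.
    for tok in await wrapper.generate_topk_per_token("The capital of France is"):
        print(tok)


asyncio.run(main())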
graphgen/models/llm/local/sglang_wrapper.py ADDED
@@ -0,0 +1,148 @@
import math
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class SGLangWrapper(BaseLLMWrapper):
    """
    Async inference backend based on SGLang offline engine.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        try:
            import sglang as sgl
            from sglang.utils import async_stream_and_merge, stream_and_merge
        except ImportError as exc:
            raise ImportError(
                "SGLangWrapper requires sglang. Install it with: "
                "uv pip install sglang --prerelease=allow"
            ) from exc

        self.model_path: str = model
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

        # Initialise the offline engine
        self.engine = sgl.Engine(model_path=self.model_path)

        # Keep helpers for streaming
        self.async_stream_and_merge = async_stream_and_merge
        self.stream_and_merge = stream_and_merge

    @staticmethod
    def _build_sampling_params(
        temperature: float,
        top_p: float,
        max_tokens: int,
        topk: int,
        logprobs: bool = False,
    ) -> Dict[str, Any]:
        """Build SGLang-compatible sampling-params dict."""
        params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
        }
        if logprobs and topk > 0:
            params["logprobs"] = topk
        return params

    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
        """Convert raw text (+ optional history) into a single prompt string."""
        parts = []
        if self.system_prompt:
            parts.append(self.system_prompt)
        if history:
            assert len(history) % 2 == 0, "History must have even length (u/a turns)."
            parts.extend([item["content"] for item in history])
        parts.append(text)
        return "\n".join(parts)

    def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
        tokens: List[Token] = []

        meta = output.get("meta_info", {})
        logprobs = meta.get("output_token_logprobs", [])
        topks = meta.get("output_top_logprobs", [])

        tokenizer = self.engine.tokenizer_manager.tokenizer

        for idx, (lp, tid, _) in enumerate(logprobs):
            prob = math.exp(lp)
            tok_str = tokenizer.decode([tid])

            top_candidates = []
            if self.topk > 0 and idx < len(topks):
                for t_lp, t_tid, _ in topks[idx][: self.topk]:
                    top_candidates.append(
                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
                    )

            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))

        return tokens

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> str:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            topk=0,  # no logprobs needed for simple generation
        )

        outputs = await self.engine.async_generate([prompt], sampling_params)
        return self.filter_think_tags(outputs[0]["text"])

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> List[Token]:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1,  # keep short for token-level analysis
            topk=self.topk,
        )

        outputs = await self.engine.async_generate(
            [prompt], sampling_params, return_logprob=True, top_logprobs_num=5
        )
        print(outputs)
        return self._tokens_from_output(outputs[0])

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "SGLangWrapper does not support per-token logprobs yet."
        )

    def shutdown(self) -> None:
        """Gracefully shutdown the SGLang engine."""
        if hasattr(self, "engine"):
            self.engine.shutdown()

    def restart(self) -> None:
        """Restart the SGLang engine."""
        self.shutdown()
        self.engine = self.engine.__class__(model_path=self.model_path)
graphgen/models/llm/local/tgi_wrapper.py ADDED
@@ -0,0 +1,36 @@
from typing import Any, List, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.bases.datatypes import Token


# TODO: implement TGIWrapper methods
class TGIWrapper(BaseLLMWrapper):
    """
    Async inference backend based on TGI (Text-Generation-Inference)
    """

    def __init__(
        self,
        model_url: str,  # e.g. "http://localhost:8080"
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        pass

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        pass

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        pass
graphgen/models/llm/{ollama_client.py → local/trt_wrapper.py} RENAMED
@@ -1,10 +1,15 @@
- # TODO: implement ollama client
  from typing import Any, List, Optional

- from graphgen.bases import BaseLLMClient, Token
+ from graphgen.bases import BaseLLMWrapper
+ from graphgen.bases.datatypes import Token


- class OllamaClient(BaseLLMClient):
+ # TODO: implement TensorRTWrapper methods
+ class TensorRTWrapper(BaseLLMWrapper):
+     """
+     Async inference backend based on TensorRT-LLM
+     """
+
      async def generate_answer(
          self, text: str, history: Optional[List[str]] = None, **extra: Any
      ) -> str:
graphgen/models/llm/local/vllm_wrapper.py ADDED
@@ -0,0 +1,137 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
    """
    Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        except ImportError as exc:
            raise ImportError(
                "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
            ) from exc

        self.SamplingParams = SamplingParams

        engine_args = AsyncEngineArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=kwargs.get("trust_remote_code", True),
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full_prompt = self._build_inputs(text, history)

        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature > 0 else 1.0,
            top_p=self.top_p if self.temperature > 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 512),
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_req"
        ):
            results = req_output.outputs
        return results[-1].text

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)

        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.topk,
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_topk"
        ):
            results = req_output.outputs
        top_logprobs = results[-1].logprobs[0]

        tokens = []
        for _, logprob_obj in top_logprobs.items():
            tok_str = logprob_obj.decoded_token
            prob = float(logprob_obj.logprob.exp())
            tokens.append(Token(tok_str, prob))
        tokens.sort(key=lambda x: -x.prob)
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)

        # vLLM has no ready-made "mask one token and compute its prob" interface,
        # so we take the most direct route: send the prompt once with
        # prompt_logprobs enabled, let vLLM return the logprob of every position
        # in the *input* part, then pick out the probability of each token.
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=0,  # do not generate new tokens
            prompt_logprobs=1,  # top-1 is enough
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_prob"
        ):
            results = req_output.outputs

        # prompt_logprobs is a list whose length equals the number of prompt tokens;
        # each element is a dict {token_id: logprob_obj}, or None for the first position.
        prompt_logprobs = results[-1].prompt_logprobs

        tokens = []
        for _, logprob_dict in enumerate(prompt_logprobs):
            if logprob_dict is None:
                continue
            # each dict holds a single kv pair because we asked for top-1
            _, logprob_obj = next(iter(logprob_dict.items()))
            tok_str = logprob_obj.decoded_token
            prob = float(logprob_obj.logprob.exp())
            tokens.append(Token(tok_str, prob))
        return tokens
graphgen/models/llm/topk_token_model.py DELETED
@@ -1,53 +0,0 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from graphgen.bases import Token


class TopkTokenModel(ABC):
    def __init__(
        self,
        do_sample: bool = False,
        temperature: float = 0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        num_beams: int = 1,
        topk: int = 50,
        topp: float = 0.95,
        topk_per_token: int = 5,
    ):
        self.do_sample = do_sample
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.num_beams = num_beams
        self.topk = topk
        self.topp = topp
        self.topk_per_token = topk_per_token

    @abstractmethod
    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """
        Generate prob, text and candidates for each token of the model's output.
        This function is used to visualize the inference process.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        """
        Generate prob and text for each token of the input text.
        This function is used to visualize the ppl.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None
    ) -> str:
        """
        Generate answer from the model.
        """
        raise NotImplementedError
graphgen/operators/__init__.py CHANGED
@@ -1,5 +1,6 @@
  from .build_kg import build_mm_kg, build_text_kg
  from .generate import generate_qas
+ from .init import init_llm
  from .judge import judge_statement
  from .partition import partition_kg
  from .quiz import quiz
graphgen/operators/build_kg/build_mm_kg.py CHANGED
@@ -3,14 +3,15 @@ from typing import List

  import gradio as gr

+ from graphgen.bases import BaseLLMWrapper
  from graphgen.bases.base_storage import BaseGraphStorage
  from graphgen.bases.datatypes import Chunk
- from graphgen.models import MMKGBuilder, OpenAIClient
+ from graphgen.models import MMKGBuilder
  from graphgen.utils import run_concurrent


  async def build_mm_kg(
-     llm_client: OpenAIClient,
+     llm_client: BaseLLMWrapper,
      kg_instance: BaseGraphStorage,
      chunks: List[Chunk],
      progress_bar: gr.Progress = None,
graphgen/operators/build_kg/build_text_kg.py CHANGED
@@ -3,14 +3,15 @@ from typing import List

  import gradio as gr

+ from graphgen.bases import BaseLLMWrapper
  from graphgen.bases.base_storage import BaseGraphStorage
  from graphgen.bases.datatypes import Chunk
- from graphgen.models import LightRAGKGBuilder, OpenAIClient
+ from graphgen.models import LightRAGKGBuilder
  from graphgen.utils import run_concurrent


  async def build_text_kg(
-     llm_client: OpenAIClient,
+     llm_client: BaseLLMWrapper,
      kg_instance: BaseGraphStorage,
      chunks: List[Chunk],
      progress_bar: gr.Progress = None,
graphgen/operators/generate/generate_qas.py CHANGED
@@ -1,6 +1,6 @@
  from typing import Any

- from graphgen.bases import BaseLLMClient
+ from graphgen.bases import BaseLLMWrapper
  from graphgen.models import (
      AggregatedGenerator,
      AtomicGenerator,
@@ -12,7 +12,7 @@ from graphgen.utils import logger, run_concurrent


  async def generate_qas(
-     llm_client: BaseLLMClient,
+     llm_client: BaseLLMWrapper,
      batches: list[
          tuple[
              list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
graphgen/operators/init/__init__.py ADDED
@@ -0,0 +1 @@
from .init_llm import init_llm
graphgen/operators/init/init_llm.py ADDED
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer


class LLMFactory:
    """
    A factory class to create LLM wrapper instances based on the specified backend.
    Supported backends include:
    - http_api: HTTPClient
    - openai_api: OpenAIClient
    - ollama_api: OllamaClient
    - ollama: OllamaWrapper
    - deepspeed: DeepSpeedWrapper
    - huggingface: HuggingFaceWrapper
    - tgi: TGIWrapper
    - sglang: SGLangWrapper
    - tensorrt: TensorRTWrapper
    """

    @staticmethod
    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
        # add tokenizer
        tokenizer: Tokenizer = Tokenizer(
            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
        )
        config["tokenizer"] = tokenizer
        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

            return HTTPClient(**config)
        if backend == "openai_api":
            from graphgen.models.llm.api.openai_client import OpenAIClient

            return OpenAIClient(**config)
        if backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

            return OllamaClient(**config)
        if backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            return HuggingFaceWrapper(**config)
        # if backend == "sglang":
        #     from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
        #
        #     return SGLangWrapper(**config)

        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            return VLLMWrapper(**config)

        raise NotImplementedError(f"Backend {backend} is not implemented yet.")


def _load_env_group(prefix: str) -> Dict[str, Any]:
    """
    Collect environment variables with the given prefix into a dictionary,
    stripping the prefix from the keys.
    """
    return {
        k[len(prefix) :].lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }


def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if model_type == "synthesizer":
        prefix = "SYNTHESIZER_"
    elif model_type == "trainee":
        prefix = "TRAINEE_"
    else:
        raise NotImplementedError(f"Model type {model_type} is not implemented yet.")
    config = _load_env_group(prefix)
    # if config is empty, return None
    if not config:
        return None
    backend = config.pop("backend")
    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
    return llm_wrapper
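
init_llm is driven entirely by prefixed environment variables: it collects every SYNTHESIZER_* (or TRAINEE_*) variable, lowercases the stripped key, pops backend, and forwards the rest as keyword arguments to the selected wrapper. A hedged sketch for the http_api backend; the values below are placeholders and the keys must match HTTPClient's keyword arguments:

import os

from graphgen.operators import init_llm

# Illustrative configuration only; any OpenAI-compatible endpoint works.
os.environ["SYNTHESIZER_BACKEND"] = "http_api"
os.environ["SYNTHESIZER_MODEL"] = "qwen2.5-7b-instruct"
os.environ["SYNTHESIZER_BASE_URL"] = "http://localhost:8000/v1"
os.environ["SYNTHESIZER_API_KEY"] = "EMPTY"

synthesizer = init_llm("synthesizer")  # returns an HTTPClient wrapping the endpoint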
graphgen/operators/judge.py CHANGED
@@ -3,13 +3,14 @@ import math

  from tqdm.asyncio import tqdm as tqdm_async

- from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
+ from graphgen.bases import BaseLLMWrapper
+ from graphgen.models import JsonKVStorage, NetworkXStorage
  from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
  from graphgen.utils import logger, yes_no_loss_entropy


  async def judge_statement(  # pylint: disable=too-many-statements
-     trainee_llm_client: OpenAIClient,
+     trainee_llm_client: BaseLLMWrapper,
      graph_storage: NetworkXStorage,
      rephrase_storage: JsonKVStorage,
      re_judge: bool = False,
graphgen/operators/quiz.py CHANGED
@@ -3,13 +3,14 @@ from collections import defaultdict

  from tqdm.asyncio import tqdm as tqdm_async

- from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
+ from graphgen.bases import BaseLLMWrapper
+ from graphgen.models import JsonKVStorage, NetworkXStorage
  from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
  from graphgen.utils import detect_main_language, logger


  async def quiz(
-     synth_llm_client: OpenAIClient,
+     synth_llm_client: BaseLLMWrapper,
      graph_storage: NetworkXStorage,
      rephrase_storage: JsonKVStorage,
      max_samples: int = 1,
requirements.txt CHANGED
@@ -25,4 +25,4 @@ igraph
  python-louvain

  # For visualization
- matplotlib
+ matplotlib