github-actions[bot] committed
Commit · d02622b
1 Parent(s): d2af8b0
Auto-sync from demo at Wed Oct 29 11:25:28 UTC 2025
- graphgen/bases/__init__.py +1 -1
- graphgen/bases/base_generator.py +2 -2
- graphgen/bases/base_kg_builder.py +2 -2
- graphgen/bases/{base_llm_client.py → base_llm_wrapper.py} +7 -1
- graphgen/graphgen.py +35 -30
- graphgen/models/__init__.py +1 -2
- graphgen/models/kg_builder/light_rag_kg_builder.py +2 -2
- graphgen/models/llm/__init__.py +4 -0
- graphgen/models/llm/api/__init__.py +0 -0
- graphgen/models/llm/api/http_client.py +197 -0
- graphgen/models/llm/api/ollama_client.py +105 -0
- graphgen/models/llm/{openai_client.py → api/openai_client.py} +4 -4
- graphgen/models/llm/local/__init__.py +0 -0
- graphgen/models/llm/local/hf_wrapper.py +147 -0
- graphgen/models/llm/local/sglang_wrapper.py +148 -0
- graphgen/models/llm/local/tgi_wrapper.py +36 -0
- graphgen/models/llm/{ollama_client.py → local/trt_wrapper.py} +8 -3
- graphgen/models/llm/local/vllm_wrapper.py +137 -0
- graphgen/models/llm/topk_token_model.py +0 -53
- graphgen/operators/__init__.py +1 -0
- graphgen/operators/build_kg/build_mm_kg.py +3 -2
- graphgen/operators/build_kg/build_text_kg.py +3 -2
- graphgen/operators/generate/generate_qas.py +2 -2
- graphgen/operators/init/__init__.py +1 -0
- graphgen/operators/init/init_llm.py +84 -0
- graphgen/operators/judge.py +3 -2
- graphgen/operators/quiz.py +3 -2
- requirements.txt +1 -1
graphgen/bases/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
-from .base_llm_client import BaseLLMClient
+from .base_llm_wrapper import BaseLLMWrapper
 from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_splitter import BaseSplitter

graphgen/bases/base_generator.py
CHANGED
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 
 
 class BaseGenerator(ABC):
@@ -9,7 +9,7 @@ class BaseGenerator(ABC):
     Generate QAs based on given prompts.
     """
 
-    def __init__(self, llm_client: BaseLLMClient):
+    def __init__(self, llm_client: BaseLLMWrapper):
         self.llm_client = llm_client
 
     @staticmethod

graphgen/bases/base_kg_builder.py
CHANGED
@@ -2,13 +2,13 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from typing import Dict, List, Tuple
 
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk
 
 
 class BaseKGBuilder(ABC):
-    def __init__(self, llm_client: BaseLLMClient):
+    def __init__(self, llm_client: BaseLLMWrapper):
         self.llm_client = llm_client
         self._nodes: Dict[str, List[dict]] = defaultdict(list)
         self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)

graphgen/bases/{base_llm_client.py → base_llm_wrapper.py}
RENAMED
@@ -8,7 +8,7 @@ from graphgen.bases.base_tokenizer import BaseTokenizer
 from graphgen.bases.datatypes import Token
 
 
-class BaseLLMClient(abc.ABC):
+class BaseLLMWrapper(abc.ABC):
     """
     LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
     """
@@ -66,3 +66,9 @@ class BaseLLMClient(abc.ABC):
         think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
         filtered_text = think_pattern.sub("", text).strip()
         return filtered_text if filtered_text else text.strip()
+
+    def shutdown(self) -> None:
+        """Shutdown the LLM engine if applicable."""
+
+    def restart(self) -> None:
+        """Reinitialize the LLM engine if applicable."""

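Note: `shutdown()` and `restart()` are added to the base class as no-op hooks, so API-backed clients inherit them unchanged while local engines (see the SGLang wrapper below) override them. A minimal sketch of a custom backend against this interface, assuming the three `generate_*` coroutines shown in the wrappers below are the full abstract contract; `EchoWrapper` is illustrative only and not part of this commit:

from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class EchoWrapper(BaseLLMWrapper):
    """Toy backend that echoes the prompt back; handy for wiring tests."""

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        # A real backend would call its inference engine here.
        return f"echo: {text}"

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError("EchoWrapper exposes no token-level probabilities.")

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError

    # shutdown() / restart() are inherited as no-ops from BaseLLMWrapper.
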
graphgen/graphgen.py
CHANGED
@@ -1,11 +1,11 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass
 from typing import Dict, cast
 
 import gradio as gr
 
+from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
 from graphgen.models import (
@@ -20,6 +20,7 @@ from graphgen.operators import (
     build_text_kg,
     chunk_documents,
     generate_qas,
+    init_llm,
     judge_statement,
     partition_kg,
     quiz,
@@ -31,40 +32,28 @@ from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
 
-@dataclass
 class GraphGen:
-    [13 deleted lines not rendered in this view]
+    def __init__(
+        self,
+        unique_id: int = int(time.time()),
+        working_dir: str = os.path.join(sys_path, "cache"),
+        tokenizer_instance: Tokenizer = None,
+        synthesizer_llm_client: OpenAIClient = None,
+        trainee_llm_client: OpenAIClient = None,
+        progress_bar: gr.Progress = None,
+    ):
+        self.unique_id: int = unique_id
+        self.working_dir: str = working_dir
+
+        # llm
+        self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
             model_name=os.getenv("TOKENIZER_MODEL")
         )
 
-        self.synthesizer_llm_client: OpenAIClient = (
-            self.synthesizer_llm_client
-            or OpenAIClient(
-                model_name=os.getenv("SYNTHESIZER_MODEL"),
-                api_key=os.getenv("SYNTHESIZER_API_KEY"),
-                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-                tokenizer=self.tokenizer_instance,
-            )
-        )
-
-        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
-            model_name=os.getenv("TRAINEE_MODEL"),
-            api_key=os.getenv("TRAINEE_API_KEY"),
-            base_url=os.getenv("TRAINEE_BASE_URL"),
-            tokenizer=self.tokenizer_instance,
+        self.synthesizer_llm_client: BaseLLMWrapper = (
+            synthesizer_llm_client or init_llm("synthesizer")
         )
+        self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
 
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
@@ -86,6 +75,9 @@ class GraphGen:
             namespace="qa",
         )
 
+        # webui
+        self.progress_bar: gr.Progress = progress_bar
+
     @async_to_sync_method
     async def insert(self, read_config: Dict, split_config: Dict):
         """
@@ -272,6 +264,12 @@
         )
 
         # TODO: assert trainee_llm_client is valid before judge
+        if not self.trainee_llm_client:
+            # TODO: shutdown existing synthesizer_llm_client properly
+            logger.info("No trainee LLM client provided, initializing a new one.")
+            self.synthesizer_llm_client.shutdown()
+            self.trainee_llm_client = init_llm("trainee")
+
         re_judge = quiz_and_judge_config["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
@@ -279,9 +277,16 @@
             self.rephrase_storage,
             re_judge,
         )
+
         await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()
 
+        logger.info("Shutting down trainee LLM client.")
+        self.trainee_llm_client.shutdown()
+        self.trainee_llm_client = None
+        logger.info("Restarting synthesizer LLM client.")
+        self.synthesizer_llm_client.restart()
+
     @async_to_sync_method
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph

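Note: with this change `GraphGen` no longer hard-codes two `OpenAIClient` instances; when no client is passed in, the synthesizer comes from `init_llm("synthesizer")` and the trainee is created lazily (and shut down again) around the judge step. A rough construction sketch, assuming the `SYNTHESIZER_*` variables are consumed by `init_llm` as shown in `graphgen/operators/init/init_llm.py` below; the model name and endpoint are placeholders:

import os

from graphgen.graphgen import GraphGen

# Consumed by init_llm("synthesizer") -> LLMFactory -> HTTPClient(**config)
os.environ["TOKENIZER_MODEL"] = "cl100k_base"
os.environ["SYNTHESIZER_BACKEND"] = "http_api"
os.environ["SYNTHESIZER_MODEL"] = "qwen2.5-7b-instruct"
os.environ["SYNTHESIZER_BASE_URL"] = "http://localhost:8080/v1"
os.environ["SYNTHESIZER_API_KEY"] = "EMPTY"

gg = GraphGen(working_dir="cache")   # synthesizer_llm_client is built from the environment
# gg.insert(read_config, split_config)
# the judge step later builds the trainee client from TRAINEE_* variables if none was passed
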
graphgen/models/__init__.py
CHANGED
@@ -7,8 +7,7 @@ from .generator import (
     VQAGenerator,
 )
 from .kg_builder import LightRAGKGBuilder, MMKGBuilder
-from .llm.openai_client import OpenAIClient
-from .llm.topk_token_model import TopkTokenModel
+from .llm import HTTPClient, OllamaClient, OpenAIClient
 from .partitioner import (
     AnchorBFSPartitioner,
     BFSPartitioner,

graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED
@@ -2,7 +2,7 @@ import re
 from collections import Counter, defaultdict
 from typing import Dict, List, Tuple
 
-from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import (
     detect_main_language,
@@ -15,7 +15,7 @@ from graphgen.utils import (
 
 
 class LightRAGKGBuilder(BaseKGBuilder):
-    def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
+    def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
         super().__init__(llm_client)
         self.max_loop = max_loop
 

graphgen/models/llm/__init__.py
CHANGED
@@ -0,0 +1,4 @@
+from .api.http_client import HTTPClient
+from .api.ollama_client import OllamaClient
+from .api.openai_client import OpenAIClient
+from .local.hf_wrapper import HuggingFaceWrapper

graphgen/models/llm/api/__init__.py
ADDED
File without changes
graphgen/models/llm/api/http_client.py
ADDED
@@ -0,0 +1,197 @@
import asyncio
import math
from typing import Any, Dict, List, Optional

import aiohttp
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class HTTPClient(BaseLLMWrapper):
    """
    A generic async HTTP client for LLMs compatible with OpenAI's chat/completions format.
    It uses aiohttp for making requests and includes retry logic and token usage tracking.
    Usage example:
        client = HTTPClient(
            model_name="gpt-4o-mini",
            base_url="http://localhost:8080",
            api_key="your_api_key",
            json_mode=True,
            seed=42,
            topk_per_token=5,
            request_limit=True,
        )

        answer = await client.generate_answer("Hello, world!")
        tokens = await client.generate_topk_per_token("Hello, world!")
    """

    _instance: Optional["HTTPClient"] = None
    _lock = asyncio.Lock()

    def __new__(cls, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        *,
        model: str,
        base_url: str,
        api_key: Optional[str] = None,
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        # Initialize only once in the singleton pattern
        if getattr(self, "_initialized", False):
            return
        self._initialized: bool = True
        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()

        self.token_usage: List[Dict[str, int]] = []
        self._session: Optional[aiohttp.ClientSession] = None

    @property
    def session(self) -> aiohttp.ClientSession:
        if self._session is None or self._session.closed:
            headers = (
                {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
            )
            self._session = aiohttp.ClientSession(headers=headers)
        return self._session

    async def close(self):
        if self._session and not self._session.closed:
            await self._session.close()

    def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]:
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        # chatml format: alternating user and assistant messages
        if history and isinstance(history[0], dict):
            messages.extend(history)

        messages.append({"role": "user", "content": text})

        body = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }
        if self.seed:
            body["seed"] = self.seed
        if self.json_mode:
            body["response_format"] = {"type": "json_object"}
        return body

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
    )
    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> str:
        body = self._build_body(text, history or [])
        prompt_tokens = sum(
            len(self.tokenizer.encode(m["content"])) for m in body["messages"]
        )
        est = prompt_tokens + body["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        async with self.session.post(
            f"{self.base_url}/chat/completions",
            json=body,
            timeout=aiohttp.ClientTimeout(total=60),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        msg = data["choices"][0]["message"]["content"]
        if "usage" in data:
            self.token_usage.append(
                {
                    "prompt_tokens": data["usage"]["prompt_tokens"],
                    "completion_tokens": data["usage"]["completion_tokens"],
                    "total_tokens": data["usage"]["total_tokens"],
                }
            )
        return self.filter_think_tags(msg)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
    )
    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> List[Token]:
        body = self._build_body(text, history or [])
        body["max_tokens"] = 1
        if self.topk_per_token > 0:
            body["logprobs"] = True
            body["top_logprobs"] = self.topk_per_token

        async with self.session.post(
            f"{self.base_url}/chat/completions",
            json=body,
            timeout=aiohttp.ClientTimeout(total=60),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        token_logprobs = data["choices"][0]["logprobs"]["content"]
        tokens = []
        for item in token_logprobs:
            candidates = [
                Token(t["token"], math.exp(t["logprob"])) for t in item["top_logprobs"]
            ]
            tokens.append(
                Token(
                    item["token"], math.exp(item["logprob"]), top_candidates=candidates
                )
            )
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "generate_inputs_prob is not implemented in HTTPClient"
        )

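Note: `HTTPClient.__init__` is keyword-only and takes `model=`, while the class docstring above writes `model_name=`; the sketch below follows the actual signature. It is a usage sketch only: the endpoint, key and model name are placeholders, and a tokenizer is passed explicitly because `generate_answer` calls `self.tokenizer.encode(...)` (in the pipeline this kwarg is injected by `LLMFactory`). The class is also a process-wide singleton, so the first construction wins.

import asyncio

from graphgen.models import Tokenizer
from graphgen.models.llm.api.http_client import HTTPClient


async def main() -> None:
    client = HTTPClient(
        model="qwen2.5-7b-instruct",          # placeholder model name
        base_url="http://localhost:8080/v1",  # any OpenAI-compatible server
        api_key="EMPTY",
        tokenizer=Tokenizer("cl100k_base"),   # mirrors what LLMFactory injects
    )
    try:
        print(await client.generate_answer("Name three graph partitioning algorithms."))
        print(client.token_usage)             # filled when the server reports usage
    finally:
        await client.close()                  # close the shared aiohttp session


asyncio.run(main())
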
graphgen/models/llm/api/ollama_client.py
ADDED
@@ -0,0 +1,105 @@
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class OllamaClient(BaseLLMWrapper):
    """
    Requires a local or remote Ollama server to be running (default port 11434).
    The top_logprobs field is not yet implemented by the official API.
    """

    def __init__(
        self,
        *,
        model: str = "gemma3",
        base_url: str = "http://localhost:11434",
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "Ollama SDK is not installed."
                "It is required to use OllamaClient."
                "Please install it with `pip install ollama`."
            ) from e
        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()
        self.token_usage: List[Dict[str, int]] = []

        self.client = ollama.AsyncClient(host=self.base_url)

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> str:
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": text})

        options = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_predict": self.max_tokens,
        }
        if self.seed is not None:
            options["seed"] = self.seed

        prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
        est = prompt_tokens + self.max_tokens
        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        response = await self.client.chat(
            model=self.model_name,
            messages=messages,
            format="json" if self.json_mode else "",
            options=options,
            stream=False,
        )

        usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0)
        self.token_usage.append(
            {
                "prompt_tokens": usage[0],
                "completion_tokens": usage[1],
                "total_tokens": sum(usage),
            }
        )
        content = response["message"]["content"]
        return self.filter_think_tags(content)

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token top-k yet.")

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token logprobs yet.")

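Note: a short usage sketch, assuming an Ollama server on the default port with the model already pulled; as above, the tokenizer kwarg mirrors what `LLMFactory` injects because request accounting calls `self.tokenizer.encode(...)`.

import asyncio

from graphgen.models import Tokenizer
from graphgen.models.llm.api.ollama_client import OllamaClient


async def main() -> None:
    client = OllamaClient(
        model="gemma3",                       # e.g. run `ollama pull gemma3` beforehand
        base_url="http://localhost:11434",
        tokenizer=Tokenizer("cl100k_base"),
    )
    print(await client.generate_answer("Summarize what a knowledge graph is."))


asyncio.run(main())
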
graphgen/models/llm/{openai_client.py → api/openai_client.py}
RENAMED
@@ -10,7 +10,7 @@ from tenacity import (
     wait_exponential,
 )
 
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
 from graphgen.models.llm.limitter import RPM, TPM
 
@@ -28,7 +28,7 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
     return tokens
 
 
-class OpenAIClient(BaseLLMClient):
+class OpenAIClient(BaseLLMWrapper):
     def __init__(
         self,
         *,
@@ -105,8 +105,8 @@ class OpenAIClient(BaseLLMClient):
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token
 
-        # Limit max_tokens to …
-        kwargs["max_tokens"] = …
+        # Limit max_tokens to 1 to avoid long completions
+        kwargs["max_tokens"] = 1
 
         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
             model=self.model_name, **kwargs

graphgen/models/llm/local/__init__.py
ADDED
File without changes
graphgen/models/llm/local/hf_wrapper.py
ADDED
@@ -0,0 +1,147 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class HuggingFaceWrapper(BaseLLMWrapper):
    """
    Async inference backend based on HuggingFace Transformers
    """

    def __init__(
        self,
        model: str,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        temperature=0.0,
        top_p=1.0,
        topk=5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                GenerationConfig,
            )
        except ImportError as exc:
            raise ImportError(
                "HuggingFaceWrapper requires torch, transformers and accelerate. "
                "Install them with: pip install torch transformers accelerate"
            ) from exc

        self.torch = torch
        self.AutoTokenizer = AutoTokenizer
        self.AutoModelForCausalLM = AutoModelForCausalLM
        self.GenerationConfig = GenerationConfig

        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
        self.model.eval()
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": extra.get("max_new_tokens", 512),
            "do_sample": self.temperature > 0,
            "temperature": self.temperature if self.temperature > 0 else 1.0,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # Add top_p and top_k only if temperature > 0
        if self.temperature > 0:
            gen_kwargs.update(top_p=self.top_p, top_k=self.topk)

        gen_config = self.GenerationConfig(**gen_kwargs)

        with self.torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gen_config)

        gen = out[0, inputs.input_ids.shape[-1] :]
        return self.tokenizer.decode(gen, skip_special_tokens=True)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        with self.torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                temperature=1.0,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        scores = out.scores[0][0]  # (vocab,)
        probs = self.torch.softmax(scores, dim=-1)
        top_probs, top_idx = self.torch.topk(probs, k=self.topk)

        tokens = []
        for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
            tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        ids = self.tokenizer.encode(full)
        logprobs = []

        for i in range(1, len(ids) + 1):
            trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1]
            inputs = self.torch.tensor([trunc]).to(self.model.device)

            with self.torch.no_grad():
                logits = self.model(inputs).logits[0, -1, :]
                probs = self.torch.softmax(logits, dim=-1)

            true_id = ids[i - 1]
            logprobs.append(
                Token(
                    self.tokenizer.decode([true_id]),
                    float(probs[true_id].cpu()),
                )
            )
        return logprobs

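Note: a local-inference sketch; the model id is illustrative and device placement is left to `device_map="auto"`. The tokenizer kwarg mirrors what `LLMFactory` injects (the wrapper replaces it internally with the HuggingFace tokenizer). `Token` carries `text` and `prob`, which is what `generate_topk_per_token` returns for the first generated token.

import asyncio

from graphgen.models import Tokenizer
from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper


async def main() -> None:
    llm = HuggingFaceWrapper(
        model="Qwen/Qwen2.5-0.5B-Instruct",   # illustrative local model id
        temperature=0.0,
        tokenizer=Tokenizer("cl100k_base"),
    )

    print(await llm.generate_answer("2 + 2 ="))

    # Top-k candidates for the first generated token as (text, prob) pairs.
    for tok in await llm.generate_topk_per_token("2 + 2 ="):
        print(repr(tok.text), round(tok.prob, 4))


asyncio.run(main())
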
graphgen/models/llm/local/sglang_wrapper.py
ADDED
@@ -0,0 +1,148 @@
import math
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class SGLangWrapper(BaseLLMWrapper):
    """
    Async inference backend based on SGLang offline engine.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        try:
            import sglang as sgl
            from sglang.utils import async_stream_and_merge, stream_and_merge
        except ImportError as exc:
            raise ImportError(
                "SGLangWrapper requires sglang. Install it with: "
                "uv pip install sglang --prerelease=allow"
            ) from exc

        self.model_path: str = model
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

        # Initialise the offline engine
        self.engine = sgl.Engine(model_path=self.model_path)

        # Keep helpers for streaming
        self.async_stream_and_merge = async_stream_and_merge
        self.stream_and_merge = stream_and_merge

    @staticmethod
    def _build_sampling_params(
        temperature: float,
        top_p: float,
        max_tokens: int,
        topk: int,
        logprobs: bool = False,
    ) -> Dict[str, Any]:
        """Build SGLang-compatible sampling-params dict."""
        params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
        }
        if logprobs and topk > 0:
            params["logprobs"] = topk
        return params

    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
        """Convert raw text (+ optional history) into a single prompt string."""
        parts = []
        if self.system_prompt:
            parts.append(self.system_prompt)
        if history:
            assert len(history) % 2 == 0, "History must have even length (u/a turns)."
            parts.extend([item["content"] for item in history])
        parts.append(text)
        return "\n".join(parts)

    def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
        tokens: List[Token] = []

        meta = output.get("meta_info", {})
        logprobs = meta.get("output_token_logprobs", [])
        topks = meta.get("output_top_logprobs", [])

        tokenizer = self.engine.tokenizer_manager.tokenizer

        for idx, (lp, tid, _) in enumerate(logprobs):
            prob = math.exp(lp)
            tok_str = tokenizer.decode([tid])

            top_candidates = []
            if self.topk > 0 and idx < len(topks):
                for t_lp, t_tid, _ in topks[idx][: self.topk]:
                    top_candidates.append(
                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
                    )

            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))

        return tokens

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> str:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            topk=0,  # no logprobs needed for simple generation
        )

        outputs = await self.engine.async_generate([prompt], sampling_params)
        return self.filter_think_tags(outputs[0]["text"])

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[str]] = None,
        **extra: Any,
    ) -> List[Token]:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1,  # keep short for token-level analysis
            topk=self.topk,
        )

        outputs = await self.engine.async_generate(
            [prompt], sampling_params, return_logprob=True, top_logprobs_num=5
        )
        print(outputs)
        return self._tokens_from_output(outputs[0])

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "SGLangWrapper does not support per-token logprobs yet."
        )

    def shutdown(self) -> None:
        """Gracefully shutdown the SGLang engine."""
        if hasattr(self, "engine"):
            self.engine.shutdown()

    def restart(self) -> None:
        """Restart the SGLang engine."""
        self.shutdown()
        self.engine = self.engine.__class__(model_path=self.model_path)

graphgen/models/llm/local/tgi_wrapper.py
ADDED
@@ -0,0 +1,36 @@
from typing import Any, List, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.bases.datatypes import Token


# TODO: implement TGIWrapper methods
class TGIWrapper(BaseLLMWrapper):
    """
    Async inference backend based on TGI (Text-Generation-Inference)
    """

    def __init__(
        self,
        model_url: str,  # e.g. "http://localhost:8080"
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        pass

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        pass

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        pass

graphgen/models/llm/{ollama_client.py → local/trt_wrapper.py}
RENAMED
@@ -1,10 +1,15 @@
-# TODO: implement ollama client
 from typing import Any, List, Optional
 
-from graphgen.bases import …
+from graphgen.bases import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
 
 
-…
+# TODO: implement TensorRTWrapper methods
+class TensorRTWrapper(BaseLLMWrapper):
+    """
+    Async inference backend based on TensorRT-LLM
+    """
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:

graphgen/models/llm/local/vllm_wrapper.py
ADDED
@@ -0,0 +1,137 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
    """
    Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        except ImportError as exc:
            raise ImportError(
                "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
            ) from exc

        self.SamplingParams = SamplingParams

        engine_args = AsyncEngineArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=kwargs.get("trust_remote_code", True),
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full_prompt = self._build_inputs(text, history)

        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature > 0 else 1.0,
            top_p=self.top_p if self.temperature > 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 512),
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_req"
        ):
            results = req_output.outputs
        return results[-1].text

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)

        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.topk,
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_topk"
        ):
            results = req_output.outputs
        top_logprobs = results[-1].logprobs[0]

        tokens = []
        for _, logprob_obj in top_logprobs.items():
            tok_str = logprob_obj.decoded_token
            prob = float(logprob_obj.logprob.exp())
            tokens.append(Token(tok_str, prob))
        tokens.sort(key=lambda x: -x.prob)
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)

        # vLLM has no ready-made "mask one token and recompute its prob" API,
        # so we take the most direct route: send the prompt in one pass with
        # prompt_logprobs enabled, let vLLM return a logprob for every position
        # of the *input* part, then pick out the probability of each actual token.
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=0,  # do not generate new tokens
            prompt_logprobs=1,  # top-1 is enough
        )

        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id="graphgen_prob"
        ):
            results = req_output.outputs

        # prompt_logprobs is a list whose length equals the number of prompt tokens;
        # each element is a dict{token_id: logprob_obj} or None (the first position is None)
        prompt_logprobs = results[-1].prompt_logprobs

        tokens = []
        for _, logprob_dict in enumerate(prompt_logprobs):
            if logprob_dict is None:
                continue
            # each dict holds exactly one kv pair here, because we asked for top-1
            _, logprob_obj = next(iter(logprob_dict.items()))
            tok_str = logprob_obj.decoded_token
            prob = float(logprob_obj.logprob.exp())
            tokens.append(Token(tok_str, prob))
        return tokens

graphgen/models/llm/topk_token_model.py
DELETED
@@ -1,53 +0,0 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from graphgen.bases import Token


class TopkTokenModel(ABC):
    def __init__(
        self,
        do_sample: bool = False,
        temperature: float = 0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        num_beams: int = 1,
        topk: int = 50,
        topp: float = 0.95,
        topk_per_token: int = 5,
    ):
        self.do_sample = do_sample
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.num_beams = num_beams
        self.topk = topk
        self.topp = topp
        self.topk_per_token = topk_per_token

    @abstractmethod
    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """
        Generate prob, text and candidates for each token of the model's output.
        This function is used to visualize the inference process.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        """
        Generate prob and text for each token of the input text.
        This function is used to visualize the ppl.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None
    ) -> str:
        """
        Generate answer from the model.
        """
        raise NotImplementedError

graphgen/operators/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from .build_kg import build_mm_kg, build_text_kg
 from .generate import generate_qas
+from .init import init_llm
 from .judge import judge_statement
 from .partition import partition_kg
 from .quiz import quiz

graphgen/operators/build_kg/build_mm_kg.py
CHANGED
@@ -3,14 +3,15 @@ from typing import List
 
 import gradio as gr
 
+from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk
-from graphgen.models import MMKGBuilder
+from graphgen.models import MMKGBuilder
 from graphgen.utils import run_concurrent
 
 
 async def build_mm_kg(
-    llm_client: …,
+    llm_client: BaseLLMWrapper,
     kg_instance: BaseGraphStorage,
     chunks: List[Chunk],
     progress_bar: gr.Progress = None,

graphgen/operators/build_kg/build_text_kg.py
CHANGED
@@ -3,14 +3,15 @@ from typing import List
 
 import gradio as gr
 
+from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk
-from graphgen.models import LightRAGKGBuilder
+from graphgen.models import LightRAGKGBuilder
 from graphgen.utils import run_concurrent
 
 
 async def build_text_kg(
-    llm_client: …,
+    llm_client: BaseLLMWrapper,
     kg_instance: BaseGraphStorage,
     chunks: List[Chunk],
     progress_bar: gr.Progress = None,

graphgen/operators/generate/generate_qas.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import Any
 
-from graphgen.bases import …
+from graphgen.bases import BaseLLMWrapper
 from graphgen.models import (
     AggregatedGenerator,
     AtomicGenerator,
@@ -12,7 +12,7 @@ from graphgen.utils import logger, run_concurrent
 
 
 async def generate_qas(
-    llm_client: …,
+    llm_client: BaseLLMWrapper,
     batches: list[
         tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]

graphgen/operators/init/__init__.py
ADDED
@@ -0,0 +1 @@
from .init_llm import init_llm

graphgen/operators/init/init_llm.py
ADDED
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer


class LLMFactory:
    """
    A factory class to create LLM wrapper instances based on the specified backend.
    Supported backends include:
    - http_api: HTTPClient
    - openai_api: OpenAIClient
    - ollama_api: OllamaClient
    - ollama: OllamaWrapper
    - deepspeed: DeepSpeedWrapper
    - huggingface: HuggingFaceWrapper
    - tgi: TGIWrapper
    - sglang: SGLangWrapper
    - tensorrt: TensorRTWrapper
    """

    @staticmethod
    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
        # add tokenizer
        tokenizer: Tokenizer = Tokenizer(
            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
        )
        config["tokenizer"] = tokenizer
        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

            return HTTPClient(**config)
        if backend == "openai_api":
            from graphgen.models.llm.api.openai_client import OpenAIClient

            return OpenAIClient(**config)
        if backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

            return OllamaClient(**config)
        if backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            return HuggingFaceWrapper(**config)
        # if backend == "sglang":
        #     from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
        #
        #     return SGLangWrapper(**config)

        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            return VLLMWrapper(**config)

        raise NotImplementedError(f"Backend {backend} is not implemented yet.")


def _load_env_group(prefix: str) -> Dict[str, Any]:
    """
    Collect environment variables with the given prefix into a dictionary,
    stripping the prefix from the keys.
    """
    return {
        k[len(prefix) :].lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }


def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if model_type == "synthesizer":
        prefix = "SYNTHESIZER_"
    elif model_type == "trainee":
        prefix = "TRAINEE_"
    else:
        raise NotImplementedError(f"Model type {model_type} is not implemented yet.")
    config = _load_env_group(prefix)
    # if config is empty, return None
    if not config:
        return None
    backend = config.pop("backend")
    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
    return llm_wrapper

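Note: `_load_env_group` strips the `SYNTHESIZER_` / `TRAINEE_` prefix and lower-cases the remainder, so apart from `*_BACKEND` (which picks the wrapper and is popped) every variable becomes a constructor kwarg for the chosen backend. A sketch of the resulting layout; the values are placeholders and the valid keys are simply whatever the selected wrapper accepts:

import os

from graphgen.operators import init_llm

# Synthesizer over an OpenAI-compatible HTTP endpoint -> HTTPClient(**config)
os.environ["SYNTHESIZER_BACKEND"] = "http_api"
os.environ["SYNTHESIZER_MODEL"] = "qwen2.5-7b-instruct"         # becomes config["model"]
os.environ["SYNTHESIZER_BASE_URL"] = "http://localhost:8080/v1"
os.environ["SYNTHESIZER_API_KEY"] = "EMPTY"

# Trainee served locally through vLLM -> VLLMWrapper(**config)
os.environ["TRAINEE_BACKEND"] = "vllm"
os.environ["TRAINEE_MODEL"] = "Qwen/Qwen2.5-0.5B-Instruct"

synthesizer = init_llm("synthesizer")   # HTTPClient
trainee = init_llm("trainee")           # VLLMWrapper; None if no TRAINEE_* vars are set
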
graphgen/operators/judge.py
CHANGED
@@ -3,13 +3,14 @@ import math
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.…
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import JsonKVStorage, NetworkXStorage
 from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
 from graphgen.utils import logger, yes_no_loss_entropy
 
 
 async def judge_statement(  # pylint: disable=too-many-statements
-    trainee_llm_client: …,
+    trainee_llm_client: BaseLLMWrapper,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,
     re_judge: bool = False,

graphgen/operators/quiz.py
CHANGED
@@ -3,13 +3,14 @@ from collections import defaultdict
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.…
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import JsonKVStorage, NetworkXStorage
 from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
 from graphgen.utils import detect_main_language, logger
 
 
 async def quiz(
-    synth_llm_client: …,
+    synth_llm_client: BaseLLMWrapper,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,
     max_samples: int = 1,

requirements.txt
CHANGED
@@ -25,4 +25,4 @@ igraph
 python-louvain
 
 # For visualization
-matplotlib
+matplotlib