Spaces:
Running
Running
github-actions[bot]
commited on
Commit
·
8c66169
1
Parent(s):
0b9d8c7
Auto-sync from demo at Thu Oct 23 12:37:24 UTC 2025
Browse files- graphgen/bases/base_generator.py +2 -3
- graphgen/bases/base_kg_builder.py +4 -8
- graphgen/bases/base_partitioner.py +0 -2
- graphgen/bases/base_splitter.py +15 -8
- graphgen/bases/base_storage.py +0 -3
- graphgen/bases/base_tokenizer.py +2 -3
- graphgen/models/evaluator/base_evaluator.py +3 -4
- graphgen/models/evaluator/length_evaluator.py +3 -6
- graphgen/models/evaluator/mtld_evaluator.py +4 -8
- graphgen/models/generator/aggregated_generator.py +0 -2
- graphgen/models/generator/atomic_generator.py +0 -2
- graphgen/models/generator/cot_generator.py +0 -2
- graphgen/models/generator/multi_hop_generator.py +0 -2
- graphgen/models/generator/vqa_generator.py +0 -2
- graphgen/models/kg_builder/light_rag_kg_builder.py +3 -4
- graphgen/models/kg_builder/mm_kg_builder.py +1 -3
- graphgen/models/llm/topk_token_model.py +24 -12
- graphgen/models/partitioner/bfs_partitioner.py +0 -2
- graphgen/models/partitioner/dfs_partitioner.py +0 -2
- graphgen/models/partitioner/ece_partitioner.py +0 -2
- graphgen/models/partitioner/leiden_partitioner.py +0 -2
- graphgen/models/search/db/uniprot_search.py +0 -3
- graphgen/models/search/kg/wiki_search.py +0 -2
- graphgen/models/search/web/bing_search.py +2 -4
- graphgen/models/search/web/google_search.py +0 -3
- graphgen/models/storage/json_storage.py +2 -0
- graphgen/models/tokenizer/__init__.py +2 -6
- graphgen/models/tokenizer/hf_tokenizer.py +2 -3
- graphgen/models/tokenizer/tiktoken_tokenizer.py +2 -3
graphgen/bases/base_generator.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
from graphgen.bases.base_llm_client import BaseLLMClient
|
| 6 |
|
| 7 |
|
| 8 |
-
@dataclass
|
| 9 |
class BaseGenerator(ABC):
|
| 10 |
"""
|
| 11 |
Generate QAs based on given prompts.
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
llm_client: BaseLLMClient
|
|
|
|
| 15 |
|
| 16 |
@staticmethod
|
| 17 |
@abstractmethod
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
|
|
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases.base_llm_client import BaseLLMClient
|
| 5 |
|
| 6 |
|
|
|
|
| 7 |
class BaseGenerator(ABC):
|
| 8 |
"""
|
| 9 |
Generate QAs based on given prompts.
|
| 10 |
"""
|
| 11 |
|
| 12 |
+
def __init__(self, llm_client: BaseLLMClient):
|
| 13 |
+
self.llm_client = llm_client
|
| 14 |
|
| 15 |
@staticmethod
|
| 16 |
@abstractmethod
|
graphgen/bases/base_kg_builder.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
from collections import defaultdict
|
| 3 |
-
from dataclasses import dataclass, field
|
| 4 |
from typing import Dict, List, Tuple
|
| 5 |
|
| 6 |
from graphgen.bases.base_llm_client import BaseLLMClient
|
|
@@ -8,14 +7,11 @@ from graphgen.bases.base_storage import BaseGraphStorage
|
|
| 8 |
from graphgen.bases.datatypes import Chunk
|
| 9 |
|
| 10 |
|
| 11 |
-
@dataclass
|
| 12 |
class BaseKGBuilder(ABC):
|
| 13 |
-
llm_client: BaseLLMClient
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
default_factory=lambda: defaultdict(list)
|
| 18 |
-
)
|
| 19 |
|
| 20 |
@abstractmethod
|
| 21 |
async def extract(
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
from collections import defaultdict
|
|
|
|
| 3 |
from typing import Dict, List, Tuple
|
| 4 |
|
| 5 |
from graphgen.bases.base_llm_client import BaseLLMClient
|
|
|
|
| 7 |
from graphgen.bases.datatypes import Chunk
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
class BaseKGBuilder(ABC):
|
| 11 |
+
def __init__(self, llm_client: BaseLLMClient):
|
| 12 |
+
self.llm_client = llm_client
|
| 13 |
+
self._nodes: Dict[str, List[dict]] = defaultdict(list)
|
| 14 |
+
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
|
|
|
|
|
|
|
| 15 |
|
| 16 |
@abstractmethod
|
| 17 |
async def extract(
|
graphgen/bases/base_partitioner.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from typing import Any, List
|
| 4 |
|
| 5 |
from graphgen.bases.base_storage import BaseGraphStorage
|
| 6 |
from graphgen.bases.datatypes import Community
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class BasePartitioner(ABC):
|
| 11 |
@abstractmethod
|
| 12 |
async def partition(
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
|
|
|
| 2 |
from typing import Any, List
|
| 3 |
|
| 4 |
from graphgen.bases.base_storage import BaseGraphStorage
|
| 5 |
from graphgen.bases.datatypes import Community
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class BasePartitioner(ABC):
|
| 9 |
@abstractmethod
|
| 10 |
async def partition(
|
graphgen/bases/base_splitter.py
CHANGED
|
@@ -1,25 +1,32 @@
|
|
| 1 |
import copy
|
| 2 |
import re
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
from typing import Callable, Iterable, List, Literal, Optional, Union
|
| 6 |
|
| 7 |
from graphgen.bases.datatypes import Chunk
|
| 8 |
from graphgen.utils import logger
|
| 9 |
|
| 10 |
|
| 11 |
-
@dataclass
|
| 12 |
class BaseSplitter(ABC):
|
| 13 |
"""
|
| 14 |
Abstract base class for splitting text into smaller chunks.
|
| 15 |
"""
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
@abstractmethod
|
| 25 |
def split_text(self, text: str) -> List[str]:
|
|
|
|
| 1 |
import copy
|
| 2 |
import re
|
| 3 |
from abc import ABC, abstractmethod
|
|
|
|
| 4 |
from typing import Callable, Iterable, List, Literal, Optional, Union
|
| 5 |
|
| 6 |
from graphgen.bases.datatypes import Chunk
|
| 7 |
from graphgen.utils import logger
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
class BaseSplitter(ABC):
|
| 11 |
"""
|
| 12 |
Abstract base class for splitting text into smaller chunks.
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
chunk_size: int = 1024,
|
| 18 |
+
chunk_overlap: int = 100,
|
| 19 |
+
length_function: Callable[[str], int] = len,
|
| 20 |
+
keep_separator: bool = False,
|
| 21 |
+
add_start_index: bool = False,
|
| 22 |
+
strip_whitespace: bool = True,
|
| 23 |
+
):
|
| 24 |
+
self.chunk_size = chunk_size
|
| 25 |
+
self.chunk_overlap = chunk_overlap
|
| 26 |
+
self.length_function = length_function
|
| 27 |
+
self.keep_separator = keep_separator
|
| 28 |
+
self.add_start_index = add_start_index
|
| 29 |
+
self.strip_whitespace = strip_whitespace
|
| 30 |
|
| 31 |
@abstractmethod
|
| 32 |
def split_text(self, text: str) -> List[str]:
|
graphgen/bases/base_storage.py
CHANGED
|
@@ -16,7 +16,6 @@ class StorageNameSpace:
|
|
| 16 |
"""commit the storage operations after querying"""
|
| 17 |
|
| 18 |
|
| 19 |
-
@dataclass
|
| 20 |
class BaseListStorage(Generic[T], StorageNameSpace):
|
| 21 |
async def all_items(self) -> list[T]:
|
| 22 |
raise NotImplementedError
|
|
@@ -34,7 +33,6 @@ class BaseListStorage(Generic[T], StorageNameSpace):
|
|
| 34 |
raise NotImplementedError
|
| 35 |
|
| 36 |
|
| 37 |
-
@dataclass
|
| 38 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
| 39 |
async def all_keys(self) -> list[str]:
|
| 40 |
raise NotImplementedError
|
|
@@ -58,7 +56,6 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
| 58 |
raise NotImplementedError
|
| 59 |
|
| 60 |
|
| 61 |
-
@dataclass
|
| 62 |
class BaseGraphStorage(StorageNameSpace):
|
| 63 |
async def has_node(self, node_id: str) -> bool:
|
| 64 |
raise NotImplementedError
|
|
|
|
| 16 |
"""commit the storage operations after querying"""
|
| 17 |
|
| 18 |
|
|
|
|
| 19 |
class BaseListStorage(Generic[T], StorageNameSpace):
|
| 20 |
async def all_items(self) -> list[T]:
|
| 21 |
raise NotImplementedError
|
|
|
|
| 33 |
raise NotImplementedError
|
| 34 |
|
| 35 |
|
|
|
|
| 36 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
| 37 |
async def all_keys(self) -> list[str]:
|
| 38 |
raise NotImplementedError
|
|
|
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
|
|
|
|
| 59 |
class BaseGraphStorage(StorageNameSpace):
|
| 60 |
async def has_node(self, node_id: str) -> bool:
|
| 61 |
raise NotImplementedError
|
graphgen/bases/base_tokenizer.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
from typing import List
|
| 6 |
|
| 7 |
|
| 8 |
-
@dataclass
|
| 9 |
class BaseTokenizer(ABC):
|
| 10 |
-
model_name: str = "cl100k_base"
|
|
|
|
| 11 |
|
| 12 |
@abstractmethod
|
| 13 |
def encode(self, text: str) -> List[int]:
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
|
|
|
| 4 |
from typing import List
|
| 5 |
|
| 6 |
|
|
|
|
| 7 |
class BaseTokenizer(ABC):
|
| 8 |
+
def __init__(self, model_name: str = "cl100k_base"):
|
| 9 |
+
self.model_name = model_name
|
| 10 |
|
| 11 |
@abstractmethod
|
| 12 |
def encode(self, text: str) -> List[int]:
|
graphgen/models/evaluator/base_evaluator.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
|
| 4 |
from tqdm.asyncio import tqdm as tqdm_async
|
| 5 |
|
|
@@ -7,10 +6,10 @@ from graphgen.bases.datatypes import QAPair
|
|
| 7 |
from graphgen.utils import create_event_loop
|
| 8 |
|
| 9 |
|
| 10 |
-
@dataclass
|
| 11 |
class BaseEvaluator:
|
| 12 |
-
max_concurrent: int = 100
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
def evaluate(self, pairs: list[QAPair]) -> list[float]:
|
| 16 |
"""
|
|
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
|
| 3 |
from tqdm.asyncio import tqdm as tqdm_async
|
| 4 |
|
|
|
|
| 6 |
from graphgen.utils import create_event_loop
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
class BaseEvaluator:
|
| 10 |
+
def __init__(self, max_concurrent: int = 100):
|
| 11 |
+
self.max_concurrent = max_concurrent
|
| 12 |
+
self.results: list[float] = None
|
| 13 |
|
| 14 |
def evaluate(self, pairs: list[QAPair]) -> list[float]:
|
| 15 |
"""
|
graphgen/models/evaluator/length_evaluator.py
CHANGED
|
@@ -1,16 +1,13 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
from graphgen.bases.datatypes import QAPair
|
| 4 |
from graphgen.models.evaluator.base_evaluator import BaseEvaluator
|
| 5 |
from graphgen.models.tokenizer import Tokenizer
|
| 6 |
from graphgen.utils import create_event_loop
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class LengthEvaluator(BaseEvaluator):
|
| 11 |
-
tokenizer_name: str = "cl100k_base"
|
| 12 |
-
|
| 13 |
-
|
| 14 |
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
|
| 15 |
|
| 16 |
async def evaluate_single(self, pair: QAPair) -> float:
|
|
|
|
|
|
|
|
|
|
| 1 |
from graphgen.bases.datatypes import QAPair
|
| 2 |
from graphgen.models.evaluator.base_evaluator import BaseEvaluator
|
| 3 |
from graphgen.models.tokenizer import Tokenizer
|
| 4 |
from graphgen.utils import create_event_loop
|
| 5 |
|
| 6 |
|
|
|
|
| 7 |
class LengthEvaluator(BaseEvaluator):
|
| 8 |
+
def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
|
| 9 |
+
super().__init__(max_concurrent)
|
| 10 |
+
self.tokenizer_name = tokenizer_name
|
| 11 |
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
|
| 12 |
|
| 13 |
async def evaluate_single(self, pair: QAPair) -> float:
|
graphgen/models/evaluator/mtld_evaluator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass, field
|
| 2 |
from typing import Set
|
| 3 |
|
| 4 |
from graphgen.bases.datatypes import QAPair
|
|
@@ -8,18 +7,15 @@ from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
|
|
| 8 |
nltk_helper = NLTKHelper()
|
| 9 |
|
| 10 |
|
| 11 |
-
@dataclass
|
| 12 |
class MTLDEvaluator(BaseEvaluator):
|
| 13 |
"""
|
| 14 |
衡量文本词汇多样性的指标
|
| 15 |
"""
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
|
| 22 |
-
)
|
| 23 |
|
| 24 |
async def evaluate_single(self, pair: QAPair) -> float:
|
| 25 |
loop = create_event_loop()
|
|
|
|
|
|
|
| 1 |
from typing import Set
|
| 2 |
|
| 3 |
from graphgen.bases.datatypes import QAPair
|
|
|
|
| 7 |
nltk_helper = NLTKHelper()
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
class MTLDEvaluator(BaseEvaluator):
|
| 11 |
"""
|
| 12 |
衡量文本词汇多样性的指标
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
def __init__(self, max_concurrent: int = 100):
|
| 16 |
+
super().__init__(max_concurrent)
|
| 17 |
+
self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
|
| 18 |
+
self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
|
|
|
|
|
|
|
| 19 |
|
| 20 |
async def evaluate_single(self, pair: QAPair) -> float:
|
| 21 |
loop = create_event_loop()
|
graphgen/models/generator/aggregated_generator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
|
@@ -6,7 +5,6 @@ from graphgen.templates import AGGREGATED_GENERATION_PROMPT
|
|
| 6 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class AggregatedGenerator(BaseGenerator):
|
| 11 |
"""
|
| 12 |
Aggregated Generator follows a TWO-STEP process:
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGenerator
|
|
|
|
| 5 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class AggregatedGenerator(BaseGenerator):
|
| 9 |
"""
|
| 10 |
Aggregated Generator follows a TWO-STEP process:
|
graphgen/models/generator/atomic_generator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
|
@@ -6,7 +5,6 @@ from graphgen.templates import ATOMIC_GENERATION_PROMPT
|
|
| 6 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class AtomicGenerator(BaseGenerator):
|
| 11 |
@staticmethod
|
| 12 |
def build_prompt(
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGenerator
|
|
|
|
| 5 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class AtomicGenerator(BaseGenerator):
|
| 9 |
@staticmethod
|
| 10 |
def build_prompt(
|
graphgen/models/generator/cot_generator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
|
@@ -6,7 +5,6 @@ from graphgen.templates import COT_GENERATION_PROMPT
|
|
| 6 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class CoTGenerator(BaseGenerator):
|
| 11 |
@staticmethod
|
| 12 |
def build_prompt(
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGenerator
|
|
|
|
| 5 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class CoTGenerator(BaseGenerator):
|
| 9 |
@staticmethod
|
| 10 |
def build_prompt(
|
graphgen/models/generator/multi_hop_generator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
|
@@ -6,7 +5,6 @@ from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
|
|
| 6 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class MultiHopGenerator(BaseGenerator):
|
| 11 |
@staticmethod
|
| 12 |
def build_prompt(
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGenerator
|
|
|
|
| 5 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class MultiHopGenerator(BaseGenerator):
|
| 9 |
@staticmethod
|
| 10 |
def build_prompt(
|
graphgen/models/generator/vqa_generator.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
|
@@ -6,7 +5,6 @@ from graphgen.templates import VQA_GENERATION_PROMPT
|
|
| 6 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class VQAGenerator(BaseGenerator):
|
| 11 |
@staticmethod
|
| 12 |
def build_prompt(
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGenerator
|
|
|
|
| 5 |
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class VQAGenerator(BaseGenerator):
|
| 9 |
@staticmethod
|
| 10 |
def build_prompt(
|
graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import re
|
| 2 |
from collections import Counter, defaultdict
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
from typing import Dict, List, Tuple
|
| 5 |
|
| 6 |
from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
|
|
@@ -15,10 +14,10 @@ from graphgen.utils import (
|
|
| 15 |
)
|
| 16 |
|
| 17 |
|
| 18 |
-
@dataclass
|
| 19 |
class LightRAGKGBuilder(BaseKGBuilder):
|
| 20 |
-
llm_client: BaseLLMClient =
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
async def extract(
|
| 24 |
self, chunk: Chunk
|
|
|
|
| 1 |
import re
|
| 2 |
from collections import Counter, defaultdict
|
|
|
|
| 3 |
from typing import Dict, List, Tuple
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
|
|
|
|
| 17 |
class LightRAGKGBuilder(BaseKGBuilder):
|
| 18 |
+
def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
|
| 19 |
+
super().__init__(llm_client)
|
| 20 |
+
self.max_loop = max_loop
|
| 21 |
|
| 22 |
async def extract(
|
| 23 |
self, chunk: Chunk
|
graphgen/models/kg_builder/mm_kg_builder.py
CHANGED
|
@@ -2,7 +2,7 @@ import re
|
|
| 2 |
from collections import defaultdict
|
| 3 |
from typing import Dict, List, Tuple
|
| 4 |
|
| 5 |
-
from graphgen.bases import
|
| 6 |
from graphgen.templates import MMKG_EXTRACTION_PROMPT
|
| 7 |
from graphgen.utils import (
|
| 8 |
detect_main_language,
|
|
@@ -16,8 +16,6 @@ from .light_rag_kg_builder import LightRAGKGBuilder
|
|
| 16 |
|
| 17 |
|
| 18 |
class MMKGBuilder(LightRAGKGBuilder):
|
| 19 |
-
llm_client: BaseLLMClient = None
|
| 20 |
-
|
| 21 |
async def extract(
|
| 22 |
self, chunk: Chunk
|
| 23 |
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
|
|
|
|
| 2 |
from collections import defaultdict
|
| 3 |
from typing import Dict, List, Tuple
|
| 4 |
|
| 5 |
+
from graphgen.bases import Chunk
|
| 6 |
from graphgen.templates import MMKG_EXTRACTION_PROMPT
|
| 7 |
from graphgen.utils import (
|
| 8 |
detect_main_language,
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class MMKGBuilder(LightRAGKGBuilder):
|
|
|
|
|
|
|
| 19 |
async def extract(
|
| 20 |
self, chunk: Chunk
|
| 21 |
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
|
graphgen/models/llm/topk_token_model.py
CHANGED
|
@@ -1,21 +1,31 @@
|
|
| 1 |
-
from
|
| 2 |
from typing import List, Optional
|
| 3 |
|
| 4 |
from graphgen.bases import Token
|
| 5 |
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
|
|
|
| 19 |
async def generate_topk_per_token(self, text: str) -> List[Token]:
|
| 20 |
"""
|
| 21 |
Generate prob, text and candidates for each token of the model's output.
|
|
@@ -23,6 +33,7 @@ class TopkTokenModel:
|
|
| 23 |
"""
|
| 24 |
raise NotImplementedError
|
| 25 |
|
|
|
|
| 26 |
async def generate_inputs_prob(
|
| 27 |
self, text: str, history: Optional[List[str]] = None
|
| 28 |
) -> List[Token]:
|
|
@@ -32,6 +43,7 @@ class TopkTokenModel:
|
|
| 32 |
"""
|
| 33 |
raise NotImplementedError
|
| 34 |
|
|
|
|
| 35 |
async def generate_answer(
|
| 36 |
self, text: str, history: Optional[List[str]] = None
|
| 37 |
) -> str:
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
from typing import List, Optional
|
| 3 |
|
| 4 |
from graphgen.bases import Token
|
| 5 |
|
| 6 |
|
| 7 |
+
class TopkTokenModel(ABC):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
do_sample: bool = False,
|
| 11 |
+
temperature: float = 0,
|
| 12 |
+
max_tokens: int = 4096,
|
| 13 |
+
repetition_penalty: float = 1.05,
|
| 14 |
+
num_beams: int = 1,
|
| 15 |
+
topk: int = 50,
|
| 16 |
+
topp: float = 0.95,
|
| 17 |
+
topk_per_token: int = 5,
|
| 18 |
+
):
|
| 19 |
+
self.do_sample = do_sample
|
| 20 |
+
self.temperature = temperature
|
| 21 |
+
self.max_tokens = max_tokens
|
| 22 |
+
self.repetition_penalty = repetition_penalty
|
| 23 |
+
self.num_beams = num_beams
|
| 24 |
+
self.topk = topk
|
| 25 |
+
self.topp = topp
|
| 26 |
+
self.topk_per_token = topk_per_token
|
| 27 |
|
| 28 |
+
@abstractmethod
|
| 29 |
async def generate_topk_per_token(self, text: str) -> List[Token]:
|
| 30 |
"""
|
| 31 |
Generate prob, text and candidates for each token of the model's output.
|
|
|
|
| 33 |
"""
|
| 34 |
raise NotImplementedError
|
| 35 |
|
| 36 |
+
@abstractmethod
|
| 37 |
async def generate_inputs_prob(
|
| 38 |
self, text: str, history: Optional[List[str]] = None
|
| 39 |
) -> List[Token]:
|
|
|
|
| 43 |
"""
|
| 44 |
raise NotImplementedError
|
| 45 |
|
| 46 |
+
@abstractmethod
|
| 47 |
async def generate_answer(
|
| 48 |
self, text: str, history: Optional[List[str]] = None
|
| 49 |
) -> str:
|
graphgen/models/partitioner/bfs_partitioner.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import random
|
| 2 |
from collections import deque
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
from typing import Any, List
|
| 5 |
|
| 6 |
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
|
@@ -10,7 +9,6 @@ NODE_UNIT: str = "n"
|
|
| 10 |
EDGE_UNIT: str = "e"
|
| 11 |
|
| 12 |
|
| 13 |
-
@dataclass
|
| 14 |
class BFSPartitioner(BasePartitioner):
|
| 15 |
"""
|
| 16 |
BFS partitioner that partitions the graph into communities of a fixed size.
|
|
|
|
| 1 |
import random
|
| 2 |
from collections import deque
|
|
|
|
| 3 |
from typing import Any, List
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
|
|
|
| 9 |
EDGE_UNIT: str = "e"
|
| 10 |
|
| 11 |
|
|
|
|
| 12 |
class BFSPartitioner(BasePartitioner):
|
| 13 |
"""
|
| 14 |
BFS partitioner that partitions the graph into communities of a fixed size.
|
graphgen/models/partitioner/dfs_partitioner.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import random
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from typing import Any, List
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
|
@@ -9,7 +8,6 @@ NODE_UNIT: str = "n"
|
|
| 9 |
EDGE_UNIT: str = "e"
|
| 10 |
|
| 11 |
|
| 12 |
-
@dataclass
|
| 13 |
class DFSPartitioner(BasePartitioner):
|
| 14 |
"""
|
| 15 |
DFS partitioner that partitions the graph into communities of a fixed size.
|
|
|
|
| 1 |
import random
|
|
|
|
| 2 |
from typing import Any, List
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
|
|
|
| 8 |
EDGE_UNIT: str = "e"
|
| 9 |
|
| 10 |
|
|
|
|
| 11 |
class DFSPartitioner(BasePartitioner):
|
| 12 |
"""
|
| 13 |
DFS partitioner that partitions the graph into communities of a fixed size.
|
graphgen/models/partitioner/ece_partitioner.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import asyncio
|
| 2 |
import random
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 5 |
|
| 6 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
@@ -13,7 +12,6 @@ NODE_UNIT: str = "n"
|
|
| 13 |
EDGE_UNIT: str = "e"
|
| 14 |
|
| 15 |
|
| 16 |
-
@dataclass
|
| 17 |
class ECEPartitioner(BFSPartitioner):
|
| 18 |
"""
|
| 19 |
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import random
|
|
|
|
| 3 |
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 4 |
|
| 5 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
|
| 12 |
EDGE_UNIT: str = "e"
|
| 13 |
|
| 14 |
|
|
|
|
| 15 |
class ECEPartitioner(BFSPartitioner):
|
| 16 |
"""
|
| 17 |
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
|
graphgen/models/partitioner/leiden_partitioner.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from collections import defaultdict
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from typing import Any, Dict, List, Set, Tuple
|
| 4 |
|
| 5 |
import igraph as ig
|
|
@@ -9,7 +8,6 @@ from graphgen.bases import BaseGraphStorage, BasePartitioner
|
|
| 9 |
from graphgen.bases.datatypes import Community
|
| 10 |
|
| 11 |
|
| 12 |
-
@dataclass
|
| 13 |
class LeidenPartitioner(BasePartitioner):
|
| 14 |
"""
|
| 15 |
Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
|
|
|
|
| 1 |
from collections import defaultdict
|
|
|
|
| 2 |
from typing import Any, Dict, List, Set, Tuple
|
| 3 |
|
| 4 |
import igraph as ig
|
|
|
|
| 8 |
from graphgen.bases.datatypes import Community
|
| 9 |
|
| 10 |
|
|
|
|
| 11 |
class LeidenPartitioner(BasePartitioner):
|
| 12 |
"""
|
| 13 |
Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
|
graphgen/models/search/db/uniprot_search.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
import requests
|
| 4 |
from fastapi import HTTPException
|
| 5 |
|
|
@@ -8,7 +6,6 @@ from graphgen.utils import logger
|
|
| 8 |
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
|
| 9 |
|
| 10 |
|
| 11 |
-
@dataclass
|
| 12 |
class UniProtSearch:
|
| 13 |
"""
|
| 14 |
UniProt Search client to search with UniProt.
|
|
|
|
|
|
|
|
|
|
| 1 |
import requests
|
| 2 |
from fastapi import HTTPException
|
| 3 |
|
|
|
|
| 6 |
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
class UniProtSearch:
|
| 10 |
"""
|
| 11 |
UniProt Search client to search with UniProt.
|
graphgen/models/search/kg/wiki_search.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import List, Union
|
| 3 |
|
| 4 |
import wikipedia
|
|
@@ -7,7 +6,6 @@ from wikipedia import set_lang
|
|
| 7 |
from graphgen.utils import detect_main_language, logger
|
| 8 |
|
| 9 |
|
| 10 |
-
@dataclass
|
| 11 |
class WikiSearch:
|
| 12 |
@staticmethod
|
| 13 |
def set_language(language: str):
|
|
|
|
|
|
|
| 1 |
from typing import List, Union
|
| 2 |
|
| 3 |
import wikipedia
|
|
|
|
| 6 |
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
class WikiSearch:
|
| 10 |
@staticmethod
|
| 11 |
def set_language(language: str):
|
graphgen/models/search/web/bing_search.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
import requests
|
| 4 |
from fastapi import HTTPException
|
| 5 |
|
|
@@ -9,13 +7,13 @@ BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
|
|
| 9 |
BING_MKT = "en-US"
|
| 10 |
|
| 11 |
|
| 12 |
-
@dataclass
|
| 13 |
class BingSearch:
|
| 14 |
"""
|
| 15 |
Bing Search client to search with Bing.
|
| 16 |
"""
|
| 17 |
|
| 18 |
-
subscription_key: str
|
|
|
|
| 19 |
|
| 20 |
def search(self, query: str, num_results: int = 1):
|
| 21 |
"""
|
|
|
|
|
|
|
|
|
|
| 1 |
import requests
|
| 2 |
from fastapi import HTTPException
|
| 3 |
|
|
|
|
| 7 |
BING_MKT = "en-US"
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
class BingSearch:
|
| 11 |
"""
|
| 12 |
Bing Search client to search with Bing.
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
def __init__(self, subscription_key: str):
|
| 16 |
+
self.subscription_key = subscription_key
|
| 17 |
|
| 18 |
def search(self, query: str, num_results: int = 1):
|
| 19 |
"""
|
graphgen/models/search/web/google_search.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
import requests
|
| 4 |
from fastapi import HTTPException
|
| 5 |
|
|
@@ -8,7 +6,6 @@ from graphgen.utils import logger
|
|
| 8 |
GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
|
| 9 |
|
| 10 |
|
| 11 |
-
@dataclass
|
| 12 |
class GoogleSearch:
|
| 13 |
def __init__(self, subscription_key: str, cx: str):
|
| 14 |
"""
|
|
|
|
|
|
|
|
|
|
| 1 |
import requests
|
| 2 |
from fastapi import HTTPException
|
| 3 |
|
|
|
|
| 6 |
GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
class GoogleSearch:
|
| 10 |
def __init__(self, subscription_key: str, cx: str):
|
| 11 |
"""
|
graphgen/models/storage/json_storage.py
CHANGED
|
@@ -53,6 +53,8 @@ class JsonKVStorage(BaseKVStorage):
|
|
| 53 |
|
| 54 |
@dataclass
|
| 55 |
class JsonListStorage(BaseListStorage):
|
|
|
|
|
|
|
| 56 |
_data: list = None
|
| 57 |
|
| 58 |
def __post_init__(self):
|
|
|
|
| 53 |
|
| 54 |
@dataclass
|
| 55 |
class JsonListStorage(BaseListStorage):
|
| 56 |
+
working_dir: str = None
|
| 57 |
+
namespace: str = None
|
| 58 |
_data: list = None
|
| 59 |
|
| 60 |
def __post_init__(self):
|
graphgen/models/tokenizer/__init__.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass, field
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
from graphgen.bases import BaseTokenizer
|
|
@@ -30,16 +29,13 @@ def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
|
|
| 30 |
)
|
| 31 |
|
| 32 |
|
| 33 |
-
@dataclass
|
| 34 |
class Tokenizer(BaseTokenizer):
|
| 35 |
"""
|
| 36 |
Encapsulates different tokenization implementations based on the specified model name.
|
| 37 |
"""
|
| 38 |
|
| 39 |
-
model_name: str = "cl100k_base"
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def __post_init__(self):
|
| 43 |
if not self.model_name:
|
| 44 |
raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
|
| 45 |
self._impl = get_tokenizer_impl(self.model_name)
|
|
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
from graphgen.bases import BaseTokenizer
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
|
|
|
|
| 32 |
class Tokenizer(BaseTokenizer):
|
| 33 |
"""
|
| 34 |
Encapsulates different tokenization implementations based on the specified model name.
|
| 35 |
"""
|
| 36 |
|
| 37 |
+
def __init__(self, model_name: str = "cl100k_base"):
|
| 38 |
+
super().__init__(model_name)
|
|
|
|
|
|
|
| 39 |
if not self.model_name:
|
| 40 |
raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
|
| 41 |
self._impl = get_tokenizer_impl(self.model_name)
|
graphgen/models/tokenizer/hf_tokenizer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
from transformers import AutoTokenizer
|
|
@@ -6,9 +5,9 @@ from transformers import AutoTokenizer
|
|
| 6 |
from graphgen.bases import BaseTokenizer
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class HFTokenizer(BaseTokenizer):
|
| 11 |
-
def
|
|
|
|
| 12 |
self.enc = AutoTokenizer.from_pretrained(self.model_name)
|
| 13 |
|
| 14 |
def encode(self, text: str) -> List[int]:
|
|
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
from transformers import AutoTokenizer
|
|
|
|
| 5 |
from graphgen.bases import BaseTokenizer
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class HFTokenizer(BaseTokenizer):
|
| 9 |
+
def __init__(self, model_name: str = "cl100k_base"):
|
| 10 |
+
super().__init__(model_name)
|
| 11 |
self.enc = AutoTokenizer.from_pretrained(self.model_name)
|
| 12 |
|
| 13 |
def encode(self, text: str) -> List[int]:
|
graphgen/models/tokenizer/tiktoken_tokenizer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
import tiktoken
|
|
@@ -6,9 +5,9 @@ import tiktoken
|
|
| 6 |
from graphgen.bases import BaseTokenizer
|
| 7 |
|
| 8 |
|
| 9 |
-
@dataclass
|
| 10 |
class TiktokenTokenizer(BaseTokenizer):
|
| 11 |
-
def
|
|
|
|
| 12 |
self.enc = tiktoken.get_encoding(self.model_name)
|
| 13 |
|
| 14 |
def encode(self, text: str) -> List[int]:
|
|
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
import tiktoken
|
|
|
|
| 5 |
from graphgen.bases import BaseTokenizer
|
| 6 |
|
| 7 |
|
|
|
|
| 8 |
class TiktokenTokenizer(BaseTokenizer):
|
| 9 |
+
def __init__(self, model_name: str = "cl100k_base"):
|
| 10 |
+
super().__init__(model_name)
|
| 11 |
self.enc = tiktoken.get_encoding(self.model_name)
|
| 12 |
|
| 13 |
def encode(self, text: str) -> List[int]:
|