github-actions[bot] commited on
Commit
8c66169
·
1 Parent(s): 0b9d8c7

Auto-sync from demo at Thu Oct 23 12:37:24 UTC 2025

Browse files
graphgen/bases/base_generator.py CHANGED
@@ -1,17 +1,16 @@
1
  from abc import ABC, abstractmethod
2
- from dataclasses import dataclass
3
  from typing import Any
4
 
5
  from graphgen.bases.base_llm_client import BaseLLMClient
6
 
7
 
8
- @dataclass
9
  class BaseGenerator(ABC):
10
  """
11
  Generate QAs based on given prompts.
12
  """
13
 
14
- llm_client: BaseLLMClient
 
15
 
16
  @staticmethod
17
  @abstractmethod
 
1
  from abc import ABC, abstractmethod
 
2
  from typing import Any
3
 
4
  from graphgen.bases.base_llm_client import BaseLLMClient
5
 
6
 
 
7
  class BaseGenerator(ABC):
8
  """
9
  Generate QAs based on given prompts.
10
  """
11
 
12
+ def __init__(self, llm_client: BaseLLMClient):
13
+ self.llm_client = llm_client
14
 
15
  @staticmethod
16
  @abstractmethod
graphgen/bases/base_kg_builder.py CHANGED
@@ -1,6 +1,5 @@
1
  from abc import ABC, abstractmethod
2
  from collections import defaultdict
3
- from dataclasses import dataclass, field
4
  from typing import Dict, List, Tuple
5
 
6
  from graphgen.bases.base_llm_client import BaseLLMClient
@@ -8,14 +7,11 @@ from graphgen.bases.base_storage import BaseGraphStorage
8
  from graphgen.bases.datatypes import Chunk
9
 
10
 
11
- @dataclass
12
  class BaseKGBuilder(ABC):
13
- llm_client: BaseLLMClient
14
-
15
- _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
16
- _edges: Dict[Tuple[str, str], List[dict]] = field(
17
- default_factory=lambda: defaultdict(list)
18
- )
19
 
20
  @abstractmethod
21
  async def extract(
 
1
  from abc import ABC, abstractmethod
2
  from collections import defaultdict
 
3
  from typing import Dict, List, Tuple
4
 
5
  from graphgen.bases.base_llm_client import BaseLLMClient
 
7
  from graphgen.bases.datatypes import Chunk
8
 
9
 
 
10
  class BaseKGBuilder(ABC):
11
+ def __init__(self, llm_client: BaseLLMClient):
12
+ self.llm_client = llm_client
13
+ self._nodes: Dict[str, List[dict]] = defaultdict(list)
14
+ self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
 
 
15
 
16
  @abstractmethod
17
  async def extract(
graphgen/bases/base_partitioner.py CHANGED
@@ -1,12 +1,10 @@
1
  from abc import ABC, abstractmethod
2
- from dataclasses import dataclass
3
  from typing import Any, List
4
 
5
  from graphgen.bases.base_storage import BaseGraphStorage
6
  from graphgen.bases.datatypes import Community
7
 
8
 
9
- @dataclass
10
  class BasePartitioner(ABC):
11
  @abstractmethod
12
  async def partition(
 
1
  from abc import ABC, abstractmethod
 
2
  from typing import Any, List
3
 
4
  from graphgen.bases.base_storage import BaseGraphStorage
5
  from graphgen.bases.datatypes import Community
6
 
7
 
 
8
  class BasePartitioner(ABC):
9
  @abstractmethod
10
  async def partition(
graphgen/bases/base_splitter.py CHANGED
@@ -1,25 +1,32 @@
1
  import copy
2
  import re
3
  from abc import ABC, abstractmethod
4
- from dataclasses import dataclass
5
  from typing import Callable, Iterable, List, Literal, Optional, Union
6
 
7
  from graphgen.bases.datatypes import Chunk
8
  from graphgen.utils import logger
9
 
10
 
11
- @dataclass
12
  class BaseSplitter(ABC):
13
  """
14
  Abstract base class for splitting text into smaller chunks.
15
  """
16
 
17
- chunk_size: int = 1024
18
- chunk_overlap: int = 100
19
- length_function: Callable[[str], int] = len
20
- keep_separator: bool = False
21
- add_start_index: bool = False
22
- strip_whitespace: bool = True
 
 
 
 
 
 
 
 
 
23
 
24
  @abstractmethod
25
  def split_text(self, text: str) -> List[str]:
 
1
  import copy
2
  import re
3
  from abc import ABC, abstractmethod
 
4
  from typing import Callable, Iterable, List, Literal, Optional, Union
5
 
6
  from graphgen.bases.datatypes import Chunk
7
  from graphgen.utils import logger
8
 
9
 
 
10
  class BaseSplitter(ABC):
11
  """
12
  Abstract base class for splitting text into smaller chunks.
13
  """
14
 
15
+ def __init__(
16
+ self,
17
+ chunk_size: int = 1024,
18
+ chunk_overlap: int = 100,
19
+ length_function: Callable[[str], int] = len,
20
+ keep_separator: bool = False,
21
+ add_start_index: bool = False,
22
+ strip_whitespace: bool = True,
23
+ ):
24
+ self.chunk_size = chunk_size
25
+ self.chunk_overlap = chunk_overlap
26
+ self.length_function = length_function
27
+ self.keep_separator = keep_separator
28
+ self.add_start_index = add_start_index
29
+ self.strip_whitespace = strip_whitespace
30
 
31
  @abstractmethod
32
  def split_text(self, text: str) -> List[str]:
graphgen/bases/base_storage.py CHANGED
@@ -16,7 +16,6 @@ class StorageNameSpace:
16
  """commit the storage operations after querying"""
17
 
18
 
19
- @dataclass
20
  class BaseListStorage(Generic[T], StorageNameSpace):
21
  async def all_items(self) -> list[T]:
22
  raise NotImplementedError
@@ -34,7 +33,6 @@ class BaseListStorage(Generic[T], StorageNameSpace):
34
  raise NotImplementedError
35
 
36
 
37
- @dataclass
38
  class BaseKVStorage(Generic[T], StorageNameSpace):
39
  async def all_keys(self) -> list[str]:
40
  raise NotImplementedError
@@ -58,7 +56,6 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
58
  raise NotImplementedError
59
 
60
 
61
- @dataclass
62
  class BaseGraphStorage(StorageNameSpace):
63
  async def has_node(self, node_id: str) -> bool:
64
  raise NotImplementedError
 
16
  """commit the storage operations after querying"""
17
 
18
 
 
19
  class BaseListStorage(Generic[T], StorageNameSpace):
20
  async def all_items(self) -> list[T]:
21
  raise NotImplementedError
 
33
  raise NotImplementedError
34
 
35
 
 
36
  class BaseKVStorage(Generic[T], StorageNameSpace):
37
  async def all_keys(self) -> list[str]:
38
  raise NotImplementedError
 
56
  raise NotImplementedError
57
 
58
 
 
59
  class BaseGraphStorage(StorageNameSpace):
60
  async def has_node(self, node_id: str) -> bool:
61
  raise NotImplementedError
graphgen/bases/base_tokenizer.py CHANGED
@@ -1,13 +1,12 @@
1
  from __future__ import annotations
2
 
3
  from abc import ABC, abstractmethod
4
- from dataclasses import dataclass
5
  from typing import List
6
 
7
 
8
- @dataclass
9
  class BaseTokenizer(ABC):
10
- model_name: str = "cl100k_base"
 
11
 
12
  @abstractmethod
13
  def encode(self, text: str) -> List[int]:
 
1
  from __future__ import annotations
2
 
3
  from abc import ABC, abstractmethod
 
4
  from typing import List
5
 
6
 
 
7
  class BaseTokenizer(ABC):
8
+ def __init__(self, model_name: str = "cl100k_base"):
9
+ self.model_name = model_name
10
 
11
  @abstractmethod
12
  def encode(self, text: str) -> List[int]:
graphgen/models/evaluator/base_evaluator.py CHANGED
@@ -1,5 +1,4 @@
1
  import asyncio
2
- from dataclasses import dataclass
3
 
4
  from tqdm.asyncio import tqdm as tqdm_async
5
 
@@ -7,10 +6,10 @@ from graphgen.bases.datatypes import QAPair
7
  from graphgen.utils import create_event_loop
8
 
9
 
10
- @dataclass
11
  class BaseEvaluator:
12
- max_concurrent: int = 100
13
- results: list[float] = None
 
14
 
15
  def evaluate(self, pairs: list[QAPair]) -> list[float]:
16
  """
 
1
  import asyncio
 
2
 
3
  from tqdm.asyncio import tqdm as tqdm_async
4
 
 
6
  from graphgen.utils import create_event_loop
7
 
8
 
 
9
  class BaseEvaluator:
10
+ def __init__(self, max_concurrent: int = 100):
11
+ self.max_concurrent = max_concurrent
12
+ self.results: list[float] = None
13
 
14
  def evaluate(self, pairs: list[QAPair]) -> list[float]:
15
  """
graphgen/models/evaluator/length_evaluator.py CHANGED
@@ -1,16 +1,13 @@
1
- from dataclasses import dataclass
2
-
3
  from graphgen.bases.datatypes import QAPair
4
  from graphgen.models.evaluator.base_evaluator import BaseEvaluator
5
  from graphgen.models.tokenizer import Tokenizer
6
  from graphgen.utils import create_event_loop
7
 
8
 
9
- @dataclass
10
  class LengthEvaluator(BaseEvaluator):
11
- tokenizer_name: str = "cl100k_base"
12
-
13
- def __post_init__(self):
14
  self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
15
 
16
  async def evaluate_single(self, pair: QAPair) -> float:
 
 
 
1
  from graphgen.bases.datatypes import QAPair
2
  from graphgen.models.evaluator.base_evaluator import BaseEvaluator
3
  from graphgen.models.tokenizer import Tokenizer
4
  from graphgen.utils import create_event_loop
5
 
6
 
 
7
  class LengthEvaluator(BaseEvaluator):
8
+ def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
9
+ super().__init__(max_concurrent)
10
+ self.tokenizer_name = tokenizer_name
11
  self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
12
 
13
  async def evaluate_single(self, pair: QAPair) -> float:
graphgen/models/evaluator/mtld_evaluator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass, field
2
  from typing import Set
3
 
4
  from graphgen.bases.datatypes import QAPair
@@ -8,18 +7,15 @@ from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
8
  nltk_helper = NLTKHelper()
9
 
10
 
11
- @dataclass
12
  class MTLDEvaluator(BaseEvaluator):
13
  """
14
  衡量文本词汇多样性的指标
15
  """
16
 
17
- stopwords_en: Set[str] = field(
18
- default_factory=lambda: set(nltk_helper.get_stopwords("english"))
19
- )
20
- stopwords_zh: Set[str] = field(
21
- default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
22
- )
23
 
24
  async def evaluate_single(self, pair: QAPair) -> float:
25
  loop = create_event_loop()
 
 
1
  from typing import Set
2
 
3
  from graphgen.bases.datatypes import QAPair
 
7
  nltk_helper = NLTKHelper()
8
 
9
 
 
10
  class MTLDEvaluator(BaseEvaluator):
11
  """
12
  衡量文本词汇多样性的指标
13
  """
14
 
15
+ def __init__(self, max_concurrent: int = 100):
16
+ super().__init__(max_concurrent)
17
+ self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
18
+ self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
 
 
19
 
20
  async def evaluate_single(self, pair: QAPair) -> float:
21
  loop = create_event_loop()
graphgen/models/generator/aggregated_generator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
@@ -6,7 +5,6 @@ from graphgen.templates import AGGREGATED_GENERATION_PROMPT
6
  from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
- @dataclass
10
  class AggregatedGenerator(BaseGenerator):
11
  """
12
  Aggregated Generator follows a TWO-STEP process:
 
 
1
  from typing import Any
2
 
3
  from graphgen.bases import BaseGenerator
 
5
  from graphgen.utils import compute_content_hash, detect_main_language, logger
6
 
7
 
 
8
  class AggregatedGenerator(BaseGenerator):
9
  """
10
  Aggregated Generator follows a TWO-STEP process:
graphgen/models/generator/atomic_generator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
@@ -6,7 +5,6 @@ from graphgen.templates import ATOMIC_GENERATION_PROMPT
6
  from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
- @dataclass
10
  class AtomicGenerator(BaseGenerator):
11
  @staticmethod
12
  def build_prompt(
 
 
1
  from typing import Any
2
 
3
  from graphgen.bases import BaseGenerator
 
5
  from graphgen.utils import compute_content_hash, detect_main_language, logger
6
 
7
 
 
8
  class AtomicGenerator(BaseGenerator):
9
  @staticmethod
10
  def build_prompt(
graphgen/models/generator/cot_generator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
@@ -6,7 +5,6 @@ from graphgen.templates import COT_GENERATION_PROMPT
6
  from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
- @dataclass
10
  class CoTGenerator(BaseGenerator):
11
  @staticmethod
12
  def build_prompt(
 
 
1
  from typing import Any
2
 
3
  from graphgen.bases import BaseGenerator
 
5
  from graphgen.utils import compute_content_hash, detect_main_language, logger
6
 
7
 
 
8
  class CoTGenerator(BaseGenerator):
9
  @staticmethod
10
  def build_prompt(
graphgen/models/generator/multi_hop_generator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
@@ -6,7 +5,6 @@ from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
6
  from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
- @dataclass
10
  class MultiHopGenerator(BaseGenerator):
11
  @staticmethod
12
  def build_prompt(
 
 
1
  from typing import Any
2
 
3
  from graphgen.bases import BaseGenerator
 
5
  from graphgen.utils import compute_content_hash, detect_main_language, logger
6
 
7
 
 
8
  class MultiHopGenerator(BaseGenerator):
9
  @staticmethod
10
  def build_prompt(
graphgen/models/generator/vqa_generator.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
@@ -6,7 +5,6 @@ from graphgen.templates import VQA_GENERATION_PROMPT
6
  from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
- @dataclass
10
  class VQAGenerator(BaseGenerator):
11
  @staticmethod
12
  def build_prompt(
 
 
1
  from typing import Any
2
 
3
  from graphgen.bases import BaseGenerator
 
5
  from graphgen.utils import compute_content_hash, detect_main_language, logger
6
 
7
 
 
8
  class VQAGenerator(BaseGenerator):
9
  @staticmethod
10
  def build_prompt(
graphgen/models/kg_builder/light_rag_kg_builder.py CHANGED
@@ -1,6 +1,5 @@
1
  import re
2
  from collections import Counter, defaultdict
3
- from dataclasses import dataclass
4
  from typing import Dict, List, Tuple
5
 
6
  from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
@@ -15,10 +14,10 @@ from graphgen.utils import (
15
  )
16
 
17
 
18
- @dataclass
19
  class LightRAGKGBuilder(BaseKGBuilder):
20
- llm_client: BaseLLMClient = None
21
- max_loop: int = 3
 
22
 
23
  async def extract(
24
  self, chunk: Chunk
 
1
  import re
2
  from collections import Counter, defaultdict
 
3
  from typing import Dict, List, Tuple
4
 
5
  from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
 
14
  )
15
 
16
 
 
17
  class LightRAGKGBuilder(BaseKGBuilder):
18
+ def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
19
+ super().__init__(llm_client)
20
+ self.max_loop = max_loop
21
 
22
  async def extract(
23
  self, chunk: Chunk
graphgen/models/kg_builder/mm_kg_builder.py CHANGED
@@ -2,7 +2,7 @@ import re
2
  from collections import defaultdict
3
  from typing import Dict, List, Tuple
4
 
5
- from graphgen.bases import BaseLLMClient, Chunk
6
  from graphgen.templates import MMKG_EXTRACTION_PROMPT
7
  from graphgen.utils import (
8
  detect_main_language,
@@ -16,8 +16,6 @@ from .light_rag_kg_builder import LightRAGKGBuilder
16
 
17
 
18
  class MMKGBuilder(LightRAGKGBuilder):
19
- llm_client: BaseLLMClient = None
20
-
21
  async def extract(
22
  self, chunk: Chunk
23
  ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
 
2
  from collections import defaultdict
3
  from typing import Dict, List, Tuple
4
 
5
+ from graphgen.bases import Chunk
6
  from graphgen.templates import MMKG_EXTRACTION_PROMPT
7
  from graphgen.utils import (
8
  detect_main_language,
 
16
 
17
 
18
  class MMKGBuilder(LightRAGKGBuilder):
 
 
19
  async def extract(
20
  self, chunk: Chunk
21
  ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
graphgen/models/llm/topk_token_model.py CHANGED
@@ -1,21 +1,31 @@
1
- from dataclasses import dataclass
2
  from typing import List, Optional
3
 
4
  from graphgen.bases import Token
5
 
6
 
7
- @dataclass
8
- class TopkTokenModel:
9
- do_sample: bool = False
10
- temperature: float = 0
11
- max_tokens: int = 4096
12
- repetition_penalty: float = 1.05
13
- num_beams: int = 1
14
- topk: int = 50
15
- topp: float = 0.95
16
-
17
- topk_per_token: int = 5 # number of topk tokens to generate for each token
 
 
 
 
 
 
 
 
 
18
 
 
19
  async def generate_topk_per_token(self, text: str) -> List[Token]:
20
  """
21
  Generate prob, text and candidates for each token of the model's output.
@@ -23,6 +33,7 @@ class TopkTokenModel:
23
  """
24
  raise NotImplementedError
25
 
 
26
  async def generate_inputs_prob(
27
  self, text: str, history: Optional[List[str]] = None
28
  ) -> List[Token]:
@@ -32,6 +43,7 @@ class TopkTokenModel:
32
  """
33
  raise NotImplementedError
34
 
 
35
  async def generate_answer(
36
  self, text: str, history: Optional[List[str]] = None
37
  ) -> str:
 
1
+ from abc import ABC, abstractmethod
2
  from typing import List, Optional
3
 
4
  from graphgen.bases import Token
5
 
6
 
7
+ class TopkTokenModel(ABC):
8
+ def __init__(
9
+ self,
10
+ do_sample: bool = False,
11
+ temperature: float = 0,
12
+ max_tokens: int = 4096,
13
+ repetition_penalty: float = 1.05,
14
+ num_beams: int = 1,
15
+ topk: int = 50,
16
+ topp: float = 0.95,
17
+ topk_per_token: int = 5,
18
+ ):
19
+ self.do_sample = do_sample
20
+ self.temperature = temperature
21
+ self.max_tokens = max_tokens
22
+ self.repetition_penalty = repetition_penalty
23
+ self.num_beams = num_beams
24
+ self.topk = topk
25
+ self.topp = topp
26
+ self.topk_per_token = topk_per_token
27
 
28
+ @abstractmethod
29
  async def generate_topk_per_token(self, text: str) -> List[Token]:
30
  """
31
  Generate prob, text and candidates for each token of the model's output.
 
33
  """
34
  raise NotImplementedError
35
 
36
+ @abstractmethod
37
  async def generate_inputs_prob(
38
  self, text: str, history: Optional[List[str]] = None
39
  ) -> List[Token]:
 
43
  """
44
  raise NotImplementedError
45
 
46
+ @abstractmethod
47
  async def generate_answer(
48
  self, text: str, history: Optional[List[str]] = None
49
  ) -> str:
graphgen/models/partitioner/bfs_partitioner.py CHANGED
@@ -1,6 +1,5 @@
1
  import random
2
  from collections import deque
3
- from dataclasses import dataclass
4
  from typing import Any, List
5
 
6
  from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -10,7 +9,6 @@ NODE_UNIT: str = "n"
10
  EDGE_UNIT: str = "e"
11
 
12
 
13
- @dataclass
14
  class BFSPartitioner(BasePartitioner):
15
  """
16
  BFS partitioner that partitions the graph into communities of a fixed size.
 
1
  import random
2
  from collections import deque
 
3
  from typing import Any, List
4
 
5
  from graphgen.bases import BaseGraphStorage, BasePartitioner
 
9
  EDGE_UNIT: str = "e"
10
 
11
 
 
12
  class BFSPartitioner(BasePartitioner):
13
  """
14
  BFS partitioner that partitions the graph into communities of a fixed size.
graphgen/models/partitioner/dfs_partitioner.py CHANGED
@@ -1,5 +1,4 @@
1
  import random
2
- from dataclasses import dataclass
3
  from typing import Any, List
4
 
5
  from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -9,7 +8,6 @@ NODE_UNIT: str = "n"
9
  EDGE_UNIT: str = "e"
10
 
11
 
12
- @dataclass
13
  class DFSPartitioner(BasePartitioner):
14
  """
15
  DFS partitioner that partitions the graph into communities of a fixed size.
 
1
  import random
 
2
  from typing import Any, List
3
 
4
  from graphgen.bases import BaseGraphStorage, BasePartitioner
 
8
  EDGE_UNIT: str = "e"
9
 
10
 
 
11
  class DFSPartitioner(BasePartitioner):
12
  """
13
  DFS partitioner that partitions the graph into communities of a fixed size.
graphgen/models/partitioner/ece_partitioner.py CHANGED
@@ -1,6 +1,5 @@
1
  import asyncio
2
  import random
3
- from dataclasses import dataclass
4
  from typing import Any, Dict, List, Optional, Set, Tuple
5
 
6
  from tqdm.asyncio import tqdm as tqdm_async
@@ -13,7 +12,6 @@ NODE_UNIT: str = "n"
13
  EDGE_UNIT: str = "e"
14
 
15
 
16
- @dataclass
17
  class ECEPartitioner(BFSPartitioner):
18
  """
19
  ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
 
1
  import asyncio
2
  import random
 
3
  from typing import Any, Dict, List, Optional, Set, Tuple
4
 
5
  from tqdm.asyncio import tqdm as tqdm_async
 
12
  EDGE_UNIT: str = "e"
13
 
14
 
 
15
  class ECEPartitioner(BFSPartitioner):
16
  """
17
  ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
graphgen/models/partitioner/leiden_partitioner.py CHANGED
@@ -1,5 +1,4 @@
1
  from collections import defaultdict
2
- from dataclasses import dataclass
3
  from typing import Any, Dict, List, Set, Tuple
4
 
5
  import igraph as ig
@@ -9,7 +8,6 @@ from graphgen.bases import BaseGraphStorage, BasePartitioner
9
  from graphgen.bases.datatypes import Community
10
 
11
 
12
- @dataclass
13
  class LeidenPartitioner(BasePartitioner):
14
  """
15
  Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
 
1
  from collections import defaultdict
 
2
  from typing import Any, Dict, List, Set, Tuple
3
 
4
  import igraph as ig
 
8
  from graphgen.bases.datatypes import Community
9
 
10
 
 
11
  class LeidenPartitioner(BasePartitioner):
12
  """
13
  Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
graphgen/models/search/db/uniprot_search.py CHANGED
@@ -1,5 +1,3 @@
1
- from dataclasses import dataclass
2
-
3
  import requests
4
  from fastapi import HTTPException
5
 
@@ -8,7 +6,6 @@ from graphgen.utils import logger
8
  UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
9
 
10
 
11
- @dataclass
12
  class UniProtSearch:
13
  """
14
  UniProt Search client to search with UniProt.
 
 
 
1
  import requests
2
  from fastapi import HTTPException
3
 
 
6
  UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
7
 
8
 
 
9
  class UniProtSearch:
10
  """
11
  UniProt Search client to search with UniProt.
graphgen/models/search/kg/wiki_search.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import List, Union
3
 
4
  import wikipedia
@@ -7,7 +6,6 @@ from wikipedia import set_lang
7
  from graphgen.utils import detect_main_language, logger
8
 
9
 
10
- @dataclass
11
  class WikiSearch:
12
  @staticmethod
13
  def set_language(language: str):
 
 
1
  from typing import List, Union
2
 
3
  import wikipedia
 
6
  from graphgen.utils import detect_main_language, logger
7
 
8
 
 
9
  class WikiSearch:
10
  @staticmethod
11
  def set_language(language: str):
graphgen/models/search/web/bing_search.py CHANGED
@@ -1,5 +1,3 @@
1
- from dataclasses import dataclass
2
-
3
  import requests
4
  from fastapi import HTTPException
5
 
@@ -9,13 +7,13 @@ BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
9
  BING_MKT = "en-US"
10
 
11
 
12
- @dataclass
13
  class BingSearch:
14
  """
15
  Bing Search client to search with Bing.
16
  """
17
 
18
- subscription_key: str
 
19
 
20
  def search(self, query: str, num_results: int = 1):
21
  """
 
 
 
1
  import requests
2
  from fastapi import HTTPException
3
 
 
7
  BING_MKT = "en-US"
8
 
9
 
 
10
  class BingSearch:
11
  """
12
  Bing Search client to search with Bing.
13
  """
14
 
15
+ def __init__(self, subscription_key: str):
16
+ self.subscription_key = subscription_key
17
 
18
  def search(self, query: str, num_results: int = 1):
19
  """
graphgen/models/search/web/google_search.py CHANGED
@@ -1,5 +1,3 @@
1
- from dataclasses import dataclass
2
-
3
  import requests
4
  from fastapi import HTTPException
5
 
@@ -8,7 +6,6 @@ from graphgen.utils import logger
8
  GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
9
 
10
 
11
- @dataclass
12
  class GoogleSearch:
13
  def __init__(self, subscription_key: str, cx: str):
14
  """
 
 
 
1
  import requests
2
  from fastapi import HTTPException
3
 
 
6
  GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
7
 
8
 
 
9
  class GoogleSearch:
10
  def __init__(self, subscription_key: str, cx: str):
11
  """
graphgen/models/storage/json_storage.py CHANGED
@@ -53,6 +53,8 @@ class JsonKVStorage(BaseKVStorage):
53
 
54
  @dataclass
55
  class JsonListStorage(BaseListStorage):
 
 
56
  _data: list = None
57
 
58
  def __post_init__(self):
 
53
 
54
  @dataclass
55
  class JsonListStorage(BaseListStorage):
56
+ working_dir: str = None
57
+ namespace: str = None
58
  _data: list = None
59
 
60
  def __post_init__(self):
graphgen/models/tokenizer/__init__.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass, field
2
  from typing import List
3
 
4
  from graphgen.bases import BaseTokenizer
@@ -30,16 +29,13 @@ def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
30
  )
31
 
32
 
33
- @dataclass
34
  class Tokenizer(BaseTokenizer):
35
  """
36
  Encapsulates different tokenization implementations based on the specified model name.
37
  """
38
 
39
- model_name: str = "cl100k_base"
40
- _impl: BaseTokenizer = field(init=False, repr=False)
41
-
42
- def __post_init__(self):
43
  if not self.model_name:
44
  raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
45
  self._impl = get_tokenizer_impl(self.model_name)
 
 
1
  from typing import List
2
 
3
  from graphgen.bases import BaseTokenizer
 
29
  )
30
 
31
 
 
32
  class Tokenizer(BaseTokenizer):
33
  """
34
  Encapsulates different tokenization implementations based on the specified model name.
35
  """
36
 
37
+ def __init__(self, model_name: str = "cl100k_base"):
38
+ super().__init__(model_name)
 
 
39
  if not self.model_name:
40
  raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
41
  self._impl = get_tokenizer_impl(self.model_name)
graphgen/models/tokenizer/hf_tokenizer.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import List
3
 
4
  from transformers import AutoTokenizer
@@ -6,9 +5,9 @@ from transformers import AutoTokenizer
6
  from graphgen.bases import BaseTokenizer
7
 
8
 
9
- @dataclass
10
  class HFTokenizer(BaseTokenizer):
11
- def __post_init__(self):
 
12
  self.enc = AutoTokenizer.from_pretrained(self.model_name)
13
 
14
  def encode(self, text: str) -> List[int]:
 
 
1
  from typing import List
2
 
3
  from transformers import AutoTokenizer
 
5
  from graphgen.bases import BaseTokenizer
6
 
7
 
 
8
  class HFTokenizer(BaseTokenizer):
9
+ def __init__(self, model_name: str = "cl100k_base"):
10
+ super().__init__(model_name)
11
  self.enc = AutoTokenizer.from_pretrained(self.model_name)
12
 
13
  def encode(self, text: str) -> List[int]:
graphgen/models/tokenizer/tiktoken_tokenizer.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
  from typing import List
3
 
4
  import tiktoken
@@ -6,9 +5,9 @@ import tiktoken
6
  from graphgen.bases import BaseTokenizer
7
 
8
 
9
- @dataclass
10
  class TiktokenTokenizer(BaseTokenizer):
11
- def __post_init__(self):
 
12
  self.enc = tiktoken.get_encoding(self.model_name)
13
 
14
  def encode(self, text: str) -> List[int]:
 
 
1
  from typing import List
2
 
3
  import tiktoken
 
5
  from graphgen.bases import BaseTokenizer
6
 
7
 
 
8
  class TiktokenTokenizer(BaseTokenizer):
9
+ def __init__(self, model_name: str = "cl100k_base"):
10
+ super().__init__(model_name)
11
  self.enc = tiktoken.get_encoding(self.model_name)
12
 
13
  def encode(self, text: str) -> List[int]: