github-actions[bot] committed
Commit 817f16e · Parent: 3a3b216

Auto-sync from demo at Tue Sep 30 07:59:12 UTC 2025
Files changed:

- app.py (+40 −31)
- graphgen/configs/aggregated_config.yaml (+15 −13)
- graphgen/configs/atomic_config.yaml (+15 −13)
- graphgen/configs/cot_config.yaml (+11 −8)
- graphgen/configs/multi_hop_config.yaml (+15 −13)
- graphgen/generate.py (+26 −25)
- graphgen/graphgen.py (+56 −76)
- graphgen/models/__init__.py (+0 −1)
- graphgen/models/strategy/__init__.py (deleted)
- graphgen/models/strategy/travserse_strategy.py (+0 −28, deleted)
- graphgen/models/tokenizer/__init__.py (+2 −0)
- graphgen/operators/build_kg/split_kg.py (+16 −15)
- graphgen/operators/traverse_graph.py (+8 −14)
- webui/app.py (+40 −31)
app.py
CHANGED

@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     set_logger(log_file, if_stream=True)
     os.environ.update({k: str(v) for k, v in env.items()})

-
-
-    graph_gen.synthesizer_llm_client = OpenAIClient(
+    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
     )
-    graph_gen.trainee_llm_client = OpenAIClient(
+    trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
    )

-    graph_gen [… line truncated in the capture]
+    graph_gen = GraphGen(
+        working_dir=working_dir,
+        tokenizer_instance=tokenizer_instance,
+        synthesizer_llm_client=synthesizer_llm_client,
+        trainee_llm_client=trainee_llm_client,
+    )

     return graph_gen

@@ -78,27 +83,32 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
             "chunk_size": params.chunk_size,
             "chunk_overlap": params.chunk_overlap,
         },
-        "output_data_type": params.output_data_type,
-        "output_data_format": params.output_data_format,
-        "tokenizer": params.tokenizer,
         "search": {"enabled": False},
-        "quiz_and_judge_strategy": {
+        "quiz_and_judge": {
            "enabled": params.if_trainee_model,
            "quiz_samples": params.quiz_samples,
        },
-        [… old traverse_strategy block; nine removed lines truncated in the capture]
+        "partition": {
+            "method": "ece",
+            "method_params": {
+                "bidirectional": params.bidirectional,
+                "expand_method": params.expand_method,
+                "max_extra_edges": params.max_extra_edges,
+                "max_tokens": params.max_tokens,
+                "max_depth": params.max_depth,
+                "edge_sampling": params.edge_sampling,
+                "isolated_node_strategy": params.isolated_node_strategy,
+                "loss_strategy": params.loss_strategy,
+            },
+        },
+        "generate": {
+            "mode": params.output_data_type,
+            "data_format": params.output_data_format,
        },
    }

    env = {
+        "TOKENIZER_MODEL": params.tokenizer,
        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
        "SYNTHESIZER_MODEL": params.synthesizer_model,
        "TRAINEE_BASE_URL": params.trainee_url,

@@ -128,19 +138,18 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):

     try:
         # Process the data
-        graph_gen.insert()
+        graph_gen.insert(read_config=config["read"], split_config=config["split"])

         if config["if_trainee_model"]:
-            # Quiz
-            graph_gen.quiz()
-
-            # Judge statements
-            graph_gen.judge()
+            # Quiz and Judge
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
-            [… removed lines truncated in the capture]
+            config["partition"]["method_params"]["edge_sampling"] = "random"

+        graph_gen.generate(
+            partition_config=config["partition"],
+            generate_config=config["generate"],
+        )

         # Save output
         output_data = graph_gen.qa_storage.data

@@ -249,6 +258,9 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
     )

     with gr.Accordion(label=_("Model Config"), open=False):
+        tokenizer = gr.Textbox(
+            label="Tokenizer", value="cl100k_base", interactive=True
+        )
         synthesizer_url = gr.Textbox(
             label="Synthesizer URL",
             value="https://api.siliconflow.cn/v1",

@@ -300,9 +312,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
             step=100,
             interactive=True,
         )
-        tokenizer = gr.Textbox(
-            label="Tokenizer", value="cl100k_base", interactive=True
-        )
         output_data_type = gr.Radio(
             choices=["atomic", "multi_hop", "aggregated"],
             label="Output Data Type",
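The init_graph_gen refactor above replaces attribute mutation with explicit construction: the tokenizer is built first, shared by both LLM clients, and everything is injected into GraphGen. Below is a minimal sketch of the same wiring outside Gradio, assuming the constructor signatures visible in this diff; the RPM/TPM import path and the single-client reuse are assumptions, not shown in the commit.

import os

from graphgen.graphgen import GraphGen
from graphgen.models import OpenAIClient, Tokenizer

# RPM/TPM are used in app.py; their module path is not shown in this diff.
from graphgen.models import RPM, TPM  # assumption

tokenizer_instance = Tokenizer("cl100k_base")  # a tiktoken name or a local tokenizer path

synthesizer_llm_client = OpenAIClient(
    model_name=os.environ["SYNTHESIZER_MODEL"],
    base_url=os.environ["SYNTHESIZER_BASE_URL"],
    api_key=os.environ["SYNTHESIZER_API_KEY"],
    request_limit=True,
    rpm=RPM(1000),
    tpm=TPM(50000),
    tokenizer=tokenizer_instance,
)

graph_gen = GraphGen(
    working_dir="cache",
    tokenizer_instance=tokenizer_instance,
    synthesizer_llm_client=synthesizer_llm_client,
    trainee_llm_client=synthesizer_llm_client,  # assumption: reuse one client when no trainee model is set
)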
graphgen/configs/aggregated_config.yaml
CHANGED

@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-[… nine removed lines truncated in the capture]
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 5 # maximum depth for graph traversal
+    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: aggregated # atomic, aggregated, multi_hop, cot
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
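The YAML above nests every traversal knob under partition.method_params and the output options under generate, which is exactly how the refactored graphgen/generate.py reads it. A small sketch of loading and picking the config apart (the path is just an example):

import yaml

with open("graphgen/configs/aggregated_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

assert config["partition"]["method"] == "ece"
method_params = config["partition"]["method_params"]  # bidirectional, edge_sampling, ...
mode = config["generate"]["mode"]                     # "aggregated"
data_format = config["generate"]["data_format"]       # "ChatML"
print(mode, data_format, method_params["edge_sampling"])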
graphgen/configs/atomic_config.yaml
CHANGED

@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-
-output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-[… nine removed lines truncated in the capture]
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 3 # maximum depth for graph traversal
+    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: atomic # atomic, aggregated, multi_hop, cot
+  data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml
CHANGED

@@ -6,11 +6,14 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-[… eight removed lines truncated in the capture]
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
+  enabled: false
+partition: # graph partition configuration
+  method: leiden # leiden is a community detection algorithm
+  method_params:
+    max_size: 20 # Maximum size of communities
+    use_lcc: false
+    random_seed: 42
+generate:
+  mode: cot # atomic, aggregated, multi_hop, cot
+  data_format: Sharegpt # Alpaca, Sharegpt, ChatML
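Unlike the other configs, the CoT pipeline partitions the graph with Leiden community detection rather than the loss-based ece method: max_size caps community size, use_lcc restricts the graph to its largest connected component, and random_seed fixes the run. GraphGen's own Leiden implementation is not part of this diff; purely as an illustration of the same parameters, here is a sketch using the external python-igraph and leidenalg packages:

import igraph as ig
import leidenalg

g = ig.Graph.Famous("Zachary")  # stand-in for the extracted knowledge graph
g = g.clusters().giant()        # what use_lcc: true would do

partition = leidenalg.find_partition(
    g,
    leidenalg.ModularityVertexPartition,
    max_comm_size=20,  # cf. method_params.max_size
    seed=42,           # cf. method_params.random_seed
)
print(len(partition), "communities")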
graphgen/configs/multi_hop_config.yaml
CHANGED

@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-[… nine removed lines truncated in the capture]
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 1 # maximum depth for graph traversal
+    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: multi_hop # strategy for generating multi-hop QA pairs
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/generate.py
CHANGED

@@ -6,8 +6,8 @@ from importlib.resources import files
 import yaml
 from dotenv import load_dotenv

-from .graphgen import GraphGen
-from .utils import logger, set_logger
+from graphgen.graphgen import GraphGen
+from graphgen.utils import logger, set_logger

 sys_path = os.path.abspath(os.path.dirname(__file__))

@@ -50,50 +50,51 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)

-    [… one removed line truncated in the capture]
+    mode = config["generate"]["mode"]
     unique_id = int(time.time())

-    output_path = os.path.join(
-        working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
-    )
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)

     set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(
-            working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
-        ),
+        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
     )

-    graph_gen = GraphGen( [… arguments truncated in the capture]
+    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)

-    graph_gen.insert()
+    graph_gen.insert(read_config=config["read"], split_config=config["split"])

-    graph_gen.search()
+    graph_gen.search(search_config=config["search"])

     # Use pipeline according to the output data type
-    if [… condition truncated in the capture]
-        graph_gen.quiz()
-        graph_gen.judge()
+    if mode in ["atomic", "aggregated", "multi_hop"]:
+        logger.info("Generation mode set to '%s'. Start generation.", mode)
+        if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
             logger.warning(
                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
             )
-            [… removed lines truncated in the capture]
+            assert (
+                config["partition"]["method"] == "ece"
+                and "ece_params" in config["partition"]
+            ), "Only ECE partition with edge sampling is supported."
+            config["partition"]["method_params"]["edge_sampling"] = "random"
+    elif mode == "cot":
+        logger.info("Generation mode set to 'cot'. Start generation.")
     else:
-        raise ValueError(f"Unsupported output data type: {[… truncated in the capture]}")
+        raise ValueError(f"Unsupported output data type: {mode}")
+
+    graph_gen.generate(
+        partition_config=config["partition"],
+        generate_config=config["generate"],
+    )

     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
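The rewritten main() gives GraphGen a uniform call order: insert → search → optional quiz_and_judge → generate. A sketch of driving the same pipeline programmatically, assuming the method signatures above (config path, unique_id, and working_dir are examples; the TOKENIZER_MODEL and SYNTHESIZER_*/TRAINEE_* environment variables must already be set). Note that the fallback branch asserts on an "ece_params" key while the shipped configs define "method_params", so that assert appears inconsistent with the configs in this same commit.

import yaml

from graphgen.graphgen import GraphGen

with open("graphgen/configs/atomic_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

gen = GraphGen(unique_id=1, working_dir="cache")

gen.insert(read_config=config["read"], split_config=config["split"])
gen.search(search_config=config["search"])

if config.get("quiz_and_judge", {}).get("enabled", False):
    gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
else:
    # mirror main(): without a trainee model, edge sampling falls back to random
    config["partition"]["method_params"]["edge_sampling"] = "random"

gen.generate(
    partition_config=config["partition"],
    generate_config=config["generate"],
)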
graphgen/graphgen.py
CHANGED

@@ -1,7 +1,7 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Dict, cast

 import gradio as gr

@@ -14,7 +14,6 @@ from graphgen.models import (
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,
-    TraverseStrategy,
 )
 from graphgen.operators import (
     chunk_documents,

@@ -42,46 +41,36 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 class GraphGen:
     unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    config: Dict = field(default_factory=dict)

     # llm
     tokenizer_instance: Tokenizer = None
     synthesizer_llm_client: OpenAIClient = None
     trainee_llm_client: OpenAIClient = None

-    # search
-    search_config: dict = field(
-        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
-    )
-
-    # traversal
-    traverse_strategy: TraverseStrategy = None
-
     # webui
     progress_bar: gr.Progress = None

     def __post_init__(self):
-        self.tokenizer_instance: Tokenizer = Tokenizer(
-            model_name=[… truncated in the capture]
-        )
-        [… old synthesizer client block; lines truncated in the capture]
+        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+            model_name=os.getenv("TOKENIZER_MODEL")
+        )
+
+        self.synthesizer_llm_client: OpenAIClient = (
+            self.synthesizer_llm_client
+            or OpenAIClient(
+                model_name=os.getenv("SYNTHESIZER_MODEL"),
+                api_key=os.getenv("SYNTHESIZER_API_KEY"),
+                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+                tokenizer=self.tokenizer_instance,
+            )
+        )
+
+        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
             tokenizer=self.tokenizer_instance,
         )
-        self.search_config = self.config["search"]
-
-        if "traverse_strategy" in self.config:
-            self.traverse_strategy = TraverseStrategy(
-                **self.config["traverse_strategy"]
-            )

         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"

@@ -99,24 +88,17 @@ class GraphGen:
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(
-                self.working_dir,
-                "data",
-                "graphgen",
-                f"{self.unique_id}_{self.config['output_data_type']}",
-            ),
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )

     @async_to_sync_method
-    async def insert(self):
+    async def insert(self, read_config: Dict, split_config: Dict):
         """
         insert chunks into the graph
         """
-        input_file = self.config["read"]["input_file"]
-
         # Step 1: Read files
-        data = read_files(input_file)
+        data = read_files(read_config["input_file"])
         if len(data) == 0:
             logger.warning("No data to process")
             return

@@ -141,8 +123,8 @@ class GraphGen:
         inserting_chunks = await chunk_documents(
             new_docs,
-            [… two argument lines truncated in the capture]
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )

@@ -178,6 +160,7 @@ class GraphGen:
             return

         await self._insert_done()
+        return _add_entities_and_relations

     async def _insert_done(self):
         tasks = []

@@ -193,14 +176,12 @@ class GraphGen:
         await asyncio.gather(*tasks)

     @async_to_sync_method
-    async def search(self):
+    async def search(self, search_config: Dict):
         logger.info(
-            "Search is %s", "enabled" if [… truncated in the capture]
+            "Search is %s", "enabled" if search_config["enabled"] else "disabled"
         )
-        if [… truncated in the capture]
-            logger.info(
-                "[Search] %s ...", ", ".join(self.search_config["search_types"])
-            )
+        if search_config["enabled"]:
+            logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
         all_nodes = await self.graph_storage.get_all_nodes()
         all_nodes_names = [node[0] for node in all_nodes]
         new_search_entities = await self.full_docs_storage.filter_keys(

@@ -210,7 +191,7 @@ class GraphGen:
             "[Search] Found %d entities to search", len(new_search_entities)
         )
         _add_search_data = await search_all(
-            search_types=[… truncated in the capture]
+            search_types=search_config["search_types"],
             search_entities=new_search_entities,
         )
         if _add_search_data:

@@ -230,78 +211,77 @@ class GraphGen:
         await self.insert()

     @async_to_sync_method
-    async def [… old quiz method; header and setup lines truncated in the capture]
+    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
+        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
+            "enabled", False
+        ):
+            logger.warning("Quiz and Judge is not used in this pipeline.")
+            return
+        max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             max_samples,
         )
-        await self.rephrase_storage.index_done_callback()

-        [… old judge method; header lines truncated in the capture]
-        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+        # TODO: assert trainee_llm_client is valid before judge
+        re_judge = quiz_and_judge_config["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             re_judge,
         )
+        await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()

     @async_to_sync_method
-    async def [… old traverse method; header lines truncated in the capture]
+    async def generate(self, partition_config: Dict, generate_config: Dict):
+        # Step 1: partition the graph
+        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
+        mode = generate_config["mode"]
+        if mode == "atomic":
             results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                [… truncated in the capture]
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif [… truncated in the capture]
+        elif mode == "multi_hop":
             results = await traverse_graph_for_multi_hop(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                [… truncated in the capture]
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif [… truncated in the capture]
+        elif mode == "aggregated":
             results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                [… truncated in the capture]
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
+        elif mode == "cot":
+            results = await generate_cot(
+                self.graph_storage,
+                self.synthesizer_llm_client,
+                method_params=partition_config["method_params"],
+            )
         else:
-            raise ValueError(f"Unknown [… truncated in the capture]
-
-        results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
-        )
-
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate_reasoning(self, method_params):
-        results = await generate_cot(
-            self.graph_storage,
-            self.synthesizer_llm_client,
-            method_params=method_params,
-        )
+            raise ValueError(f"Unknown generation mode: {mode}")
+        # Step 2: generate QA pairs
+        # TODO

+        # Step 3: format
         results = format_generation_results(
-            results, output_data_format=[… truncated in the capture]
+            results, output_data_format=generate_config["data_format"]
         )

         await self.qa_storage.upsert(results)
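__post_init__ now treats every dependency as injectable: an instance passed by the caller (the webui path) wins, and "x = x or default" falls back to environment-driven construction (the CLI path). A self-contained sketch of the pattern, with a stand-in for the client class:

import os
from dataclasses import dataclass
from typing import Optional


def client_from_env() -> str:
    # stand-in for OpenAIClient(model_name=os.getenv("SYNTHESIZER_MODEL"), ...)
    return f"client:{os.getenv('SYNTHESIZER_MODEL', 'unset')}"


@dataclass
class Pipeline:
    client: Optional[str] = None

    def __post_init__(self):
        # an injected client is kept as-is; None falls back to the env default
        self.client = self.client or client_from_env()


print(Pipeline().client)                   # built from the environment
print(Pipeline(client="injected").client)  # injected instance preserved

One caveat of "or": any falsy injected value is silently replaced by the default, which is harmless for client objects but worth remembering for other field types.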
graphgen/models/__init__.py
CHANGED

@@ -13,5 +13,4 @@ from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
-from .strategy.travserse_strategy import TraverseStrategy
 from .tokenizer import Tokenizer
graphgen/models/strategy/__init__.py
DELETED (empty file, no content changes)
graphgen/models/strategy/travserse_strategy.py
DELETED

@@ -1,28 +0,0 @@
-from dataclasses import dataclass, fields
-
-
-@dataclass
-class TraverseStrategy:
-    # form of generated QA: atomic, multi-hop, or aggregated
-    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-    # only one of the max-edges and max-tokens methods takes effect
-    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-    # expand in one direction or both
-    bidirectional: bool = True
-    # maximum number of edges to expand per direction
-    max_extra_edges: int = 5
-    # maximum number of tokens
-    max_tokens: int = 256
-    # maximum expansion depth per direction
-    max_depth: int = 2
-    # edge selection strategy within a level (for bidirectional expansion, a level is the set of edges connected on both sides)
-    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-    # handling strategy for isolated nodes
-    isolated_node_strategy: str = "add"  # "add" or "ignore"
-    loss_strategy: str = "only_edge"  # only_edge, both
-
-    def to_yaml(self):
-        strategy_dict = {}
-        for f in fields(self):
-            strategy_dict[f.name] = getattr(self, f.name)
-        return {"traverse_strategy": strategy_dict}
graphgen/models/tokenizer/__init__.py
CHANGED

@@ -39,6 +39,8 @@ class Tokenizer(BaseTokenizer):
     _impl: BaseTokenizer = field(init=False, repr=False)

     def __post_init__(self):
+        if not self.model_name:
+            raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
         self._impl = get_tokenizer_impl(self.model_name)

     def encode(self, text: str) -> List[int]:
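With this guard, a missing tokenizer name fails fast at construction time instead of surfacing later inside the tokenizer backend. Expected behavior, as a sketch (using the Tokenizer re-exported from graphgen.models):

import os

from graphgen.models import Tokenizer

os.environ.pop("TOKENIZER_MODEL", None)
try:
    Tokenizer(model_name=os.getenv("TOKENIZER_MODEL"))  # None -> fails fast
except ValueError as err:
    print(err)  # TOKENIZER_MODEL must be specified in the ENV variables.

tok = Tokenizer(model_name="cl100k_base")  # a tiktoken name or a local tokenizer path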
graphgen/operators/build_kg/split_kg.py
CHANGED

@@ -1,9 +1,10 @@
 import random
 from collections import defaultdict
+from typing import Dict

 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.models import NetworkXStorage
 from graphgen.utils import logger

@@ -247,9 +248,9 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
 ):
-    expand_method = traverse_strategy.expand_method
+    expand_method = traverse_strategy["expand_method"]
     if expand_method == "max_width":
         logger.info("Using max width strategy")
     elif expand_method == "max_tokens":

@@ -257,8 +258,8 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     else:
         raise ValueError(f"Invalid expand method: {expand_method}")

-    max_depth = traverse_strategy.max_depth
-    edge_sampling = traverse_strategy.edge_sampling
+    max_depth = traverse_strategy["max_depth"]
+    edge_sampling = traverse_strategy["edge_sampling"]

     # build the adjacency list
     edge_adj_list = defaultdict(list)

@@ -275,16 +276,16 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i

-    if traverse_strategy.loss_strategy == "both":
+    if traverse_strategy["loss_strategy"] == "both":
         er_tuples = [
             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
             for edge in edges
         ]
         edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy.loss_strategy == "only_edge":
+    elif traverse_strategy["loss_strategy"] == "only_edge":
         edges = _sort_edges(edges, edge_sampling)
     else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")

     for i, (src, tgt, _) in enumerate(edges):
         edge_adj_list[src].append(i)

@@ -315,10 +316,10 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_extra_edges,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_extra_edges"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
         else:
             level_n_edges = _get_level_n_edges_by_max_tokens(

@@ -328,10 +329,10 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_tokens,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_tokens"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )

         for _edge in level_n_edges:

@@ -352,7 +353,7 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     logger.info("Processing batches: %d", len(processing_batches))

     # isolated nodes
-    isolated_node_strategy = traverse_strategy.isolated_node_strategy
+    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
     if isolated_node_strategy == "add":
         processing_batches = await _add_isolated_nodes(
             nodes, processing_batches, graph_storage
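After this refactor, get_batches_with_strategy consumes a plain dict rather than the deleted TraverseStrategy dataclass. The keys below are the complete set the function reads; the values mirror aggregated_config.yaml, and only one of max_extra_edges/max_tokens applies, depending on expand_method:

# A valid traverse_strategy dict for get_batches_with_strategy
traverse_strategy = {
    "expand_method": "max_width",        # "max_width" or "max_tokens"
    "bidirectional": True,               # expand in both directions
    "max_extra_edges": 20,               # used when expand_method == "max_width"
    "max_tokens": 256,                   # used when expand_method == "max_tokens"
    "max_depth": 5,                      # maximum traversal depth per direction
    "edge_sampling": "max_loss",         # "random" | "max_loss" | "min_loss"
    "isolated_node_strategy": "ignore",  # "add" | "ignore"
    "loss_strategy": "only_edge",        # "only_edge" | "both"
}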
graphgen/operators/traverse_graph.py
CHANGED

@@ -1,15 +1,10 @@
 import asyncio
+from typing import Dict

 import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import (
-    JsonKVStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-    TraverseStrategy,
-)
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
 from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,

@@ -164,7 +159,7 @@ async def traverse_graph_for_aggregated(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -240,7 +235,7 @@ async def traverse_graph_for_aggregated(
                     "question": question,
                     "answer": context,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }

@@ -272,7 +267,7 @@ async def traverse_graph_for_aggregated(
                 "question": qa["question"],
                 "answer": qa["answer"],
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         return final_results

@@ -313,7 +308,7 @@ async def traverse_graph_for_atomic(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -331,7 +326,6 @@ async def traverse_graph_for_atomic(
     :return: question and answer
     """

-    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)

     def _parse_qa(qa: str) -> tuple:

@@ -429,7 +423,7 @@ async def traverse_graph_for_multi_hop(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -517,7 +511,7 @@ async def traverse_graph_for_multi_hop(
                     "question": question,
                     "answer": answer,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }
webui/app.py
CHANGED

Identical to the app.py diff above; the root-level app.py mirrors webui/app.py, so both files carry the same +40 −31 change.