Spaces:
Running
Running
github-actions[bot]
committed on
Commit
·
9e67c3b
1
Parent(s):
e83bd85
Auto-sync from demo at Tue Nov 25 11:19:13 UTC 2025
Browse files
- graphgen/bases/base_partitioner.py +3 -3
- graphgen/bases/base_storage.py +28 -32
- graphgen/graphgen.py +34 -38
- graphgen/models/kg_builder/light_rag_kg_builder.py +6 -6
- graphgen/models/partitioner/anchor_bfs_partitioner.py +2 -2
- graphgen/models/partitioner/bfs_partitioner.py +2 -2
- graphgen/models/partitioner/dfs_partitioner.py +2 -2
- graphgen/models/partitioner/ece_partitioner.py +2 -2
- graphgen/models/partitioner/leiden_partitioner.py +2 -2
- graphgen/models/storage/json_storage.py +18 -18
- graphgen/models/storage/networkx_storage.py +19 -23
- graphgen/operators/partition/partition_kg.py +3 -3
- graphgen/operators/partition/pre_tokenize.py +3 -3
- graphgen/operators/quiz_and_judge/judge.py +9 -15
- graphgen/operators/quiz_and_judge/quiz.py +4 -4
- graphgen/operators/storage.py +59 -0
- requirements.txt +1 -1
graphgen/bases/base_partitioner.py
CHANGED
|
@@ -39,16 +39,16 @@ class BasePartitioner(ABC):
|
|
| 39 |
edges = comm.edges
|
| 40 |
nodes_data = []
|
| 41 |
for node in nodes:
|
| 42 |
-
node_data =
|
| 43 |
if node_data:
|
| 44 |
nodes_data.append((node, node_data))
|
| 45 |
edges_data = []
|
| 46 |
for u, v in edges:
|
| 47 |
-
edge_data =
|
| 48 |
if edge_data:
|
| 49 |
edges_data.append((u, v, edge_data))
|
| 50 |
else:
|
| 51 |
-
edge_data =
|
| 52 |
if edge_data:
|
| 53 |
edges_data.append((v, u, edge_data))
|
| 54 |
batches.append((nodes_data, edges_data))
|
|
|
|
| 39 |
edges = comm.edges
|
| 40 |
nodes_data = []
|
| 41 |
for node in nodes:
|
| 42 |
+
node_data = g.get_node(node)
|
| 43 |
if node_data:
|
| 44 |
nodes_data.append((node, node_data))
|
| 45 |
edges_data = []
|
| 46 |
for u, v in edges:
|
| 47 |
+
edge_data = g.get_edge(u, v)
|
| 48 |
if edge_data:
|
| 49 |
edges_data.append((u, v, edge_data))
|
| 50 |
else:
|
| 51 |
+
edge_data = g.get_edge(v, u)
|
| 52 |
if edge_data:
|
| 53 |
edges_data.append((v, u, edge_data))
|
| 54 |
batches.append((nodes_data, edges_data))
|
graphgen/bases/base_storage.py
CHANGED
|
@@ -9,103 +9,99 @@ class StorageNameSpace:
|
|
| 9 |
working_dir: str = None
|
| 10 |
namespace: str = None
|
| 11 |
|
| 12 |
-
|
| 13 |
"""commit the storage operations after indexing"""
|
| 14 |
|
| 15 |
-
|
| 16 |
"""commit the storage operations after querying"""
|
| 17 |
|
| 18 |
|
| 19 |
class BaseListStorage(Generic[T], StorageNameSpace):
|
| 20 |
-
|
| 21 |
raise NotImplementedError
|
| 22 |
|
| 23 |
-
|
| 24 |
raise NotImplementedError
|
| 25 |
|
| 26 |
-
|
| 27 |
raise NotImplementedError
|
| 28 |
|
| 29 |
-
|
| 30 |
raise NotImplementedError
|
| 31 |
|
| 32 |
-
|
| 33 |
raise NotImplementedError
|
| 34 |
|
| 35 |
|
| 36 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
| 37 |
-
|
| 38 |
raise NotImplementedError
|
| 39 |
|
| 40 |
-
|
| 41 |
raise NotImplementedError
|
| 42 |
|
| 43 |
-
|
| 44 |
self, ids: list[str], fields: Union[set[str], None] = None
|
| 45 |
) -> list[Union[T, None]]:
|
| 46 |
raise NotImplementedError
|
| 47 |
|
| 48 |
-
|
| 49 |
raise NotImplementedError
|
| 50 |
|
| 51 |
-
|
| 52 |
"""return un-exist keys"""
|
| 53 |
raise NotImplementedError
|
| 54 |
|
| 55 |
-
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
-
|
| 59 |
raise NotImplementedError
|
| 60 |
|
| 61 |
|
| 62 |
class BaseGraphStorage(StorageNameSpace):
|
| 63 |
-
|
| 64 |
raise NotImplementedError
|
| 65 |
|
| 66 |
-
|
| 67 |
raise NotImplementedError
|
| 68 |
|
| 69 |
-
|
| 70 |
raise NotImplementedError
|
| 71 |
|
| 72 |
-
|
| 73 |
raise NotImplementedError
|
| 74 |
|
| 75 |
-
|
| 76 |
raise NotImplementedError
|
| 77 |
|
| 78 |
-
|
| 79 |
raise NotImplementedError
|
| 80 |
|
| 81 |
-
|
| 82 |
raise NotImplementedError
|
| 83 |
|
| 84 |
-
|
| 85 |
-
self, source_node_id: str, target_node_id: str
|
| 86 |
-
) -> Union[dict, None]:
|
| 87 |
raise NotImplementedError
|
| 88 |
|
| 89 |
-
|
| 90 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 91 |
):
|
| 92 |
raise NotImplementedError
|
| 93 |
|
| 94 |
-
|
| 95 |
raise NotImplementedError
|
| 96 |
|
| 97 |
-
|
| 98 |
-
self, source_node_id: str
|
| 99 |
-
) -> Union[list[tuple[str, str]], None]:
|
| 100 |
raise NotImplementedError
|
| 101 |
|
| 102 |
-
|
| 103 |
raise NotImplementedError
|
| 104 |
|
| 105 |
-
|
| 106 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 107 |
):
|
| 108 |
raise NotImplementedError
|
| 109 |
|
| 110 |
-
|
| 111 |
raise NotImplementedError
|
|
|
|
| 9 |
working_dir: str = None
|
| 10 |
namespace: str = None
|
| 11 |
|
| 12 |
+
def index_done_callback(self):
|
| 13 |
"""commit the storage operations after indexing"""
|
| 14 |
|
| 15 |
+
def query_done_callback(self):
|
| 16 |
"""commit the storage operations after querying"""
|
| 17 |
|
| 18 |
|
| 19 |
class BaseListStorage(Generic[T], StorageNameSpace):
|
| 20 |
+
def all_items(self) -> list[T]:
|
| 21 |
raise NotImplementedError
|
| 22 |
|
| 23 |
+
def get_by_index(self, index: int) -> Union[T, None]:
|
| 24 |
raise NotImplementedError
|
| 25 |
|
| 26 |
+
def append(self, data: T):
|
| 27 |
raise NotImplementedError
|
| 28 |
|
| 29 |
+
def upsert(self, data: list[T]):
|
| 30 |
raise NotImplementedError
|
| 31 |
|
| 32 |
+
def drop(self):
|
| 33 |
raise NotImplementedError
|
| 34 |
|
| 35 |
|
| 36 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
| 37 |
+
def all_keys(self) -> list[str]:
|
| 38 |
raise NotImplementedError
|
| 39 |
|
| 40 |
+
def get_by_id(self, id: str) -> Union[T, None]:
|
| 41 |
raise NotImplementedError
|
| 42 |
|
| 43 |
+
def get_by_ids(
|
| 44 |
self, ids: list[str], fields: Union[set[str], None] = None
|
| 45 |
) -> list[Union[T, None]]:
|
| 46 |
raise NotImplementedError
|
| 47 |
|
| 48 |
+
def get_all(self) -> dict[str, T]:
|
| 49 |
raise NotImplementedError
|
| 50 |
|
| 51 |
+
def filter_keys(self, data: list[str]) -> set[str]:
|
| 52 |
"""return un-exist keys"""
|
| 53 |
raise NotImplementedError
|
| 54 |
|
| 55 |
+
def upsert(self, data: dict[str, T]):
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
+
def drop(self):
|
| 59 |
raise NotImplementedError
|
| 60 |
|
| 61 |
|
| 62 |
class BaseGraphStorage(StorageNameSpace):
|
| 63 |
+
def has_node(self, node_id: str) -> bool:
|
| 64 |
raise NotImplementedError
|
| 65 |
|
| 66 |
+
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 67 |
raise NotImplementedError
|
| 68 |
|
| 69 |
+
def node_degree(self, node_id: str) -> int:
|
| 70 |
raise NotImplementedError
|
| 71 |
|
| 72 |
+
def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 73 |
raise NotImplementedError
|
| 74 |
|
| 75 |
+
def get_node(self, node_id: str) -> Union[dict, None]:
|
| 76 |
raise NotImplementedError
|
| 77 |
|
| 78 |
+
def update_node(self, node_id: str, node_data: dict[str, str]):
|
| 79 |
raise NotImplementedError
|
| 80 |
|
| 81 |
+
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
|
| 82 |
raise NotImplementedError
|
| 83 |
|
| 84 |
+
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
|
|
|
|
|
|
| 85 |
raise NotImplementedError
|
| 86 |
|
| 87 |
+
def update_edge(
|
| 88 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 89 |
):
|
| 90 |
raise NotImplementedError
|
| 91 |
|
| 92 |
+
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
|
| 93 |
raise NotImplementedError
|
| 94 |
|
| 95 |
+
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
|
|
|
|
|
|
|
| 96 |
raise NotImplementedError
|
| 97 |
|
| 98 |
+
def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 99 |
raise NotImplementedError
|
| 100 |
|
| 101 |
+
def upsert_edge(
|
| 102 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 103 |
):
|
| 104 |
raise NotImplementedError
|
| 105 |
|
| 106 |
+
def delete_node(self, node_id: str):
|
| 107 |
raise NotImplementedError
|
graphgen/graphgen.py
CHANGED
|
@@ -104,15 +104,15 @@ class GraphGen:
|
|
| 104 |
# TODO: configurable whether to use coreference resolution
|
| 105 |
|
| 106 |
new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
|
| 107 |
-
_add_doc_keys =
|
| 108 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
| 109 |
|
| 110 |
if len(new_docs) == 0:
|
| 111 |
logger.warning("All documents are already in the storage")
|
| 112 |
return
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
|
| 117 |
@op("chunk", deps=["read"])
|
| 118 |
@async_to_sync_method
|
|
@@ -121,7 +121,7 @@ class GraphGen:
|
|
| 121 |
chunk documents into smaller pieces from full_docs_storage if not already present
|
| 122 |
"""
|
| 123 |
|
| 124 |
-
new_docs =
|
| 125 |
if len(new_docs) == 0:
|
| 126 |
logger.warning("All documents are already in the storage")
|
| 127 |
return
|
|
@@ -133,9 +133,7 @@ class GraphGen:
|
|
| 133 |
**chunk_config,
|
| 134 |
)
|
| 135 |
|
| 136 |
-
_add_chunk_keys =
|
| 137 |
-
list(inserting_chunks.keys())
|
| 138 |
-
)
|
| 139 |
inserting_chunks = {
|
| 140 |
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
| 141 |
}
|
|
@@ -144,10 +142,10 @@ class GraphGen:
|
|
| 144 |
logger.warning("All chunks are already in the storage")
|
| 145 |
return
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
|
| 152 |
@op("build_kg", deps=["chunk"])
|
| 153 |
@async_to_sync_method
|
|
@@ -156,7 +154,7 @@ class GraphGen:
|
|
| 156 |
build knowledge graph from text chunks
|
| 157 |
"""
|
| 158 |
# Step 1: get new chunks according to meta and chunks storage
|
| 159 |
-
inserting_chunks =
|
| 160 |
if len(inserting_chunks) == 0:
|
| 161 |
logger.warning("All chunks are already in the storage")
|
| 162 |
return
|
|
@@ -174,9 +172,9 @@ class GraphGen:
|
|
| 174 |
return
|
| 175 |
|
| 176 |
# Step 3: mark meta
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
|
| 181 |
return _add_entities_and_relations
|
| 182 |
|
|
@@ -185,7 +183,7 @@ class GraphGen:
|
|
| 185 |
async def search(self, search_config: Dict):
|
| 186 |
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
|
| 187 |
|
| 188 |
-
seeds =
|
| 189 |
if len(seeds) == 0:
|
| 190 |
logger.warning("All documents are already been searched")
|
| 191 |
return
|
|
@@ -194,19 +192,17 @@ class GraphGen:
|
|
| 194 |
search_config=search_config,
|
| 195 |
)
|
| 196 |
|
| 197 |
-
_add_search_keys =
|
| 198 |
-
list(search_results.keys())
|
| 199 |
-
)
|
| 200 |
search_results = {
|
| 201 |
k: v for k, v in search_results.items() if k in _add_search_keys
|
| 202 |
}
|
| 203 |
if len(search_results) == 0:
|
| 204 |
logger.warning("All search results are already in the storage")
|
| 205 |
return
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
|
| 211 |
@op("quiz_and_judge", deps=["build_kg"])
|
| 212 |
@async_to_sync_method
|
|
@@ -240,8 +236,8 @@ class GraphGen:
|
|
| 240 |
progress_bar=self.progress_bar,
|
| 241 |
)
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
|
| 246 |
logger.info("Shutting down trainee LLM client.")
|
| 247 |
self.trainee_llm_client.shutdown()
|
|
@@ -258,7 +254,7 @@ class GraphGen:
|
|
| 258 |
self.tokenizer_instance,
|
| 259 |
partition_config,
|
| 260 |
)
|
| 261 |
-
|
| 262 |
return batches
|
| 263 |
|
| 264 |
@op("extract", deps=["chunk"])
|
|
@@ -276,10 +272,10 @@ class GraphGen:
|
|
| 276 |
logger.warning("No information extracted")
|
| 277 |
return
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
|
| 284 |
@op("generate", deps=["partition"])
|
| 285 |
@async_to_sync_method
|
|
@@ -303,17 +299,17 @@ class GraphGen:
|
|
| 303 |
return
|
| 304 |
|
| 305 |
# Step 3: store the generated QA pairs
|
| 306 |
-
|
| 307 |
-
|
| 308 |
|
| 309 |
@async_to_sync_method
|
| 310 |
async def clear(self):
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
|
| 318 |
logger.info("All caches are cleared")
|
| 319 |
|
|
|
|
| 104 |
# TODO: configurable whether to use coreference resolution
|
| 105 |
|
| 106 |
new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
|
| 107 |
+
_add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
|
| 108 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
| 109 |
|
| 110 |
if len(new_docs) == 0:
|
| 111 |
logger.warning("All documents are already in the storage")
|
| 112 |
return
|
| 113 |
|
| 114 |
+
self.full_docs_storage.upsert(new_docs)
|
| 115 |
+
self.full_docs_storage.index_done_callback()
|
| 116 |
|
| 117 |
@op("chunk", deps=["read"])
|
| 118 |
@async_to_sync_method
|
|
|
|
| 121 |
chunk documents into smaller pieces from full_docs_storage if not already present
|
| 122 |
"""
|
| 123 |
|
| 124 |
+
new_docs = self.meta_storage.get_new_data(self.full_docs_storage)
|
| 125 |
if len(new_docs) == 0:
|
| 126 |
logger.warning("All documents are already in the storage")
|
| 127 |
return
|
|
|
|
| 133 |
**chunk_config,
|
| 134 |
)
|
| 135 |
|
| 136 |
+
_add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
|
|
|
|
|
|
|
| 137 |
inserting_chunks = {
|
| 138 |
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
| 139 |
}
|
|
|
|
| 142 |
logger.warning("All chunks are already in the storage")
|
| 143 |
return
|
| 144 |
|
| 145 |
+
self.chunks_storage.upsert(inserting_chunks)
|
| 146 |
+
self.chunks_storage.index_done_callback()
|
| 147 |
+
self.meta_storage.mark_done(self.full_docs_storage)
|
| 148 |
+
self.meta_storage.index_done_callback()
|
| 149 |
|
| 150 |
@op("build_kg", deps=["chunk"])
|
| 151 |
@async_to_sync_method
|
|
|
|
| 154 |
build knowledge graph from text chunks
|
| 155 |
"""
|
| 156 |
# Step 1: get new chunks according to meta and chunks storage
|
| 157 |
+
inserting_chunks = self.meta_storage.get_new_data(self.chunks_storage)
|
| 158 |
if len(inserting_chunks) == 0:
|
| 159 |
logger.warning("All chunks are already in the storage")
|
| 160 |
return
|
|
|
|
| 172 |
return
|
| 173 |
|
| 174 |
# Step 3: mark meta
|
| 175 |
+
self.graph_storage.index_done_callback()
|
| 176 |
+
self.meta_storage.mark_done(self.chunks_storage)
|
| 177 |
+
self.meta_storage.index_done_callback()
|
| 178 |
|
| 179 |
return _add_entities_and_relations
|
| 180 |
|
|
|
|
| 183 |
async def search(self, search_config: Dict):
|
| 184 |
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
|
| 185 |
|
| 186 |
+
seeds = self.meta_storage.get_new_data(self.full_docs_storage)
|
| 187 |
if len(seeds) == 0:
|
| 188 |
logger.warning("All documents are already been searched")
|
| 189 |
return
|
|
|
|
| 192 |
search_config=search_config,
|
| 193 |
)
|
| 194 |
|
| 195 |
+
_add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
|
|
|
|
|
|
|
| 196 |
search_results = {
|
| 197 |
k: v for k, v in search_results.items() if k in _add_search_keys
|
| 198 |
}
|
| 199 |
if len(search_results) == 0:
|
| 200 |
logger.warning("All search results are already in the storage")
|
| 201 |
return
|
| 202 |
+
self.search_storage.upsert(search_results)
|
| 203 |
+
self.search_storage.index_done_callback()
|
| 204 |
+
self.meta_storage.mark_done(self.full_docs_storage)
|
| 205 |
+
self.meta_storage.index_done_callback()
|
| 206 |
|
| 207 |
@op("quiz_and_judge", deps=["build_kg"])
|
| 208 |
@async_to_sync_method
|
|
|
|
| 236 |
progress_bar=self.progress_bar,
|
| 237 |
)
|
| 238 |
|
| 239 |
+
self.rephrase_storage.index_done_callback()
|
| 240 |
+
_update_relations.index_done_callback()
|
| 241 |
|
| 242 |
logger.info("Shutting down trainee LLM client.")
|
| 243 |
self.trainee_llm_client.shutdown()
|
|
|
|
| 254 |
self.tokenizer_instance,
|
| 255 |
partition_config,
|
| 256 |
)
|
| 257 |
+
self.partition_storage.upsert(batches)
|
| 258 |
return batches
|
| 259 |
|
| 260 |
@op("extract", deps=["chunk"])
|
|
|
|
| 272 |
logger.warning("No information extracted")
|
| 273 |
return
|
| 274 |
|
| 275 |
+
self.extract_storage.upsert(results)
|
| 276 |
+
self.extract_storage.index_done_callback()
|
| 277 |
+
self.meta_storage.mark_done(self.chunks_storage)
|
| 278 |
+
self.meta_storage.index_done_callback()
|
| 279 |
|
| 280 |
@op("generate", deps=["partition"])
|
| 281 |
@async_to_sync_method
|
|
|
|
| 299 |
return
|
| 300 |
|
| 301 |
# Step 3: store the generated QA pairs
|
| 302 |
+
self.qa_storage.upsert(results)
|
| 303 |
+
self.qa_storage.index_done_callback()
|
| 304 |
|
| 305 |
@async_to_sync_method
|
| 306 |
async def clear(self):
|
| 307 |
+
self.full_docs_storage.drop()
|
| 308 |
+
self.chunks_storage.drop()
|
| 309 |
+
self.search_storage.drop()
|
| 310 |
+
self.graph_storage.clear()
|
| 311 |
+
self.rephrase_storage.drop()
|
| 312 |
+
self.qa_storage.drop()
|
| 313 |
|
| 314 |
logger.info("All caches are cleared")
|
| 315 |
|
graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED
|
@@ -105,7 +105,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 105 |
source_ids = []
|
| 106 |
descriptions = []
|
| 107 |
|
| 108 |
-
node =
|
| 109 |
if node is not None:
|
| 110 |
entity_types.append(node["entity_type"])
|
| 111 |
source_ids.extend(
|
|
@@ -134,7 +134,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 134 |
"description": description,
|
| 135 |
"source_id": source_id,
|
| 136 |
}
|
| 137 |
-
|
| 138 |
|
| 139 |
async def merge_edges(
|
| 140 |
self,
|
|
@@ -146,7 +146,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 146 |
source_ids = []
|
| 147 |
descriptions = []
|
| 148 |
|
| 149 |
-
edge =
|
| 150 |
if edge is not None:
|
| 151 |
source_ids.extend(
|
| 152 |
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
|
|
@@ -161,8 +161,8 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 161 |
)
|
| 162 |
|
| 163 |
for insert_id in [src_id, tgt_id]:
|
| 164 |
-
if not
|
| 165 |
-
|
| 166 |
insert_id,
|
| 167 |
node_data={
|
| 168 |
"source_id": source_id,
|
|
@@ -175,7 +175,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 175 |
f"({src_id}, {tgt_id})", description
|
| 176 |
)
|
| 177 |
|
| 178 |
-
|
| 179 |
src_id,
|
| 180 |
tgt_id,
|
| 181 |
edge_data={"source_id": source_id, "description": description},
|
|
|
|
| 105 |
source_ids = []
|
| 106 |
descriptions = []
|
| 107 |
|
| 108 |
+
node = kg_instance.get_node(entity_name)
|
| 109 |
if node is not None:
|
| 110 |
entity_types.append(node["entity_type"])
|
| 111 |
source_ids.extend(
|
|
|
|
| 134 |
"description": description,
|
| 135 |
"source_id": source_id,
|
| 136 |
}
|
| 137 |
+
kg_instance.upsert_node(entity_name, node_data=node_data)
|
| 138 |
|
| 139 |
async def merge_edges(
|
| 140 |
self,
|
|
|
|
| 146 |
source_ids = []
|
| 147 |
descriptions = []
|
| 148 |
|
| 149 |
+
edge = kg_instance.get_edge(src_id, tgt_id)
|
| 150 |
if edge is not None:
|
| 151 |
source_ids.extend(
|
| 152 |
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
for insert_id in [src_id, tgt_id]:
|
| 164 |
+
if not kg_instance.has_node(insert_id):
|
| 165 |
+
kg_instance.upsert_node(
|
| 166 |
insert_id,
|
| 167 |
node_data={
|
| 168 |
"source_id": source_id,
|
|
|
|
| 175 |
f"({src_id}, {tgt_id})", description
|
| 176 |
)
|
| 177 |
|
| 178 |
+
kg_instance.upsert_edge(
|
| 179 |
src_id,
|
| 180 |
tgt_id,
|
| 181 |
edge_data={"source_id": source_id, "description": description},
|
graphgen/models/partitioner/anchor_bfs_partitioner.py
CHANGED
|
@@ -36,8 +36,8 @@ class AnchorBFSPartitioner(BFSPartitioner):
|
|
| 36 |
max_units_per_community: int = 1,
|
| 37 |
**kwargs: Any,
|
| 38 |
) -> List[Community]:
|
| 39 |
-
nodes =
|
| 40 |
-
edges =
|
| 41 |
|
| 42 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 43 |
|
|
|
|
| 36 |
max_units_per_community: int = 1,
|
| 37 |
**kwargs: Any,
|
| 38 |
) -> List[Community]:
|
| 39 |
+
nodes = g.get_all_nodes() # List[tuple[id, meta]]
|
| 40 |
+
edges = g.get_all_edges() # List[tuple[u, v, meta]]
|
| 41 |
|
| 42 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 43 |
|
graphgen/models/partitioner/bfs_partitioner.py
CHANGED
|
@@ -23,8 +23,8 @@ class BFSPartitioner(BasePartitioner):
|
|
| 23 |
max_units_per_community: int = 1,
|
| 24 |
**kwargs: Any,
|
| 25 |
) -> List[Community]:
|
| 26 |
-
nodes =
|
| 27 |
-
edges =
|
| 28 |
|
| 29 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 30 |
|
|
|
|
| 23 |
max_units_per_community: int = 1,
|
| 24 |
**kwargs: Any,
|
| 25 |
) -> List[Community]:
|
| 26 |
+
nodes = g.get_all_nodes()
|
| 27 |
+
edges = g.get_all_edges()
|
| 28 |
|
| 29 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 30 |
|
graphgen/models/partitioner/dfs_partitioner.py
CHANGED
|
@@ -22,8 +22,8 @@ class DFSPartitioner(BasePartitioner):
|
|
| 22 |
max_units_per_community: int = 1,
|
| 23 |
**kwargs: Any,
|
| 24 |
) -> List[Community]:
|
| 25 |
-
nodes =
|
| 26 |
-
edges =
|
| 27 |
|
| 28 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 29 |
|
|
|
|
| 22 |
max_units_per_community: int = 1,
|
| 23 |
**kwargs: Any,
|
| 24 |
) -> List[Community]:
|
| 25 |
+
nodes = g.get_all_nodes()
|
| 26 |
+
edges = g.get_all_edges()
|
| 27 |
|
| 28 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 29 |
|
graphgen/models/partitioner/ece_partitioner.py
CHANGED
|
@@ -60,8 +60,8 @@ class ECEPartitioner(BFSPartitioner):
|
|
| 60 |
unit_sampling: str = "random",
|
| 61 |
**kwargs: Any,
|
| 62 |
) -> List[Community]:
|
| 63 |
-
nodes: List[Tuple[str, dict]] =
|
| 64 |
-
edges: List[Tuple[str, str, dict]] =
|
| 65 |
|
| 66 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 67 |
node_dict = dict(nodes)
|
|
|
|
| 60 |
unit_sampling: str = "random",
|
| 61 |
**kwargs: Any,
|
| 62 |
) -> List[Community]:
|
| 63 |
+
nodes: List[Tuple[str, dict]] = g.get_all_nodes()
|
| 64 |
+
edges: List[Tuple[str, str, dict]] = g.get_all_edges()
|
| 65 |
|
| 66 |
adj, _ = self._build_adjacency_list(nodes, edges)
|
| 67 |
node_dict = dict(nodes)
|
graphgen/models/partitioner/leiden_partitioner.py
CHANGED
|
@@ -34,8 +34,8 @@ class LeidenPartitioner(BasePartitioner):
|
|
| 34 |
:param kwargs: other parameters for the leiden algorithm
|
| 35 |
:return:
|
| 36 |
"""
|
| 37 |
-
nodes =
|
| 38 |
-
edges =
|
| 39 |
|
| 40 |
node2cid: Dict[str, int] = await self._run_leiden(
|
| 41 |
nodes, edges, use_lcc, random_seed
|
|
|
|
| 34 |
:param kwargs: other parameters for the leiden algorithm
|
| 35 |
:return:
|
| 36 |
"""
|
| 37 |
+
nodes = g.get_all_nodes() # List[Tuple[str, dict]]
|
| 38 |
+
edges = g.get_all_edges() # List[Tuple[str, str, dict]]
|
| 39 |
|
| 40 |
node2cid: Dict[str, int] = await self._run_leiden(
|
| 41 |
nodes, edges, use_lcc, random_seed
|
graphgen/models/storage/json_storage.py
CHANGED
|
@@ -7,7 +7,7 @@ from graphgen.utils import load_json, logger, write_json
|
|
| 7 |
|
| 8 |
@dataclass
|
| 9 |
class JsonKVStorage(BaseKVStorage):
|
| 10 |
-
_data: dict[str,
|
| 11 |
|
| 12 |
def __post_init__(self):
|
| 13 |
self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
|
|
@@ -18,16 +18,16 @@ class JsonKVStorage(BaseKVStorage):
|
|
| 18 |
def data(self):
|
| 19 |
return self._data
|
| 20 |
|
| 21 |
-
|
| 22 |
return list(self._data.keys())
|
| 23 |
|
| 24 |
-
|
| 25 |
write_json(self._data, self._file_name)
|
| 26 |
|
| 27 |
-
|
| 28 |
return self._data.get(id, None)
|
| 29 |
|
| 30 |
-
|
| 31 |
if fields is None:
|
| 32 |
return [self._data.get(id, None) for id in ids]
|
| 33 |
return [
|
|
@@ -39,19 +39,19 @@ class JsonKVStorage(BaseKVStorage):
|
|
| 39 |
for id in ids
|
| 40 |
]
|
| 41 |
|
| 42 |
-
|
| 43 |
return self._data
|
| 44 |
|
| 45 |
-
|
| 46 |
return {s for s in data if s not in self._data}
|
| 47 |
|
| 48 |
-
|
| 49 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
| 50 |
if left_data:
|
| 51 |
self._data.update(left_data)
|
| 52 |
return left_data
|
| 53 |
|
| 54 |
-
|
| 55 |
if self._data:
|
| 56 |
self._data.clear()
|
| 57 |
|
|
@@ -71,26 +71,26 @@ class JsonListStorage(BaseListStorage):
|
|
| 71 |
def data(self):
|
| 72 |
return self._data
|
| 73 |
|
| 74 |
-
|
| 75 |
return self._data
|
| 76 |
|
| 77 |
-
|
| 78 |
write_json(self._data, self._file_name)
|
| 79 |
|
| 80 |
-
|
| 81 |
if index < 0 or index >= len(self._data):
|
| 82 |
return None
|
| 83 |
return self._data[index]
|
| 84 |
|
| 85 |
-
|
| 86 |
self._data.append(data)
|
| 87 |
|
| 88 |
-
|
| 89 |
left_data = [d for d in data if d not in self._data]
|
| 90 |
self._data.extend(left_data)
|
| 91 |
return left_data
|
| 92 |
|
| 93 |
-
|
| 94 |
self._data = []
|
| 95 |
|
| 96 |
|
|
@@ -101,14 +101,14 @@ class MetaJsonKVStorage(JsonKVStorage):
|
|
| 101 |
self._data = load_json(self._file_name) or {}
|
| 102 |
logger.info("Load KV %s with %d data", self.namespace, len(self._data))
|
| 103 |
|
| 104 |
-
|
| 105 |
new_data = {}
|
| 106 |
for k, v in storage_instance.data.items():
|
| 107 |
if k not in self._data:
|
| 108 |
new_data[k] = v
|
| 109 |
return new_data
|
| 110 |
|
| 111 |
-
|
| 112 |
-
new_data =
|
| 113 |
if new_data:
|
| 114 |
self._data.update(new_data)
|
|
|
|
| 7 |
|
| 8 |
@dataclass
|
| 9 |
class JsonKVStorage(BaseKVStorage):
|
| 10 |
+
_data: dict[str, dict] = None
|
| 11 |
|
| 12 |
def __post_init__(self):
|
| 13 |
self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
|
|
|
|
| 18 |
def data(self):
|
| 19 |
return self._data
|
| 20 |
|
| 21 |
+
def all_keys(self) -> list[str]:
|
| 22 |
return list(self._data.keys())
|
| 23 |
|
| 24 |
+
def index_done_callback(self):
|
| 25 |
write_json(self._data, self._file_name)
|
| 26 |
|
| 27 |
+
def get_by_id(self, id):
|
| 28 |
return self._data.get(id, None)
|
| 29 |
|
| 30 |
+
def get_by_ids(self, ids, fields=None) -> list:
|
| 31 |
if fields is None:
|
| 32 |
return [self._data.get(id, None) for id in ids]
|
| 33 |
return [
|
|
|
|
| 39 |
for id in ids
|
| 40 |
]
|
| 41 |
|
| 42 |
+
def get_all(self) -> dict[str, dict]:
|
| 43 |
return self._data
|
| 44 |
|
| 45 |
+
def filter_keys(self, data: list[str]) -> set[str]:
|
| 46 |
return {s for s in data if s not in self._data}
|
| 47 |
|
| 48 |
+
def upsert(self, data: dict):
|
| 49 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
| 50 |
if left_data:
|
| 51 |
self._data.update(left_data)
|
| 52 |
return left_data
|
| 53 |
|
| 54 |
+
def drop(self):
|
| 55 |
if self._data:
|
| 56 |
self._data.clear()
|
| 57 |
|
|
|
|
| 71 |
def data(self):
|
| 72 |
return self._data
|
| 73 |
|
| 74 |
+
def all_items(self) -> list:
|
| 75 |
return self._data
|
| 76 |
|
| 77 |
+
def index_done_callback(self):
|
| 78 |
write_json(self._data, self._file_name)
|
| 79 |
|
| 80 |
+
def get_by_index(self, index: int):
|
| 81 |
if index < 0 or index >= len(self._data):
|
| 82 |
return None
|
| 83 |
return self._data[index]
|
| 84 |
|
| 85 |
+
def append(self, data):
|
| 86 |
self._data.append(data)
|
| 87 |
|
| 88 |
+
def upsert(self, data: list):
|
| 89 |
left_data = [d for d in data if d not in self._data]
|
| 90 |
self._data.extend(left_data)
|
| 91 |
return left_data
|
| 92 |
|
| 93 |
+
def drop(self):
|
| 94 |
self._data = []
|
| 95 |
|
| 96 |
|
|
|
|
| 101 |
self._data = load_json(self._file_name) or {}
|
| 102 |
logger.info("Load KV %s with %d data", self.namespace, len(self._data))
|
| 103 |
|
| 104 |
+
def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
|
| 105 |
new_data = {}
|
| 106 |
for k, v in storage_instance.data.items():
|
| 107 |
if k not in self._data:
|
| 108 |
new_data[k] = v
|
| 109 |
return new_data
|
| 110 |
|
| 111 |
+
def mark_done(self, storage_instance: "JsonKVStorage"):
|
| 112 |
+
new_data = self.get_new_data(storage_instance)
|
| 113 |
if new_data:
|
| 114 |
self._data.update(new_data)
|
graphgen/models/storage/networkx_storage.py
CHANGED
|
@@ -91,60 +91,56 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 91 |
)
|
| 92 |
self._graph = preloaded_graph or nx.Graph()
|
| 93 |
|
| 94 |
-
|
| 95 |
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
| 96 |
|
| 97 |
-
|
| 98 |
return self._graph.has_node(node_id)
|
| 99 |
|
| 100 |
-
|
| 101 |
return self._graph.has_edge(source_node_id, target_node_id)
|
| 102 |
|
| 103 |
-
|
| 104 |
return self._graph.nodes.get(node_id)
|
| 105 |
|
| 106 |
-
|
| 107 |
return list(self._graph.nodes(data=True))
|
| 108 |
|
| 109 |
-
|
| 110 |
-
return self._graph.degree
|
| 111 |
|
| 112 |
-
|
| 113 |
-
return self._graph.degree
|
| 114 |
|
| 115 |
-
|
| 116 |
-
self, source_node_id: str, target_node_id: str
|
| 117 |
-
) -> Union[dict, None]:
|
| 118 |
return self._graph.edges.get((source_node_id, target_node_id))
|
| 119 |
|
| 120 |
-
|
| 121 |
return list(self._graph.edges(data=True))
|
| 122 |
|
| 123 |
-
|
| 124 |
-
self, source_node_id: str
|
| 125 |
-
) -> Union[list[tuple[str, str]], None]:
|
| 126 |
if self._graph.has_node(source_node_id):
|
| 127 |
return list(self._graph.edges(source_node_id, data=True))
|
| 128 |
return None
|
| 129 |
|
| 130 |
-
|
| 131 |
return self._graph
|
| 132 |
|
| 133 |
-
|
| 134 |
self._graph.add_node(node_id, **node_data)
|
| 135 |
|
| 136 |
-
|
| 137 |
if self._graph.has_node(node_id):
|
| 138 |
self._graph.nodes[node_id].update(node_data)
|
| 139 |
else:
|
| 140 |
logger.warning("Node %s not found in the graph for update.", node_id)
|
| 141 |
|
| 142 |
-
|
| 143 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 144 |
):
|
| 145 |
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 146 |
|
| 147 |
-
|
| 148 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 149 |
):
|
| 150 |
if self._graph.has_edge(source_node_id, target_node_id):
|
|
@@ -156,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 156 |
target_node_id,
|
| 157 |
)
|
| 158 |
|
| 159 |
-
|
| 160 |
"""
|
| 161 |
Delete a node from the graph based on the specified node_id.
|
| 162 |
|
|
@@ -168,7 +164,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 168 |
else:
|
| 169 |
logger.warning("Node %s not found in the graph for deletion.", node_id)
|
| 170 |
|
| 171 |
-
|
| 172 |
"""
|
| 173 |
Clear the graph by removing all nodes and edges.
|
| 174 |
"""
|
|
|
|
| 91 |
)
|
| 92 |
self._graph = preloaded_graph or nx.Graph()
|
| 93 |
|
| 94 |
+
def index_done_callback(self):
|
| 95 |
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
| 96 |
|
| 97 |
+
def has_node(self, node_id: str) -> bool:
|
| 98 |
return self._graph.has_node(node_id)
|
| 99 |
|
| 100 |
+
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 101 |
return self._graph.has_edge(source_node_id, target_node_id)
|
| 102 |
|
| 103 |
+
def get_node(self, node_id: str) -> Union[dict, None]:
|
| 104 |
return self._graph.nodes.get(node_id)
|
| 105 |
|
| 106 |
+
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
|
| 107 |
return list(self._graph.nodes(data=True))
|
| 108 |
|
| 109 |
+
def node_degree(self, node_id: str) -> int:
|
| 110 |
+
return int(self._graph.degree[node_id])
|
| 111 |
|
| 112 |
+
def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 113 |
+
return int(self._graph.degree[src_id] + self._graph.degree[tgt_id])
|
| 114 |
|
| 115 |
+
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
|
|
|
|
|
|
| 116 |
return self._graph.edges.get((source_node_id, target_node_id))
|
| 117 |
|
| 118 |
+
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
|
| 119 |
return list(self._graph.edges(data=True))
|
| 120 |
|
| 121 |
+
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
|
|
|
|
|
|
|
| 122 |
if self._graph.has_node(source_node_id):
|
| 123 |
return list(self._graph.edges(source_node_id, data=True))
|
| 124 |
return None
|
| 125 |
|
| 126 |
+
def get_graph(self) -> nx.Graph:
|
| 127 |
return self._graph
|
| 128 |
|
| 129 |
+
def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 130 |
self._graph.add_node(node_id, **node_data)
|
| 131 |
|
| 132 |
+
def update_node(self, node_id: str, node_data: dict[str, str]):
|
| 133 |
if self._graph.has_node(node_id):
|
| 134 |
self._graph.nodes[node_id].update(node_data)
|
| 135 |
else:
|
| 136 |
logger.warning("Node %s not found in the graph for update.", node_id)
|
| 137 |
|
| 138 |
+
def upsert_edge(
|
| 139 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 140 |
):
|
| 141 |
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 142 |
|
| 143 |
+
def update_edge(
|
| 144 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 145 |
):
|
| 146 |
if self._graph.has_edge(source_node_id, target_node_id):
|
|
|
|
| 152 |
target_node_id,
|
| 153 |
)
|
| 154 |
|
| 155 |
+
def delete_node(self, node_id: str):
|
| 156 |
"""
|
| 157 |
Delete a node from the graph based on the specified node_id.
|
| 158 |
|
|
|
|
| 164 |
else:
|
| 165 |
logger.warning("Node %s not found in the graph for deletion.", node_id)
|
| 166 |
|
| 167 |
+
def clear(self):
|
| 168 |
"""
|
| 169 |
Clear the graph by removing all nodes and edges.
|
| 170 |
"""
|
graphgen/operators/partition/partition_kg.py
CHANGED
|
@@ -34,8 +34,8 @@ async def partition_kg(
|
|
| 34 |
# TODO: before ECE partitioning, we need to:
|
| 35 |
# 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
|
| 36 |
# 2. pre-tokenize nodes and edges to get the token length
|
| 37 |
-
edges =
|
| 38 |
-
nodes =
|
| 39 |
await pre_tokenize(kg_instance, tokenizer, edges, nodes)
|
| 40 |
partitioner = ECEPartitioner()
|
| 41 |
elif method == "leiden":
|
|
@@ -105,7 +105,7 @@ async def _attach_by_type(
|
|
| 105 |
image_chunks = [
|
| 106 |
data
|
| 107 |
for sid in source_ids
|
| 108 |
-
if "image" in sid.lower() and (data :=
|
| 109 |
]
|
| 110 |
if image_chunks:
|
| 111 |
# The generator expects a dictionary with an 'img_path' key, not a list of captions.
|
|
|
|
| 34 |
# TODO: before ECE partitioning, we need to:
|
| 35 |
# 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
|
| 36 |
# 2. pre-tokenize nodes and edges to get the token length
|
| 37 |
+
edges = kg_instance.get_all_edges()
|
| 38 |
+
nodes = kg_instance.get_all_nodes()
|
| 39 |
await pre_tokenize(kg_instance, tokenizer, edges, nodes)
|
| 40 |
partitioner = ECEPartitioner()
|
| 41 |
elif method == "leiden":
|
|
|
|
| 105 |
image_chunks = [
|
| 106 |
data
|
| 107 |
for sid in source_ids
|
| 108 |
+
if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
|
| 109 |
]
|
| 110 |
if image_chunks:
|
| 111 |
# The generator expects a dictionary with an 'img_path' key, not a list of captions.
|
graphgen/operators/partition/pre_tokenize.py
CHANGED
|
@@ -29,9 +29,9 @@ async def pre_tokenize(
|
|
| 29 |
)
|
| 30 |
)
|
| 31 |
if is_node:
|
| 32 |
-
|
| 33 |
else:
|
| 34 |
-
|
| 35 |
return obj
|
| 36 |
|
| 37 |
new_edges, new_nodes = await asyncio.gather(
|
|
@@ -51,5 +51,5 @@ async def pre_tokenize(
|
|
| 51 |
),
|
| 52 |
)
|
| 53 |
|
| 54 |
-
|
| 55 |
return new_edges, new_nodes
|
|
|
|
| 29 |
)
|
| 30 |
)
|
| 31 |
if is_node:
|
| 32 |
+
graph_storage.update_node(obj[0], obj[1])
|
| 33 |
else:
|
| 34 |
+
graph_storage.update_edge(obj[0], obj[1], obj[2])
|
| 35 |
return obj
|
| 36 |
|
| 37 |
new_edges, new_nodes = await asyncio.gather(
|
|
|
|
| 51 |
),
|
| 52 |
)
|
| 53 |
|
| 54 |
+
graph_storage.index_done_callback()
|
| 55 |
return new_edges, new_nodes
|
graphgen/operators/quiz_and_judge/judge.py
CHANGED
|
@@ -45,16 +45,14 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
| 45 |
description = edge_data["description"]
|
| 46 |
|
| 47 |
try:
|
| 48 |
-
descriptions =
|
| 49 |
assert descriptions is not None
|
| 50 |
|
| 51 |
judgements = []
|
| 52 |
gts = [gt for _, gt in descriptions]
|
| 53 |
for description, gt in descriptions:
|
| 54 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
| 55 |
-
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
|
| 56 |
-
statement=description
|
| 57 |
-
)
|
| 58 |
)
|
| 59 |
judgements.append(judgement[0].top_candidates)
|
| 60 |
|
|
@@ -76,10 +74,10 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
| 76 |
logger.info("Use default loss 0.1")
|
| 77 |
edge_data["loss"] = -math.log(0.1)
|
| 78 |
|
| 79 |
-
|
| 80 |
return source_id, target_id, edge_data
|
| 81 |
|
| 82 |
-
edges =
|
| 83 |
|
| 84 |
await run_concurrent(
|
| 85 |
_judge_single_relation,
|
|
@@ -104,24 +102,20 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
| 104 |
description = node_data["description"]
|
| 105 |
|
| 106 |
try:
|
| 107 |
-
descriptions =
|
| 108 |
assert descriptions is not None
|
| 109 |
|
| 110 |
judgements = []
|
| 111 |
gts = [gt for _, gt in descriptions]
|
| 112 |
for description, gt in descriptions:
|
| 113 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
| 114 |
-
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
|
| 115 |
-
statement=description
|
| 116 |
-
)
|
| 117 |
)
|
| 118 |
judgements.append(judgement[0].top_candidates)
|
| 119 |
|
| 120 |
loss = yes_no_loss_entropy(judgements, gts)
|
| 121 |
|
| 122 |
-
logger.debug(
|
| 123 |
-
"Node %s description: %s loss: %s", node_id, description, loss
|
| 124 |
-
)
|
| 125 |
|
| 126 |
node_data["loss"] = loss
|
| 127 |
except Exception as e: # pylint: disable=broad-except
|
|
@@ -129,10 +123,10 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
| 129 |
logger.error("Use default loss 0.1")
|
| 130 |
node_data["loss"] = -math.log(0.1)
|
| 131 |
|
| 132 |
-
|
| 133 |
return node_id, node_data
|
| 134 |
|
| 135 |
-
nodes =
|
| 136 |
|
| 137 |
await run_concurrent(
|
| 138 |
_judge_single_entity,
|
|
|
|
| 45 |
description = edge_data["description"]
|
| 46 |
|
| 47 |
try:
|
| 48 |
+
descriptions = rephrase_storage.get_by_id(description)
|
| 49 |
assert descriptions is not None
|
| 50 |
|
| 51 |
judgements = []
|
| 52 |
gts = [gt for _, gt in descriptions]
|
| 53 |
for description, gt in descriptions:
|
| 54 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
| 55 |
+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
judgements.append(judgement[0].top_candidates)
|
| 58 |
|
|
|
|
| 74 |
logger.info("Use default loss 0.1")
|
| 75 |
edge_data["loss"] = -math.log(0.1)
|
| 76 |
|
| 77 |
+
graph_storage.update_edge(source_id, target_id, edge_data)
|
| 78 |
return source_id, target_id, edge_data
|
| 79 |
|
| 80 |
+
edges = graph_storage.get_all_edges()
|
| 81 |
|
| 82 |
await run_concurrent(
|
| 83 |
_judge_single_relation,
|
|
|
|
| 102 |
description = node_data["description"]
|
| 103 |
|
| 104 |
try:
|
| 105 |
+
descriptions = rephrase_storage.get_by_id(description)
|
| 106 |
assert descriptions is not None
|
| 107 |
|
| 108 |
judgements = []
|
| 109 |
gts = [gt for _, gt in descriptions]
|
| 110 |
for description, gt in descriptions:
|
| 111 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
| 112 |
+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
|
|
|
|
|
|
|
| 113 |
)
|
| 114 |
judgements.append(judgement[0].top_candidates)
|
| 115 |
|
| 116 |
loss = yes_no_loss_entropy(judgements, gts)
|
| 117 |
|
| 118 |
+
logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
|
|
|
|
|
|
|
| 119 |
|
| 120 |
node_data["loss"] = loss
|
| 121 |
except Exception as e: # pylint: disable=broad-except
|
|
|
|
| 123 |
logger.error("Use default loss 0.1")
|
| 124 |
node_data["loss"] = -math.log(0.1)
|
| 125 |
|
| 126 |
+
graph_storage.update_node(node_id, node_data)
|
| 127 |
return node_id, node_data
|
| 128 |
|
| 129 |
+
nodes = graph_storage.get_all_nodes()
|
| 130 |
|
| 131 |
await run_concurrent(
|
| 132 |
_judge_single_entity,
|
graphgen/operators/quiz_and_judge/quiz.py
CHANGED
|
@@ -31,7 +31,7 @@ async def quiz(
|
|
| 31 |
description, template_type, gt = item
|
| 32 |
try:
|
| 33 |
# if rephrase_storage exists already, directly get it
|
| 34 |
-
descriptions =
|
| 35 |
if descriptions:
|
| 36 |
return None
|
| 37 |
|
|
@@ -46,8 +46,8 @@ async def quiz(
|
|
| 46 |
logger.error("Error when quizzing description %s: %s", description, e)
|
| 47 |
return None
|
| 48 |
|
| 49 |
-
edges =
|
| 50 |
-
nodes =
|
| 51 |
|
| 52 |
results = defaultdict(list)
|
| 53 |
items = []
|
|
@@ -88,6 +88,6 @@ async def quiz(
|
|
| 88 |
|
| 89 |
for key, value in results.items():
|
| 90 |
results[key] = list(set(value))
|
| 91 |
-
|
| 92 |
|
| 93 |
return rephrase_storage
|
|
|
|
| 31 |
description, template_type, gt = item
|
| 32 |
try:
|
| 33 |
# if rephrase_storage exists already, directly get it
|
| 34 |
+
descriptions = rephrase_storage.get_by_id(description)
|
| 35 |
if descriptions:
|
| 36 |
return None
|
| 37 |
|
|
|
|
| 46 |
logger.error("Error when quizzing description %s: %s", description, e)
|
| 47 |
return None
|
| 48 |
|
| 49 |
+
edges = graph_storage.get_all_edges()
|
| 50 |
+
nodes = graph_storage.get_all_nodes()
|
| 51 |
|
| 52 |
results = defaultdict(list)
|
| 53 |
items = []
|
|
|
|
| 88 |
|
| 89 |
for key, value in results.items():
|
| 90 |
results[key] = list(set(value))
|
| 91 |
+
rephrase_storage.upsert({key: results[key]})
|
| 92 |
|
| 93 |
return rephrase_storage
|
graphgen/operators/storage.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
|
| 6 |
+
from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@ray.remote
|
| 10 |
+
class StorageManager:
|
| 11 |
+
"""
|
| 12 |
+
Centralized storage for all operators
|
| 13 |
+
|
| 14 |
+
Example Usage:
|
| 15 |
+
----------
|
| 16 |
+
# init
|
| 17 |
+
storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123)
|
| 18 |
+
|
| 19 |
+
# visit storage in tasks
|
| 20 |
+
@ray.remote
|
| 21 |
+
def some_task(storage_manager):
|
| 22 |
+
full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs"))
|
| 23 |
+
|
| 24 |
+
# visit storage in other actors
|
| 25 |
+
@ray.remote
|
| 26 |
+
class SomeOperator:
|
| 27 |
+
def __init__(self, storage_manager):
|
| 28 |
+
self.storage_manager = storage_manager
|
| 29 |
+
def some_method(self):
|
| 30 |
+
full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs"))
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, working_dir: str, unique_id: int):
|
| 34 |
+
self.working_dir = working_dir
|
| 35 |
+
self.unique_id = unique_id
|
| 36 |
+
|
| 37 |
+
# Initialize all storage backends
|
| 38 |
+
self.storages = {
|
| 39 |
+
"full_docs": JsonKVStorage(working_dir, namespace="full_docs"),
|
| 40 |
+
"chunks": JsonKVStorage(working_dir, namespace="chunks"),
|
| 41 |
+
"graph": NetworkXStorage(working_dir, namespace="graph"),
|
| 42 |
+
"rephrase": JsonKVStorage(working_dir, namespace="rephrase"),
|
| 43 |
+
"partition": JsonListStorage(working_dir, namespace="partition"),
|
| 44 |
+
"search": JsonKVStorage(
|
| 45 |
+
os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
|
| 46 |
+
namespace="search",
|
| 47 |
+
),
|
| 48 |
+
"extraction": JsonKVStorage(
|
| 49 |
+
os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
|
| 50 |
+
namespace="extraction",
|
| 51 |
+
),
|
| 52 |
+
"qa": JsonListStorage(
|
| 53 |
+
os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
|
| 54 |
+
namespace="qa",
|
| 55 |
+
),
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
def get_storage(self, name: str) -> Any:
|
| 59 |
+
return self.storages.get(name)
|
requirements.txt
CHANGED
|
@@ -12,7 +12,7 @@ nltk
|
|
| 12 |
jieba
|
| 13 |
plotly
|
| 14 |
pandas
|
| 15 |
-
gradio
|
| 16 |
kaleido
|
| 17 |
pyyaml
|
| 18 |
langcodes
|
|
|
|
| 12 |
jieba
|
| 13 |
plotly
|
| 14 |
pandas
|
| 15 |
+
gradio==5.44.1
|
| 16 |
kaleido
|
| 17 |
pyyaml
|
| 18 |
langcodes
|