Spaces:
Sleeping
Sleeping
Add CONCURRENCY_LIMIT; Graph config change -> directed
Browse files
interactive/pyvis_graph.py
CHANGED
|
@@ -3,12 +3,12 @@ import json
|
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
from pyvis.network import Network
|
| 6 |
-
|
| 7 |
sys.path.append(".")
|
|
|
|
| 8 |
from src.benchmarks import get_semistructured_data
|
| 9 |
|
| 10 |
-
|
| 11 |
-
TITLE = "STaRK Knowledge Base Explorer"
|
| 12 |
BRAND_NAME = {
|
| 13 |
"amazon": "STaRK-Amazon",
|
| 14 |
"mag": "STaRK-MAG",
|
|
@@ -22,20 +22,16 @@ NODE_COLORS = [
|
|
| 22 |
"#00796B", # Teal
|
| 23 |
"#03A9F4", # Light Blue
|
| 24 |
"#CDDC39", # Lime
|
| 25 |
-
"#E91E63", # Pink
|
| 26 |
"#3F51B5", # Indigo
|
| 27 |
"#00BCD4", # Cyan
|
| 28 |
"#FFC107", # Amber
|
| 29 |
"#8BC34A", # Light Green
|
| 30 |
-
"#795548", # Brown
|
| 31 |
"#9E9E9E", # Grey
|
| 32 |
"#607D8B", # Blue Grey
|
| 33 |
"#FFEB3B", # Bright Yellow
|
| 34 |
"#E1F5FE", # Light Blue 50
|
| 35 |
"#F1F8E9", # Light Green 50
|
| 36 |
"#FFF3E0", # Orange 50
|
| 37 |
-
"#FCE4EC", # Pink 50
|
| 38 |
-
"#F3E5F5", # Purple 50
|
| 39 |
"#FFFDE7", # Yellow 50
|
| 40 |
"#E0F7FA", # Cyan 50
|
| 41 |
"#E8F5E9", # Green 50
|
|
@@ -90,10 +86,14 @@ def relabel(x, edge_index, batch, pos=None):
|
|
| 90 |
return x, edge_index, batch, pos
|
| 91 |
|
| 92 |
|
| 93 |
-
def generate_network(kb, node_id, max_nodes=10, num_hops=
|
| 94 |
max_nodes = int(max_nodes)
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
def get_one_hop(kb, node_id, max_nodes):
|
| 99 |
edge_index = kb.edge_index
|
|
@@ -137,7 +137,6 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
|
|
| 137 |
node_ids, relabel_edge_index, _, _ = relabel(
|
| 138 |
torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
|
| 139 |
)
|
| 140 |
-
|
| 141 |
for idx, n_id in enumerate(node_ids):
|
| 142 |
if node_id == n_id:
|
| 143 |
net.add_node(
|
|
@@ -158,31 +157,45 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
|
|
| 158 |
font={"align": "middle", "size": 10},
|
| 159 |
)
|
| 160 |
for idx in range(relabel_edge_index.size(-1)):
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
return net.get_network_data()
|
| 173 |
|
| 174 |
|
| 175 |
def get_text_html(kb, node_id):
|
| 176 |
text = kb.get_doc_info(node_id, add_rel=False, compact=False)
|
| 177 |
-
# need a text box, figure left, text right
|
| 178 |
-
text = text.replace("\n", "<br>").replace(" ", " ")
|
| 179 |
# add a title
|
|
|
|
| 180 |
text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
|
|
|
|
| 181 |
# show the text as what it is with empty space and can be scrolled
|
| 182 |
-
return f"""<
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
-
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops=
|
| 186 |
network = generate_network(kb, node_id, max_nodes, num_hops)
|
| 187 |
|
| 188 |
nodes = network[0]
|
|
@@ -200,7 +213,7 @@ def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops="1"):
|
|
| 200 |
|
| 201 |
def main():
|
| 202 |
# kb = get_semistructured_data(DATASET_NAME)
|
| 203 |
-
kbs = {k: get_semistructured_data(k) for k in BRAND_NAME.keys()}
|
| 204 |
|
| 205 |
with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
|
| 206 |
gr.Markdown(f"# {TITLE}")
|
|
@@ -208,11 +221,14 @@ def main():
|
|
| 208 |
with gr.Tab(BRAND_NAME[name]):
|
| 209 |
with gr.Row():
|
| 210 |
entity_id = gr.Number(
|
| 211 |
-
label="Entity ID",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
)
|
| 213 |
-
max_paths = gr.Slider(1, 200, 10, step=1, label="Max Paths")
|
| 214 |
num_hops = gr.Dropdown(
|
| 215 |
-
["1", "2", "inf"], value="
|
| 216 |
)
|
| 217 |
query_btn = gr.Button(
|
| 218 |
value="Show Graph",
|
|
@@ -232,7 +248,7 @@ def main():
|
|
| 232 |
),
|
| 233 |
inputs=[entity_id, max_paths, num_hops],
|
| 234 |
outputs=[graph_area, text_area],
|
| 235 |
-
api_name=f"{name}-fetch-graph"
|
| 236 |
)
|
| 237 |
|
| 238 |
# Hidden inputs for fetch just text
|
|
@@ -248,11 +264,12 @@ def main():
|
|
| 248 |
lambda e, kb=kb: get_text_html(kb, e),
|
| 249 |
inputs=[entity_for_text],
|
| 250 |
outputs=text_area,
|
| 251 |
-
api_name=f"{name}-fetch-text"
|
| 252 |
)
|
| 253 |
-
|
| 254 |
demo.launch(share=True)
|
| 255 |
|
| 256 |
|
| 257 |
if __name__ == "__main__":
|
| 258 |
-
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
from pyvis.network import Network
|
|
|
|
| 6 |
sys.path.append(".")
|
| 7 |
+
import re
|
| 8 |
from src.benchmarks import get_semistructured_data
|
| 9 |
|
| 10 |
+
CONCURRENCY_LIMIT = 1000
|
| 11 |
+
TITLE = "STaRK Semistructure Knowledge Base Explorer"
|
| 12 |
BRAND_NAME = {
|
| 13 |
"amazon": "STaRK-Amazon",
|
| 14 |
"mag": "STaRK-MAG",
|
|
|
|
| 22 |
"#00796B", # Teal
|
| 23 |
"#03A9F4", # Light Blue
|
| 24 |
"#CDDC39", # Lime
|
|
|
|
| 25 |
"#3F51B5", # Indigo
|
| 26 |
"#00BCD4", # Cyan
|
| 27 |
"#FFC107", # Amber
|
| 28 |
"#8BC34A", # Light Green
|
|
|
|
| 29 |
"#9E9E9E", # Grey
|
| 30 |
"#607D8B", # Blue Grey
|
| 31 |
"#FFEB3B", # Bright Yellow
|
| 32 |
"#E1F5FE", # Light Blue 50
|
| 33 |
"#F1F8E9", # Light Green 50
|
| 34 |
"#FFF3E0", # Orange 50
|
|
|
|
|
|
|
| 35 |
"#FFFDE7", # Yellow 50
|
| 36 |
"#E0F7FA", # Cyan 50
|
| 37 |
"#E8F5E9", # Green 50
|
|
|
|
| 86 |
return x, edge_index, batch, pos
|
| 87 |
|
| 88 |
|
| 89 |
+
def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
|
| 90 |
max_nodes = int(max_nodes)
|
| 91 |
+
if 'gene/protein' in kb.node_type_dict.values():
|
| 92 |
+
indirected = True
|
| 93 |
+
net = Network(directed=False)
|
| 94 |
+
else:
|
| 95 |
+
indirected = False
|
| 96 |
+
net = Network()
|
| 97 |
|
| 98 |
def get_one_hop(kb, node_id, max_nodes):
|
| 99 |
edge_index = kb.edge_index
|
|
|
|
| 137 |
node_ids, relabel_edge_index, _, _ = relabel(
|
| 138 |
torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
|
| 139 |
)
|
|
|
|
| 140 |
for idx, n_id in enumerate(node_ids):
|
| 141 |
if node_id == n_id:
|
| 142 |
net.add_node(
|
|
|
|
| 157 |
font={"align": "middle", "size": 10},
|
| 158 |
)
|
| 159 |
for idx in range(relabel_edge_index.size(-1)):
|
| 160 |
+
if indirected:
|
| 161 |
+
net.add_edge(
|
| 162 |
+
relabel_edge_index[0][idx].item(),
|
| 163 |
+
relabel_edge_index[1][idx].item(),
|
| 164 |
+
color=EDGE_COLORS[edge_types[idx].item()],
|
| 165 |
+
label=kb.edge_type_dict[edge_types[idx].item()]
|
| 166 |
+
.replace('___', " ")
|
| 167 |
+
.replace('_', " "),
|
| 168 |
+
width=1,
|
| 169 |
+
font={"align": "middle", "size": 10})
|
| 170 |
+
else:
|
| 171 |
+
net.add_edge(
|
| 172 |
+
relabel_edge_index[0][idx].item(),
|
| 173 |
+
relabel_edge_index[1][idx].item(),
|
| 174 |
+
color=EDGE_COLORS[edge_types[idx].item()],
|
| 175 |
+
label=kb.edge_type_dict[edge_types[idx].item()]
|
| 176 |
+
.replace('___', " ")
|
| 177 |
+
.replace('_', " "),
|
| 178 |
+
width=1,
|
| 179 |
+
font={"align": "middle", "size": 10},
|
| 180 |
+
arrows="to",
|
| 181 |
+
arrowStrikethrough=False)
|
| 182 |
return net.get_network_data()
|
| 183 |
|
| 184 |
|
| 185 |
def get_text_html(kb, node_id):
    """Render the textual description of one entity as a scrollable HTML panel.

    Args:
        kb: knowledge base object. Only
            ``kb.get_doc_info(node_id, add_rel=False, compact=False)`` is
            called on it, and the result is treated as a string.
        node_id: identifier of the entity; it is also shown in the title.

    Returns:
        str: an HTML snippet consisting of the MathJax loader ``<script>``
        tag followed by a fixed-height, vertically scrollable ``<div>`` that
        contains the formatted text.
    """
    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # NOTE(review): the text is inserted into the page without HTML escaping;
    # confirm the knowledge-base content is trusted before exposing publicly.

    # Preserve the document layout in HTML: newlines become <br> and every
    # space becomes a non-breaking space so runs of blanks are not collapsed.
    text = text.replace("\n", "<br>").replace(" ", "&nbsp;")
    # add a title (after the whitespace pass so the heading wraps normally)
    text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
    # Rewrite inline TeX written as $...$ into \(...\), which is the inline
    # delimiter MathJax recognises by default. Any pair of dollar signs is
    # treated as math, including plain currency amounts.
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)

    # The polyfill.io <script> that used to precede MathJax is intentionally
    # omitted: that domain was taken over and served malicious code, and
    # MathJax 3 does not need the ES6 polyfill on current browsers.
    mathjax_loader = (
        '<script id="MathJax-script" async '
        'src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">'
        '</script>'
    )
    # show the text as what it is with empty space and can be scrolled
    panel_style = (
        "width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; "
        "overflow-wrap: break-word; hyphens: auto; padding: 10px; "
        "margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;\n"
        "font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, "
        "Arial, sans-serif;"
    )
    return f'{mathjax_loader}\n<div style="{panel_style}">{text}</div>'
|
| 196 |
|
| 197 |
|
| 198 |
+
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops=1):
|
| 199 |
network = generate_network(kb, node_id, max_nodes, num_hops)
|
| 200 |
|
| 201 |
nodes = network[0]
|
|
|
|
| 213 |
|
| 214 |
def main():
|
| 215 |
# kb = get_semistructured_data(DATASET_NAME)
|
| 216 |
+
kbs = {k: get_semistructured_data(k, indirected=False) for k in BRAND_NAME.keys()}
|
| 217 |
|
| 218 |
with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
|
| 219 |
gr.Markdown(f"# {TITLE}")
|
|
|
|
| 221 |
with gr.Tab(BRAND_NAME[name]):
|
| 222 |
with gr.Row():
|
| 223 |
entity_id = gr.Number(
|
| 224 |
+
label="Entity ID",
|
| 225 |
+
elem_id=f"{name}-entity-id-input"
|
| 226 |
+
)
|
| 227 |
+
max_paths = gr.Slider(
|
| 228 |
+
1, 200, 10, step=1, label="Max Number of Paths"
|
| 229 |
)
|
|
|
|
| 230 |
num_hops = gr.Dropdown(
|
| 231 |
+
["1", "2", "inf"], value="2", label="Number of Hops"
|
| 232 |
)
|
| 233 |
query_btn = gr.Button(
|
| 234 |
value="Show Graph",
|
|
|
|
| 248 |
),
|
| 249 |
inputs=[entity_id, max_paths, num_hops],
|
| 250 |
outputs=[graph_area, text_area],
|
| 251 |
+
api_name=f"{name}-fetch-graph"
|
| 252 |
)
|
| 253 |
|
| 254 |
# Hidden inputs for fetch just text
|
|
|
|
| 264 |
lambda e, kb=kb: get_text_html(kb, e),
|
| 265 |
inputs=[entity_for_text],
|
| 266 |
outputs=text_area,
|
| 267 |
+
api_name=f"{name}-fetch-text"
|
| 268 |
)
|
| 269 |
+
demo.queue(max_size=2*CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
|
| 270 |
demo.launch(share=True)
|
| 271 |
|
| 272 |
|
| 273 |
if __name__ == "__main__":
|
| 274 |
+
|
| 275 |
+
main()
|
src/benchmarks/get_semistruct.py
CHANGED
|
@@ -2,21 +2,22 @@ import os.path as osp
|
|
| 2 |
from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
|
| 3 |
|
| 4 |
|
| 5 |
-
def get_semistructured_data(name, root='data/', download_processed=True):
|
| 6 |
data_root = osp.join(root, name)
|
| 7 |
if name == 'amazon':
|
| 8 |
categories = ['Sports_and_Outdoors']
|
| 9 |
kb = AmazonSemiStruct(root=data_root,
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
if name == 'primekg':
|
| 16 |
kb = PrimeKGSemiStruct(root=data_root,
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
if name == 'mag':
|
| 20 |
kb = MagSemiStruct(root=data_root,
|
| 21 |
-
|
| 22 |
return kb
|
|
|
|
| 2 |
from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
|
| 3 |
|
| 4 |
|
| 5 |
+
def get_semistructured_data(name, root='data/', download_processed=True, **kwargs):
    '''
    Build the semi-structured knowledge base for one benchmark dataset.

    Args:
        name (str): dataset key; one of 'amazon', 'primekg' or 'mag'
        root (str): parent directory; the dataset is stored under root/name
        download_processed (bool): forwarded unchanged to the dataset class
        **kwargs: extra keyword arguments forwarded to the dataset class
            constructor (for example ``indirected=False``)

    Returns:
        the constructed AmazonSemiStruct, PrimeKGSemiStruct or MagSemiStruct

    Raises:
        ValueError: if ``name`` is not one of the supported dataset keys
    '''
    data_root = osp.join(root, name)
    if name == 'amazon':
        categories = ['Sports_and_Outdoors']
        kb = AmazonSemiStruct(root=data_root,
                              categories=categories,
                              meta_link_types=['brand'],
                              download_processed=download_processed,
                              **kwargs)
    elif name == 'primekg':
        kb = PrimeKGSemiStruct(root=data_root,
                               download_processed=download_processed,
                               **kwargs)
    elif name == 'mag':
        # Forward kwargs here as well so options such as ``indirected`` are
        # honoured for every dataset instead of being silently dropped.
        kb = MagSemiStruct(root=data_root,
                           download_processed=download_processed,
                           **kwargs)
    else:
        # Fail with a clear message rather than an UnboundLocalError on `kb`.
        raise ValueError(
            f"Unknown dataset name {name!r}; "
            "expected one of 'amazon', 'primekg', 'mag'"
        )
    return kb
|
src/benchmarks/semistruct/amazon.py
CHANGED
|
@@ -63,8 +63,8 @@ class AmazonSemiStruct(SemiStructureKB):
|
|
| 63 |
categories: list,
|
| 64 |
meta_link_types=['brand'],
|
| 65 |
max_entries=25,
|
| 66 |
-
|
| 67 |
-
|
| 68 |
'''
|
| 69 |
Args:
|
| 70 |
root (str): root directory to store the data
|
|
@@ -108,7 +108,7 @@ class AmazonSemiStruct(SemiStructureKB):
|
|
| 108 |
if meta_link_types:
|
| 109 |
# customize the graph by adding meta links
|
| 110 |
processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
|
| 111 |
-
super(AmazonSemiStruct, self).__init__(**processed_data,
|
| 112 |
|
| 113 |
def __getitem__(self, idx):
|
| 114 |
idx = int(idx)
|
|
|
|
| 63 |
categories: list,
|
| 64 |
meta_link_types=['brand'],
|
| 65 |
max_entries=25,
|
| 66 |
+
download_processed=True,
|
| 67 |
+
**kwargs):
|
| 68 |
'''
|
| 69 |
Args:
|
| 70 |
root (str): root directory to store the data
|
|
|
|
| 108 |
if meta_link_types:
|
| 109 |
# customize the graph by adding meta links
|
| 110 |
processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
|
| 111 |
+
super(AmazonSemiStruct, self).__init__(**processed_data, **kwargs)
|
| 112 |
|
| 113 |
def __getitem__(self, idx):
|
| 114 |
idx = int(idx)
|
src/benchmarks/semistruct/mag.py
CHANGED
|
@@ -40,7 +40,7 @@ class MagSemiStruct(SemiStructureKB):
|
|
| 40 |
ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
|
| 41 |
mag_mapping_url = 'https://zenodo.org/records/2628216/files'
|
| 42 |
|
| 43 |
-
def __init__(self, root, download_processed=True):
|
| 44 |
'''
|
| 45 |
Args:
|
| 46 |
root (str): root directory to store the dataset folder
|
|
@@ -88,7 +88,7 @@ class MagSemiStruct(SemiStructureKB):
|
|
| 88 |
processed_data = self._process_raw()
|
| 89 |
processed_data.update({'node_type_dict': self.node_type_dict,
|
| 90 |
'edge_type_dict': self.edge_type_dict})
|
| 91 |
-
super(MagSemiStruct, self).__init__(**processed_data)
|
| 92 |
|
| 93 |
def load_edge(self, edge_type):
|
| 94 |
edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
|
|
|
|
| 40 |
ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
|
| 41 |
mag_mapping_url = 'https://zenodo.org/records/2628216/files'
|
| 42 |
|
| 43 |
+
def __init__(self, root, download_processed=True, **kwargs):
|
| 44 |
'''
|
| 45 |
Args:
|
| 46 |
root (str): root directory to store the dataset folder
|
|
|
|
| 88 |
processed_data = self._process_raw()
|
| 89 |
processed_data.update({'node_type_dict': self.node_type_dict,
|
| 90 |
'edge_type_dict': self.edge_type_dict})
|
| 91 |
+
super(MagSemiStruct, self).__init__(**processed_data, **kwargs)
|
| 92 |
|
| 93 |
def load_edge(self, edge_type):
|
| 94 |
edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
|
src/benchmarks/semistruct/primekg.py
CHANGED
|
@@ -30,7 +30,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
|
|
| 30 |
candidate_types = NODE_TYPES
|
| 31 |
raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
|
| 32 |
|
| 33 |
-
def __init__(self, root, download_processed=True):
|
| 34 |
'''
|
| 35 |
Args:
|
| 36 |
root (str): root directory to store the dataset folder
|
|
@@ -61,7 +61,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
|
|
| 61 |
print(f'Loaded from {self.processed_data_dir}!')
|
| 62 |
else:
|
| 63 |
processed_data = self._process_raw()
|
| 64 |
-
super(PrimeKGSemiStruct, self).__init__(**processed_data)
|
| 65 |
|
| 66 |
self.node_info = clean_dict(self.node_info)
|
| 67 |
self.node_attr_dict = {}
|
|
|
|
| 30 |
candidate_types = NODE_TYPES
|
| 31 |
raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
|
| 32 |
|
| 33 |
+
def __init__(self, root, download_processed=True, **kwargs):
|
| 34 |
'''
|
| 35 |
Args:
|
| 36 |
root (str): root directory to store the dataset folder
|
|
|
|
| 61 |
print(f'Loaded from {self.processed_data_dir}!')
|
| 62 |
else:
|
| 63 |
processed_data = self._process_raw()
|
| 64 |
+
super(PrimeKGSemiStruct, self).__init__(**processed_data, **kwargs)
|
| 65 |
|
| 66 |
self.node_info = clean_dict(self.node_info)
|
| 67 |
self.node_attr_dict = {}
|