Spaces:
Sleeping
Sleeping
| import sys | |
| import json | |
| import torch | |
| import gradio as gr | |
| from pyvis.network import Network | |
| sys.path.append(".") | |
| import re | |
| from src.benchmarks import get_semistructured_data | |
# Maximum number of requests processed concurrently; main() sizes the request
# queue at twice this value.
CONCURRENCY_LIMIT = 1000
TITLE = "STaRK Semi-structured Knowledge Base Explorer"

# Tab label for each dataset key passed to get_semistructured_data().
BRAND_NAME = dict(
    amazon="STaRK-Amazon",
    mag="STaRK-MAG",
    primekg="STaRK-Prime",
)

# Node fill colours, indexed by node-type id when the subgraph is drawn.
# Kept as (hex, shade name) pairs so the names stay next to the values;
# only the hex codes end up in the list.
NODE_COLORS = [
    hex_code
    for hex_code, _shade in (
        ("#4285F4", "Blue"),
        ("#F4B400", "Yellow"),
        ("#0F9D58", "Green"),
        ("#00796B", "Teal"),
        ("#03A9F4", "Light Blue"),
        ("#CDDC39", "Lime"),
        ("#3F51B5", "Indigo"),
        ("#00BCD4", "Cyan"),
        ("#FFC107", "Amber"),
        ("#8BC34A", "Light Green"),
        ("#9E9E9E", "Grey"),
        ("#607D8B", "Blue Grey"),
        ("#FFEB3B", "Bright Yellow"),
        ("#E1F5FE", "Light Blue 50"),
        ("#F1F8E9", "Light Green 50"),
        ("#FFF3E0", "Orange 50"),
        ("#FFFDE7", "Yellow 50"),
        ("#E0F7FA", "Cyan 50"),
        ("#E8F5E9", "Green 50"),
        ("#E3F2FD", "Blue 50"),
        ("#FFF8E1", "Amber 50"),
        ("#E0F2F1", "Teal 50"),
        ("#F9FBE7", "Lime 50"),
    )
]

# Edge colours, indexed by edge-type id when the subgraph is drawn.
# NOTE(review): the "50" shades (Indigo 50, Blue Grey 50, Deep Purple 50) are
# nearly white and may be hard to see on the #ffffff graph background.
EDGE_COLORS = [
    hex_code
    for hex_code, _shade in (
        ("#1B5E20", "Green 900"),
        ("#004D40", "Teal 900"),
        ("#1A237E", "Indigo 900"),
        ("#3E2723", "Brown 900"),
        ("#880E4F", "Pink 900"),
        ("#01579B", "Light Blue 900"),
        ("#F57F17", "Yellow 900"),
        ("#FF6F00", "Amber 900"),
        ("#4A148C", "Purple 900"),
        ("#0D47A1", "Blue 900"),
        ("#006064", "Cyan 900"),
        ("#827717", "Lime 900"),
        ("#E8EAF6", "Indigo 50"),
        ("#ECEFF1", "Blue Grey 50"),
        ("#9C27B0", "Purple"),
        ("#311B92", "Deep Purple 900"),
        ("#673AB7", "Deep Purple"),
        ("#EDE7F6", "Deep Purple 50"),
    )
]
# Extra markup injected into the page <head> by gr.Blocks(head=...):
#   * vis-network 9.1.9 JS and CSS from cdnjs, pinned with SRI (sha512) hashes;
#   * CSS for the graph panel (.graph-area width, .network-graph box);
#   * the contents of interactive/draw_graph.js inlined as a <script>
#     (presumably where drawGraph(), called from get_subgraph_html, is defined).
VISJS_HEAD = """
<script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/vis-network.min.js" integrity="sha512-4/EGWWWj7LIr/e+CvsslZkRk0fHDpf04dydJHoHOH32Mpw8jYU28GNI6mruO7fh/1kq15kSvwhKJftMSlgm0FA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
<style type="text/css"> .graph-area { flex-basis: 30% !important; } .network-graph { width: 100%; height: 600px; background-color: #ffffff; border: 1px solid lightgray; position: relative; float: left; } </style>
"""
# The path is relative to the working directory (the app is expected to be
# launched from the project root). Decode explicitly as UTF-8 so the inlined
# script does not depend on the platform's default locale encoding.
with open("interactive/draw_graph.js", "r", encoding="utf-8") as f:
    VISJS_HEAD += f"<script>{f.read()}</script>"
def relabel(x, edge_index, batch, pos=None):
    """Keep only the nodes referenced by ``edge_index`` and renumber them 0..k-1.

    Args:
        x: per-node tensor; ``x.size(0)`` is taken as the total node count.
        edge_index: integer tensor with exactly two rows (source ids, target ids).
        batch: per-node tensor, filtered the same way as ``x``.
        pos: optional per-node tensor, filtered the same way as ``x`` when given.

    Returns:
        Tuple ``(x, edge_index, batch, pos)``:
        * ``x``, ``batch`` and ``pos`` are restricted to the referenced nodes, in
          ascending id order (the order produced by ``torch.unique``);
        * ``edge_index`` is rewritten to the new compact ids;
        * ``pos`` stays ``None`` if it was ``None``.
    """
    total = x.size(0)
    # Unpacking requires edge_index to have exactly two rows (source, target).
    src, _dst = edge_index
    kept = torch.unique(edge_index)

    # Lookup table from old id to new compact id; ids outside the subgraph stay -1.
    remap = src.new_full((total,), -1)
    remap[kept] = torch.arange(kept.size(0), device=src.device)

    return (
        x[kept],
        remap[edge_index],
        batch[kept],
        None if pos is None else pos[kept],
    )
def _one_hop_edges(kb, node_id, max_edges):
    """Return up to ``max_edges`` edges of ``kb`` incident to ``node_id`` and their types.

    An edge is incident when ``node_id`` is its source (row 0 of
    ``kb.edge_index``) or its target (row 1). When more edges match than
    ``max_edges``, a random subset (via ``torch.randperm``) is kept.
    """
    edge_index = kb.edge_index
    mask = (edge_index[0] == node_id) | (edge_index[1] == node_id)
    edges = edge_index[:, mask]
    types = kb.edge_types[mask]
    if edges.size(1) > max_edges:
        keep = torch.randperm(edges.size(1))[:max_edges]
        edges = edges[:, keep]
        types = types[keep]
    return edges, types


def _select_edges(kb, node_id, max_nodes, num_hops):
    """Choose the ``(edge_index, edge_types)`` to display around ``node_id``.

    * ``"1"``: up to ``max_nodes`` random edges incident to ``node_id``.
    * ``"2"``: those edges, plus one random incident edge for each distinct
      neighbour reached by them.
    * ``"inf"``: up to ``max_nodes`` random edges from the whole graph, plus one
      random edge incident to ``node_id``.

    Raises:
        ValueError: if ``num_hops`` is not ``"1"``, ``"2"`` or ``"inf"``.
    """
    if num_hops == "1":
        return _one_hop_edges(kb, node_id, max_nodes)

    if num_hops == "2":
        edge_index, edge_types = _one_hop_edges(kb, node_id, max_nodes)
        neighbors = torch.unique(edge_index).tolist()
        if node_id in neighbors:
            neighbors.remove(node_id)
        index_parts, type_parts = [edge_index], [edge_types]
        for neighbor in neighbors:
            e_index, e_type = _one_hop_edges(kb, neighbor, 1)
            index_parts.append(e_index)
            type_parts.append(e_type)
        # A single concatenation instead of growing the tensors inside the loop.
        return torch.cat(index_parts, dim=1), torch.cat(type_parts, dim=0)

    if num_hops == "inf":
        edge_index, edge_types = kb.edge_index, kb.edge_types
        if edge_index.size(1) > max_nodes:
            perm = torch.randperm(edge_index.size(1))[:max_nodes]
            edge_index = edge_index[:, perm]
            edge_types = edge_types[perm]
        # Guarantee at least one edge touching the queried node (if it has any).
        add_index, add_types = _one_hop_edges(kb, node_id, 1)
        return (
            torch.cat([edge_index, add_index], dim=1),
            torch.cat([edge_types, add_types], dim=0),
        )

    raise ValueError(f"num_hops must be '1', '2' or 'inf', got {num_hops!r}")


def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
    """Build pyvis graph data for a randomly sampled neighbourhood of ``node_id``.

    Args:
        kb: knowledge base exposing ``edge_index``, ``edge_types``,
            ``node_types``, ``node_type_dict``, ``edge_type_dict`` and
            ``num_nodes()``.
        node_id: entity to centre the view on. It is coerced with ``int()``
            because Gradio ``Number`` inputs arrive as floats.
        max_nodes: cap on the number of sampled edges (see ``_select_edges``).
            Coerced with ``int()``.
        num_hops: ``"1"``, ``"2"`` or ``"inf"``. The ints 1 and 2 are also
            accepted.

    Returns:
        The tuple returned by pyvis ``Network.get_network_data()``. Its first
        two items are the node and edge option dicts.

    Raises:
        ValueError: if ``num_hops`` is not a supported value.
    """
    node_id = int(node_id)
    max_nodes = int(max_nodes)
    num_hops = str(num_hops)

    # A KB with a 'gene/protein' node type (presumably STaRK-Prime) is drawn
    # without edge arrows; all other KBs get "to" arrows.
    undirected = 'gene/protein' in kb.node_type_dict.values()
    net = Network(directed=False) if undirected else Network()

    edge_index, edge_types = _select_edges(kb, node_id, max_nodes, num_hops)

    # Append a self-loop on the centre node so relabel() keeps it even when no
    # sampled edge touches it. No type is appended to edge_types for this loop,
    # so it is the one extra column in edge_index.
    center_loop = torch.tensor([[node_id], [node_id]], dtype=torch.long)
    edge_index = torch.cat([edge_index, center_loop], dim=1)
    node_ids, relabel_edge_index, _, _ = relabel(
        torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
    )

    for idx, n_id in enumerate(node_ids.tolist()):
        type_id = kb.node_types[n_id].item()
        type_name = kb.node_type_dict[type_id]
        if n_id == node_id:
            net.add_node(
                idx,
                node_id=n_id,
                color="#DB4437",
                size=20,
                label=f"{type_name}<{n_id}>",
                font={"align": "middle", "size": 10},
            )
        else:
            net.add_node(
                idx,
                node_id=n_id,
                size=15,
                # Cycle through the palette so unseen type ids cannot raise IndexError.
                color=NODE_COLORS[type_id % len(NODE_COLORS)],
                label=f"{type_name}",
                font={"align": "middle", "size": 10},
            )

    # zip() stops at edge_types, which dropped the appended centre self-loop
    # (the last column of relabel_edge_index). Self-loops are skipped anyway.
    for src, dst, etype in zip(
        relabel_edge_index[0].tolist(),
        relabel_edge_index[1].tolist(),
        edge_types.tolist(),
    ):
        if src == dst:
            continue
        options = dict(
            color=EDGE_COLORS[etype % len(EDGE_COLORS)],
            label=kb.edge_type_dict[etype].replace('___', " ").replace('_', " "),
            width=1,
            font={"align": "middle", "size": 10},
        )
        if not undirected:
            options.update(arrows="to", arrowStrikethrough=False)
        net.add_edge(src, dst, **options)
    return net.get_network_data()
def get_text_html(kb, node_id):
    """Render the textual description of an entity as a scrollable HTML panel.

    The text comes from ``kb.get_doc_info(node_id, add_rel=False, compact=False)``.
    It is HTML-escaped (shown literally, never parsed as markup), newlines
    become ``<br>``, and ``$...$`` spans are rewritten to ``\\(...\\)`` so
    MathJax can typeset them as inline math.

    Args:
        kb: knowledge base exposing ``get_doc_info``.
        node_id: entity id. It is coerced with ``int()`` because Gradio
            ``Number`` inputs arrive as floats (which would render as "5.0").

    Returns:
        str: an HTML snippet containing the MathJax loader and a styled ``<div>``.
    """
    import html

    node_id = int(node_id)
    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # Escape before adding our own markup, so characters such as "<" or "&" in
    # the KB text are displayed instead of being interpreted by the browser.
    text = html.escape(text, quote=False).replace("\n", "<br>")
    text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)
    # Scrollable box; white-space: pre-wrap keeps the text's own spacing while
    # still wrapping long lines. The polyfill.io script is deliberately not
    # loaded: that domain was compromised in a 2024 supply-chain attack.
    return f"""<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<div style="width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; overflow-wrap: break-word; hyphens: auto; white-space: pre-wrap; padding: 10px; margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;
font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, Arial, sans-serif;">{text}</div>"""
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops='1'):
    """Return HTML that draws the sampled subgraph around ``node_id`` with vis-network.

    Args:
        kb: knowledge base passed through to ``generate_network``.
        kb_name: dataset key. It names the container ``<div id="{kb_name}-network">``
            and is passed to ``drawGraph`` as ``dataset``.
        node_id, max_nodes, num_hops: forwarded to ``generate_network``.

    Returns:
        str: a container ``<div>`` plus a hidden ``<img>`` whose ``onerror``
        handler calls ``drawGraph({"nodes", "edges", "dataset"})``.
    """
    import html

    nodes, edges = generate_network(kb, node_id, max_nodes, num_hops)[:2]
    payload = json.dumps({"nodes": nodes, "edges": edges, "dataset": kb_name})
    # Scripts inside a Gradio HTML update do not run, so a broken <img> fires
    # onerror to invoke drawGraph instead. The JSON sits inside a single-quoted
    # attribute: html.escape stops quotes, "&" or "<" in labels from breaking
    # it. The browser decodes the entities before executing the handler, so
    # drawGraph still receives the original JSON.
    figure_html = f"""
<div id="{html.escape(kb_name)}-network" class="network-graph"></div>
<img src="/dummy.img" style="display: none;" onerror='drawGraph({html.escape(payload)});'>
"""
    return figure_html
def main():
    """Build and launch the Gradio explorer with one tab per STaRK knowledge base.

    Each tab contains:
    * an entity-id input, a slider capping the number of sampled edges, and a
      hop-count dropdown;
    * a button that renders the subgraph and the entity text side by side;
    * a hidden row exposing a text-only endpoint (``{name}-fetch-text``),
      presumably triggered from the inlined draw_graph.js via its elem_ids.
    """
    kbs = {name: get_semistructured_data(name, indirected=False) for name in BRAND_NAME}
    with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
        gr.Markdown(f"# {TITLE}")
        for name, kb in kbs.items():
            with gr.Tab(BRAND_NAME[name]):
                with gr.Row():
                    # precision=0 makes Gradio round the value and pass an int,
                    # so entity ids are not forwarded as floats such as 5.0.
                    entity_id = gr.Number(
                        label="Entity ID",
                        precision=0,
                        elem_id=f"{name}-entity-id-input"
                    )
                    max_paths = gr.Slider(
                        1, 200, 10, step=1, label="Max Number of Paths"
                    )
                    num_hops = gr.Dropdown(
                        ["1", "2", "inf"], value="2", label="Number of Hops"
                    )
                    query_btn = gr.Button(
                        value="Display Semi-structured Data",
                        variant="primary",
                        elem_id=f"{name}-fetch-btn"
                    )
                with gr.Row():
                    graph_area = gr.HTML(elem_classes="graph-area")
                    text_area = gr.HTML(elem_classes="text-area")
                query_btn.click(
                    # Bind kb/name as defaults so each tab's callback captures
                    # its own values (avoids the late-binding closure pitfall).
                    lambda e, n, h, kb=kb, name=name: (
                        get_subgraph_html(kb, name, e, n, h),
                        get_text_html(kb, e),
                    ),
                    inputs=[entity_id, max_paths, num_hops],
                    outputs=[graph_area, text_area],
                    api_name=f"{name}-fetch-graph"
                )
                # Hidden inputs for fetching just the text.
                with gr.Row(visible=False):
                    entity_for_text = gr.Number(
                        label="Text Entity ID",
                        precision=0,
                        elem_id=f"{name}-entity-id-text-input"
                    )
                    query_text_btn = gr.Button(
                        value="Show Text", elem_id=f"{name}-fetch-text-btn"
                    )
                query_text_btn.click(
                    lambda e, kb=kb: get_text_html(kb, e),
                    inputs=[entity_for_text],
                    outputs=text_area,
                    api_name=f"{name}-fetch-text"
                )
    demo.queue(max_size=2*CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
    demo.launch(share=True)


if __name__ == "__main__":
    main()