| | import json |
| | import networkx as nx |
| | import matplotlib.pyplot as plt |
| | import os |
| |
|
| | |
# Location of the index JSON read by the script entry point; being relative,
# it is resolved against the current working directory when opened.
index_file_path = "graphs/index.json"
| |
|
| | |
def load_index_data(file_path):
    """Load the index JSON file and return its parsed contents.

    Args:
        file_path: Path to the index file (e.g. ``"graphs/index.json"``).

    Returns:
        The deserialized JSON document. ``build_graph`` reads its
        ``"entities"`` mapping and ``"relationships"`` list.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open()
    # uses the platform locale encoding, which mis-decodes non-ASCII
    # labels on systems such as Windows (cp1252).
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
| |
|
def load_entity_file(entity_info):
    """Load the entity-specific JSON file referenced by ``entity_info``.

    Args:
        entity_info: Mapping describing one entity; its optional
            ``"file_path"`` key names the JSON file to load.

    Returns:
        The parsed JSON document, or ``None`` when no ``file_path`` is given
        or the file cannot be read or parsed. Every failure case prints a
        diagnostic message; nothing is raised.
    """
    file_path = entity_info.get("file_path")
    if not file_path:
        return None

    # EAFP: attempt the open directly instead of checking os.path.exists()
    # first, which is racy and also passes for directories.
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except json.JSONDecodeError as e:
        # Must precede UnicodeDecodeError/ValueError-style handlers below;
        # JSONDecodeError is itself a ValueError subclass.
        print(f"Error loading JSON file at {file_path}: {e}")
    except (OSError, UnicodeDecodeError) as e:
        # e.g. the path is a directory, is unreadable, or is not UTF-8.
        print(f"Error reading file at {file_path}: {e}")
    return None
| |
|
def build_graph(data, excluded_nodes=None):
    """Build a directed graph of entities and their relationships.

    Nodes come from ``data["entities"]`` (keyed by entity id). Each node gets
    a ``label`` (the entity's ``"label"``, else its id) and a ``domain`` (the
    entity's ``"inherits_from"``, else ``"Default"``).

    Edges come from two sources:

    * the ``"relationships"`` list in each entity's own JSON file, for
      entities that have a ``"file_path"``;
    * ``data["relationships"]``.

    Each edge's ``label`` is ``attributes["relationship"]``, defaulting to
    ``"related_to"``.

    Args:
        data: Parsed index document with an ``"entities"`` mapping and an
            optional ``"relationships"`` list.
        excluded_nodes: Entity ids to leave out of the graph entirely,
            together with every relationship touching them. Defaults to
            ``{"patient_protection._tmp", "phsa_sec_340b", "medicade_tmp"}``.

    Returns:
        A ``networkx.DiGraph``. Repeated relationships between the same
        ordered pair collapse into a single edge whose label is the last one
        added.
    """
    if excluded_nodes is None:
        excluded_nodes = {"patient_protection._tmp", "phsa_sec_340b", "medicade_tmp"}
    excluded_nodes = set(excluded_nodes)

    G = nx.DiGraph()

    for entity_id, entity_info in data["entities"].items():
        if entity_id in excluded_nodes:
            continue
        G.add_node(
            entity_id,
            label=entity_info.get("label", entity_id),
            domain=entity_info.get("inherits_from", "Default"),
        )

        # Entities without their own file simply contribute no extra edges;
        # that is not an error, so don't report it as one.
        if not entity_info.get("file_path"):
            continue
        entity_data = load_entity_file(entity_info)
        if isinstance(entity_data, dict):
            _add_relationships(G, entity_data.get("relationships", []), excluded_nodes)
        else:
            print(f"Skipping entity {entity_id} due to invalid data format.")

    _add_relationships(G, data.get("relationships", []), excluded_nodes)
    return G


def _add_relationships(G, relationships, excluded_nodes):
    """Add one labelled edge to ``G`` per relationship dict, skipping excluded endpoints."""
    for relationship in relationships:
        source = relationship.get("source")
        target = relationship.get("target")
        if source is None or target is None:
            print(f"Skipping relationship with missing endpoint: {relationship}")
            continue
        # add_edge() silently creates missing endpoint nodes, so an edge to an
        # excluded id would re-add that node (without label/domain attributes).
        if source in excluded_nodes or target in excluded_nodes:
            continue
        attributes = relationship.get("attributes") or {}
        G.add_edge(source, target, label=attributes.get("relationship", "related_to"))
| |
|
| | |
def visualize_graph(G, title="Inferred Contextual Relationships", color_map=None):
    """Draw ``G`` with domain-coloured nodes, readable labels and edge labels.

    Args:
        G: Graph whose nodes may carry ``domain`` and ``label`` attributes and
            whose edges may carry a ``label`` attribute (as set by
            ``build_graph``). Missing attributes fall back to defaults.
        title: Figure title.
        color_map: Mapping from domain name to matplotlib colour. Domains
            absent from the mapping are drawn ``"lightgrey"``. Defaults to a
            palette for Legislation / Healthcare Systems / Healthcare
            Policies / Default.

    The figure is displayed with ``plt.show()``.
    """
    if color_map is None:
        color_map = {
            "Legislation": "lightcoral",
            "Healthcare Systems": "lightgreen",
            "Healthcare Policies": "lightblue",
            "Default": "lightgrey",
        }
    # One value shared by the node and edge drawing calls (see below).
    node_size = 3000

    node_colors = [
        color_map.get(attrs.get("domain", "Default"), "lightgrey")
        for _, attrs in G.nodes(data=True)
    ]
    # Show each node's human-readable "label" attribute. Nodes created
    # implicitly by an edge have none, so fall back to the node id.
    node_labels = {node: attrs.get("label", node) for node, attrs in G.nodes(data=True)}

    pos = nx.kamada_kawai_layout(G)

    plt.figure(figsize=(15, 10))
    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color=node_colors, alpha=0.8)
    nx.draw_networkx_labels(
        G, pos, labels=node_labels, font_size=9, font_color="black", font_weight="bold"
    )

    # draw_networkx_edges assumes node_size=300 unless told otherwise. With
    # 3000-sized nodes the arrowheads would end inside (and be hidden by)
    # the node markers, so pass the real size.
    nx.draw_networkx_edges(
        G,
        pos,
        node_size=node_size,
        arrowstyle="->",
        arrowsize=15,
        edge_color="gray",
        connectionstyle="arc3,rad=0.1",
    )
    edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True) if "label" in d}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red", font_size=8)

    plt.title(title, fontsize=14)
    plt.axis("off")
    plt.show()
| |
|
| | |
if __name__ == "__main__":
    data = load_index_data(index_file_path)  # parsed index document
    G = build_graph(data)  # entities + relationships as a directed graph
    visualize_graph(G)  # render the graph with matplotlib