| | import requests |
| | import networkx as nx |
| | import matplotlib.pyplot as plt |
| |
|
| | |
# Root URL of the graph service that every request in this script targets.
base_url: str = "http://localhost:5000"
| |
|
def fetch_relationships(node_id, direction="down", timeout=30):
    """Fetch the traversal hierarchy around ``node_id`` from the graph service.

    Args:
        node_id: Id of the node the traversal starts from.
        direction: ``"down"`` for descendants or ``"up"`` for ancestors.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The ``"traversal_path"`` value of the JSON response, or ``{}`` when
        the response has no such key.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # Let requests URL-encode the query: a hand-built query string mangles
    # ids containing spaces, '&', '+', '#' or '='.
    response = requests.get(
        f"{base_url}/traverse_node",
        params={"node_id": node_id, "direction": direction},
        timeout=timeout,
    )
    # Fail loudly instead of silently treating an error body as "no relations".
    response.raise_for_status()
    return response.json().get("traversal_path", {})
| |
|
def build_graph_from_relationships(node_id):
    """Merge the descendant and ancestor hierarchies of ``node_id`` into one DiGraph.

    Both hierarchies are fetched first ("down", then "up"), then added to a
    fresh ``nx.DiGraph`` in that same order.

    Args:
        node_id: Id of the node at the centre of the graph.

    Returns:
        The assembled ``nx.DiGraph``.
    """
    hierarchies = [
        fetch_relationships(node_id, direction=which) for which in ("down", "up")
    ]

    graph = nx.DiGraph()
    # Each hierarchy gets its own visited set (no `visited` argument): the
    # root node appears in both, and a shared set would make the second pass
    # stop at the root without adding the ancestors.
    for hierarchy in hierarchies:
        add_nodes_and_edges(graph, hierarchy)
    return graph
| |
|
def add_nodes_and_edges(G, node, visited=None):
    """Recursively copy a traversal hierarchy into a graph.

    A hierarchy node is a dict that may carry a ``"node_id"``, a
    ``"descendants"`` list and an ``"ancestors"`` list. Each entry of those
    lists is itself a hierarchy node and may carry a ``"relationship"``
    string that becomes the edge ``label`` (``"related_to"`` when absent).
    Descendant edges point ``node -> child``; ancestor edges point
    ``ancestor -> node``.

    Args:
        G: Graph mutated in place; only ``add_node`` and ``add_edge`` are used
            (e.g. a ``networkx.DiGraph``).
        node: Hierarchy dict to add. ``None``, empty dicts and dicts without a
            truthy ``"node_id"`` are ignored.
        visited: Ids already expanded in this traversal. Created on the first
            call and shared with the recursive calls, so cycles and repeated
            sub-trees are expanded only once.

    Returns:
        None; ``G`` and ``visited`` are updated in place.
    """
    if visited is None:
        visited = set()
    if not node:
        return

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return
    visited.add(node_id)

    G.add_node(node_id, label=node_id)

    # (list key, whether edges point away from `node`); descendants first,
    # then ancestors.
    for key, outgoing in (("descendants", True), ("ancestors", False)):
        # `or []` also tolerates an explicit JSON null for the list.
        for relative in node.get(key) or []:
            relative_id = relative.get("node_id")
            # An entry without an id cannot become an edge endpoint
            # (networkx rejects None as a node), and the recursive call would
            # ignore it anyway, so skip it.
            if not relative_id:
                continue
            label = relative.get("relationship", "related_to")
            if outgoing:
                G.add_edge(node_id, relative_id, label=label)
            else:
                G.add_edge(relative_id, node_id, label=label)
            add_nodes_and_edges(G, relative, visited)
| |
|
def visualize_graph(G, title="Graph Structure and Relationships", seed=None):
    """Draw ``G`` with a spring layout and display it with matplotlib.

    Nodes are labelled with their ids; each edge is labelled with its
    ``"label"`` attribute (blank when the edge has none).

    Args:
        G: Graph to draw.
        title: Figure title.
        seed: Optional random seed forwarded to ``nx.spring_layout`` for a
            reproducible layout; ``None`` keeps the default random layout.

    Returns:
        None; the figure is displayed via ``plt.show()``.
    """
    node_size = 3000
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, seed=seed)

    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color="skyblue", alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color="black")

    # draw_networkx_edges does not draw nodes, but it uses node_size to decide
    # where edges and arrowheads stop. With the default size, arrowheads are
    # hidden under the enlarged nodes drawn above.
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, node_size=node_size)
    edge_labels = {(u, v): d.get("label", "") for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

    plt.title(title)
    plt.axis("off")
    plt.show()
| |
|
| | |
def main():
    """Load the 340B graph into the service, then fetch and plot its relationships.

    Raises:
        requests.HTTPError: If the ``/load_graph`` request fails, so the
            traversal is not attempted against an unloaded graph.
    """
    print("\n--- Loading Graph ---")
    graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
    response = requests.post(f"{base_url}/load_graph", json=graph_data, timeout=30)
    response.raise_for_status()
    print("Load Graph Response:", response.json())

    print("\n--- Building Graph for Visualization ---")
    G = build_graph_from_relationships("340B Program")
    visualize_graph(G, title="340B Program - Inferred Contextual Relationships")


# Run only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()