File size: 3,216 Bytes
ce34030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import matplotlib
matplotlib.use('Agg')
import numpy as np
import networkx as nx
from io import BytesIO
from PIL import Image, ImageDraw
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from singapo_utils.refs import graph_color_ref

def add_text(text, imgarr):
    """
    Draw a text caption onto an image.

    The caption is drawn in black at pixel offset (10, 10) from the top-left
    corner. No font is passed, so Pillow's default font is used.

    Args:
    - text (str): caption to draw
    - imgarr (np.array): image array accepted by ``PIL.Image.fromarray``

    Returns:
    - (np.array): the captioned image as an array
    """
    canvas = Image.fromarray(imgarr)
    ImageDraw.Draw(canvas).text((10, 10), text, fill='black')
    return np.asarray(canvas)

def get_color(ref, n_nodes):
    '''
    Function to color the nodes

    Builds one RGB color per node by cycling through a palette of color
    strings. Node ``i`` gets ``ref[i % len(ref)]``.

    Args:
    - ref (list): list of color reference strings. Each entry used is parsed
      by dropping its first 4 characters and its last character, then
      splitting on commas, so entries are expected to look like
      ``'rgb(R,G,B)'`` with integer channels (presumably 0-255).
    - n_nodes (int): number of nodes

    Returns:
    - colors (list): ``n_nodes`` separate float arrays of shape (1, 3), each
      equal to the parsed channels divided by 255

    Raises:
    - ValueError: if ``ref`` is empty while ``n_nodes > 0``
    '''
    if n_nodes <= 0:
        return []
    n_ref = len(ref)
    if n_ref == 0:
        raise ValueError('get_color needs a non-empty color reference when n_nodes > 0')
    # Parse each palette entry once instead of once per node. Only the first
    # n_nodes entries can ever be selected (idx % n_ref < min(n_ref, n_nodes)),
    # so unused entries are never parsed. The slice [4:-1] strips 'rgb(' and ')'.
    palette = [
        [int(channel) for channel in entry[4:-1].split(',')]
        for entry in ref[:n_nodes]
    ]
    # Build a fresh array per node so mutating one color cannot affect another.
    return [np.array([palette[idx % n_ref]]) / 255. for idx in range(n_nodes)]


def make_grid(images, cols=5):
    """
    Arrange list of images into a N x cols grid.

    Images are placed row-major (left to right, then top to bottom). When the
    number of images is not a multiple of ``cols``, the unused cells of the
    last row are left as zeros.

    Args:
    - images (list): List of Numpy arrays representing the images. Every image
      must have the same (H, W, C) shape as ``images[0]``; the grid uses the
      dtype of ``images[0]``.
    - cols (int): Number of columns for the grid.

    Returns:
    - grid (numpy array): array of shape (ceil(len(images) / cols) * H, cols * W, C).

    Raises:
    - ValueError: if ``images`` is empty or ``cols`` is not positive.
    """
    # len() rather than truthiness so a stacked ndarray of images also works.
    if len(images) == 0:
        raise ValueError('make_grid needs at least one image')
    if cols <= 0:
        raise ValueError(f'cols must be positive, got {cols}')

    img_h, img_w, n_ch = images[0].shape
    # Ceiling division so a trailing partial row still gets its own slot.
    rows = -(-len(images) // cols)

    grid = np.zeros((rows * img_h, cols * img_w, n_ch), dtype=images[0].dtype)
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        grid[row * img_h:(row + 1) * img_h, col * img_w:(col + 1) * img_w] = img

    return grid

def viz_graph(info_dict, res=256):
    '''
    Function to plot the directed graph

    Draws the tree stored in ``info_dict['diffuse_tree']`` with Graphviz's
    ``twopi`` (radial) layout, colors node ``k`` (in sorted-id order) with the
    k-th color from ``get_color(graph_color_ref, len(tree))``, and rasterizes
    the figure to an array.

    Args:
    - info_dict (dict): output json containing the graph information. Reads
      ``info_dict['diffuse_tree']``, a list of nodes, each with an ``'id'`` and
      a ``'children'`` list of child ids.
    - res (int): resolution of the image (the figure is res/100 inches at
      100 dpi, i.e. about res x res pixels)

    Returns:
    - img_arr (np.array): image array with the alpha channel dropped (first 3 channels)
    '''
    # build tree
    tree = info_dict['diffuse_tree']
    G = nx.DiGraph()
    G.add_edges_from(
        (node['id'], child) for node in tree for child in node['children']
    )
    # Edges alone only register nodes that have a parent or a child. Also add
    # every tree node (after the edges, so connected nodes keep their original
    # insertion order) so isolated nodes, e.g. a single-node tree, are still
    # drawn and, for a well-formed tree, there is one drawn node per color.
    G.add_nodes_from(node['id'] for node in tree)

    # plot tree; close the figure even if layout/drawing raises
    fig = plt.figure(figsize=(res / 100, res / 100))
    try:
        colors = get_color(graph_color_ref, len(tree))
        pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
        node_order = sorted(G.nodes())
        nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)

        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=100)
            buf.seek(0)
            # np.asarray forces Pillow to decode while the buffer is still open.
            img_arr = np.asarray(Image.open(buf))
    finally:
        plt.close(fig)
    return img_arr[:, :, :3]

def viz_patch_feat_pca(feat, grid_hw=None):
    '''
    Function to visualize patch features as an RGB image via PCA

    Projects every row of ``feat`` onto its first three principal components,
    min-max normalizes each component independently to [0, 255], and lays the
    rows out on a 2D grid in row-major order.

    Args:
    - feat (array-like): features passed to ``PCA.fit``/``PCA.transform``; one
      row per patch, i.e. presumably (n_patches, dim)
    - grid_hw (tuple, optional): (height, width) of the patch grid. By default
      a square grid is used, which requires n_patches to be a perfect square
      (256 patches -> 16 x 16, the previous fixed layout).

    Returns:
    - img_array (np.array): uint8 array of shape (height, width, 3)

    Raises:
    - ValueError: if the patches cannot be arranged into the grid
    '''
    pca = PCA(n_components=3)
    pca.fit(feat)
    t = np.asarray(pca.transform(feat))

    n_patches = t.shape[0]
    if grid_hw is None:
        side = int(round(np.sqrt(n_patches)))
        if side * side != n_patches:
            raise ValueError(
                f'{n_patches} patches do not form a square grid; pass grid_hw=(h, w)'
            )
        grid_hw = (side, side)

    t_min = t.min(axis=0, keepdims=True)
    t_range = t.max(axis=0, keepdims=True) - t_min
    # A constant component has zero range; dividing by it would give NaN, which
    # has no defined uint8 value. Divide by 1 instead so it maps to 0.
    t_range[t_range == 0] = 1
    normalized_t = (t - t_min) / t_range

    array = (normalized_t * 255).astype(np.uint8)
    img_array = array.reshape(grid_hw[0], grid_hw[1], 3)
    return img_array