Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,18 +1,171 @@
|
|
| 1 |
import gradio
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
gradio_interface.launch()
|
|
|
|
|
|
| 1 |
import gradio
|
| 2 |
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch
|
| 5 |
+
from torch_geometric.loader import DataLoader
|
| 6 |
+
|
| 7 |
+
import utils.clean_data as cd
|
| 8 |
+
import utils.shape_features as sf
|
| 9 |
+
import utils.node_features as nf
|
| 10 |
+
import utils.edge_features as ef
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
node_model_path = 'utils/emb_model/Node_64.pt'
|
| 15 |
+
edge_model_path = 'utils/emb_model/Edge_64.pt'
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class InfoGraph(nn.Module):
|
| 19 |
+
def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
|
| 20 |
+
super(InfoGraph, self).__init__()
|
| 21 |
+
|
| 22 |
+
self.alpha = alpha
|
| 23 |
+
self.beta = beta
|
| 24 |
+
self.gamma = gamma
|
| 25 |
+
self.prior = False
|
| 26 |
+
|
| 27 |
+
self.embedding_dim = mi_units = hidden_dim * num_gc_layers
|
| 28 |
+
self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)
|
| 29 |
+
|
| 30 |
+
self.local_d = FF(self.embedding_dim)
|
| 31 |
+
self.global_d = FF(self.embedding_dim)
|
| 32 |
+
# self.local_d = MI1x1ConvNet(self.embedding_dim, mi_units)
|
| 33 |
+
# self.global_d = MIFCNet(self.embedding_dim, mi_units)
|
| 34 |
+
|
| 35 |
+
if self.prior:
|
| 36 |
+
self.prior_d = PriorDiscriminator(self.embedding_dim)
|
| 37 |
+
|
| 38 |
+
self.init_emb()
|
| 39 |
+
|
| 40 |
+
def init_emb(self):
|
| 41 |
+
initrange = -1.5 / self.embedding_dim
|
| 42 |
+
for m in self.modules():
|
| 43 |
+
if isinstance(m, nn.Linear):
|
| 44 |
+
torch.nn.init.xavier_uniform_(m.weight.data)
|
| 45 |
+
if m.bias is not None:
|
| 46 |
+
m.bias.data.fill_(0.0)
|
| 47 |
+
|
| 48 |
+
def forward(self, x, edge_index, batch, num_graphs):
|
| 49 |
+
# batch_size = data.num_graphs
|
| 50 |
+
if x is None:
|
| 51 |
+
x = torch.ones(batch.shape[0]).to(device)
|
| 52 |
+
|
| 53 |
+
y, M = self.encoder(x, edge_index, batch)
|
| 54 |
+
|
| 55 |
+
g_enc = self.global_d(y)
|
| 56 |
+
l_enc = self.local_d(M)
|
| 57 |
+
|
| 58 |
+
mode='fd'
|
| 59 |
+
measure='JSD'
|
| 60 |
+
local_global_loss = local_global_loss_(l_enc, g_enc, edge_index, batch, measure)
|
| 61 |
+
|
| 62 |
+
if self.prior:
|
| 63 |
+
prior = torch.rand_like(y)
|
| 64 |
+
term_a = torch.log(self.prior_d(prior)).mean()
|
| 65 |
+
term_b = torch.log(1.0 - self.prior_d(y)).mean()
|
| 66 |
+
PRIOR = - (term_a + term_b) * self.gamma
|
| 67 |
+
else:
|
| 68 |
+
PRIOR = 0
|
| 69 |
+
|
| 70 |
+
return local_global_loss + PRIOR
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def outline_embedding(wkt, wall):
|
| 74 |
+
wall_f, wkt_f = cd.read_wall_wkt(wall, wkt)
|
| 75 |
+
apa_wall, apa_geo = cd.clean_geometry(wall_f, wkt_f)
|
| 76 |
+
|
| 77 |
+
apa_geo = apa_geo
|
| 78 |
+
apa_line = apa_geo.boundary
|
| 79 |
+
|
| 80 |
+
apa_wall_O = cd.exterior_wall(apa_line, apa_wall)
|
| 81 |
+
apa_coor = cd.geo_coor(apa_geo)
|
| 82 |
+
|
| 83 |
+
xarr4cv, yarr4cv = apa_geo.exterior.coords.xy
|
| 84 |
+
x4cv = xarr4cv.tolist()
|
| 85 |
+
y4cv = yarr4cv.tolist()
|
| 86 |
+
|
| 87 |
+
scale = 100000
|
| 88 |
+
xmin_abs = abs(min(x4cv))
|
| 89 |
+
ymin_abs = abs(min(y4cv))
|
| 90 |
+
|
| 91 |
+
p_4_cv = cd.points4cv(x4cv, y4cv, xmin_abs, ymin_abs, scale)
|
| 92 |
+
|
| 93 |
+
grid_points = cd.gridpoints(apa_geo, 1)
|
| 94 |
+
|
| 95 |
+
Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge, Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max, Facade_length, Facade_ratio = sf.wall_direction_ratio(apa_line, apa_wall)
|
| 96 |
+
Perimeter = sf.apartment_perimeter(apa_geo)
|
| 97 |
+
Area = sf.apartment_area(apa_geo)
|
| 98 |
+
BBox_width_x, BBox_height_y, Aspect_ratio, Extent, ULC_x, ULC_y, LRC_x, LRC_y = sf.boundingbox_features(apa_geo)
|
| 99 |
+
Max_diameter = sf.max_diameter(apa_geo)
|
| 100 |
+
Fractality = sf.fractality(apa_geo)
|
| 101 |
+
Circularity = sf.circularity(apa_geo)
|
| 102 |
+
Outer_radius = sf.outer_radius(p_4_cv, xmin_abs, ymin_abs, scale)
|
| 103 |
+
Inner_radius = sf.inner_radius(apa_geo, apa_line)
|
| 104 |
+
Dist_mean, Dist_sigma, Roundness = sf.roundness_features(apa_line)
|
| 105 |
+
Compactness = sf.compactness(apa_geo)
|
| 106 |
+
Equivalent_diameter = sf.equivalent_diameter(apa_geo)
|
| 107 |
+
Shape_membership_index = sf.shape_membership_index(apa_line)
|
| 108 |
+
Convexity, Hull_geo = sf.convexity(p_4_cv, apa_geo, xmin_abs, ymin_abs, scale)
|
| 109 |
+
Rectangularity, Rect_phi, Rect_width, Rect_height = sf.rectangle_features(p_4_cv, apa_geo, xmin_abs, ymin_abs, scale)
|
| 110 |
+
Squareness = sf.squareness(apa_geo)
|
| 111 |
+
Moment_index = sf.moment_index(apa_geo, Convexity, Compactness)
|
| 112 |
+
nDetour_index = sf.ndetour_index(apa_geo, Hull_geo)
|
| 113 |
+
nCohesion_index = sf.ncohesion_index(apa_geo, grid_points)
|
| 114 |
+
nProximity_index, nSpin_index = sf.nproximity_nspin_index(apa_geo, grid_points)
|
| 115 |
+
nExchange_index = sf.nexchange_index(apa_geo)
|
| 116 |
+
nPerimeter_index = sf.nperimeter_index(apa_geo)
|
| 117 |
+
nDepth_index = sf.ndepth_index(apa_geo, apa_line, grid_points)
|
| 118 |
+
nGirth_index = sf.ngirth_index(apa_geo, Inner_radius)
|
| 119 |
+
nRange_index = sf.nrange_index(apa_geo, Outer_radius)
|
| 120 |
+
nTraversal_index = sf.ntraversal_index(apa_geo, apa_line)
|
| 121 |
+
|
| 122 |
+
shape = [Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge, Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max, Facade_length, Facade_ratio,
|
| 123 |
+
Perimeter, Area,
|
| 124 |
+
BBox_width_x, BBox_height_y, Aspect_ratio, Extent, ULC_x, ULC_y, LRC_x, LRC_y,
|
| 125 |
+
Max_diameter, Fractality, Circularity, Outer_radius, Inner_radius,
|
| 126 |
+
Dist_mean, Dist_sigma, Roundness,
|
| 127 |
+
Compactness, Equivalent_diameter, Shape_membership_index, Convexity,
|
| 128 |
+
Rectangularity, Rect_phi, Rect_width, Rect_height,
|
| 129 |
+
Squareness, Moment_index, nDetour_index, nCohesion_index,
|
| 130 |
+
nProximity_index, nExchange_index, nSpin_index, nPerimeter_index,
|
| 131 |
+
nDepth_index, nGirth_index, nRange_index, nTraversal_index]
|
| 132 |
+
shape = [float(i) for i in shape]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
node_graph = nf.node_graph(apa_coor, apa_geo)
|
| 136 |
+
node_model = torch.load(node_model_path)
|
| 137 |
+
node_model.eval()
|
| 138 |
+
|
| 139 |
+
node_dataloader = DataLoader(node_graph, batch_size=1)
|
| 140 |
+
node_emb = node_model.encoder.get_embeddings(node_dataloader)
|
| 141 |
+
node = node_emb[0].tolist()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
edge_graph = ef.edge_graph(apa_line, apa_wall)
|
| 145 |
+
edge_model = torch.load(edge_model_path)
|
| 146 |
+
edge_model.eval()
|
| 147 |
+
|
| 148 |
+
edge_dataloader = DataLoader(edge_graph, batch_size=1)
|
| 149 |
+
edge_emb = edge_model.encoder.get_embeddings(edge_dataloader)
|
| 150 |
+
edge = edge_emb[0].tolist()
|
| 151 |
+
|
| 152 |
+
json = {"shape": shape,
|
| 153 |
+
"node": node,
|
| 154 |
+
"edge": edge}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
return json
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
gradio_interface = gradio.Interface(fn=outline_embedding,
|
| 162 |
+
inputs = [gradio.Textbox(type="text", label="wkt", show_legend=True),
|
| 163 |
+
gradio.Textbox(type="text", label="wall", show_legend=True),
|
| 164 |
+
gradio.Textbox(type="text", label="highway", show_legend=True),
|
| 165 |
+
gradio.Textbox(type="text", label="primary", show_legend=True),
|
| 166 |
+
gradio.Textbox(type="text", label="railway", show_legend=True)],
|
| 167 |
+
outputs = "json",
|
| 168 |
+
title="outline embedding")
|
| 169 |
+
|
| 170 |
gradio_interface.launch()
|
| 171 |
+
|