fardeenKhadri committed on
Commit f71a064 · verified · 1 Parent(s): 61eecce

Update inference.py

Files changed (1):
  1. inference.py (+257 -261)
inference.py CHANGED
@@ -1,261 +1,257 @@
- """
- @Date: 2021/09/19
- @description:
- """
- import json
- import os
- import argparse
- import cv2
- import numpy as np
- import torch
- import matplotlib.pyplot as plt
- import glob
-
- from tqdm import tqdm
- from PIL import Image
- from config.defaults import merge_from_file, get_config
- from dataset.mp3d_dataset import MP3DDataset
- from dataset.zind_dataset import ZindDataset
- from models.build import build_model
- from loss import GradLoss
- from postprocessing.post_process import post_process
- from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama
- from utils.boundary import corners2boundaries, layout2depth
- from utils.conversion import depth2xyz
- from utils.logger import get_logger
- from utils.misc import tensor2np_d, tensor2np
- from evaluation.accuracy import show_grad
- from models.lgt_net import LGT_Net
- from utils.writer import xyz2json
- from visualization.boundary import draw_boundaries
- from visualization.floorplan import draw_floorplan, draw_iou_floorplan
- from visualization.obj3d import create_3d_obj
-
-
- def parse_option():
-     parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
-     parser.add_argument('--img_glob',
-                         type=str,
-                         required=True,
-                         help='image glob path')
-
-     parser.add_argument('--cfg',
-                         type=str,
-                         required=True,
-                         metavar='FILE',
-                         help='path of config file')
-
-     parser.add_argument('--post_processing',
-                         type=str,
-                         default='manhattan',
-                         choices=['manhattan', 'atalanta', 'original'],
-                         help='post-processing type')
-
-     parser.add_argument('--output_dir',
-                         type=str,
-                         default='src/output',
-                         help='path of output')
-
-     parser.add_argument('--visualize_3d', action='store_true',
-                         help='visualize_3d')
-
-     parser.add_argument('--output_3d', action='store_true',
-                         help='output_3d')
-
-     parser.add_argument('--device',
-                         type=str,
-                         default='cuda',
-                         help='device')
-
-     args = parser.parse_args()
-     args.mode = 'test'
-
-     print("arguments:")
-     for arg in vars(args):
-         print(arg, ":", getattr(args, arg))
-     print("-" * 50)
-     return args
-
-
- def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
-     dt_np = tensor2np_d(dt)
-     dt_depth = dt_np['depth'][0]
-     dt_xyz = depth2xyz(np.abs(dt_depth))
-     dt_ratio = dt_np['ratio'][0][0]
-     dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
-     vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
-
-     if 'processed_xyz' in dt:
-         dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
-                                            length=img.shape[1])
-         vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])
-
-     if show_depth:
-         dt_grad_img = show_depth_normal_grad(dt)
-         grad_h = dt_grad_img.shape[0]
-         vis_merge = [
-             vis_img[0:-grad_h, :, :],
-             dt_grad_img,
-         ]
-         vis_img = np.concatenate(vis_merge, axis=0)
-         # vis_img = dt_grad_img.transpose(1, 2, 0)[100:]
-
-     if show_floorplan:
-         if 'processed_xyz' in dt:
-             floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
-                                            dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
-         else:
-             floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])
-
-         vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
-     if show:
-         plt.imshow(vis_img)
-         plt.show()
-     if save_path:
-         result = Image.fromarray((vis_img * 255).astype(np.uint8))
-         result.save(save_path)
-     return vis_img
-
-
- def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
-     # Align images with VP
-     if os.path.exists(vp_cache_path):
-         with open(vp_cache_path) as f:
-             vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
-             vp = np.array(vp)
-     else:
-         # VP detection and line segment extraction
-         _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
-                                                  qError=q_error,
-                                                  refineIter=refine_iter)
-     i_img = rotatePanorama(img_ori, vp[2::-1])
-
-     if vp_cache_path is not None:
-         with open(vp_cache_path, 'w') as f:
-             for i in range(3):
-                 f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
-
-     return i_img, vp
-
-
- def show_depth_normal_grad(dt):
-     grad_conv = GradLoss().to(dt['depth'].device).grad_conv
-     dt_grad_img = show_grad(dt['depth'][0], grad_conv, 50)
-     dt_grad_img = cv2.resize(dt_grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
-     return dt_grad_img
-
-
- def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
-     if border_color is None:
-         border_color = [1, 0, 0, 1]
-     fill_color = [0.2, 0.2, 0.2, 0.2]
-     dt_floorplan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
-                                   border_color=border_color, side_l=side_l, show=False, center_color=[1, 0, 0, 1])
-     dt_floorplan = Image.fromarray((dt_floorplan * 255).astype(np.uint8), mode='RGBA')
-     back = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
-     back[..., :] = [0.8, 0.8, 0.8, 1]
-     back = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
-     iou_floorplan = Image.alpha_composite(back, dt_floorplan).convert("RGB")
-     dt_floorplan = np.array(iou_floorplan) / 255.0
-     return dt_floorplan
-
-
- def save_pred_json(xyz, ration, save_path):
-     # xyz[..., -1] = -xyz[..., -1]
-     json_data = xyz2json(xyz, ration)
-     with open(save_path, 'w') as f:
-         f.write(json.dumps(json_data, indent=4) + '\n')
-     return json_data
-
-
- def inference():
-     if len(img_paths) == 0:
-         logger.error('No images found')
-         return
-
-     bar = tqdm(img_paths, ncols=100)
-     for img_path in bar:
-         if not os.path.isfile(img_path):
-             logger.error(f'The {img_path} not is file')
-             continue
-         name = os.path.basename(img_path).split('.')[0]
-         bar.set_description(name)
-         img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
-         if args.post_processing is not None and 'manhattan' in args.post_processing:
-             bar.set_description("Preprocessing")
-             img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
-
-         img = (img / 255.0).astype(np.float32)
-         run_one_inference(img, model, args, name)
-
-
- def inference_dataset(dataset):
-     bar = tqdm(dataset, ncols=100)
-     for data in bar:
-         bar.set_description(data['id'])
-         run_one_inference(data['image'].transpose(1, 2, 0), model, args, name=data['id'], logger=logger)
-
-
- @torch.no_grad()
- def run_one_inference(img, model, args, name, logger, show=True, show_depth=True,
-                       show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
-     model.eval()
-     logger.info("model inference...")
-     dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
-     if args.post_processing != 'original':
-         logger.info(f"post-processing, type:{args.post_processing}...")
-         dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)
-
-     visualize_2d(img, dt,
-                  show_depth=show_depth,
-                  show_floorplan=show_floorplan,
-                  show=show,
-                  save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
-     output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))
-
-     logger.info(f"saving predicted layout json...")
-     json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
-                                save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
-     # if args.visualize_3d:
-     #     from visualization.visualizer.visualizer import visualize_3d
-     #     visualize_3d(json_data, (img * 255).astype(np.uint8))
-
-     if args.visualize_3d or args.output_3d:
-         dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
-                                            length=mesh_resolution if 'processed_xyz' in dt else None,
-                                            visible=True if 'processed_xyz' in dt else False)
-         dt_layout_depth = layout2depth(dt_boundaries, show=False)
-
-         logger.info(f"creating 3d mesh ...")
-         create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
-                       save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
-                       mesh=True, show=args.visualize_3d)
-
-
- if __name__ == '__main__':
-     logger = get_logger()
-     args = parse_option()
-     config = get_config(args)
-
-     if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
-         logger.info(f'The {args.device} is not available, will use cpu ...')
-         config.defrost()
-         args.device = "cpu"
-         config.TRAIN.DEVICE = "cpu"
-         config.freeze()
-
-     model, _, _, _ = build_model(config, logger)
-     os.makedirs(args.output_dir, exist_ok=True)
-     img_paths = sorted(glob.glob(args.img_glob))
-
-     inference()
-
-     # dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
-     #     ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
-     #     ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
-     # ])
-     # dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
-     #     '1169_pano_21',
-     #     '0583_pano_59',
-     # ], vp_align=True)
-     # inference_dataset(dataset)
+ import json
+ import os
+ import argparse
+ import cv2
+ import numpy as np
+ import torch
+ import matplotlib.pyplot as plt
+ import glob
+
+ from tqdm import tqdm
+ from PIL import Image
+ from config.defaults import merge_from_file, get_config
+ from dataset.mp3d_dataset import MP3DDataset
+ from dataset.zind_dataset import ZindDataset
+ from models.build import build_model
+ from loss import GradLoss
+ from postprocessing.post_process import post_process
+ from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama
+ from utils.boundary import corners2boundaries, layout2depth
+ from utils.conversion import depth2xyz
+ from utils.logger import get_logger
+ from utils.misc import tensor2np_d, tensor2np
+ from evaluation.accuracy import show_grad
+ from models.lgt_net import LGT_Net
+ from utils.writer import xyz2json
+ from visualization.boundary import draw_boundaries
+ from visualization.floorplan import draw_floorplan, draw_iou_floorplan
+ from visualization.obj3d import create_3d_obj
+
+
+ def parse_option():
+     parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
+     parser.add_argument('--img_glob',
+                         type=str,
+                         required=True,
+                         help='image glob path')
+
+     parser.add_argument('--cfg',
+                         type=str,
+                         required=True,
+                         metavar='FILE',
+                         help='path of config file')
+
+     parser.add_argument('--post_processing',
+                         type=str,
+                         default='manhattan',
+                         choices=['manhattan', 'atalanta', 'original'],
+                         help='post-processing type')
+
+     parser.add_argument('--output_dir',
+                         type=str,
+                         default='src/output',
+                         help='path of output')
+
+     parser.add_argument('--visualize_3d', action='store_true',
+                         help='visualize_3d')
+
+     parser.add_argument('--output_3d', action='store_true',
+                         help='output_3d')
+
+     parser.add_argument('--device',
+                         type=str,
+                         default='cuda',
+                         help='device')
+
+     args = parser.parse_args()
+     args.mode = 'test'
+
+     print("arguments:")
+     for arg in vars(args):
+         print(arg, ":", getattr(args, arg))
+     print("-" * 50)
+     return args
+
+
+ def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
+     dt_np = tensor2np_d(dt)
+     dt_depth = dt_np['depth'][0]
+     dt_xyz = depth2xyz(np.abs(dt_depth))
+     dt_ratio = dt_np['ratio'][0][0]
+     dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
+     vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
+
+     if 'processed_xyz' in dt:
+         dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
+                                            length=img.shape[1])
+         vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])
+
+     if show_depth:
+         dt_grad_img = show_depth_normal_grad(dt)
+         grad_h = dt_grad_img.shape[0]
+         vis_merge = [
+             vis_img[0:-grad_h, :, :],
+             dt_grad_img,
+         ]
+         vis_img = np.concatenate(vis_merge, axis=0)
+         # vis_img = dt_grad_img.transpose(1, 2, 0)[100:]
+
+     if show_floorplan:
+         if 'processed_xyz' in dt:
+             floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
+                                            dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
+         else:
+             floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])
+
+         vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
+     if show:
+         plt.imshow(vis_img)
+         plt.show()
+     if save_path:
+         result = Image.fromarray((vis_img * 255).astype(np.uint8))
+         result.save(save_path)
+     return vis_img
+
+
+ def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
+     # Align images with VP
+     if os.path.exists(vp_cache_path):
+         with open(vp_cache_path) as f:
+             vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
+             vp = np.array(vp)
+     else:
+         # VP detection and line segment extraction
+         _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
+                                                  qError=q_error,
+                                                  refineIter=refine_iter)
+     i_img = rotatePanorama(img_ori, vp[2::-1])
+
+     if vp_cache_path is not None:
+         with open(vp_cache_path, 'w') as f:
+             for i in range(3):
+                 f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
+
+     return i_img, vp
+
+
+ def show_depth_normal_grad(dt):
+     grad_conv = GradLoss().to(dt['depth'].device).grad_conv
+     dt_grad_img = show_grad(dt['depth'][0], grad_conv, 50)
+     dt_grad_img = cv2.resize(dt_grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
+     return dt_grad_img
+
+
+ def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
+     if border_color is None:
+         border_color = [1, 0, 0, 1]
+     fill_color = [0.2, 0.2, 0.2, 0.2]
+     dt_floorplan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
+                                   border_color=border_color, side_l=side_l, show=False, center_color=[1, 0, 0, 1])
+     dt_floorplan = Image.fromarray((dt_floorplan * 255).astype(np.uint8), mode='RGBA')
+     back = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
+     back[..., :] = [0.8, 0.8, 0.8, 1]
+     back = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
+     iou_floorplan = Image.alpha_composite(back, dt_floorplan).convert("RGB")
+     dt_floorplan = np.array(iou_floorplan) / 255.0
+     return dt_floorplan
+
+
+ def save_pred_json(xyz, ration, save_path):
+     # xyz[..., -1] = -xyz[..., -1]
+     json_data = xyz2json(xyz, ration)
+     with open(save_path, 'w') as f:
+         f.write(json.dumps(json_data, indent=4) + '\n')
+     return json_data
+
+
+ def inference():
+     if len(img_paths) == 0:
+         logger.error('No images found')
+         return
+
+     bar = tqdm(img_paths, ncols=100)
+     for img_path in bar:
+         if not os.path.isfile(img_path):
+             logger.error(f'The {img_path} not is file')
+             continue
+         name = os.path.basename(img_path).split('.')[0]
+         bar.set_description(name)
+         img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
+         if args.post_processing is not None and 'manhattan' in args.post_processing:
+             bar.set_description("Preprocessing")
+             img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
+
+         img = (img / 255.0).astype(np.float32)
+         run_one_inference(img, model, args, name)
+
+
+ def inference_dataset(dataset):
+     bar = tqdm(dataset, ncols=100)
+     for data in bar:
+         bar.set_description(data['id'])
+         run_one_inference(data['image'].transpose(1, 2, 0), model, args, name=data['id'], logger=logger)
+
+
+ @torch.no_grad()
+ def run_one_inference(img, model, args, name, logger, show=True, show_depth=True,
+                       show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
+     model.eval()
+     logger.info("model inference...")
+     dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
+     if args.post_processing != 'original':
+         logger.info(f"post-processing, type:{args.post_processing}...")
+         dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)
+
+     visualize_2d(img, dt,
+                  show_depth=show_depth,
+                  show_floorplan=show_floorplan,
+                  show=show,
+                  save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
+     output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))
+
+     logger.info(f"saving predicted layout json...")
+     json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
+                                save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
+     # if args.visualize_3d:
+     #     from visualization.visualizer.visualizer import visualize_3d
+     #     visualize_3d(json_data, (img * 255).astype(np.uint8))
+
+     if args.visualize_3d or args.output_3d:
+         dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
+                                            length=mesh_resolution if 'processed_xyz' in dt else None,
+                                            visible=True if 'processed_xyz' in dt else False)
+         dt_layout_depth = layout2depth(dt_boundaries, show=False)
+
+         logger.info(f"creating 3d mesh ...")
+         create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
+                       save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
+                       mesh=True, show=args.visualize_3d)
+
+
+ if __name__ == '__main__':
+     logger = get_logger()
+     args = parse_option()
+     config = get_config(args)
+
+     if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
+         logger.info(f'The {args.device} is not available, will use cpu ...')
+         config.defrost()
+         args.device = "cpu"
+         config.TRAIN.DEVICE = "cpu"
+         config.freeze()
+
+     model, _, _, _ = build_model(config, logger)
+     os.makedirs(args.output_dir, exist_ok=True)
+     img_paths = sorted(glob.glob(args.img_glob))
+
+     inference()
+
+     # dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
+     #     ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
+     #     ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
+     # ])
+     # dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
+     #     '1169_pano_21',
+     #     '0583_pano_59',
+     # ], vp_align=True)
+     # inference_dataset(dataset)
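
For reference, a minimal usage sketch based on the flags `parse_option()` defines. The `--cfg` and `--img_glob` values below are illustrative placeholders, not paths confirmed by this commit:

```python
# Hedged usage sketch: drives inference.py through the CLI defined in parse_option().
import subprocess

subprocess.run(
    [
        "python", "inference.py",
        "--cfg", "src/config/mp3d.yaml",   # hypothetical config file path
        "--img_glob", "src/demo/*.png",    # hypothetical panorama glob (expanded by glob.glob, not the shell)
        "--post_processing", "manhattan",  # choices: manhattan, atalanta, original
        "--output_dir", "src/output",      # matches the parse_option() default
        "--output_3d",                     # also writes a .gltf mesh per image
    ],
    check=True,
)
```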
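
One behavioral wrinkle worth flagging: on both sides of this diff, `run_one_inference()` takes `logger` as a required positional parameter, yet `inference()` calls it as `run_one_inference(img, model, args, name)`, so the glob-based path would raise a `TypeError` as committed. A minimal sketch of a defensive signature, assuming the repo's `utils.logger.get_logger` (this is not the committed code):

```python
# Minimal sketch, not the committed code: make `logger` optional so the
# call in inference(), which omits it, still works.
from utils.logger import get_logger  # repo helper, as imported in inference.py

def run_one_inference(img, model, args, name, logger=None, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    # Fall back to a default logger when the caller omits one.
    logger = logger if logger is not None else get_logger()
    logger.info("model inference...")
    # ... remainder identical to the committed function body ...
```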
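
A second fragile spot: `preprocess()` declares `vp_cache_path=None` but passes it straight to `os.path.exists()`, which raises `TypeError` on `None`; the default is only safe because every caller in this file supplies a path. A hedged guard, with `load_cached_vp` as a hypothetical helper name:

```python
# Minimal sketch (hypothetical helper, not the committed code): guard the cache
# lookup so vp_cache_path=None falls through to fresh VP detection instead of
# crashing in os.path.exists().
import os

import numpy as np

def load_cached_vp(vp_cache_path):
    """Return cached vanishing points as an array, or None when uncached."""
    if vp_cache_path is None or not os.path.exists(vp_cache_path):
        return None
    with open(vp_cache_path) as f:
        return np.array([[float(v) for v in line.split()] for line in f])
```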