# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Evaluation script for Nerf."""
import functools
from os import path

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as tf_hub
# import wandb
import glob
import cv2
import os

from jaxnerf.nerf import datasets
from jaxnerf.nerf import models
from jaxnerf.nerf import utils

FLAGS = flags.FLAGS
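
# define_flags() registers the jaxnerf command-line flags used below
# (train_dir, data_dir, config, chunk, eval_once, render_path, save_output, ...).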
utils.define_flags()

# LPIPS_TFHUB_PATH = "@neural-rendering/lpips/distance/1"
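# LPIPS (Learned Perceptual Image Patch Similarity) would be served from TF Hub;
# loading and computation of the metric are commented out in this script.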


def compute_lpips(image1, image2, model):
  """Compute the LPIPS metric."""
  # The LPIPS model expects a batch dimension.
  return model(
      tf.convert_to_tensor(image1[None, Ellipsis]),
      tf.convert_to_tensor(image2[None, Ellipsis]))[0]


def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  # wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set (it is currently None).")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set (it is currently None).")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables
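  # `state` currently wraps freshly initialized parameters; the eval loop below
  # restores the latest training checkpoint into it before rendering.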

  # lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
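  # in_axes=(None, None, None, 0) broadcasts the model variables and the two RNG
  # keys to every device and shards only the rays; donate_argnums=3 lets XLA
  # reuse the donated ray buffers.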
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
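
  # Keep polling train_dir for new checkpoints; with --eval_once the loop body
  # runs exactly once on the latest checkpoint.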
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    # lpips_values = []
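    # In continuous-eval mode, pick one random test view whose renders are
    # logged to TensorBoard as a showcase image.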
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx + 1}/{dataset.size}")
      batch = next(dataset)
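      # Render the full test image; chunk=FLAGS.chunk bounds how many rays are
      # processed per device step so a whole image fits in memory.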
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
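        # utils.compute_psnr takes the mean squared error; for images in
        # [0, 1] this is PSNR = -10 * log10(MSE).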
        psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        # lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
        # lpips_values.append(float(lpips))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        # summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      # with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
      #   f.write(" ".join([str(v) for v in lpips_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
      # with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
      #   f.write("{}".format(np.mean(np.array(lpips_values))))
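
      # Stitch the saved color and disparity frames into a side-by-side video.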
      imglist = glob.glob(os.path.join(out_dir, "[0-9][0-9][0-9].png"))
      sorted_files = sorted(
          imglist, key=lambda x: int(x.split('/')[-1].split('.')[0]))
      imglist2 = glob.glob(os.path.join(out_dir, "disp_[0-9][0-9][0-9].png"))
      sorted_files2 = sorted(
          imglist2,
          key=lambda x: int(x.split('/')[-1].split('.')[0].split('_')[-1]))
      fourcc = cv2.VideoWriter_fourcc(*'MP4V')
      fps = 10.0
      # Read the first frame up front so the writer can be sized before the loop.
      first_img = cv2.imread(sorted_files[0], cv2.IMREAD_COLOR)
      out = cv2.VideoWriter(os.path.join(out_dir, "rendering_video.mp4"),
                            fourcc, fps,
                            (2 * first_img.shape[1], first_img.shape[0]))
      for i in range(len(sorted_files)):
        img = cv2.imread(sorted_files[i], cv2.IMREAD_COLOR)
        img2 = cv2.imread(sorted_files2[i], cv2.IMREAD_COLOR)
        catimg = np.concatenate((img, img2), axis=1)
        out.write(catimg)
      out.release()
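
    # Stop after a single pass in eval_once mode, or once the checkpoint for the
    # final training step (max_steps) has been evaluated.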
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step


if __name__ == "__main__":
  app.run(main)