# NOTE: removed build-log artifacts ("Spaces:", "Build error") left by the extraction pipeline.
| # coding=utf-8 | |
| # Copyright 2021 The Google Research Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Lint as: python3 | |
| """Training script for Nerf.""" | |
| import functools | |
| import gc | |
| import time | |
| from absl import app | |
| from absl import flags | |
| import flax | |
| from flax.metrics import tensorboard | |
| from flax.training import checkpoints | |
| import jax | |
| from jax import config | |
| from jax import random | |
| import jax.numpy as jnp | |
| import numpy as np | |
| # import wandb | |
| from tqdm import tqdm | |
| from jaxnerf.nerf import datasets | |
| from jaxnerf.nerf import models | |
| from jaxnerf.nerf import utils | |
| from jaxnerf.nerf import clip_utils | |
FLAGS = flags.FLAGS
utils.define_flags()  # Register all jaxnerf command-line flags on FLAGS.
config.parse_flags_with_absl()  # Let absl parse jax config flags too.
# set up TPU for colab
import os  # NOTE(review): late import, kept in place to preserve statement order.
if "COLAB_TPU_ADDR" in os.environ:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
print(f"detected device: {jax.local_devices()}")
def train_step(model, clip_model, rng, state, batch, lr, step, K):
  """One optimization step (runs per-device under `jax.pmap`).

  Args:
    model: The linen model.
    clip_model: CLIP model for the semantic loss. Unused inside this function
      (the semantic-loss gradient is applied separately via `update_step`);
      kept so the pmapped partial's positional signature stays stable.
    rng: jnp.ndarray, random number generator.
    state: utils.TrainState, state of the model/optimizer.
    batch: dict, a mini-batch of data for training.
    lr: float, real-time learning rate.
    step: int, current training step. Currently unused here.
    K: int, semantic-loss period (FLAGS.sc_loss_every). Currently unused here.

  Returns:
    new_state: utils.TrainState, new training state.
    stats: utils.Stats, device-averaged (loss, psnr, loss_c, psnr_c, weight_l2).
    rng: jnp.ndarray, updated random number generator.
  """
  rng, key_0, key_1 = random.split(rng, 3)

  def loss_fn(variables):
    rays = batch["rays"]
    ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
    if len(ret) not in (1, 2):
      raise ValueError(
          "ret should contain either 1 set of output (coarse only), or 2 sets"
          " of output (coarse as ret[0] and fine as ret[1]).")
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
    psnr = utils.compute_psnr(loss)
    if len(ret) > 1:
      # If there are both coarse and fine predictions, we compute the loss for
      # the coarse prediction (ret[0]) as well.
      rgb_c, unused_disp_c, unused_acc_c = ret[0]
      loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
      psnr_c = utils.compute_psnr(loss_c)
    else:
      loss_c = 0.
      psnr_c = 0.

    def tree_sum_fn(fn):
      # Sum fn(leaf) over every leaf of the parameter pytree.
      return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                       variables, initializer=0)

    # Mean squared weight: total sum-of-squares over total parameter count,
    # so weight decay is scale-independent of model size.
    weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
                 tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))
    total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
    stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
                        psnr_c=psnr_c, weight_l2=weight_l2)
    return total_loss, stats

  (_, stats), grad = (
      jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
  # Average gradients and stats across all devices in the "batch" axis.
  grad = jax.lax.pmean(grad, axis_name="batch")
  stats = jax.lax.pmean(stats, axis_name="batch")
  # Clip the gradient by value.
  if FLAGS.grad_max_val > 0:
    clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
    grad = jax.tree_util.tree_map(clip_fn, grad)
  # Clip the (possibly value-clipped) gradient by norm.
  if FLAGS.grad_max_norm > 0:
    grad_norm = jnp.sqrt(
        jax.tree_util.tree_reduce(
            lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
    # 1e-7 guards against division by zero for an all-zero gradient.
    mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
    grad = jax.tree_util.tree_map(lambda z: mult * z, grad)
  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state, stats, rng
def update_step(state, grad, lr):
  """Apply a precomputed gradient to the optimizer at learning rate `lr`.

  Args:
    state: utils.TrainState holding the optimizer to step.
    grad: gradient pytree matching the optimizer target.
    lr: float, learning rate for this update.

  Returns:
    A new utils.TrainState containing the stepped optimizer.
  """
  stepped = state.optimizer.apply_gradient(grad, learning_rate=lr)
  return state.replace(optimizer=stepped)
def main(unused_argv):
  """Train a NeRF model, optionally with a periodic CLIP semantic loss."""
  #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by different
  # hosts.
  np.random.seed(20201473 + jax.host_id())
  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")
  # Set up the CLIP model for the semantic loss; None disables it downstream.
  if FLAGS.use_semantic_loss:
    clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
                                      FLAGS.clip_model_name)
    print('semantic loss ACTIVATED, CLIP is set up')
  else:
    clip_model = None
    print('semantic loss DEACTIVATED, CLIP is set to None')
  dataset = datasets.get_dataset("train", FLAGS, clip_model)
  test_dataset = datasets.get_dataset("test", FLAGS, clip_model)
  # Set up the NeRF model and its optimizer state.
  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables
  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)
  # Per-device training step; lr/step/K are broadcast (in_axes=None) while
  # rng, state and batch are sharded across devices.
  train_pstep = jax.pmap(
      functools.partial(train_step, model, clip_model),
      axis_name="batch",
      in_axes=(0, 0, 0, None, None, None),
      donate_argnums=(2,))
  # Applies the (broadcast) semantic-loss gradient to the replicated state.
  update_pstep = jax.pmap(
      functools.partial(update_step,),
      axis_name="batch",
      in_axes=(0, None, None),
      donate_argnums=(0,))

  def render_fn(variables, key_0, key_1, rays):
    # Render one shard of rays and gather the outputs from all devices.
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch")
  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  # Replicate state across local devices for pmapped training.
  state = flax.jax_utils.replicate(state)
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
  # Prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  # For the periodic semantic-loss update.
  # NOTE(review): `cnter` and `trigger` are never read below — confirm dead.
  cnter = 1
  trigger = int(FLAGS.sc_loss_every / n_local_devices)
  for step, batch in tqdm(zip(range(init_step, FLAGS.max_steps + 1), pdataset)):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    # Every sc_loss_every steps, compute the CLIP semantic-loss gradient on
    # the host core only (hence no device dimension at this point).
    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
      # remove dimension for device coz its only run in host core
      sc_batch = dataset.get_clip_data()
      sc_loss, sc_grad = clip_utils.update_semantic_loss(model, clip_model,
                                                         keys[0], state, sc_batch, lr)
      # Replicate then strip the leading device axis so update_pstep can
      # broadcast the same gradient to every device (in_axes=None).
      sc_grad = flax.jax_utils.replicate(sc_grad)
      sc_grad = jax.tree_map( lambda x: x[0], sc_grad)
    else:
      sc_loss = 0.  # NOTE(review): only read by the commented-out log line below.
    state, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)#, grad)
    # Apply the semantic-loss gradient after the photometric step.
    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
      state = update_pstep(state, sc_grad, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()
    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        # Averages over the stats collected since the last print.
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        # Pad the printed step number to the width of max_steps.
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              # f"sc_loss={sc_loss:0.4f}, " +
              f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        # Save an unreplicated copy of the state (device 0's shard).
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)
    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, eval_variables),
          test_case["rays"],
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"]) ** 2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)
  # Save a final checkpoint if the loop did not end on a save step.
  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
| if __name__ == "__main__": | |
| app.run(main) | |