{
  "imports": [
    "$import glob",
    "$import os",
    "$import scripts",
    "$import ignite"
  ],
  "bundle_root": ".",
  "ckpt_dir": "$@bundle_root + '/models'",
  "output_dir": "$@bundle_root + '/output'",
  "data_list_file_path": "$@bundle_root + '/datasets/C4KC-KiTS_subset.json'",
  "dataset_dir": "$@bundle_root + '/datasets/C4KC-KiTS_subset'",
  "trained_diffusion_path": "$@ckpt_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'",
  "trained_controlnet_path": "$@ckpt_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'",
  "use_tensorboard": true,
  "fold": 0,
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
  "epochs": 100,
  "batch_size": 1,
  "val_at_start": false,
  "learning_rate": 0.0001,
  "weighted_loss_label": [129],
  "weighted_loss": 100,
  "amp": true,
  "train_datalist": "$scripts.utils.maisi_datafold_read(json_list=@data_list_file_path, data_base_dir=@dataset_dir, fold=@fold)[0]",
  "spatial_dims": 3,
  "image_channels": 1,
  "latent_channels": 4,
  "diffusion_unet_def": {
    "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "out_channels": "@latent_channels",
    "num_channels": [64, 128, 256, 512],
    "attention_levels": [false, false, true, true],
    "num_head_channels": [0, 0, 32, 32],
    "num_res_blocks": 2,
    "use_flash_attention": true,
    "include_top_region_index_input": true,
    "include_bottom_region_index_input": true,
    "include_spacing_input": true
  },
  "controlnet_def": {
    "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "num_channels": [64, 128, 256, 512],
    "attention_levels": [false, false, true, true],
    "num_head_channels": [0, 0, 32, 32],
    "num_res_blocks": 2,
    "use_flash_attention": true,
    "conditioning_embedding_in_channels": 8,
    "conditioning_embedding_num_channels": [8, 32, 64]
  },
  "noise_scheduler": {
    "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
    "num_train_timesteps": 1000,
    "beta_start": 0.0015,
    "beta_end": 0.0195,
    "schedule": "scaled_linear_beta",
    "clip_sample": false
  },
  "unzip_dataset": "$scripts.utils.unzip_dataset(@dataset_dir)",
  "diffusion_unet": "$@diffusion_unet_def.to(@device)",
  "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, map_location=@device, weights_only=False)",
  "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
  "controlnet": "$@controlnet_def.to(@device)",
  "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
  "checkpoint_controlnet": "$torch.load(@trained_controlnet_path, map_location=@device, weights_only=False)",
  "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
  "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
  "loss": {
    "_target_": "torch.nn.L1Loss",
    "reduction": "none"
  },
  "optimizer": {
    "_target_": "torch.optim.AdamW",
    "params": "$@controlnet.parameters()",
    "lr": "@learning_rate",
    "weight_decay": 1e-05
  },
  "lr_schedule": {
    "activate": true,
    "lr_scheduler": {
      "_target_": "torch.optim.lr_scheduler.PolynomialLR",
      "optimizer": "@optimizer",
      "total_iters": "$@epochs * len(@train#dataloader)",
      "power": 2.0
    }
  },
  "train": {
    "deterministic_transforms": [
      {
        "_target_": "LoadImaged",
        "keys": [
          "image",
          "label"
        ],
        "image_only": true,
        "ensure_channel_first": true
      },
      {
        "_target_": "Orientationd",
        "keys": [
          "label"
        ],
        "axcodes": "RAS"
      },
      {
        "_target_": "EnsureTyped",
        "keys": [
          "label"
        ],
        "dtype": "$torch.uint8",
        "track_meta": true
      },
      {
        "_target_": "Lambdad",
        "keys": "top_region_index",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "bottom_region_index",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "spacing",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "top_region_index",
        "func": "$lambda x: x * 1e2"
      },
      {
        "_target_": "Lambdad",
        "keys": "bottom_region_index",
        "func": "$lambda x: x * 1e2"
      },
      {
        "_target_": "Lambdad",
        "keys": "spacing",
        "func": "$lambda x: x * 1e2"
      }
    ],
    "inferer": {
      "_target_": "SimpleInferer"
    },
    "preprocessing": {
      "_target_": "Compose",
      "transforms": "$@train#deterministic_transforms"
    },
    "dataset": {
      "_target_": "Dataset",
      "data": "@train_datalist",
      "transform": "@train#preprocessing"
    },
    "dataloader": {
      "_target_": "DataLoader",
      "dataset": "@train#dataset",
      "batch_size": "@batch_size",
      "shuffle": true,
      "num_workers": 4,
      "pin_memory": true,
      "persistent_workers": true
    },
    "handlers": [
      {
        "_target_": "LrScheduleHandler",
        "_disabled_": "$not @lr_schedule#activate",
        "lr_scheduler": "@lr_schedule#lr_scheduler",
        "epoch_level": false,
        "print_lr": true
      },
      {
        "_target_": "CheckpointSaver",
        "save_dir": "@ckpt_dir",
        "save_dict": {
          "controlnet_state_dict": "@controlnet",
          "optimizer": "@optimizer"
        },
        "save_interval": 1,
        "n_saved": 5
      },
      {
        "_target_": "TensorBoardStatsHandler",
        "_disabled_": "$not @use_tensorboard",
        "log_dir": "@output_dir",
        "tag_name": "train_loss",
        "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
      },
      {
        "_target_": "StatsHandler",
        "tag_name": "train_loss",
        "name": "StatsHandler",
        "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
      }
    ],
    "trainer": {
      "_target_": "scripts.trainer.MAISIControlNetTrainer",
      "_requires_": [
        "@load_diffusion",
        "@copy_controlnet_state",
        "@load_controlnet",
        "@unzip_dataset"
      ],
      "max_epochs": "@epochs",
      "device": "@device",
      "train_data_loader": "@train#dataloader",
      "diffusion_unet": "@diffusion_unet",
      "controlnet": "@controlnet",
      "noise_scheduler": "@noise_scheduler",
      "loss_function": "@loss",
      "optimizer": "@optimizer",
      "inferer": "@train#inferer",
      "key_train_metric": null,
      "train_handlers": "@train#handlers",
      "amp": "@amp",
      "hyper_kwargs": {
        "weighted_loss": "@weighted_loss",
        "weighted_loss_label": "@weighted_loss_label",
        "scale_factor": "@scale_factor"
      }
    }
  },
  "initialize": [
    "$monai.utils.set_determinism(seed=0)"
  ],
  "run": [
    "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())",
    "$@train#trainer.run()"
  ]
}