---
# This implements the workflow for applying the network to a directory of
# images and measuring network performance with metrics.
#
# Values starting with `$` are Python expressions and `@name` refers to another
# entry of this file (MONAI bundle config syntax).

imports:
  - $import os
  - $import json
  - $import torch
  - $import glob
  - $import pathlib

# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# hyperparameters for you to modify on the command line
batch_size: 1  # number of images per batch
num_workers: 0  # number of workers to generate batches with
num_classes: 4  # number of classes in training data which network should predict
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# define various paths
bundle_root: .  # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt'  # checkpoint to load before starting
# where data is coming from (not referenced by any other entry in this file)
dataset_dir: $@bundle_root + '/data/test_data'

# network definition, this could be parameterised by pre-defined values or on
# the command line
network_def:
  _target_: DenseNet121
  spatial_dims: 2
  in_channels: 1
  out_channels: '@num_classes'
network: $@network_def.to(@device)

# load the list of test samples from a JSON file inside the bundle
test_json: "$@bundle_root+'/data/test_samples.json'"
# Raw file handle on the JSON file. Kept so existing references/overrides keep
# working, but `test_dict` no longer depends on it: nothing in this config
# closes the handle, so by default it is no longer opened at all.
test_fp: "$open(@test_json,'r', encoding='utf8')"
# parse the JSON file; read_text opens and closes the file itself
test_dict: "$json.loads(pathlib.Path(@test_json).read_text(encoding='utf8'))"

# these transforms are used for inference to load and regularise inputs
transforms:
  - _target_: LoadImaged
    keys: '@image'
  - _target_: EnsureChannelFirstd
    keys: '@image'
  - _target_: ScaleIntensityd
    keys: '@image'

preprocessing:
  _target_: Compose
  transforms: $@transforms

dataset:
  _target_: Dataset
  data: '@test_dict'
  transform: '@preprocessing'

dataloader:
  _target_: ThreadDataLoader  # generate data asynchronously from inference
  dataset: '@dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# should be replaced with other inferer types if training process is different
# for your network
inferer:
  _target_: SimpleInferer

# transform to apply to data from network to be suitable for validation
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    # per-key flags: argmax is applied to the prediction only, not the label
    - _target_: AsDiscreted
      keys: ['@pred', '@label']
      argmax: [true, false]
      to_onehot: '@num_classes'
    - _target_: ToTensord
      keys: ['@pred', '@label']
      device: '@device'

# inference handlers to load checkpoint, gather statistics
val_handlers:
  # disabled (skipped) when no checkpoint file exists at ckpt_path
  - _target_: CheckpointLoader
    _disabled_: $not os.path.exists(@ckpt_path)
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  - _target_: StatsHandler
    name: null  # use engine.logger as the Logger object to log to
    output_transform: '$lambda x: None'

# engine for running inference, ties together objects defined above and has
# metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_f1:  # can have other metrics
      _target_: ConfusionMatrix
      metric_name: 'f1 score'
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# executed before `run`
initialize:
  - "$setattr(torch.backends.cudnn, 'benchmark', True)"

run:
  - "$@evaluator.run()"