Customize Training and Inference
This file describes the main parameters that can be used to customize training and inference in GaNDLF.
Model
- Defined under the global key `model` in the config file.
  - `architecture`: Defines the model architecture (aka "network topology") to be used for training. All options can be found here. Some examples are:
    - Segmentation:
      - Standardized 4-layer UNet with (`resunet`) and without (`unet`) residual connections, as described in this paper.
      - Multi-layer UNet with (`resunet_multilayer`) and without (`unet_multilayer`) residual connections; this is a more general version of the standard UNet, where the number of layers can be specified by the user.
      - UNet with Inception Blocks (`uinc`) is a variant of UNet with inception blocks, as described in this paper.
      - UNetR (`unetr`) is a variant of UNet with transformers, as described in this paper.
      - TransUNet (`transunet`) is a variant of UNet with transformers, as described in this paper.
      - And many more.
    - Classification/Regression:
      - VGG configurations (`vgg11`, `vgg13`, `vgg16`, `vgg19`), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - VGG configurations initialized with weights trained on ImageNet (`imagenet_vgg11`, `imagenet_vgg13`, `imagenet_vgg16`, `imagenet_vgg19`), as described in this paper.
      - DenseNet configurations (`densenet121`, `densenet161`, `densenet169`, `densenet201`, `densenet264`), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - ResNet configurations (`resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - And many more.
  - `dimension`: Defines the dimensionality of convolutions. This is usually the same as the dimension of the input image, unless specialized processing is done to convert images to a different dimensionality (usually not recommended). For example, 2D images can be stacked to form a "pseudo" 3D image, and 3D images can be processed as 2D "slices".
  - `final_layer`: The final layer of the model that will be used to generate the final prediction. Unless otherwise specified, it can be one of `softmax`, `sigmoid`, `logits`, or `none` (the latter 2 are only used for regression tasks).
  - `class_list`: The list of classes that will be used for training. This is expected to be a list of integers.
    - For example, for a segmentation task, this can be a list of integers `[0, 1, 2, 4]` for the BraTS training case for all labels (background, necrosis, edema, and enhancing tumor). Additionally, different labels can be combined to perform "combinatorial training", such as `[0, 1||4, 1||2||4, 4]`, for the BraTS training to train on background, tumor core, whole tumor, and enhancing, respectively.
    - For a classification task, this can be a list of integers `[0, 1]`.
  - `ignore_label_validation`: This is the location of the label in `class_list` whose performance is to be ignored during metric calculation for validation/testing data.
  - `norm_type`: The type of normalization to be used. This can be one of `batch`, `instance`, or `none`.
  - Various other options specific to architectures, such as (but not limited to):
    - `densenet` models:
      - `growth_rate`: how many filters to add each layer (k in paper)
      - `bn_size`: multiplicative factor for the number of bottleneck layers (i.e. bn_size * k features in the bottleneck layer)
      - `drop_rate`: dropout rate after each dense layer
    - `unet_multilayer` and other networks that support multiple layers:
      - `depth`: the number of encoder/decoder (or other types of) layers
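To make the `class_list` description concrete, here is a minimal plain-Python sketch. It assumes that an entry such as `1||4` means "label 1 or label 4" and that each entry becomes one binary mask. The dictionary values and the helper names `parse_class_entry` and `binary_masks` are hypothetical and are not part of GaNDLF.

```python
# Hypothetical config fragment using the keys described above; the values
# are illustrative only and are not recommended defaults.
model_config = {
    "architecture": "resunet",
    "dimension": 3,
    "final_layer": "softmax",
    # Combined labels are written as strings here so that Python can hold them.
    "class_list": [0, "1||4", "1||2||4", 4],
    "norm_type": "instance",
}


def parse_class_entry(entry):
    """Return the set of integer labels named by one class_list entry.

    Assumption: "1||4" means "label 1 or label 4" (combinatorial training).
    """
    return {int(part) for part in str(entry).split("||")}


def binary_masks(labels, class_list):
    """Build one 0/1 mask per class_list entry from a flat list of labels."""
    masks = []
    for entry in class_list:
        wanted = parse_class_entry(entry)
        masks.append([1 if value in wanted else 0 for value in labels])
    return masks


voxels = [0, 1, 2, 4, 4, 0]  # toy flattened segmentation
masks = binary_masks(voxels, model_config["class_list"])
for entry, mask in zip(model_config["class_list"], masks):
    print(entry, mask)
```

With the BraTS-style list above, the second mask covers the tumor core (labels 1 and 4) and the third covers the whole tumor (labels 1, 2, and 4).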
Loss function
- Defined in the `loss_function` parameter of the model configuration.
- By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the function on which the model is trained. All options can be found here. Some examples are:
  - Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`), matthews (`mcc`)
  - Classification/regression: mean squared error (`mse`)
  - And many more.
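The inverse-frequency weighting mentioned for `weighted_loss` can be sketched as follows. This is an illustration only: the function name `inverse_frequency_weights`, the normalization so that weights sum to 1, and the zero weight for absent classes are all assumptions, and the exact formula used by GaNDLF may differ.

```python
from collections import Counter


def inverse_frequency_weights(labels, class_list):
    """Weight each class by total / count, then normalize to sum to 1.

    Assumption: classes that never occur get weight 0 so that no division
    by zero takes place.
    """
    counts = Counter(labels)
    total = len(labels)
    raw = {}
    for label in class_list:
        raw[label] = total / counts[label] if counts[label] else 0.0
    norm = sum(raw.values())
    return {label: (w / norm if norm else 0.0) for label, w in raw.items()}


# Toy cohort: class 0 is common, class 2 is rare.
labels = [0] * 6 + [1] * 3 + [2] * 1
weights = inverse_frequency_weights(labels, [0, 1, 2])
print(weights)
```

The rare class ends up with the largest weight, so errors on it contribute more to the loss.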
Metrics
- Defined in the `metrics` parameter of the model configuration.
- This parameter controls the metrics to be used for model evaluation for the training/validation/testing datasets. All options can be found here. Most of these metrics are calculated using TorchMetrics. Some examples are:
  - Segmentation: dice (`dice` and `dice_per_label`), hausdorff distances (`hausdorff` or `hausdorff100` and `hausdorff100_per_label`), hausdorff distances based on the 95th percentile of distances (`hausdorff95` and `hausdorff95_per_label`)
  - Classification/regression: mean squared error (`mse`) calculated per sample
  - Metrics calculated per cohort (these are automatically calculated for classification and regression and cannot be disabled):
    - Classification: accuracy, precision, recall, f1, for the entire cohort ("global"), per classified class ("per_class"), per classified class averaged ("per_class_average"), per classified class weighted/balanced ("per_class_weighted")
    - Regression: mean absolute error, pearson and spearman coefficients, calculated as mean, sum, or standard.
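As a rough illustration of what a per-label Dice score measures, the sketch below computes it on flattened label lists in plain Python. It is not the TorchMetrics or GaNDLF implementation; the helper names and the convention of returning 1.0 when a label is absent from both inputs are assumptions.

```python
def dice_score(pred, target, label):
    """Dice = 2 * |P and T| / (|P| + |T|) for one label."""
    p = [value == label for value in pred]
    t = [value == label for value in target]
    intersection = sum(1 for a, b in zip(p, t) if a and b)
    denominator = sum(p) + sum(t)
    if denominator == 0:
        return 1.0  # assumption: label absent from both counts as a match
    return 2.0 * intersection / denominator


def dice_per_label(pred, target, class_list):
    """Return a dict mapping each label to its Dice score."""
    return {label: dice_score(pred, target, label) for label in class_list}


pred = [0, 1, 1, 2, 2, 0]
target = [0, 1, 2, 2, 2, 0]
scores = dice_per_label(pred, target, [0, 1, 2])
mean_dice = sum(scores.values()) / len(scores)
print(scores, mean_dice)
```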
Patching Strategy
- `patch_size`: The size of the patch to be used for training. This is expected to be a list of integers, with the length of the list being the same as the dimensionality of the input image. For example, for a 2D image, this can be `[128, 128]`, and for a 3D image, this can be `[128, 128, 128]`.
- `patch_sampler`: The sampler to be used for patch sampling during training. This can be one of `uniform` (the entire input image has equal weight on contributing a valid patch) or `label` (only the regions that have a valid ground truth segmentation label can contribute a patch). The `label` sampler usually requires padding of the image to ensure blank patches are not inadvertently sampled; this can be controlled by the `enable_padding` parameter.
- `inference_mechanism`
  - `grid_aggregator_overlap`: defines how the outputs of overlapping patches are combined during grid aggregation; should be either `crop` or `average` - https://torchio.readthedocs.io/patches/patch_inference.html#grid-aggregator
  - `patch_overlap`: the amount of overlap of patches during inference in terms of pixels, defaults to `0`; see https://torchio.readthedocs.io/patches/patch_inference.html#gridsampler for details.
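The interaction between `patch_size`, `patch_overlap`, and the `average` aggregation mode can be pictured with a one-dimensional toy example. This is a simplified sketch in plain Python, not TorchIO's implementation: the function names are hypothetical, the last patch is simply shifted to end at the image border, and the `crop` mode is not covered.

```python
def grid_starts(length, patch_size, patch_overlap=0):
    """Start indices of patches along one axis for a given overlap."""
    if patch_size > length:
        raise ValueError("patch_size must not exceed the image length")
    if not 0 <= patch_overlap < patch_size:
        raise ValueError("patch_overlap must be in [0, patch_size)")
    step = patch_size - patch_overlap
    starts = list(range(0, length - patch_size + 1, step))
    # Add a final patch flush with the end if the grid did not reach it.
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)
    return starts


def aggregate_average(length, starts, patch_values):
    """Average the predictions of all patches that cover each position."""
    total = [0.0] * length
    count = [0] * length
    for start, values in zip(starts, patch_values):
        for offset, value in enumerate(values):
            total[start + offset] += value
            count[start + offset] += 1
    return [t / c for t, c in zip(total, count)]


starts = grid_starts(length=6, patch_size=4, patch_overlap=2)
# Pretend the model predicted a constant value for each patch.
averaged = aggregate_average(6, starts, [[1.0] * 4, [3.0] * 4])
print(starts, averaged)
```

Positions covered by both patches receive the mean of the two predictions, which is the smoothing effect that averaging provides at patch borders.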
Data Preprocessing
- Defined in the `data_preprocessing` parameter of the model configuration.
- This parameter controls the various preprocessing functions that are applied to the entire image before the patching strategy is applied.
- All options can be found here. Some of the most important examples are:
  - Intensity harmonization: GaNDLF provides multiple normalization and rescaling options to ensure intensity-level harmonization of the entire cohort. Some examples include:
    - `normalize`: simple Z-score normalization
    - `normalize_positive`: this performs z-score normalization only on `pixels > 0`
    - `normalize_nonZero`: this performs z-score normalization only on `pixels != 0`
    - `normalize_nonZero_masked`: this performs z-score normalization only on the region defined by the ground truth annotation
    - `rescale`: simple min-max rescaling, sub-parameters include `in_min_max`, `out_min_max`, `percentiles`; this option is useful to discard outliers in the intensity distribution
    - Template-based normalization: These options take a target image as input (defined by the `target` sub-parameter) and perform different matching strategies to match input image(s) to this target.
      - `histogram_matching`: this performs histogram matching as defined by this paper.
        - If the `target` image is absent, this will perform global histogram equalization.
        - If `target` is `adaptive`, this will perform adaptive histogram equalization.
      - `stain_normalization`: these are normalization techniques specifically designed for histology images; the different options include `vahadane`, `macenko`, or `ruifrok`, under the `extractor` sub-parameter. Always needs a `target` image to work.
  - Resolution harmonization: GaNDLF provides multiple resampling options to ensure resolution-level harmonization of the entire cohort. Some examples include:
    - `resample`: resamples the image to the resolution specified by the `resolution` sub-parameter
    - `resample_min`: resamples the image to the maximum spacing defined by the `resolution` sub-parameter; this is useful in cohorts that have varying resolutions, but the user wants to resample to the minimum resolution for consistency
    - `resize_image`: NOT RECOMMENDED; resizes the image to the specified size
    - `resize_patch`: NOT RECOMMENDED; resizes the extracted patch to the specified size
  - And many more.
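The difference between the z-score variants comes down to which pixels contribute to the mean and standard deviation. The plain-Python sketch below illustrates this, together with a basic min-max rescale. It makes several assumptions: statistics come from the masked pixels but are applied to every pixel, the population standard deviation is used, and only an `out_min_max`-style range is modeled for rescaling. The function names are hypothetical and GaNDLF's behavior may differ in these details.

```python
from statistics import mean, pstdev


def z_score(pixels, mask=None):
    """Z-score normalize using statistics from the masked pixels only."""
    if mask is None:
        mask = [True] * len(pixels)
    selected = [p for p, keep in zip(pixels, mask) if keep]
    mu = mean(selected)
    sigma = pstdev(selected)  # assumption: population standard deviation
    if sigma == 0:
        raise ValueError("selected pixels have zero variance")
    return [(p - mu) / sigma for p in pixels]


def rescale(pixels, out_min_max=(0.0, 1.0)):
    """Min-max rescale the full intensity range into out_min_max."""
    low, high = min(pixels), max(pixels)
    if high == low:
        raise ValueError("cannot rescale a constant image")
    out_min, out_max = out_min_max
    return [(p - low) / (high - low) * (out_max - out_min) + out_min
            for p in pixels]


pixels = [0.0, 0.0, 2.0, 4.0, 6.0]  # zeros stand in for background
normalized_all = z_score(pixels)
normalized_positive = z_score(pixels, [p > 0 for p in pixels])
normalized_nonzero = z_score(pixels, [p != 0 for p in pixels])
rescaled = rescale(pixels)
print(normalized_all, normalized_positive, rescaled)
```

On this toy image the positive and non-zero variants coincide because there are no negative values; they would differ on data such as CT, where background can be negative.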
Data Augmentation
- Defined in the `data_augmentation` parameter of the model configuration.
- This parameter controls the various augmentation functions that are applied to the entire image before the patching strategy is applied.
- These should be chosen with the task at hand in mind (for example, RGB augmentations will not work for MRI/CT and other similar radiology images).
- All options can contain a `probability` sub-parameter, which defines the probability of the augmentation being applied to the image. When present, this will supersede the `default_probability` parameter.
- All options can be found here. Some of the most important examples are:
  - Radiology-specific augmentations
    - `kspace`: one of either `ghosting` or `spiking` is picked for augmentation.
    - `bias`: applies a random bias field artefact to the input image using this function.
  - RGB-specific augmentations
    - `colorjitter`: applies the ColorJitter transform from PyTorch; has sub-parameters `brightness`, `contrast`, `saturation`, and `hue`.
  - General-purpose augmentations
    - Spatial transforms: they only change the resolution (and thereby, the shape) of the input image, and only apply interpolation to the intensities for consistency
      - `affine`: applies a random affine transformation to the input image; for details, see this page; has sub-parameters `scales` (defining the scaling ranges), `degrees` (defining the rotation ranges), and `translation` (defining the translation ranges in real-world coordinates, which is usually in mm)
      - `elastic`: applies a random elastic deformation to the input image; for details, see this page; has sub-parameters `num_control_points` (defining the number of control points), `locked_borders` (defining the number of locked borders), and `max_displacement` (defining the maximum displacement of the control points).
      - `flip`: applies a random flip to the input image; for details, see this page; has sub-parameter `axes` (defining the axes to flip).
      - `rotate`: applies a random rotation by 90 degrees (`rotate_90`) or 180 degrees (`rotate_180`); has sub-parameter `axes` (defining the axes to rotate).
      - `swap`: applies a random swap; has sub-parameters `patch_size` (defining the patch size to swap) and `num_iterations` (number of iterations that 2 patches will be swapped).
    - Intensity transforms: they change the intensity of the input image, but never the actual resolution or shape.
      - `motion`: applies a random motion blur to the input image using this function.
      - `blur`: applies a random Gaussian blur to the input image using this function; has sub-parameter `std` (defines the standard deviation range).
      - `noise`: applies random noise to the input image using this function; has sub-parameters `std` (defines the standard deviation range) and `mean` (defines the mean of the noise to be added).
      - `noise_var`: applies random noise to the input image, but with the default `std = [0, 0.015 * std(image)]`.
      - `anisotropic`: applies a random anisotropic transform to the input image using this function. This changes the resolution and brings it back to its original resolution, thus applying "real-world" interpolation to images.
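The precedence between a per-augmentation `probability` and `default_probability` can be sketched in a few lines of plain Python. The config values and the helper names `resolve_probability` and `pick_augmentations` are hypothetical; this only illustrates the override rule described above and does not apply any real transform.

```python
import random


def resolve_probability(options, default_probability):
    """Use the augmentation's own probability if present, else the default."""
    if options and "probability" in options:
        return options["probability"]
    return default_probability


def pick_augmentations(augmentation_config, rng):
    """Return the names of augmentations selected for one image."""
    default_probability = augmentation_config["default_probability"]
    chosen = []
    for name, options in augmentation_config.items():
        if name == "default_probability":
            continue
        if rng.random() < resolve_probability(options, default_probability):
            chosen.append(name)
    return chosen


# Illustrative values only.
config = {
    "default_probability": 0.5,
    "flip": {"axes": [0, 1], "probability": 1.0},  # always applied
    "blur": {"std": [0, 1]},                       # falls back to 0.5
    "noise": {"probability": 0.0},                 # never applied
}
chosen = pick_augmentations(config, random.Random(0))
print(chosen)
```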
Training Parameters
- These are various parameters that control the overall training process.
  - `verbose`: generate verbose messages on console; generally used for debugging.
  - `batch_size`: batch size to be used for training.
  - `in_memory`: this is to enable or disable lazy loading. If set to `True`, all data is loaded onto the RAM at once during the construction of the dataloader (either training/validation/testing), thus resulting in faster training. If set to `False`, data gets read into RAM on-the-go when needed (also called "lazy loading"), which slows down training but lessens the memory load. The latter is recommended if the user's RAM has limited capacity.
  - `num_epochs`: number of epochs to train for.
  - `patience`: number of epochs to wait for improvement in the validation loss before early stopping.
  - `learning_rate`: learning rate to be used for training.
  - `scheduler`: learning rate scheduler to be used for training, more details are here; can take the following sub-parameters:
    - `type`: `triangle`, `triangle_modified`, `exp`, `step`, `reduce-on-plateau`, `cosineannealing`, `triangular`, `triangular2`, `exp_range`
    - `min_lr`: minimum learning rate to be used for training.
    - `max_lr`: maximum learning rate to be used for training.
  - `optimizer`: optimizer to be used for training, more details are here.
  - `nested_training`: number of folds to use for nested training; takes `testing` and `validation` as sub-parameters, with integer values defining the number of folds to use.
  - `memory_save_mode`: if enabled, resize/resample operations in `data_preprocessing` will save files to disk instead of directly getting read into memory as tensors
  - Queue configuration: this defines how the queue for the input to the model is to be designed after the patching strategy has been applied, and more details are here. This takes the following sub-parameters:
    - `q_max_length`: this determines the maximum number of patches that can be stored in the queue. Using a large number means that the queue needs to be filled less often, but more CPU memory is needed to store the patches.
    - `q_samples_per_volume`: this determines the number of patches to extract from each volume. A small number of patches ensures a large variability in the queue, but training will be slower.
    - `q_num_workers`: this determines the number of subprocesses to use for data loading; `0` means the main process is used; scale this according to available CPU resources.
    - `q_verbose`: used to debug the queue
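The role of `patience` relative to `num_epochs` can be illustrated with a small early-stopping loop over a list of validation losses. This is a generic sketch under stated assumptions (training stops once `patience` consecutive epochs fail to improve on the best loss); the function name is hypothetical and the exact stopping rule in GaNDLF may differ.

```python
def train_with_patience(validation_losses, num_epochs, patience):
    """Return (last_epoch_run, best_loss) for a sequence of epoch losses."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    last_epoch = -1
    for epoch, loss in enumerate(validation_losses[:num_epochs]):
        last_epoch = epoch
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # Assumption: stop as soon as the counter reaches `patience`.
        if epochs_without_improvement >= patience:
            break
    return last_epoch, best_loss


losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped_early = train_with_patience(losses, num_epochs=6, patience=3)
ran_to_end = train_with_patience(losses, num_epochs=6, patience=5)
print(stopped_early, ran_to_end)
```

With `patience=3` the loop stops before the later improvement to 0.7 is seen, which shows why a small `patience` can end training prematurely on noisy validation curves.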
Differentially Private Training
GaNDLF supports training differentially private models using Opacus. Here are some resources that can help with training private models:
- TLDR on DP and private training: read this paper and this blog post.
- All options are present in a new key called `differential_privacy` in the config file. It has the following options:
  - `noise_multiplier`: The ratio of the standard deviation of the Gaussian noise to the L2-sensitivity of the function to which the noise is added.
  - `max_grad_norm`: The maximum norm of the per-sample gradients. Any gradient with norm higher than this will be clipped to this value.
  - `accountant`: Accounting mechanism. Currently supported: `rdp` (RDPAccountant), `gdp` (GaussianAccountant), `prv` (PRVAccountant)
  - `secure_mode`: Set to `True` if a cryptographically strong DP guarantee is required. `secure_mode=True` uses a secure random number generator for noise and shuffling (as opposed to `pseudo-rng` in vanilla PyTorch) and prevents certain floating-point arithmetic-based attacks.
  - `allow_opacus_model_fix`: Enables automated fixing of the model based on Opacus [ref]
  - `delta`: Target delta to be achieved. Probability of information being leaked. Use either this or `epsilon`.
  - `epsilon`: Target epsilon to be achieved, a metric of privacy loss at differential changes in data. Use either this or `delta`.
  - `physical_batch_size`: The batch size to use for DP computation (it is usually set lower than the baseline or non-DP batch size). Defaults to `batch_size`.
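To show how `max_grad_norm` and `noise_multiplier` interact, the following plain-Python sketch performs the core of one DP-SGD step: clip each per-sample gradient, sum, add Gaussian noise with standard deviation `noise_multiplier * max_grad_norm`, and average. It does not use Opacus, and the function names are hypothetical. It also uses Python's ordinary pseudo-random generator, so it is not suitable for real privacy guarantees.

```python
import math
import random


def clip_gradient(grad, max_grad_norm):
    """Scale a gradient vector so its L2 norm is at most max_grad_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_grad_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def private_mean_gradient(per_sample_grads, noise_multiplier,
                          max_grad_norm, rng):
    """Clip per-sample gradients, add Gaussian noise to the sum, average."""
    clipped = [clip_gradient(g, max_grad_norm) for g in per_sample_grads]
    summed = [sum(column) for column in zip(*clipped)]
    std = noise_multiplier * max_grad_norm
    noisy = [s + rng.gauss(0.0, std) for s in summed]
    return [value / len(per_sample_grads) for value in noisy]


grads = [[3.0, 4.0], [0.3, 0.4]]  # toy per-sample gradients
clipped_large = clip_gradient(grads[0], max_grad_norm=1.0)
no_noise = private_mean_gradient(grads, 0.0, 1.0, random.Random(0))
with_noise = private_mean_gradient(grads, 1.0, 1.0, random.Random(0))
print(clipped_large, no_noise, with_noise)
```

Only the first gradient exceeds the clipping threshold, so only it is rescaled; a larger `noise_multiplier` trades accuracy of the averaged gradient for stronger privacy.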