Customize Training and Inference
This page describes the parameters that can be used to customize training and inference in GaNDLF.
Model¶
- Defined under the global key model in the config file
  - architecture: Defines the model architecture (aka "network topology") to be used for training. All options can be found here. Some examples are:
    - Segmentation:
      - Standardized 4-layer UNet with (resunet) and without (unet) residual connections, as described in this paper.
      - Multi-layer UNet with (resunet_multilayer) and without (unet_multilayer) residual connections; this is a more general version of the standard UNet, where the number of layers can be specified by the user.
      - UNet with Inception Blocks (uinc) is a variant of UNet with inception blocks, as described in this paper.
      - UNetR (unetr) is a variant of UNet with transformers, as described in this paper.
      - TransUNet (transunet) is a variant of UNet with transformers, as described in this paper.
      - And many more.
    - Classification/Regression:
      - VGG configurations (vgg11, vgg13, vgg16, vgg19), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - VGG configurations initialized with weights trained on ImageNet (imagenet_vgg11, imagenet_vgg13, imagenet_vgg16, imagenet_vgg19), as described in this paper.
      - DenseNet configurations (densenet121, densenet161, densenet169, densenet201, densenet264), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - ResNet configurations (resnet18, resnet34, resnet50, resnet101, resnet152), as described in this paper. Our implementation allows true 3D computations (as opposed to 2D+1D convolutions).
      - And many more.
  - dimension: Defines the dimensionality of convolutions. This is usually the same as the dimensionality of the input image, unless specialized processing is done to convert images to a different dimensionality (usually not recommended). For example, 2D images can be stacked to form a "pseudo" 3D image, and 3D images can be processed slice by slice as 2D images.
  - final_layer: The final layer of the model, which is used to generate the final prediction. Unless otherwise specified, it can be one of softmax, sigmoid, logits, or none (the latter 2 are only used for regression tasks).
  - class_list: The list of classes that will be used for training. This is expected to be a list of integers.
    - For example, for a segmentation task, this can be [0, 1, 2, 4] for the BraTS training case for all labels (background, necrosis, edema, and enhancing tumor). Additionally, different labels can be combined to perform "combinatorial training", such as [0, 1||4, 1||2||4, 4], for the BraTS training to train on background, tumor core, whole tumor, and enhancing tumor, respectively.
    - For a classification task, this can be a list of integers such as [0, 1].
  - ignore_label_validation: This is the location of the label in class_list whose performance is to be ignored during metric calculation for validation/testing data.
  - norm_type: The type of normalization to be used. This can be batch, instance, or none.
  - Various other options specific to architectures, such as (but not limited to):
    - densenet models:
      - growth_rate: how many filters to add each layer (k in paper)
      - bn_size: multiplicative factor for number of bottleneck layers (i.e., bn_size * k features in the bottleneck layer)
      - drop_rate: dropout rate after each dense layer
    - unet_multilayer and other networks that support multiple layers:
      - depth: the number of encoder/decoder (or other types of) layers
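As a rough illustration of how the model options above fit together, the following Python sketch mirrors a model block as a plain dictionary and expands combinatorial class_list entries into the groups of label values they cover. The dictionary values and the helper expand_class_list are hypothetical; they only illustrate the semantics described above and are not GaNDLF's own parser.

```python
# Hypothetical `model` block, written as a Python dict for illustration;
# the chosen values are assumptions, not recommended defaults.
model_config = {
    "architecture": "resunet",
    "dimension": 3,
    "final_layer": "softmax",
    "norm_type": "batch",
    # Combined labels are written here as strings joined by "||".
    "class_list": [0, "1||4", "1||2||4", 4],
}


def expand_class_list(class_list):
    """Turn each class entry into the list of label values it covers."""
    groups = []
    for entry in class_list:
        # str() lets plain integers and "a||b" strings share one code path.
        groups.append([int(part) for part in str(entry).split("||")])
    return groups


label_groups = expand_class_list(model_config["class_list"])
print(label_groups)  # [[0], [1, 4], [1, 2, 4], [4]]
```

With the BraTS example, the four groups correspond to background, tumor core, whole tumor, and enhancing tumor.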
Loss function¶
- Defined in the loss_function parameter of the model configuration.
- By passing weighted_loss: True, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the function with which the model is trained. All options can be found here. Some examples are:
  - Segmentation: dice (dice or dc), dice and cross entropy (dcce), focal loss (focal), dice and focal (dc_focal), matthews (mcc)
  - Classification/regression: mean squared error (mse)
  - And many more.
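To make the idea of weighting by inverse class frequency concrete, here is a minimal Python sketch. The helper inverse_frequency_weights is hypothetical, and normalizing the weights so that they sum to 1 is an assumption made for readability; the exact scaling GaNDLF applies is not stated here.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each class by the inverse of how often it occurs."""
    counts = Counter(labels)
    total = len(labels)
    # Rare classes get large raw weights, frequent classes small ones.
    raw = {cls: total / count for cls, count in counts.items()}
    # Assumption: rescale so that the weights sum to 1.
    norm = sum(raw.values())
    return {cls: weight / norm for cls, weight in raw.items()}


# Class 0 is three times as frequent as class 1.
weights = inverse_frequency_weights([0, 0, 0, 1])
print(weights)
```

In this example the rarer class 1 receives three times the weight of class 0.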
Metrics¶
- Defined in the metrics parameter of the model configuration.
- This parameter controls the metrics to be used for model evaluation for the training/validation/testing datasets. All options can be found here. Most of these metrics are calculated using TorchMetrics. Some examples are:
  - Segmentation: dice (dice and dice_per_label), hausdorff distances (hausdorff or hausdorff100 and hausdorff100_per_label), hausdorff distances based on the 95th percentile of distances (hausdorff95 and hausdorff95_per_label)
  - Classification/regression: mean squared error (mse) calculated per sample
  - Metrics calculated per cohort (these are automatically calculated for classification and regression):
    - Classification: accuracy, precision, recall, f1, for the entire cohort ("global"), per classified class ("per_class"), per classified class averaged ("per_class_average"), per classified class weighted/balanced ("per_class_weighted")
    - Regression: mean absolute error, pearson and spearman coefficients, calculated as mean, sum, or standard.
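For readers unfamiliar with the dice metric listed above, the following pure-Python sketch computes it for two flat binary masks. It is only meant to show the formula; GaNDLF's own metrics rely on TorchMetrics and tensors, and the function name dice_score and the handling of two empty masks are assumptions of this sketch.

```python
def dice_score(prediction, target):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for flat binary (0/1) masks."""
    intersection = sum(p * t for p, t in zip(prediction, target))
    total = sum(prediction) + sum(target)
    # Convention assumed here: two empty masks agree perfectly.
    if total == 0:
        return 1.0
    return 2.0 * intersection / total


# One overlapping foreground pixel, two predicted, one in the ground truth.
score = dice_score([1, 1, 0, 0], [1, 0, 0, 0])
print(round(score, 3))  # 0.667
```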
Patching Strategy¶
- patch_size: The size of the patch to be used for training. This is expected to be a list of integers, with the length of the list being the same as the dimensionality of the input image. For example, for a 2D image, this can be [128, 128], and for a 3D image, this can be [128, 128, 128].
- patch_sampler: The sampler to be used for patch sampling during training. This can be one of uniform (the entire input image has equal weight on contributing a valid patch) or label (only the regions that have a valid ground truth segmentation label can contribute a patch). The label sampler usually requires padding of the image to ensure blank patches are not inadvertently sampled; this can be controlled by the enable_padding parameter.
- inference_mechanism
  - grid_aggregator_overlap: this option controls how overlapping patches are combined in the grid aggregation output; should be either crop or average; see https://torchio.readthedocs.io/patches/patch_inference.html#grid-aggregator for details.
  - patch_overlap: the amount of overlap of patches during inference in terms of pixels, defaults to 0; see https://torchio.readthedocs.io/patches/patch_inference.html#gridsampler for details.
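The interaction between patch_size and patch_overlap during grid-based inference can be sketched along a single axis as follows. This is a simplified illustration under stated assumptions, not TorchIO's implementation: the helper grid_starts is hypothetical, and it assumes that a final patch is aligned with the image border whenever the regular grid does not reach it.

```python
def grid_starts(image_size, patch_size, patch_overlap=0):
    """Start indices of patches that tile one axis of an image."""
    if patch_size > image_size:
        raise ValueError("patch_size must not exceed image_size")
    if not 0 <= patch_overlap < patch_size:
        raise ValueError("patch_overlap must be in [0, patch_size)")
    # Neighbouring patches are shifted by the non-overlapping part.
    step = patch_size - patch_overlap
    starts = list(range(0, image_size - patch_size + 1, step))
    # Assumption: add one border-aligned patch so every pixel is covered.
    last = image_size - patch_size
    if starts[-1] != last:
        starts.append(last)
    return starts


print(grid_starts(256, 128))      # [0, 128]
print(grid_starts(300, 128, 28))  # [0, 100, 172]
```

A larger patch_overlap produces more patches per axis, which increases inference time but gives the aggregator more overlapping predictions to crop or average.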
Data Preprocessing¶
- Defined in the data_preprocessing parameter of the model configuration.
- This parameter controls the various preprocessing functions that are applied to the entire image before the patching strategy is applied.
- All options can be found here. Some of the most important examples are:
  - Intensity harmonization: GaNDLF provides multiple normalization and rescaling options to ensure intensity-level harmonization of the entire cohort. Some examples include:
    - normalize: simple z-score normalization
    - normalize_positive: this performs z-score normalization only on pixels > 0
    - normalize_nonZero: this performs z-score normalization only on pixels != 0
    - normalize_nonZero_masked: this performs z-score normalization only on the region defined by the ground truth annotation
    - rescale: simple min-max rescaling, sub-parameters include in_min_max, out_min_max, percentiles; this option is useful to discard outliers in the intensity distribution
    - Template-based normalization: These options take a target image as input (defined by the target sub-parameter) and perform different matching strategies to match input image(s) to this target.
      - histogram_matching: this performs histogram matching as defined by this paper.
        - If the target image is absent, this will perform global histogram equalization.
        - If target is adaptive, this will perform adaptive histogram equalization.
      - stain_normalization: these are normalization techniques specifically designed for histology images; the different options include vahadane, macenko, or ruifrok, under the extractor sub-parameter. Always needs a target image to work.
  - Resolution harmonization: GaNDLF provides multiple resampling options to ensure resolution-level harmonization of the entire cohort. Some examples include:
    - resample: resamples the image to the resolution specified by the resolution sub-parameter
    - resample_min: resamples the image to the maximum spacing defined by the resolution sub-parameter; this is useful in cohorts that have varying resolutions, but the user wants to resample to the minimum resolution for consistency
    - resize_image: NOT RECOMMENDED; resizes the image to the specified size
    - resize_patch: NOT RECOMMENDED; resizes the extracted patch to the specified size
  - And many more.
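The difference between the z-score variants listed above comes down to which pixels are selected. The sketch below illustrates this on a flat list of intensities. It is an assumption-laden illustration rather than GaNDLF's code: the helper zscore is hypothetical, it uses the population standard deviation, and it leaves unselected pixels unchanged.

```python
from statistics import mean, pstdev


def zscore(pixels, select=lambda value: True):
    """Z-score the selected pixels; leave the others unchanged."""
    chosen = [value for value in pixels if select(value)]
    mu = mean(chosen)
    sigma = pstdev(chosen)  # assumption: population standard deviation
    if sigma == 0:
        raise ValueError("selected pixels have zero variance")
    return [(v - mu) / sigma if select(v) else v for v in pixels]


# Analogue of normalize_nonZero: statistics come from pixels != 0.
nonzero_result = zscore([0, 0, 2, 4], select=lambda v: v != 0)
# Analogue of normalize_positive: statistics come from pixels > 0.
positive_result = zscore([-2, 0, 2, 4], select=lambda v: v > 0)
print(nonzero_result)   # [0, 0, -1.0, 1.0]
print(positive_result)  # [-2, 0, -1.0, 1.0]
```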
Data Augmentation¶
- Defined in the data_augmentation parameter of the model configuration.
- This parameter controls the various augmentation functions that are applied to the entire image before the patching strategy is applied.
- These should be defined with the task at hand in mind (for example, RGB augmentations will not work for MRI/CT and other similar radiology images).
- All options can contain a probability sub-parameter, which defines the probability of the augmentation being applied to the image. When present, this will supersede the default_probability parameter.
- All options can be found here. Some of the most important examples are:
  - Radiology-specific augmentations
    - kspace: one of either ghosting or spiking is picked for augmentation.
    - bias: applies a random bias field artefact to the input image using this function.
  - RGB-specific augmentations
    - colorjitter: applies the ColorJitter transform from PyTorch; has sub-parameters brightness, contrast, saturation, and hue.
  - General-purpose augmentations
    - Spatial transforms: they only change the resolution (and thereby, the shape) of the input image, and only apply interpolation to the intensities for consistency
      - affine: applies a random affine transformation to the input image; for details, see this page; has sub-parameters scales (defining the scaling ranges), degrees (defining the rotation ranges), and translation (defining the translation ranges in real-world coordinates, which is usually in mm)
      - elastic: applies a random elastic deformation to the input image; for details, see this page; has sub-parameters num_control_points (defining the number of control points), locked_borders (defining the number of locked borders), and max_displacement (defining the maximum displacement of the control points).
      - flip: applies a random flip to the input image; for details, see this page; has sub-parameter axes (defining the axes to flip).
      - rotate: applies a random rotation by 90 degrees (rotate_90) or 180 degrees (rotate_180); has sub-parameter axes (defining the axes to rotate).
      - swap: applies a random swap; has sub-parameters patch_size (defining the patch size to swap) and num_iterations (number of iterations that 2 patches will be swapped).
    - Intensity transforms: they change the intensity of the input image, but never the actual resolution or shape.
      - motion: applies a random motion blur to the input image using this function.
      - blur: applies a random Gaussian blur to the input image using this function; has sub-parameter std (defines the standard deviation range).
      - noise: applies a random noise to the input image using this function; has sub-parameters std (defines the standard deviation range) and mean (defines the mean of the noise to be added).
      - noise_var: applies a random noise to the input image, but with the default std = [0, 0.015 * std(image)].
      - anisotropic: applies a random anisotropic transform to the input image using this function. This changes the resolution and brings it back to its original resolution, thus applying "real-world" interpolation to images.
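The precedence between probability and default_probability described above can be sketched as follows. The dictionary layout mirrors the text, but the values and the helper effective_probabilities are hypothetical and only illustrate the override rule; this is not GaNDLF's own config parser.

```python
# Hypothetical `data_augmentation` block as a Python dict.
data_augmentation = {
    "default_probability": 0.5,
    "affine": {"probability": 0.25},  # overrides the default
    "flip": {"axes": [0, 1]},         # falls back to the default
}


def effective_probabilities(config):
    """Resolve the probability that applies to each augmentation."""
    default = config["default_probability"]
    resolved = {}
    for name, options in config.items():
        if name == "default_probability":
            continue
        # An augmentation may be listed without any sub-parameters.
        options = options or {}
        resolved[name] = options.get("probability", default)
    return resolved


probabilities = effective_probabilities(data_augmentation)
print(probabilities)  # {'affine': 0.25, 'flip': 0.5}
```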
Training Parameters¶
- These are various parameters that control the overall training process.
- verbose: generate verbose messages on console; generally used for debugging.
- batch_size: batch size to be used for training.
- in_memory: this enables or disables lazy loading. If set to True, all data is loaded into RAM at once during the construction of the dataloader (training, validation, or testing), which results in faster training. If set to False, data gets read into RAM on-the-go when needed (also called "lazy loading"), which slows down training but lessens the memory load. The latter is recommended if the user's RAM has limited capacity.
- num_epochs: number of epochs to train for.
- patience: number of epochs to wait for improvement in the validation loss before early stopping.
- learning_rate: learning rate to be used for training.
- scheduler: learning rate scheduler to be used for training, more details are here; can take the following sub-parameters:
  - type: triangle, triangle_modified, exp, step, reduce-on-plateau, cosineannealing, triangular, triangular2, exp_range
  - min_lr: minimum learning rate to be used for training.
  - max_lr: maximum learning rate to be used for training.
- optimizer: optimizer to be used for training, more details are here.
- nested_training: number of folds to use for nested training; takes testing and validation as sub-parameters, with integer values defining the number of folds to use.
- memory_save_mode: if enabled, resize/resample operations in data_preprocessing will save files to disk instead of directly getting read into memory as tensors
- Queue configuration: this defines how the queue for the input to the model is to be designed after the patching strategy has been applied, and more details are here. This takes the following sub-parameters:
  - q_max_length: this determines the maximum number of patches that can be stored in the queue. Using a large number means that the queue needs to be filled less often, but more CPU memory is needed to store the patches.
  - q_samples_per_volume: this determines the number of patches to extract from each volume. A small number of patches ensures a large variability in the queue, but training will be slower.
  - q_num_workers: this determines the number of subprocesses to use for data loading; '0' means the main process is used; scale this according to available CPU resources.
  - q_verbose: used to debug the queue
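The role of patience in early stopping can be illustrated with a short Python sketch. The helper early_stop_epoch is hypothetical, and stopping as soon as the number of epochs without improvement reaches patience is an assumption about the exact rule; GaNDLF's training loop may count slightly differently.

```python
def early_stop_epoch(val_losses, patience):
    """Return the 0-based epoch at which training stops, or None."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember it and reset the counter.
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.95]
print(early_stop_epoch(losses, patience=2))  # 3
print(early_stop_epoch(losses, patience=3))  # 4
```

A larger patience tolerates longer plateaus in the validation loss at the cost of more epochs before stopping.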
Differentially Private Training¶
GaNDLF supports training differentially private models using Opacus. Here are some resources for training private models:
- TLDR on DP and private training: read this paper and this blog post.
- All options are present in a new key called differential_privacy in the config file. It has the following options:
  - noise_multiplier: The ratio of the standard deviation of the Gaussian noise to the L2-sensitivity of the function to which the noise is added.
  - max_grad_norm: The maximum norm of the per-sample gradients. Any gradient with norm higher than this will be clipped to this value.
  - accountant: Accounting mechanism. Currently supported: rdp (RDPAccountant), gdp (GaussianAccountant), prv (PRVAccountant)
  - secure_mode: Set to True if a cryptographically strong DP guarantee is required. secure_mode=True uses a secure random number generator for noise and shuffling (as opposed to pseudo-rng in vanilla PyTorch) and prevents certain floating-point arithmetic-based attacks.
  - allow_opacus_model_fix: Enables automated fixing of the model based on Opacus [ref]
  - delta: Target delta to be achieved. Probability of information being leaked. Use either this or epsilon.
  - epsilon: Target epsilon to be achieved, a metric of privacy loss at differential changes in data. Use either this or delta.
  - physical_batch_size: The batch size to use for DP computation (it is usually set lower than the baseline or non-DP batch size). Defaults to batch_size.
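To show what max_grad_norm and noise_multiplier mean numerically, the sketch below clips a single per-sample gradient and derives the noise standard deviation from the definitions above. It is a pure-Python illustration, not Opacus code: the helpers clip_gradient and noise_std are hypothetical, and the sketch assumes that the L2-sensitivity equals max_grad_norm, as is usual in DP-SGD.

```python
import math


def clip_gradient(gradient, max_grad_norm):
    """Scale a per-sample gradient so its L2 norm is at most max_grad_norm."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm <= max_grad_norm:
        return list(gradient)
    scale = max_grad_norm / norm
    return [g * scale for g in gradient]


def noise_std(noise_multiplier, max_grad_norm):
    """Noise std, assuming the L2-sensitivity equals max_grad_norm."""
    return noise_multiplier * max_grad_norm


# A gradient with norm 5 is scaled down to norm 1.
clipped = clip_gradient([3.0, 4.0], max_grad_norm=1.0)
print([round(g, 3) for g in clipped])  # [0.6, 0.8]
```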