Trainer Handlers
General NN trainer
- class tardis_em.utils.trainer.ISR_LR(optimizer: Adam, lr_mul: float, warmup_steps=1000, scale=100)
Custom inverse square root learning rate scheduler.
- load_state_dict(checkpoint: dict)
Wrapper for loading Optimizer state dictionary
- Parameters:
checkpoint (dict) – Dictionary with optimizer state.
- state_dict()
Wrapper for retrieving Optimizer state dictionary
- step()
Step with the inner optimizer.
- zero_grad()
Zero out the gradients with the inner optimizer
- get_lr_scale()
Compute the scaling factor for the learning rate.
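The exact schedule is implemented in get_lr_scale; as a hedged sketch, an assumption based on the common Transformer-style inverse square root schedule with warmup (lr_mul stands in as the multiplier; the scale argument is not modeled here):

    import torch

    def isr_lr_scale(step: int, warmup_steps: int = 1000) -> float:
        # Assumed form: linear warmup, then decay proportional to step ** -0.5.
        return min(step ** -0.5, step * warmup_steps ** -1.5)

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)

    for step in range(1, 2001):
        for group in optimizer.param_groups:
            group["lr"] = 0.001 * isr_lr_scale(step)  # 0.001 plays the role of lr_mul
        optimizer.zero_grad()
        model(torch.rand(4, 8)).sum().backward()
        optimizer.step()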
- class tardis_em.utils.trainer.BasicTrainer(model, structure: dict, device: device, criterion, optimizer: ISR_LR | Adam, print_setting: tuple, training_DataLoader, validation_DataLoader=None, lr_scheduler=False, epochs=100, early_stop_rate=10, instance_cov=2, checkpoint_name='DIST', classification=False)
Basic model trainer for DIST and CNN models.
- Parameters:
model (nn.Module) – ML model built with nn.Module or nn.Sequential.
structure (dict) – Model structure as dictionary.
device (torch.device) – Device for training.
criterion (nn.loss) – Loss function type.
optimizer (optim.Adam, ISR_LR) – Optimizer type.
training_DataLoader (torch.DataLoader) – DataLoader with training dataset.
validation_DataLoader (torch.DataLoader, optional) – DataLoader with validation dataset.
print_setting (tuple) – Model property to display in TARDIS progress bar.
lr_scheduler (bool) – If True, use a learning rate scheduler.
epochs (int) – Maximum number of epochs.
early_stop_rate (int) – Number of epochs without improvement after which the trainer stops training.
checkpoint_name (str) – Name of the checkpoint.
- run_trainer()
Main training loop.
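A minimal usage sketch; the toy model, loaders, and the empty structure dict are placeholders (whether an empty dict is accepted depends on the model type), and only the keyword names follow the signature above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from tardis_em.utils.trainer import BasicTrainer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 3, padding=1)).to(device)

    # Toy data; real training uses the TARDIS dataset loaders.
    data = TensorDataset(torch.rand(8, 1, 64, 64), torch.rand(8, 1, 64, 64))
    loader = DataLoader(data, batch_size=2)

    trainer = BasicTrainer(
        model=model,
        structure={},  # placeholder; real models pass their structure dict
        device=device,
        criterion=torch.nn.BCEWithLogitsLoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        print_setting=(),
        training_DataLoader=loader,
        validation_DataLoader=loader,
        epochs=5,
        checkpoint_name="DIST",
    )
    trainer.run_trainer()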
Loss functions
- class tardis_em.utils.losses.AbstractLoss(smooth=1e-16, reduction='mean', diagonal=False, sigmoid=True)
- activate(logits)
- ignor_diagonal(logits, targets, mask=False)
- initialize_tensors(logits, targets, mask)
- abstract forward(logits: Tensor, targets: Tensor, mask=False)
- Parameters:
logits (torch.Tensor) – The predicted logits. Shape: [Batch x Channels x …].
targets (torch.Tensor) – The target values. Shape: [Batch x Channels x …].
mask (bool) – If True, mask out the diagonal axis of the output.
- Returns:
Computed loss function.
- Return type:
torch.Tensor
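New losses subclass AbstractLoss and override forward; a minimal hypothetical example (MAELoss is not part of the library, and activate is assumed to apply the configured sigmoid):

    import torch
    from tardis_em.utils.losses import AbstractLoss

    class MAELoss(AbstractLoss):
        """Hypothetical subclass: only forward() must be overridden."""

        def forward(self, logits: torch.Tensor, targets: torch.Tensor, mask=False):
            logits = self.activate(logits)  # assumed to apply sigmoid when enabled
            return torch.mean(torch.abs(logits - targets))

    loss = MAELoss()(torch.randn(2, 1, 8, 8), torch.rand(2, 1, 8, 8))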
- class tardis_em.utils.losses.AdaptiveDiceLoss(alpha=0.1, **kwargs)
Implements an adaptive Dice loss function, which gives more weight to false negatives.
The AdaptiveDiceLoss function is a variant of the standard Dice loss with an additional adaptive term, (1 - logits) ** self.alpha, applied to the logits. This term gives higher weight to false negatives (cases where the prediction is low but the ground truth is high), which is useful when false negatives are particularly costly; see the sketch below.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Computes the adaptive Dice loss between the logits and targets.
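A standalone sketch of the idea, not the class's exact implementation (which also handles reduction and diagonal masking):

    import torch

    def adaptive_dice(logits, targets, alpha=0.1, smooth=1e-16):
        # Re-weight predictions so low-confidence positives (false negatives)
        # contribute more: p' = ((1 - p) ** alpha) * p
        p = torch.sigmoid(logits).flatten()
        t = targets.flatten()
        p = ((1.0 - p) ** alpha) * p
        inter = (p * t).sum()
        return 1.0 - (2.0 * inter + smooth) / (p.sum() + t.sum() + smooth)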
- class tardis_em.utils.losses.BCELoss(**kwargs)
Implements the Binary Cross-Entropy loss function with an option to ignore the diagonal elements.
The BCELoss class can be used for training where pixel-level accuracy is important.
- forward(logits: Tensor, targets: Tensor, mask=True) Tensor
Computes the BCE loss between the logits and targets.
- class tardis_em.utils.losses.BCEGraphWiseLoss(**kwargs)
Implements the Binary Cross-Entropy loss function with an option to ignore the diagonal elements.
The BCEGraphWiseLoss class can be used for training where per-edge (graph-wise) accuracy is important.
- forward(logits: Tensor, targets: Tensor, mask=True) Tensor
Computes the BCE loss between the logits and targets.
- class tardis_em.utils.losses.BCEDiceLoss(**kwargs)
Combined Dice + BCE loss function.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function.
- class tardis_em.utils.losses.CELoss(**kwargs)
Standard cross-entropy loss function.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function.
- class tardis_em.utils.losses.DiceLoss(**kwargs)
Dice coefficient loss function.
Dice = 2|A∩B| / (|A| + |B|), where |A∩B| is the number of elements common to sets A and B, and |A| and |B| are the numbers of elements in sets A and B, respectively.
This loss effectively zeroes out any pixels in the prediction that are not “activated” in the target mask; see the sketch below.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function.
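A minimal sketch of the soft Dice loss computed from the formula above (the class itself also supports diagonal masking and configurable smoothing):

    import torch

    def dice_loss(logits, targets, smooth=1e-16):
        # Dice = 2|A ∩ B| / (|A| + |B|); the loss is 1 - Dice.
        p = torch.sigmoid(logits).flatten()
        t = targets.flatten()
        inter = (p * t).sum()
        return 1.0 - (2.0 * inter + smooth) / (p.sum() + t.sum() + smooth)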
- class tardis_em.utils.losses.LaplacianEigenmapsLoss(**kwargs)
A loss function for deep learning models that computes the mean squared error between the first non-zero eigenvectors of the Laplacian matrices of the ground truth and predicted adjacency matrices.
- static compute_laplacian(A: Tensor) Tensor
Computes the Laplacian matrix of an adjacency matrix.
- Parameters:
A (torch.Tensor) – The adjacency matrix.
- Returns:
The Laplacian matrix.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Computes the Laplacian-Eigenmaps loss between the true and predicted adjacency matrices.
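Assuming the unnormalized graph Laplacian L = D - A (the class may normalize; check the source), compute_laplacian reduces to:

    import torch

    def laplacian(A: torch.Tensor) -> torch.Tensor:
        # L = D - A, where D is the diagonal degree matrix of A.
        return torch.diag(A.sum(dim=1)) - A

    A = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
    print(laplacian(A))  # row sums of L are zero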
- class tardis_em.utils.losses.SoftSkeletonization(_iter=5, **kwargs)
General soft skeletonization with Dice loss function.
- soft_skel(binary_mask: Tensor, iter_: int) Tensor
Soft skeletonization
- Parameters:
binary_mask – Binary target mask.
iter_ (int) – Number of iterations for erosion.
- Returns:
Skeleton one-hot mask.
- Return type:
torch.Tensor
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function.
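Soft skeletonization is commonly implemented as iterated differentiable erosion and opening via min/max pooling, as in the clDice paper; a hedged 2D sketch of that standard recipe (TARDIS may use 3D pooling for volumes):

    import torch
    import torch.nn.functional as F

    def soft_erode(img):  # min-pooling expressed as negated max-pooling
        return -F.max_pool2d(-img, kernel_size=3, stride=1, padding=1)

    def soft_dilate(img):
        return F.max_pool2d(img, kernel_size=3, stride=1, padding=1)

    def soft_open(img):
        return soft_dilate(soft_erode(img))

    def soft_skel(binary_mask, iter_):
        # Expects an (N, C, H, W) soft mask in [0, 1]; accumulates the
        # residue left behind by successive openings.
        skel = F.relu(binary_mask - soft_open(binary_mask))
        img = binary_mask
        for _ in range(iter_):
            img = soft_erode(img)
            delta = F.relu(img - soft_open(img))
            skel = skel + F.relu(delta - skel * delta)
        return skel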
- class tardis_em.utils.losses.ClBCELoss(**kwargs)
Soft skeletonization with BCE loss function
Implements a custom version of the Binary Cross Entropy (BCE) loss, where an additional term is added to the standard BCE loss. This additional term is a kind of F1 score calculated on the soft-skeletonized version of the predicted and target masks.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function
- class tardis_em.utils.losses.ClDiceLoss(**kwargs)
Soft skeletonization with DICE loss function
Implements a custom version of the Dice loss, where an additional term is added to the standard Dice loss. This additional term is a kind of F1 score calculated on the soft-skeletonized version of the predicted and target masks.
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Forward loss function
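In the clDice paper's formulation, the skeleton-based F1 term used by ClBCELoss and ClDiceLoss combines topology precision and sensitivity; a sketch of that standard definition (not necessarily the exact code here):

    import torch

    def cl_dice_term(p, t, skel_p, skel_t, smooth=1e-16):
        # tprec = |S(P) ∩ T| / |S(P)|,  tsens = |S(T) ∩ P| / |S(T)|
        tprec = ((skel_p * t).sum() + smooth) / (skel_p.sum() + smooth)
        tsens = ((skel_t * p).sum() + smooth) / (skel_t.sum() + smooth)
        return 2.0 * tprec * tsens / (tprec + tsens)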
- class tardis_em.utils.losses.SigmoidFocalLoss(gamma=0.25, alpha=None, **kwargs)
Implements the Sigmoid Focal Loss function with an option to ignore the diagonal elements.
The SigmoidFocalLoss class implements the Focal Loss, which was proposed as a method for focusing the model on hard examples during the training of an object detector. It provides an option to ignore the diagonal elements of the input matrices.
References: 10.1088/1742-6596/1229/1/012045
- forward(logits: Tensor, targets: Tensor, mask=False) Tensor
Computes the sigmoid focal loss between the logits and targets.
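The standard sigmoid focal loss (Lin et al.) down-weights easy examples with a (1 - p_t) ** gamma factor; a sketch of that common form (note the class's gamma=0.25 default differs from the usual gamma=2, so the exact weighting here may differ):

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)  # prob. of the true class
        loss = ce * (1 - p_t) ** gamma               # focus on hard examples
        if alpha is not None:
            loss = loss * (alpha * targets + (1 - alpha) * (1 - targets))
        return loss.mean()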
- class tardis_em.utils.losses.WBCELoss(**kwargs)
Implements a weighted Binary Cross-Entropy loss function with an option to ignore the diagonal elements.
The WBCELoss class can help to balance the contribution of positive and negative samples in datasets where one class significantly outnumbers the other. It provides an option to ignore the diagonal elements of the input matrices, which could be useful for applications like graph prediction where self-connections might not be meaningful.
- forward(logits: Tensor, targets: Tensor, mask=False, pos=1, neg=0.1) Tensor
Computes the weighted BCE loss between the logits and targets.
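A plausible reading of the pos/neg arguments as per-class weights (a sketch, not the exact implementation):

    import torch
    import torch.nn.functional as F

    def weighted_bce(logits, targets, pos=1.0, neg=0.1):
        # Positive pixels weighted by `pos`, negative pixels by `neg`.
        weight = targets * pos + (1.0 - targets) * neg
        return F.binary_cross_entropy_with_logits(logits, targets, weight=weight)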
- class tardis_em.utils.losses.BCEMSELoss(mse_weight=0.1, **kwargs)
Implements the Binary Cross-Entropy loss combined with an MSE term, with an option to ignore the diagonal elements.
The BCEMSELoss class can be used for training where pixel-level accuracy is important. The MSE loss is applied across consecutive Z slices to encourage smooth segmentation between slices.
- forward(logits: Tensor, targets: Tensor, mask=True) Tensor
Computes the combined BCE and MSE loss between the logits and targets.
Training metrics
- tardis_em.utils.metrics.compare_dict_metrics(last_best_dict: dict, new_dict: dict) bool
Compares the average metric values of two metric dictionaries.
- Parameters:
last_best_dict (dict) – The previous best metric dictionary.
new_dict (dict) – The new metric dictionary to compare.
- Returns:
True if the new dictionary has a higher average metric value.
- Return type:
bool
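Under the behavior described above, the function reduces to a mean-value comparison (assumes all dictionary values are numeric):

    def compare_dict_metrics(last_best_dict: dict, new_dict: dict) -> bool:
        # True when the new metrics have the higher average value.
        new_avg = sum(new_dict.values()) / len(new_dict)
        best_avg = sum(last_best_dict.values()) / len(last_best_dict)
        return new_avg > best_avg

    compare_dict_metrics({"f1": 0.80, "iou": 0.70}, {"f1": 0.85, "iou": 0.72})  # True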
- tardis_em.utils.metrics.eval_graph_f1(logits: Tensor, targets: Tensor, threshold: float, soft=False)
Calculates graph F1 training metrics.
Works with both torch and numpy data.
- Parameters:
logits (np.ndarray, torch.Tensor) – Prediction output from the model.
targets (np.ndarray, torch.Tensor) – Ground truth mask.
threshold (float) – Threshold used to binarize predicted probabilities.
soft
- tardis_em.utils.metrics.calculate_f1(logits: ndarray | Tensor, targets: ndarray | Tensor, best_f1=True)
Calculates F1 training metrics.
Works with both torch and numpy data.
- Parameters:
logits (np.ndarray, torch.Tensor) – Prediction output from the model.
targets (np.ndarray, torch.Tensor) – Ground truth mask.
best_f1 (bool) – If True, the expected input is class probabilities and the measured metric is soft-F1.
- tardis_em.utils.metrics.AP(logits: ndarray, targets: ndarray) float
- tardis_em.utils.metrics.AP_instance(input_: ndarray, targets: ndarray) float
- tardis_em.utils.metrics.AUC(logits: ndarray, targets: ndarray, diagonal=False) float
- tardis_em.utils.metrics.IoU(input_: ndarray, targets: ndarray, diagonal=False)
- tardis_em.utils.metrics.mcov(input_, targets)
Mean Coverage (mCov) metric.
- Parameters:
input_ (np.ndarray, torch.Tensor) – Predicted instance segmentation.
targets (np.ndarray, torch.Tensor) – Ground-truth instance segmentation.
- Returns:
Mean coverage score.
- Return type:
float
- tardis_em.utils.metrics.confusion_matrix(logits: ndarray | Tensor, targets: ndarray | Tensor)
- tardis_em.utils.metrics.normalize_image(image: ndarray)
Simple image data normalizer to the range [0, 1]; see the sketch below.
- Parameters:
image (np.ndarray) – Image data set.
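Assuming min-max normalization, the usual reading of a [0, 1] normalizer (the real function may special-case already-normalized input):

    import numpy as np

    def normalize_image(image: np.ndarray) -> np.ndarray:
        # Min-max scaling to [0, 1]; guards against constant images.
        lo, hi = image.min(), image.max()
        if hi == lo:
            return np.zeros_like(image, dtype=float)
        return (image - lo) / (hi - lo)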