Global -> Trainer Handlers

General NN trainer

class tardis_em.utils.trainer.ISR_LR(optimizer: Adam, lr_mul: float, warmup_steps=1000, scale=100)

Custom Inverse Square Root Learning Rate Scheduler

load_state_dict(checkpoint: dict)

Wrapper for loading Optimizer state dictionary


checkpoint (dict) – Dictionary with optimizer state.


Wrapper for retrieving Optimizer state dictionary


Step with the inner optimizer.


Zero out the gradients with the inner optimizer


Compute the scaling factor for the learning rate.
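A minimal sketch of how such a wrapper could work, assuming a Noam-style inverse-square-root schedule with linear warmup. The exact formula used by ISR_LR and the role of `scale` are assumptions here; the class name `ISRScheduleSketch` and its `lr_scale` method are hypothetical, and only the wrapped `torch.optim.Adam` calls are real PyTorch API.

```python
import torch


class ISRScheduleSketch:
    """Hypothetical stand-in for ISR_LR: wraps an optimizer and rescales its LR.

    Assumes a Noam-style schedule: linear warmup, then 1/sqrt(step) decay.
    """

    def __init__(self, optimizer, lr_mul: float, warmup_steps: int = 1000, scale: int = 100):
        self.optimizer = optimizer
        self.lr_mul = lr_mul
        self.warmup_steps = warmup_steps
        self.scale = scale
        self.n_steps = 0

    def lr_scale(self) -> float:
        # Rises linearly until warmup_steps, then decays proportionally to 1/sqrt(step).
        n, w = self.n_steps, self.warmup_steps
        return (self.scale ** -0.5) * min(n ** -0.5, n * w ** -1.5)

    def step(self):
        # Update the learning rate of every parameter group, then step the inner optimizer.
        self.n_steps += 1
        lr = self.lr_mul * self.lr_scale()
        for group in self.optimizer.param_groups:
            group["lr"] = lr
        self.optimizer.step()

    def zero_grad(self):
        # Zero out the gradients with the inner optimizer.
        self.optimizer.zero_grad()

    def state_dict(self) -> dict:
        return self.optimizer.state_dict()

    def load_state_dict(self, checkpoint: dict):
        self.optimizer.load_state_dict(checkpoint)
```

Under this assumed formula, the learning rate peaks exactly at `warmup_steps`, where the two terms inside `min` are equal.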

class tardis_em.utils.trainer.BasicTrainer(model, structure: dict, device: device, criterion, optimizer: ISR_LR | Adam, print_setting: tuple, training_DataLoader, validation_DataLoader=None, lr_scheduler=False, epochs=100, early_stop_rate=10, instance_cov=2, checkpoint_name='DIST', classification=False)


  • model (nn.Module) – ML model built with nn.Module or nn.Sequential.

  • structure (dict) – Model structure as dictionary.

  • device (torch.device) – Device for training.

  • criterion (nn.loss) – Loss function type.

  • optimizer (optim.Adam, ISR_LR) – Optimizer type.

  • training_DataLoader (torch.DataLoader) – DataLoader with training dataset.

  • validation_DataLoader (torch.DataLoader, optional) – DataLoader with test dataset.

  • print_setting (tuple) – Model property to display in TARDIS progress bar.

  • lr_scheduler (bool) – Whether to use the optional learning-rate scheduler.

  • epochs (int) – Maximum number of epochs.

  • early_stop_rate (int) – Number of epochs without improvement after which the trainer stops training.

  • checkpoint_name (str) – Name of the checkpoint.


Main training loop.
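The early-stopping rule implied by `epochs` and `early_stop_rate` can be sketched as below, assuming the trainer monitors a validation loss where lower is better (BasicTrainer may instead track a metric such as F1). `epochs_until_early_stop` is a hypothetical helper, not part of tardis_em.

```python
def epochs_until_early_stop(val_losses, epochs=100, early_stop_rate=10):
    """Hypothetical helper: number of epochs run before early stopping triggers.

    val_losses: validation loss recorded after each epoch (lower is better).
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses[:epochs], start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0  # improvement: reset the counter
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= early_stop_rate:
            return epoch  # stop: no improvement for early_stop_rate epochs
    # Never triggered: training ran for all available (or allowed) epochs.
    return min(len(val_losses), epochs)
```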

Loss functions

class tardis_em.utils.losses.AbstractLoss(smooth=1e-16, reduction='mean', diagonal=False, sigmoid=True)
ignor_diagonal(logits, targets, mask=False)
initialize_tensors(logits, targets, mask)
abstract forward(logits: Tensor, targets: Tensor, mask=False)
  • logits (torch.Tensor) – The predicted logits. Shape: [Batch x Channels x …].

  • targets (torch.Tensor) – The target values. Shape: [Batch x Channels x …].

  • mask (bool) – If True, mask the diagonal axis in the output.


Computed loss value.

Return type: torch.Tensor
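A rough sketch of what ignoring the diagonal can mean for graph outputs shaped [Batch x N x N]: entries (i, i) describe self-connections, so they are zeroed in both predictions and targets. This assumes the masking is applied to probabilities (after the sigmoid) and uses a hypothetical helper name; the actual `ignor_diagonal` implementation may differ.

```python
import torch


def ignore_diagonal_sketch(probs: torch.Tensor, targets: torch.Tensor):
    """Hypothetical illustration: zero the diagonal of [B, N, N] graph tensors.

    Self-connections (node i -> node i) then contribute nothing to the loss.
    """
    n = probs.shape[-1]
    # 1 everywhere except the diagonal; broadcasts over the batch dimension.
    off_diag = 1.0 - torch.eye(n, dtype=probs.dtype, device=probs.device)
    return probs * off_diag, targets * off_diag
```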


class tardis_em.utils.losses.AdaptiveDiceLoss(alpha=0.1, **kwargs)

Implements an adaptive Dice loss function, which gives more weight to false negatives.

The AdaptiveDiceLoss function is a variant of the standard Dice loss with an additional adaptive term, (1 - logits) ** self.alpha, applied to the logits. This term gives higher weight to false negatives (i.e., cases where the prediction is low but the ground truth is high), which can be useful when these errors are particularly costly.

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Computes the adaptive Dice loss between the logits and targets.

class tardis_em.utils.losses.BCELoss(**kwargs)

Implements the Binary Cross-Entropy loss function with an option to ignore the diagonal elements.

The BCELoss class can be used for training where pixel-level accuracy is important.

forward(logits: Tensor, targets: Tensor, mask=True) -> Tensor

Computes the BCE loss between the logits and targets.

class tardis_em.utils.losses.BCEGraphWiseLoss(**kwargs)

Implements the Binary Cross-Entropy loss function with an option to ignore the diagonal elements.

The BCEGraphWiseLoss class can be used for training where pixel-level accuracy is important.

forward(logits: Tensor, targets: Tensor, mask=True) -> Tensor

Computes the BCE loss between the logits and targets.

class tardis_em.utils.losses.BCEDiceLoss(**kwargs)


forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function.

class tardis_em.utils.losses.CELoss(**kwargs)


forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function.

class tardis_em.utils.losses.DiceLoss(**kwargs)

Dice coefficient loss function.

$\mathrm{Dice} = \frac{2|A \cap B|}{|A| + |B|}$, where |A ∩ B| is the number of elements common to sets A and B, and |A| and |B| are the numbers of elements in sets A and B, respectively.

This loss effectively zeroes out any predicted pixels that are not “activated” in the target mask.

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function.
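A minimal soft Dice loss following the formula above, assuming logits are first mapped to probabilities with a sigmoid (as the `sigmoid=True` default of AbstractLoss suggests) and that `smooth` mirrors the AbstractLoss default. The function name is hypothetical.

```python
import torch


def dice_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-16) -> torch.Tensor:
    """Soft Dice loss: 1 - 2|A ∩ B| / (|A| + |B|), with sigmoid probabilities as A."""
    probs = torch.sigmoid(logits)  # assumption: raw logits -> probabilities
    # Multiplying by the target zeroes out predictions outside the target mask.
    intersection = (probs * targets).sum()
    dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
    return 1.0 - dice
```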

class tardis_em.utils.losses.LaplacianEigenmapsLoss(**kwargs)

A loss function that computes the mean squared error between the eigenvectors associated with the smallest non-zero eigenvalues of the Laplacian matrices of the ground-truth and predicted adjacency matrices.

static compute_laplacian(A: Tensor) -> Tensor

Computes the Laplacian matrix of an adjacency matrix.


A (torch.Tensor) – The adjacency matrix.


The Laplacian matrix.

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Computes the Laplacian-Eigenmaps loss between the true and predicted adjacency matrices.
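A sketch of the idea for single graphs, assuming the unnormalized Laplacian L = D - A (D is the degree matrix) and symmetric adjacency matrices. Whether TARDIS normalizes the Laplacian, and how many eigenvectors it compares, are not stated here; the function names and the `k`/`eps` parameters are hypothetical.

```python
import torch
import torch.nn.functional as F


def compute_laplacian_sketch(A: torch.Tensor) -> torch.Tensor:
    # Degree matrix D holds each node's summed edge weights on its diagonal.
    D = torch.diag(A.sum(dim=-1))
    return D - A


def nonzero_eigenvectors(L: torch.Tensor, k: int, eps: float = 1e-6) -> torch.Tensor:
    # eigh is for symmetric matrices and returns eigenvalues in ascending order.
    eigvals, eigvecs = torch.linalg.eigh(L)
    # Drop (near-)zero eigenvalues, then keep the first k remaining eigenvectors.
    return eigvecs[:, eigvals > eps][:, :k]


def laplacian_eigenmaps_loss_sketch(pred_adj: torch.Tensor, true_adj: torch.Tensor, k: int = 2) -> torch.Tensor:
    vec_pred = nonzero_eigenvectors(compute_laplacian_sketch(pred_adj), k)
    vec_true = nonzero_eigenvectors(compute_laplacian_sketch(true_adj), k)
    # Eigenvectors are only defined up to sign; this sketch does not align them.
    return F.mse_loss(vec_pred, vec_true)
```

Because of the sign ambiguity noted in the comment, a practical version would likely need to align eigenvector signs before comparing them.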

class tardis_em.utils.losses.SoftSkeletonization(_iter=5, **kwargs)

General soft skeletonization with DICE loss function

soft_skel(binary_mask: Tensor, iter_: int) -> Tensor

Soft skeletonization

  • binary_mask – Binary target mask.

  • iter_ – Number of iterations for erosion.


Skeleton one-hot mask.

Return type: torch.Tensor


forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function.
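A 2D sketch of soft skeletonization as formulated in the clDice paper: erosion is min-pooling (written as negated max-pooling), dilation is max-pooling, and each iteration adds the part of the eroded mask that an opening removes. It assumes [Batch x Channels x H x W] inputs with values in [0, 1]; 3D data would use `max_pool3d`. All names are hypothetical, and the TARDIS implementation may differ in details.

```python
import torch
import torch.nn.functional as F


def soft_erode(img: torch.Tensor) -> torch.Tensor:
    # Cross-shaped min-pooling: minimum over vertical and horizontal 3-pixel windows.
    p1 = -F.max_pool2d(-img, kernel_size=(3, 1), stride=1, padding=(1, 0))
    p2 = -F.max_pool2d(-img, kernel_size=(1, 3), stride=1, padding=(0, 1))
    return torch.min(p1, p2)


def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    # 3x3 max-pooling acts as a soft dilation.
    return F.max_pool2d(img, kernel_size=3, stride=1, padding=1)


def soft_open(img: torch.Tensor) -> torch.Tensor:
    return soft_dilate(soft_erode(img))


def soft_skel_sketch(binary_mask: torch.Tensor, iter_: int = 5) -> torch.Tensor:
    img = binary_mask
    skel = F.relu(img - soft_open(img))
    for _ in range(iter_):
        img = soft_erode(img)
        delta = F.relu(img - soft_open(img))
        # Add only the part of delta that is not already in the skeleton.
        skel = skel + F.relu(delta - skel * delta)
    return skel
```

For a horizontal bar three pixels thick, this sketch keeps only the bar's middle row.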

class tardis_em.utils.losses.ClBCELoss(**kwargs)

Soft skeletonization with BCE loss function

Implements a custom version of the Binary Cross Entropy (BCE) loss, where an additional term is added to the standard BCE loss. This additional term is a kind of F1 score calculated on the soft-skeletonized version of the predicted and target masks.

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function

class tardis_em.utils.losses.ClDiceLoss(**kwargs)

Soft skeletonization with DICE loss function

Implements a custom version of the Dice loss, where an additional term is added to the standard Dice loss. This additional term is a kind of F1 score calculated on the soft-skeletonized versions of the predicted and target masks.

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Forward loss function
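The F1-like skeleton term described for ClBCELoss and ClDiceLoss can be sketched as in clDice: topology precision measures how much of the predicted skeleton lies inside the target mask, and topology sensitivity measures how much of the target skeleton is covered by the prediction. This sketch takes precomputed soft skeletons as inputs (for example from a soft-skeleton step like `soft_skel`); the function name, the `smooth` value, and how TARDIS weights this term against BCE or Dice are assumptions.

```python
import torch


def cl_dice_term_sketch(probs, targets, skel_pred, skel_true, smooth: float = 1.0) -> torch.Tensor:
    """F1-like clDice term from predictions, targets and their soft skeletons."""
    # Topology precision: fraction of the predicted skeleton inside the target mask.
    tprec = ((skel_pred * targets).sum() + smooth) / (skel_pred.sum() + smooth)
    # Topology sensitivity: fraction of the target skeleton covered by the prediction.
    tsens = ((skel_true * probs).sum() + smooth) / (skel_true.sum() + smooth)
    # Harmonic mean of the two (F1-style); 1 - F1 turns it into a loss.
    return 1.0 - 2.0 * tprec * tsens / (tprec + tsens)
```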

class tardis_em.utils.losses.SigmoidFocalLoss(gamma=0.25, alpha=None, **kwargs)

Implements the Sigmoid Focal Loss function with an option to ignore the diagonal elements.

The SigmoidFocalLoss class implements the Focal Loss, which was proposed as a method for focusing the model on hard examples during the training of an object detector. It provides an option to ignore the diagonal elements of the input matrices.

References: 10.1088/1742-6596/1229/1/012045

forward(logits: Tensor, targets: Tensor, mask=False) -> Tensor

Computes the sigmoid focal loss between the logits and targets.
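A minimal sketch of the binary focal loss, FL = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The defaults mirror the documented signature (`gamma=0.25`, `alpha=None`); diagonal masking is omitted and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss_sketch(logits, targets, gamma: float = 0.25, alpha=None) -> torch.Tensor:
    """Binary focal loss: FL = -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p = torch.sigmoid(logits)
    # Per-element BCE equals -log(p_t); computed from logits for numerical stability.
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    # (1 - p_t)**gamma down-weights easy, well-classified elements.
    loss = ce * (1 - p_t) ** gamma
    if alpha is not None:
        # Optional class balancing: alpha for positives, 1 - alpha for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()
```

With `gamma=0` and `alpha=None`, the sketch reduces to the mean BCE loss.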

class tardis_em.utils.losses.WBCELoss(**kwargs)

Implements a weighted Binary Cross-Entropy loss function with an option to ignore the diagonal elements.

The WBCELoss class can help to balance the contribution of positive and negative samples in datasets where one class significantly outnumbers the other. It provides an option to ignore the diagonal elements of the input matrices, which could be useful for applications like graph prediction where self-connections might not be meaningful.

forward(logits: Tensor, targets: Tensor, mask=False, pos=1, neg=0.1) -> Tensor

Computes the weighted BCE loss between the logits and targets.
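A sketch of the weighting, assuming `pos` multiplies the loss of elements whose target is 1 and `neg` the loss of elements whose target is 0 (how WBCELoss applies them exactly is not documented here). It relies on the per-element `weight` argument of `torch.nn.functional.binary_cross_entropy_with_logits`; the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_bce_sketch(logits, targets, pos: float = 1.0, neg: float = 0.1) -> torch.Tensor:
    """BCE where positive elements are weighted by `pos` and negatives by `neg`."""
    # Assumption: pos/neg act as per-class weights on target==1 / target==0 elements.
    weights = targets * pos + (1 - targets) * neg
    return F.binary_cross_entropy_with_logits(logits, targets, weight=weights)
```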

class tardis_em.utils.losses.BCEMSELoss(mse_weight=0.1, **kwargs)

Implements a Binary Cross-Entropy loss combined with a Mean Squared Error (MSE) term, with an option to ignore the diagonal elements.

The BCEMSELoss class can be used for training where pixel-level accuracy is important. The MSE term is applied across consecutive Z slices to encourage smooth segmentation.

forward(logits: Tensor, targets: Tensor, mask=True) -> Tensor

Computes the combined BCE and MSE loss between the logits and targets.
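One possible reading of the Z-slice term, sketched under explicit assumptions: predictions are shaped [Batch x Channels x Z x Y x X], and the MSE penalizes differences between neighbouring Z slices of the predicted probabilities, scaled by `mse_weight`. This interpretation is not confirmed by the description above, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def bce_mse_sketch(logits, targets, mse_weight: float = 0.1) -> torch.Tensor:
    """BCE plus an MSE penalty between neighbouring Z slices of the prediction.

    Assumes [B, C, Z, Y, X] tensors; the Z-smoothness reading is an interpretation.
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = torch.sigmoid(logits)
    # Squared difference between each Z slice and the next one.
    mse = F.mse_loss(probs[:, :, 1:], probs[:, :, :-1])
    return bce + mse_weight * mse
```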

Training metrics

tardis_em.utils.metrics.compare_dict_metrics(last_best_dict: dict, new_dict: dict) -> bool

Compares two metric dictionaries and returns True if the new one has a higher average metric value.

  • last_best_dict (dict) – The previous best metric dictionary.

  • new_dict (dict) – The new metric dictionary to compare.


True if the new dictionary has a higher average metric value.

Return type: bool
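A minimal sketch of the comparison, assuming all dictionary values are numeric scores where higher is better and that ties count as no improvement. The function name is hypothetical.

```python
def compare_dict_metrics_sketch(last_best_dict: dict, new_dict: dict) -> bool:
    """Return True if new_dict has a higher mean metric value than last_best_dict."""
    last_mean = sum(last_best_dict.values()) / len(last_best_dict)
    new_mean = sum(new_dict.values()) / len(new_dict)
    # Strict comparison: equal averages are not treated as an improvement.
    return new_mean > last_mean
```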


tardis_em.utils.metrics.eval_graph_f1(logits: Tensor, targets: Tensor, threshold: float, soft=False)

Function used for calculating training metrics.

Works with both torch and NumPy inputs.

  • logits (np.ndarray, torch.Tensor) – Prediction output from the model.

  • targets (np.ndarray, torch.Tensor) – Ground truth mask.

  • threshold (float)

  • soft

tardis_em.utils.metrics.calculate_f1(logits: ndarray | Tensor, targets: ndarray | Tensor, best_f1=True)

Function used for calculating training metrics.

Works with both torch and NumPy inputs.

  • logits (np.ndarray, torch.Tensor) – Prediction output from the model.

  • targets (np.ndarray, torch.Tensor) – Ground truth mask.

  • best_f1 (bool) – If True, the inputs are expected to be class probabilities and the measured metric is the soft F1.
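A NumPy sketch of hard and soft F1, assuming binary targets and predictions in [0, 1]; with `soft=True` the probabilities are used directly, so true and false positives become expected counts. CPU torch tensors can be passed through `np.asarray`. The function name, the `threshold` default, and the (precision, recall, F1) return tuple are assumptions; the return values of `calculate_f1` and `eval_graph_f1` are not documented here.

```python
import numpy as np


def f1_sketch(probs, targets, threshold=0.5, soft=False, eps=1e-16):
    """Precision, recall and F1 for binary predictions (soft-F1 if soft=True)."""
    probs = np.asarray(probs, dtype=float).ravel()
    targets = np.asarray(targets, dtype=float).ravel()
    # Hard F1 thresholds the probabilities; soft F1 keeps them as fractional counts.
    pred = probs if soft else (probs >= threshold).astype(float)
    tp = np.sum(pred * targets)
    fp = np.sum(pred * (1 - targets))
    fn = np.sum((1 - pred) * targets)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1
```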

tardis_em.utils.metrics.AP(logits: ndarray, targets: ndarray) -> float
tardis_em.utils.metrics.AP_instance(input_: ndarray, targets: ndarray) -> float
tardis_em.utils.metrics.AUC(logits: ndarray, targets: ndarray, diagonal=False) -> float
tardis_em.utils.metrics.IoU(input_: ndarray, targets: ndarray, diagonal=False)
tardis_em.utils.metrics.mcov(input_, targets)

Mean Coverage metric

  • input_ (np.ndarray, torch.Tensor)

  • targets (np.ndarray, torch.Tensor)



Return type:
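A sketch of the commonly used mean coverage (mCov) definition from instance segmentation: for every ground-truth instance, take the best IoU with any predicted instance, then average over ground-truth instances. It assumes integer instance labels per element with 0 as background; whether `mcov` uses exactly this convention is an assumption, and the function name is hypothetical.

```python
import numpy as np


def mcov_sketch(pred_labels, true_labels) -> float:
    """Mean coverage: average over ground-truth instances of their best IoU."""
    pred_labels = np.asarray(pred_labels).ravel()
    true_labels = np.asarray(true_labels).ravel()
    coverages = []
    for gt_id in np.unique(true_labels):
        if gt_id == 0:
            continue  # assumption: label 0 is background
        gt = true_labels == gt_id
        best_iou = 0.0
        # Only predicted instances overlapping this ground truth can have IoU > 0.
        for pred_id in np.unique(pred_labels[gt]):
            if pred_id == 0:
                continue
            pred = pred_labels == pred_id
            iou = np.logical_and(gt, pred).sum() / np.logical_or(gt, pred).sum()
            best_iou = max(best_iou, iou)
        coverages.append(best_iou)
    return float(np.mean(coverages)) if coverages else 0.0
```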


tardis_em.utils.metrics.confusion_matrix(logits: ndarray | Tensor, targets: ndarray | Tensor)
tardis_em.utils.metrics.normalize_image(image: ndarray)

Simple normalizer that rescales image data to the range [0, 1].


image (np.ndarray) – Image data set.
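A minimal min-max normalization sketch, assuming `normalize_image` rescales linearly so that the minimum maps to 0 and the maximum to 1 (the exact method is not stated). Constant images are returned as zeros to avoid division by zero; the function name is hypothetical.

```python
import numpy as np


def normalize_image_sketch(image) -> np.ndarray:
    """Min-max rescale image values into [0, 1] (assumed approach)."""
    image = np.asarray(image, dtype=np.float32)
    min_val, max_val = image.min(), image.max()
    if max_val == min_val:
        return np.zeros_like(image)  # constant image: avoid division by zero
    return (image - min_val) / (max_val - min_val)
```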