Trainer Handlers

General NN trainer

class tardis_em.utils.trainer.ISR_LR(optimizer: Adam, lr_mul: float, warmup_steps=1000, scale=100)

Handles learning rate scheduling with warmup and scaling for an inner optimizer.

This class provides a custom learning rate schedule: a warmup period followed by scaled, decreasing learning rates. It also wraps core optimizer functionality such as stepping, zeroing gradients, and saving and loading the optimizer's state. The schedule provides dynamic learning-rate control to improve training stability and efficiency.
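
A minimal usage sketch (the placeholder model and hyperparameter values below are illustrative assumptions, not library defaults beyond those in the signature):

    import torch
    from torch import nn
    from torch.optim import Adam

    from tardis_em.utils.trainer import ISR_LR

    model = nn.Linear(16, 2)  # placeholder network

    # Wrap a standard Adam optimizer; ISR_LR drives the learning rate itself,
    # so the schedule overrides whatever base lr Adam would otherwise use.
    optimizer = ISR_LR(
        Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
        lr_mul=0.5,
        warmup_steps=1000,
        scale=100,
    )

    loss = model(torch.randn(4, 16)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # updates weights and advances the schedule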

load_state_dict(checkpoint: dict)

Loads the state_dict of the optimizer from a given checkpoint.

Parameters:

checkpoint – A dictionary containing the state_dict of the optimizer.

Returns:

None

state_dict()

Wrapper for retrieving the inner optimizer's state dictionary.

step()

Step with the inner optimizer.

zero_grad()

Zero out the gradients with the inner optimizer

get_lr_scale()

Compute the scale factor for the learning rate.
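
The class name suggests the inverse-square-root (ISR) schedule popularized by the original Transformer. A plausible reading of the scale computation, offered only as an assumption about how warmup_steps and scale interact:

    def get_lr_scale(step: int, warmup_steps: int = 1000, scale: int = 100) -> float:
        # step must be >= 1. Linear warmup for warmup_steps steps, then
        # inverse-square-root decay; `scale` sets the overall magnitude.
        return (scale ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)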

class tardis_em.utils.trainer.BasicTrainer(model, structure: dict, device: device, criterion, optimizer: ISR_LR | Adam, print_setting: tuple, training_DataLoader, validation_DataLoader=None, lr_scheduler=False, epochs=100, early_stop_rate=10, instance_cov=2, checkpoint_name='DIST', classification=False)

This class initializes a trainer for machine learning models, with the capability to handle various neural network structures and configurations. It includes train and validation data loaders, optional learning rate scheduling, early stopping, and metric tracking. It targets both classification and distributed computation tasks.

run_trainer()

Executes the training process for a machine learning model. This method handles the initialization of necessary components such as progress bars, early stopping, and training directories. It iteratively trains and validates the model over a specified number of epochs, updating progress and metrics. The process supports early stopping to terminate training if a certain condition is met.

Raises:

FileExistsError – If errors occur while handling directory setup.

Returns:

None
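
A hedged construction sketch. The structure dict schema, the print_setting contents, and the toy data below are placeholders, not the library's documented values:

    import torch
    from torch import nn
    from torch.optim import Adam
    from torch.utils.data import DataLoader, TensorDataset

    from tardis_em.utils.trainer import BasicTrainer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(16, 1).to(device)  # placeholder network

    # Toy loaders; real training uses the TARDIS dataset loaders.
    train_dl = DataLoader(TensorDataset(torch.randn(32, 16), torch.rand(32, 1)),
                          batch_size=8)
    val_dl = DataLoader(TensorDataset(torch.randn(8, 16), torch.rand(8, 1)),
                        batch_size=8)

    trainer = BasicTrainer(
        model=model,
        structure={},                    # model-structure dict; schema is library-specific
        device=device,
        criterion=nn.BCEWithLogitsLoss(),
        optimizer=Adam(model.parameters(), lr=1e-3),
        print_setting=("", "", "", ""),  # progress-bar text fields (assumed)
        training_DataLoader=train_dl,
        validation_DataLoader=val_dl,
        epochs=10,
        early_stop_rate=10,
        checkpoint_name="DIST",
    )
    trainer.run_trainer()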

Loss functions

class tardis_em.utils.losses.AbstractLoss(smooth=1e-16, reduction='mean', diagonal=False, sigmoid=True)
activate(logits)
ignor_diagonal(logits, targets, mask=False)

Applies processing to logits and targets to ignore the diagonal entries when the diagonal condition is set. The behavior changes based on the value of the mask parameter. If masking is enabled, the diagonal entries of logits and targets are nullified. Without masking, the diagonal entries of logits and targets are set to 1.

Parameters:
  • logits – A tensor representing unnormalized log probabilities of predicted classes.

  • targets – A tensor representing the actual target classes.

  • mask – A boolean indicating whether to apply zero masking on the diagonal elements of logits and targets.

Returns:

A tuple containing the processed logits and targets tensors, with the diagonal entries adjusted based on the mask parameter.
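
A standalone sketch of the described behavior on square, adjacency-style tensors (a hypothetical helper, not the method itself):

    import torch

    def ignore_diagonal(logits: torch.Tensor, targets: torch.Tensor,
                        mask: bool = False):
        # Operates on [N, N] tensors; for batched [B, N, N] input the
        # eye mask broadcasts over the batch dimension.
        eye = torch.eye(logits.shape[-1], dtype=torch.bool, device=logits.device)
        fill = 0.0 if mask else 1.0  # mask=True nullifies, mask=False pins to 1
        return logits.masked_fill(eye, fill), targets.masked_fill(eye, fill)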

initialize_tensors(logits, targets, mask)

Initializes tensors by applying activation to the logits and optionally ignoring diagonal in graph structures.

Parameters:
  • logits (Tensor) – The input tensor representing logits which undergo activation.

  • targets (Tensor) – A tensor representing target values.

  • mask (Tensor) – A boolean tensor to apply masking for ignoring specific elements, such as diagonal values in graphs.

Returns:

Activated logits tensor with potential diagonal values ignored based on the provided mask.

Return type:

Tensor

abstract forward(logits: Tensor, targets: Tensor, mask=False)

Computes the forward pass of the layer.

Parameters:
  • logits – The predicted logits represented as a PyTorch Tensor.

  • targets – The ground-truth values represented as a PyTorch Tensor.

  • mask – A boolean flag. If set to True, a mask will be applied during computation. Defaults to False.

Returns:

The computed result as a PyTorch Tensor.

class tardis_em.utils.losses.AdaptiveDiceLoss(alpha=0.1, **kwargs)

Computes the Adaptive Dice Loss, a metric commonly used in image segmentation that combines the concept of Dice loss with an adaptive exponent to control the weight given to false negatives. The method calculates the similarity between predicted outputs (logits) and target values while considering class imbalance.

This loss function is particularly useful in medical image analysis or other scenarios where the regions of interest can be very small compared to the overall image size, making it difficult for traditional loss functions to perform well.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Computes the Soft Weighted Dice Loss between a predicted tensor (logits) and a target tensor.
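
A common formulation of adaptive Dice, shown as a sketch; the exact exponent placement in TARDIS may differ. Each prediction is weighted by (1 - p)^alpha so confident, easy pixels contribute less than hard ones:

    import torch

    def adaptive_dice(logits: torch.Tensor, targets: torch.Tensor,
                      alpha: float = 0.1, smooth: float = 1e-16) -> torch.Tensor:
        p = torch.sigmoid(logits).flatten()
        t = targets.flatten()
        w = (1.0 - p) ** alpha  # down-weight already-confident predictions
        intersection = (w * p * t).sum()
        return 1.0 - (2.0 * intersection + smooth) / ((w * p).sum() + t.sum() + smooth)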

class tardis_em.utils.losses.BCELoss(**kwargs)

Binary Cross-Entropy Loss (BCELoss) class.

This class is responsible for computing the binary cross-entropy loss between predicted logits and target values. It inherits from the AbstractLoss class and uses PyTorch’s nn.BCELoss.

The purpose of this class is to provide a specialized loss computation for binary classification tasks or similar problems where binary labels are involved. It supports masking functionality to handle specific use cases where part of the data needs to be ignored during loss calculation.

forward(logits: Tensor, targets: Tensor, mask=True) Tensor

Computes the BCE loss between the logits and targets.

class tardis_em.utils.losses.BCEGraphWiseLoss(**kwargs)

Handles binary cross-entropy loss computation on a graph-wise level.

This class extends AbstractLoss and is designed to compute binary cross-entropy (BCE) loss with specific functionality for handling graph-based input and output. It applies BCE loss separately to positive and negative targets, allowing for different treatment of these cases. The loss is calculated using a mask to focus only on specific parts of the logits and targets.

forward(logits: Tensor, targets: Tensor, mask=True) Tensor

Computes and returns the combined loss for positive and negative examples.

class tardis_em.utils.losses.BCEDiceLoss(**kwargs)

The BCEDiceLoss class combines Binary Cross-Entropy (BCE) loss and Dice loss to support both pixel-wise classification and segmentation tasks. It is designed to accommodate weighted or masked losses and is particularly useful in applications such as medical image segmentation where overlapping metric-based losses are advantageous.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Computes the combined loss by summing Binary Cross-Entropy (BCE) loss and Dice loss.

class tardis_em.utils.losses.CELoss(**kwargs)

Implements a Cross-Entropy Loss (CELoss) for neural networks.

Cross-Entropy loss is commonly used in classification problems as a measure of how well the predicted probability distribution aligns with the actual distribution. This implementation supports optional masking for specific elements during the computation of the loss.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Forward loss function

class tardis_em.utils.losses.DiceLoss(**kwargs)

Computes the Dice Loss, primarily used for image segmentation tasks.

This loss function is frequently used to evaluate the overlap between predicted segmentations and ground truth segmentations. It is designed to penalize predictions that deviate from the ground truth, especially in situations where the classes are imbalanced. Dice Loss is defined as 1 - Dice Coefficient and is differentiable, making it useful for optimization in deep learning models.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Forward loss function
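
For reference, the standard soft Dice formulation, consistent with "1 - Dice Coefficient" above (smooth is inherited from AbstractLoss):

    import torch

    def soft_dice(logits: torch.Tensor, targets: torch.Tensor,
                  smooth: float = 1e-16) -> torch.Tensor:
        p = torch.sigmoid(logits).flatten()
        t = targets.flatten()
        intersection = (p * t).sum()
        # Dice loss = 1 - Dice coefficient.
        return 1.0 - (2.0 * intersection + smooth) / (p.sum() + t.sum() + smooth)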

class tardis_em.utils.losses.LaplacianEigenmapsLoss(**kwargs)

Implements a loss function based on Laplacian Eigenmaps. This loss function compares the smallest non-zero eigenvectors of Laplacian matrices derived from true and predicted adjacency matrices. It uses mean squared error to quantify the discrepancy.

This class leverages nn.MSELoss internally to compute the error between eigenvectors. The Laplacian matrix is computed as the degree matrix minus the adjacency matrix. The loss is primarily designed for graph-structured data.

static compute_laplacian(A: Tensor) Tensor

Computes the Laplacian matrix of a given adjacency matrix.

Parameters:

A (torch.Tensor) – The adjacency matrix represented as a tensor. Each element in the matrix defines the connection weights between nodes in a graph.

Returns:

The Laplacian matrix computed from the given adjacency matrix.

Return type:

torch.Tensor
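
The degree-minus-adjacency construction is standard; a minimal equivalent sketch:

    import torch

    def compute_laplacian(A: torch.Tensor) -> torch.Tensor:
        # Degree matrix: row sums of the adjacency placed on the diagonal.
        D = torch.diag(A.sum(dim=-1))
        return D - A  # L = D - A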

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Computes the Laplacian-Eigenmaps loss between the true and predicted adjacency matrices.

class tardis_em.utils.losses.SoftSkeletonization(_iter=5, **kwargs)

Implements a loss function based on soft skeletonization.

This class provides a mechanism to compute a specific type of loss by applying soft morphological operations (erosion, dilation, opening) iteratively to extract soft skeletal representations of binary target masks. Primarily used for tasks requiring representation of skeletonized structures in binary masks.

soft_skel(binary_mask: Tensor, iter_i: int) Tensor

Extracts a soft skeleton from a binary mask using iterative erosion and morphological opening techniques. The algorithm starts by computing a partial skeleton and refines it iteratively over a user-defined number of iterations. This operation is used in image processing tasks to extract a skeletonized representation of binary masks.

Parameters:
  • binary_mask – Input binary mask as a tensor on which the soft skeleton operation will be performed.

  • iter_i – Number of iterations for which the skeletonization process will be performed. This controls the refinement of the resulting skeleton.

Returns:

The resulting soft skeleton as a tensor after processing the input binary mask and applying iterative operations.
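
Soft skeletonization is typically built from min/max-pooling morphology, as in the clDice paper; a sketch for 2D inputs with an assumed [B, C, H, W] layout:

    import torch
    import torch.nn.functional as F

    def soft_erode(img):   # min-pool, implemented as a negated max-pool
        return -F.max_pool2d(-img, kernel_size=3, stride=1, padding=1)

    def soft_dilate(img):
        return F.max_pool2d(img, kernel_size=3, stride=1, padding=1)

    def soft_open(img):    # erosion followed by dilation
        return soft_dilate(soft_erode(img))

    def soft_skel(binary_mask: torch.Tensor, iter_i: int) -> torch.Tensor:
        skel = F.relu(binary_mask - soft_open(binary_mask))
        img = binary_mask
        for _ in range(iter_i):
            img = soft_erode(img)
            delta = F.relu(img - soft_open(img))
            # Add newly exposed skeleton pixels without double counting.
            skel = skel + F.relu(delta - skel * delta)
        return skel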

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Forward loss function

class tardis_em.utils.losses.ClBCELoss(**kwargs)

Computes a combined Binary Cross-Entropy (BCE) loss and a soft skeletonization-based, class-sensitive loss. This class extends SoftSkeletonization and aims to address challenges in pixel-based confidence predictions for segmentation tasks.

The main purpose of this loss function is to enhance performance by integrating standard BCE loss with a sensitivity and precision balance derived from skeletonized inputs. It calculates a harmonized ClBCE loss by prioritizing critical regions of the predictions while maintaining global and structural integrity.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Forward loss function

class tardis_em.utils.losses.ClDiceLoss(**kwargs)

The ClDiceLoss class combines Soft Dice Loss with a topology-preserving loss based on soft skeletonization, referred to as clDice loss.

This class is designed to compute the clDice loss, which aligns with the Dice Loss for pixel-wise accuracy while incorporating additional terms to preserve the topology of structures in binary segmentation tasks. The clDice loss employs soft skeletonization techniques to evaluate the similarity between skeletonized predictions and ground truth.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Forward loss function
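
The clDice term itself combines topology precision and sensitivity computed on the skeletons (standard formulation from the clDice paper; how TARDIS weights this term against the soft Dice term is not specified here):

    import torch

    def cl_dice_term(pred: torch.Tensor, target: torch.Tensor,
                     skel_pred: torch.Tensor, skel_true: torch.Tensor,
                     smooth: float = 1e-16) -> torch.Tensor:
        # Topology precision: predicted skeleton inside the target mask.
        tprec = ((skel_pred * target).sum() + smooth) / (skel_pred.sum() + smooth)
        # Topology sensitivity: true skeleton covered by the prediction.
        tsens = ((skel_true * pred).sum() + smooth) / (skel_true.sum() + smooth)
        return 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)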

class tardis_em.utils.losses.SigmoidFocalLoss(gamma=0.25, alpha=None, **kwargs)

Implements the Sigmoid Focal Loss function, a type of loss function often used for addressing class imbalance problems in classification tasks.

The function applies a modulating factor to the standard cross-entropy criterion to focus learning more on hard-to-classify examples. It includes parameters like gamma for focusing and an optional alpha for handling class imbalance.

forward(logits: Tensor, targets: Tensor, mask=False) Tensor

Computes the sigmoid focal loss between the logits and targets.
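
For reference, the standard sigmoid focal loss (the gamma default follows the class signature; whether TARDIS applies alpha exactly this way is an assumption):

    from typing import Optional

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                           gamma: float = 0.25,
                           alpha: Optional[float] = None) -> torch.Tensor:
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true class
        loss = ce * (1 - p_t) ** gamma               # down-weight easy examples
        if alpha is not None:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        return loss.mean()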

class tardis_em.utils.losses.WBCELoss(**kwargs)

Defines the WBCELoss class which calculates a Weighted Binary Cross-Entropy (BCE) loss. This loss function is useful in scenarios where there is a significant class imbalance, as it allows for custom weighting of positive and negative classes. The loss calculation takes into account the logits (predicted probabilities), the target values, and optional parameters for masking and adjusting class weights.

Its primary purpose is to compute a more flexible BCE loss that accommodates customizable weight scaling for positive and negative classes, especially in datasets with imbalanced distributions.

forward(logits: Tensor, targets: Tensor, mask=False, pos=1, neg=0.1) Tensor

Computes the weighted BCE loss between the logits and targets.
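
A sketch of the weighted BCE described above, assuming pos and neg scale the positive and negative terms directly:

    import torch

    def weighted_bce(logits: torch.Tensor, targets: torch.Tensor,
                     pos: float = 1.0, neg: float = 0.1,
                     eps: float = 1e-16) -> torch.Tensor:
        p = torch.sigmoid(logits).clamp(eps, 1.0 - eps)
        loss = -(pos * targets * torch.log(p)
                 + neg * (1 - targets) * torch.log(1 - p))
        return loss.mean()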

class tardis_em.utils.losses.BCEMSELoss(mse_weight=0.1, **kwargs)

Combines Binary Cross Entropy (BCE) loss and Mean Squared Error (MSE) loss for enhanced predictive modeling.

This class is designed to compute a loss function that combines BCE for binary classification tasks and MSE for additional temporal consistency by penalizing differences between adjacent frames. The mse_weight parameter controls the relative contribution of the MSE to the final loss value.

forward(logits: Tensor, targets: Tensor, mask=True) Tensor

Computes the BCE loss between the logits and targets.
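
A sketch of the combination, assuming the temporal MSE compares adjacent frames of the sigmoid-activated prediction along an assumed frame axis:

    import torch
    import torch.nn.functional as F

    def bce_mse(logits: torch.Tensor, targets: torch.Tensor,
                mse_weight: float = 0.1) -> torch.Tensor:
        # logits/targets assumed shaped [B, T, H, W], with T the frame axis.
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        p = torch.sigmoid(logits)
        # Penalize change between adjacent frames for temporal consistency.
        mse = F.mse_loss(p[:, 1:], p[:, :-1])
        return bce + mse_weight * mse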

Training metrics

tardis_em.utils.metrics.compare_dict_metrics(last_best_dict: dict, new_dict: dict) bool

Compare two metric dictionaries by the average of their values and return whether the new dictionary's average exceeds that of the last best dictionary.

Parameters:
  • last_best_dict – A dictionary representing the last best metrics. Values must be numeric and will be used to compute their average.

  • new_dict – A dictionary representing the new metrics. Values must be numeric and will be used to compute their average.

Returns:

A boolean indicating whether the average value of metrics in the new dictionary is greater than the average value of the last best dictionary.
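
Usage is straightforward; the metric keys below are illustrative:

    from tardis_em.utils.metrics import compare_dict_metrics

    best = {"f1": 0.80, "precision": 0.85}  # average 0.825
    new = {"f1": 0.82, "precision": 0.86}   # average 0.840

    if compare_dict_metrics(best, new):
        best = new  # the new checkpoint becomes the best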

tardis_em.utils.metrics.eval_graph_f1(logits: Tensor, targets: Tensor, threshold: float, soft=False)

Evaluates the Graph-F1 metric for given logits and targets using a threshold-based approach with an option to calculate the soft variant.

The function supports both soft approximation and threshold-based calculation to determine precision, recall, accuracy, and F1-score of predictions. It uses specific configurations for masking the diagonal elements and applies varied thresholds to optimize performance metrics. This is particularly useful in evaluating graph-based predictions.

Parameters:
  • logits – Prediction scores or probabilities from the model (e.g., output of a neural network).

  • targets – True binary labels corresponding to the predictions.

  • threshold – A threshold value for converting logits into binary predictions. Determines decision boundaries.

  • soft – If True, computes a soft approximation of metrics; otherwise, uses threshold-based evaluation.

Returns:

If soft is True, returns precision cost, recall cost, and F1 cost as tensors. Otherwise, returns averaged accuracy score, precision score, recall score, F1 score, and selected threshold.

tardis_em.utils.metrics.calculate_f1(logits: ndarray | Tensor, targets: ndarray | Tensor, best_f1=True)

Calculates evaluation metrics for binary classification tasks, such as F1 score, precision, recall, and accuracy. Depending on the best_f1 flag, the function either computes metrics for a specific threshold or iteratively finds the threshold yielding the best F1 score.

Parameters:
  • logits (Union[numpy.ndarray, torch.Tensor]) – Predictions or model outputs, either as probabilities or logits.

  • targets (Union[numpy.ndarray, torch.Tensor]) – Ground-truth binary targets for the classification task.

  • best_f1 (bool) – Flag to enable iterative calculation to find the threshold for the highest F1 score. If False, metrics are calculated directly.

Returns:

A tuple containing metrics - accuracy, precision, recall, F1 score, and (if applicable) the best threshold.

Return type:

Tuple[float, float, float, float, Optional[float]]
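
The best_f1 search can be pictured as a simple threshold sweep (an illustrative re-implementation; the library's exact loop and step size may differ):

    import numpy as np

    def best_f1_threshold(probs: np.ndarray, targets: np.ndarray):
        best_f1, best_th = 0.0, 0.5
        for th in np.arange(0.05, 1.0, 0.05):  # assumed step size
            pred = (probs >= th).astype(int)
            tp = int(((pred == 1) & (targets == 1)).sum())
            fp = int(((pred == 1) & (targets == 0)).sum())
            fn = int(((pred == 0) & (targets == 1)).sum())
            precision = tp / (tp + fp + 1e-16)
            recall = tp / (tp + fn + 1e-16)
            f1 = 2 * precision * recall / (precision + recall + 1e-16)
            if f1 > best_f1:
                best_f1, best_th = f1, th
        return best_f1, best_th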

tardis_em.utils.metrics.AP(logits: ndarray, targets: ndarray) float
tardis_em.utils.metrics.AP_instance(input_n: ndarray, targets: ndarray) float

Compute the average precision (AP) for the given input and target instances. The function compares the input predictions with ground truth instances and calculates precision values based on the best matches. Precision is evaluated as the ratio of true positives to the total number of positive detections. The final AP is normalized by the total number of unique target instances.

Parameters:
  • input_n – A 2D numpy array representing predicted instances, where the first column corresponds to instance labels and the remaining columns represent associated features.

  • targets – A 2D numpy array representing ground truth (GT) instances, where the first column corresponds to instance labels and the remaining columns represent associated features.

Returns:

The computed average precision (AP) as a float, normalized across all unique GT instances in the targets dataset.

tardis_em.utils.metrics.AUC(logits: ndarray, targets: ndarray, diagonal=False) float

Computes the Area Under the Curve (AUC) for given logits and targets.

This function calculates the AUC score, which is a measure of the performance of a classification model. It uses logits and targets as input and computes the Receiver Operating Characteristic (ROC) curve for evaluation. If the diagonal parameter is set to True, it modifies the diagonal of the input logits and targets matrices to ensure specific conditions before calculating the AUC.

Parameters:
  • logits – The predicted scores or probabilities as a numpy array. Can be 2D or 3D.

  • targets – Ground truth binary labels as a numpy array. Must match the shape of logits.

  • diagonal – A boolean indicating if the diagonal elements of logits/targets matrices should be altered before computation. Default is False.

Returns:

The computed AUC metric as a float.

tardis_em.utils.metrics.IoU(input_n: ndarray, targets: ndarray, diagonal=False)

Compute the Intersection Over Union (IoU) metric for given input and targets.

The IoU is a performance metric commonly used for evaluating segmentation models in computer vision. It measures the overlap between the ground-truth target and predicted values. The optional ‘diagonal’ parameter allows modifications by excluding diagonal elements in the computation, particularly useful in multi-class datasets.

Parameters:
  • input_n – Input numpy array, typically model predictions. Can be a 2D or 3D array.

  • targets – Target numpy array, representing the ground-truth labels. Must match the shape of the input.

  • diagonal – Flag to enforce diagonal elements to be 1. If True, modifies the diagonal elements in the input and targets arrays.

Returns:

Computed IoU value as a floating-point number.
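
For binary arrays, IoU reduces to foreground intersection over union; a reference sketch that omits the library's diagonal handling:

    import numpy as np

    def iou(pred: np.ndarray, target: np.ndarray, smooth: float = 1e-16) -> float:
        pred, target = pred.astype(bool), target.astype(bool)
        intersection = np.logical_and(pred, target).sum()
        union = np.logical_or(pred, target).sum()
        return float((intersection + smooth) / (union + smooth))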

tardis_em.utils.metrics.mcov(input_n, targets)

Calculate the mean Coverage (mCov) and weighted mean Coverage (mwCov) for given input and target data.

This function computes coverage metrics that evaluate how well input instances match ground-truth (GT) instances based on Intersection over Union (IoU). It accounts for both the ratio of each instance's size (w_g) to the total size and the overall mean matching.

Parameters:
  • input_n (numpy.ndarray) – Input point cloud data where the first column identifies instance labels and the remaining columns represent corresponding coordinates or features.

  • targets (numpy.ndarray) – Ground truth (GT) point cloud data with the first column identifying instance labels and the remaining columns representing corresponding coordinates or features.

Returns:

A tuple containing:
  • mCov (float) – Mean Coverage across all GT instances.

  • mwCov (float) – Weighted Mean Coverage based on instance ratio.

Return type:

tuple[float, float]

tardis_em.utils.metrics.confusion_matrix(logits: ndarray | Tensor, targets: ndarray | Tensor)

Calculates the confusion matrix components for the provided logits and targets. The confusion matrix consists of True Positives (TP), False Positives (FP), True Negatives (TN), and False Negatives (FN). This function handles both PyTorch tensors and NumPy arrays as input for logits and targets.

Parameters:
  • logits – The predicted values, supporting either PyTorch tensors or NumPy arrays.

  • targets – The ground-truth values, supporting either PyTorch tensors or NumPy arrays.

Returns:

A tuple containing four integer values representing True Positives (TP), False Positives (FP), True Negatives (TN), and False Negatives (FN), computed based on the provided logits and targets.

Return type:

Tuple[int, int, int, int]
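
The four components follow from element-wise comparison of binarized predictions and targets; a NumPy sketch (the library additionally accepts PyTorch tensors):

    import numpy as np

    def confusion_counts(pred: np.ndarray, target: np.ndarray):
        pred, target = pred.astype(bool), target.astype(bool)
        tp = int(np.sum(pred & target))    # predicted 1, actually 1
        fp = int(np.sum(pred & ~target))   # predicted 1, actually 0
        tn = int(np.sum(~pred & ~target))  # predicted 0, actually 0
        fn = int(np.sum(~pred & target))   # predicted 0, actually 1
        return tp, fp, tn, fn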

tardis_em.utils.metrics.normalize_image(image: ndarray)

Normalizes a given image represented as a NumPy array.

This function ensures that the input image is binary. If the image's minimum and maximum values are already 0 and 1, respectively, it is returned unchanged. Otherwise, the image is normalized to binary values (0 or 1) based on its minimum and maximum values.

Parameters:

image – The input image to be normalized, represented as a NumPy array.

Returns:

The normalized image array with binary values, where pixel values are either 0 or 1.
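
A sketch of the described behavior; the 0.5 cut-off for the non-binary case is an assumption, not a documented default:

    import numpy as np

    def normalize_binary(image: np.ndarray) -> np.ndarray:
        if image.min() == 0 and image.max() == 1:
            return image  # already binary-normalized
        # Rescale to [0, 1], then binarize (threshold choice assumed).
        rescaled = (image - image.min()) / (image.max() - image.min() + 1e-16)
        return (rescaled > 0.5).astype(image.dtype)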