DIST
DIST Model Class
- class tardis_em.dist_pytorch.dist.BasicDIST(n_out=1, node_input=0, node_dim=None, edge_dim=128, num_layers=6, num_heads=8, num_cls=None, rgb_embed_sigma=1.0, coord_embed_sigma=1.0, dropout_rate=0, structure='full', predict=False, edge_angles=False)
This class implements the BasicDIST model, a graph-based transformer that predicts graph edges from input node and edge features. The model embeds nodes and edges, passes them through a stack of transformer layers, and decodes the result into edge predictions.
It serves as a flexible, modular framework for processing graph-like structures with transformer-based operations. Configurable options include the embeddings, the number of layers and attention heads, and whether outputs are passed through a sigmoid activation for prediction.
- embed_input(coords: Tensor, node_features: Tensor | None = None)
Embeds input coordinates and optional node features using separate embedding mechanisms.
- Parameters:
coords – A tensor representing the input coordinates of shape [Batch x Length x Coordinate_Dim].
node_features – An optional tensor representing the input node features of shape [Batch x Length x Feature_Dim]. If None, no node features are used in embedding.
- Returns:
A tuple containing the embedded node features and the embedded coordinates. The first element represents the embedded node features of shape [Batch x Length x Embedded_Dim], or None if node_features is not provided. The second element represents the embedded coordinates of shape [Batch x Length x Length x Channels].
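The shape contract above, where coordinates of shape [Batch x Length x Coordinate_Dim] become edge features of shape [Batch x Length x Length x Channels], can be illustrated with a small NumPy sketch. It assumes, purely for illustration, that edge features are built from pairwise Euclidean distances passed through Gaussian kernels whose width is set by a sigma like the coord_embed_sigma constructor argument. The helper name embed_coords_sketch, the kernel choice, and the channel count are hypothetical and are not taken from tardis_em.

```python
import numpy as np


def embed_coords_sketch(coords: np.ndarray, n_channels: int = 4, sigma: float = 1.0) -> np.ndarray:
    """Hypothetical pairwise embedding: [B, L, D] -> [B, L, L, C].

    Each channel applies a Gaussian kernel exp(-d^2 / (2 * s^2)) to the
    pairwise distance d, with widths s = sigma, 2*sigma, ..., n_channels*sigma.
    This is an illustrative assumption, not the tardis_em embedding.
    """
    # Pairwise differences via broadcasting: [B, L, 1, D] - [B, 1, L, D] -> [B, L, L, D]
    diff = coords[:, :, None, :] - coords[:, None, :, :]
    dist = np.linalg.norm(diff, axis=-1)  # [B, L, L]
    widths = sigma * np.arange(1, n_channels + 1)  # [C]
    # Broadcast distances against kernel widths: [B, L, L, 1] with [C] -> [B, L, L, C]
    return np.exp(-dist[..., None] ** 2 / (2.0 * widths ** 2))


rng = np.random.default_rng(0)
coords = rng.random((2, 5, 3))  # Batch=2, Length=5, Coordinate_Dim=3
edges = embed_coords_sketch(coords)
print(edges.shape)  # (2, 5, 5, 4)
```

Because every node pair gets its own feature vector, the edge tensor grows quadratically with Length. This is why the second element of the returned tuple has two Length axes.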
- forward(coords: Tensor, node_features=None) → Tuple[Tensor, Tensor] | Tensor
Processes input node and edge features through a transformer-based architecture to predict the graph edges.
The method takes as input the coordinates and optional node features of a graph and embeds them. It applies transformer layers for encoding, followed by decoding to predict the graph edges. The predictions are based on the transformed edge features, optionally applying a sigmoid function for binary prediction.
- Parameters:
coords – Coordinates of the nodes in the graph.
node_features (torch.Tensor or None) – Optional features of the nodes in the graph.
- Returns:
Predicted logits for the graph edges. If predict is True, the logits are passed through a sigmoid function, and edge probabilities for binary prediction are returned instead.
- Return type:
Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
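The effect of the predict flag on the output of forward can be sketched as a post-processing step. This is a minimal NumPy illustration. It assumes that predict=True simply maps each edge logit through a sigmoid. The helper edge_output_sketch is hypothetical and not part of tardis_em.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function mapping real-valued logits into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def edge_output_sketch(logits: np.ndarray, predict: bool) -> np.ndarray:
    """Hypothetical post-processing: raw logits for training, probabilities for prediction."""
    return sigmoid(logits) if predict else logits


logits = np.array([[0.0, 2.0], [-2.0, 0.0]])  # toy [Length x Length] edge logits
print(edge_output_sketch(logits, predict=False))  # unchanged logits
print(edge_output_sketch(logits, predict=True))   # 0.5 on the diagonal, values in (0, 1)
```

One common reason to keep raw logits during training is that losses such as PyTorch's BCEWithLogitsLoss apply the sigmoid internally for numerical stability.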
- class tardis_em.dist_pytorch.dist.DIST(**kwargs)
DIST class, a specialized subclass of BasicDIST.
DIST inherits from BasicDIST and is the network that build_dist_network creates for instance segmentation tasks. It is instantiated with keyword arguments that configure the model, so its attributes and properties can be set dynamically at construction time.
- class tardis_em.dist_pytorch.dist.CDIST(**kwargs)
CDIST inherits from BasicDIST and builds a classification network; build_dist_network creates it for semantic segmentation tasks. Its layers are created dynamically from the keyword arguments passed at instantiation. It produces logits or predictions and checks that the configuration it receives is valid for classification.
- tardis_em.dist_pytorch.dist.build_dist_network(network_type: str, structure: dict, prediction: bool)
Builds a DIST network for instance segmentation or a CDIST network for semantic segmentation, based on the specified network type, structure parameters, and prediction mode.
This function creates a network object, either DIST for instance segmentation tasks or CDIST for semantic segmentation tasks, depending on the provided network_type. The network is configured dynamically using the provided structure dictionary and the prediction flag, which determines whether the network will operate in prediction mode or not.
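The dispatch described above can be sketched in plain Python. The stand-in classes (BasicDISTStub, DISTStub, CDISTStub), the helper build_network_sketch, and the mapping of the prediction flag onto a predict keyword are hypothetical. They only mirror the documented behavior: structure entries are forwarded as keyword arguments, and None is returned for an unsupported network_type, as stated under Returns. This is not the tardis_em implementation.

```python
from typing import Optional


class BasicDISTStub:
    """Hypothetical stand-in for BasicDIST that just records its configuration."""

    def __init__(self, **kwargs):
        self.config = kwargs


class DISTStub(BasicDISTStub):
    """Hypothetical stand-in for DIST (instance segmentation)."""


class CDISTStub(BasicDISTStub):
    """Hypothetical stand-in for CDIST (semantic segmentation)."""


def build_network_sketch(network_type: str, structure: dict, prediction: bool) -> Optional[BasicDISTStub]:
    # Map each accepted network_type to the class that would be instantiated.
    registry = {"instance": DISTStub, "semantic": CDISTStub}
    cls = registry.get(network_type)
    if cls is None:
        return None  # unsupported type, following the Returns description
    # Structure entries become keyword arguments; prediction is assumed to map to `predict`.
    return cls(**structure, predict=prediction)


net = build_network_sketch("semantic", {"num_layers": 6, "num_heads": 8}, prediction=True)
print(type(net).__name__, net.config)
```

A lookup table keeps the accepted network types in one place, so adding a variant means adding one dictionary entry.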
- Parameters:
network_type (str) – Specifies the type of the network to be built. Accepted values are “instance” or “semantic”. An error is raised if an unsupported value is provided.
structure (dict) – Dictionary containing all the configuration parameters required for creating the network. These parameters are passed as arguments when instantiating the DIST or CDIST objects.
prediction (bool) – Boolean flag indicating whether the network should be set up in prediction mode.
- Returns:
An instantiated network object of type DIST or CDIST. If an invalid network_type is provided, None is returned.
- Return type:
Union[DIST, CDIST, None]
DIST Train Module
- tardis_em.dist_pytorch.train.train_dist(dataset_type: str, edge_angles: bool, train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=0.001, lr_scheduler=False, early_stop_rate=10, device='gpu', epochs=1000)
Train a DIST (Dimensionless Instance Segmentation via Transformers) model using the provided configurations and hyperparameters. This function supports multiple types of DIST models, including instance, instance-sparse, and semantic segmentation. The training process involves setting up the model, optimizer, loss function, and optionally loading checkpoints before initiating the trainer for the specified number of epochs. The function allows for flexibility in dataset types, device selection, learning rate schedulers, and custom loss functions.
- Parameters:
dataset_type – The type of dataset to be trained on; determines coverage calculation strategy
edge_angles – A boolean indicating whether edge angles are considered in the model
train_dataloader – DataLoader object for the training dataset
test_dataloader – DataLoader object for the testing/validation dataset
model_structure – Dictionary containing the structural specifications for the DIST model
checkpoint – Optional string specifying the path to a checkpoint file for resuming training
loss_function – String specifying the name of the loss function to be used, defaults to “bce”
learning_rate – Float value for the learning rate of the model optimizer, default is 0.001
lr_scheduler – Boolean indicating whether to use a learning rate scheduler, defaults to False
early_stop_rate – Integer value for the patience of early stopping mechanism, defaults to 10
device – String or torch.device indicating which device (e.g., “gpu” or “cpu”) to use for training
epochs – Integer specifying the number of epochs to train the model, default is 1000
- Returns:
None
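The early_stop_rate parameter is described as the patience of the early stopping mechanism. Below is a minimal sketch of patience-based early stopping. It assumes that training stops once the validation loss has not improved for early_stop_rate consecutive epochs. The helper run_with_early_stopping and the loss values are hypothetical, and the trainers' actual criterion may differ; for example, CDistTrainer is documented as stopping on F1 score.

```python
from typing import List


def run_with_early_stopping(val_losses: List[float], early_stop_rate: int = 10, epochs: int = 1000) -> int:
    """Return how many epochs run before patience-based early stopping triggers.

    val_losses stands in for the validation loss computed after each epoch.
    """
    window = val_losses[:epochs]  # never run more than `epochs` epochs
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(window, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0  # an improvement resets the patience counter
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= early_stop_rate:
            return epoch  # patience exhausted
    return len(window)


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.74]
print(run_with_early_stopping(losses, early_stop_rate=3))  # 6
```

With early_stop_rate=3, the best loss (0.7) is reached at epoch 3, and three epochs without improvement end training at epoch 6.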
DIST Trainer Wrapper
- class tardis_em.dist_pytorch.trainer.SparseDistTrainer(**kwargs)
The SparseDistTrainer class trains models on graph-based data structures. It extends the BasicTrainer class with graph-based thresholding and metric tracking, and provides utilities to save metrics, train the model, and validate its performance. It is tailored to models that require graph propagation and metric-driven checkpointing.
Initialization involves setting up key components and thresholds for various graph configurations. Training and validation methods are provided to execute these processes on loaded datasets. The class incorporates functions to save metrics, display progress updates, and manage checkpoints based on improvements in tracked metrics.
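Metric-driven checkpointing usually means saving the model state only when a tracked metric improves. Below is a minimal sketch under that assumption. The hypothetical BestMetricCheckpointer records which epochs would trigger a save instead of writing files, and it treats higher metric values as better. The metric and its values are illustrative, not tardis_em internals.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class BestMetricCheckpointer:
    """Hypothetical helper: track the best value of a metric where higher is better."""

    best: Optional[float] = None
    saved_epochs: List[int] = field(default_factory=list)

    def update(self, epoch: int, metric: float) -> bool:
        """Return True (and record the epoch) when the metric beats the best so far."""
        if self.best is None or metric > self.best:
            self.best = metric
            self.saved_epochs.append(epoch)  # a real trainer would save the model state here
            return True
        return False


ckpt = BestMetricCheckpointer()
for epoch, score in enumerate([0.61, 0.70, 0.68, 0.74], start=1):
    ckpt.update(epoch, score)
print(ckpt.saved_epochs, ckpt.best)  # [1, 2, 4] 0.74
```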
- class tardis_em.dist_pytorch.trainer.DistTrainer(**kwargs)
Handles the training process for a DIST graph-based model, using several segmentation thresholds. This class is designed to optimize training and checkpointing for models that use graph-based representations.
The class uses various greedy graph cut algorithms with different thresholds to generate segmentations. Metrics and model states are saved during training to ensure progress tracking and checkpointing for better model reproducibility.
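The sketch below illustrates how the segmentation of a graph changes with the threshold. It is not the greedy graph cut used by tardis_em. It substitutes plain thresholding followed by connected components (union-find) on a symmetric matrix of edge probabilities. The helper name segment_at_threshold and the toy matrix are hypothetical.

```python
from typing import Dict, List


def segment_at_threshold(probs: List[List[float]], threshold: float) -> List[int]:
    """Label nodes by connected components of edges with probability >= threshold."""
    n = len(probs)
    parent = list(range(n))

    def find(i: int) -> int:
        # Path halving keeps the union-find trees shallow.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            if probs[i][j] >= threshold:
                parent[find(i)] = find(j)  # merge the two components

    # Relabel component roots as consecutive instance ids in order of first appearance.
    ids: Dict[int, int] = {}
    return [ids.setdefault(find(i), len(ids)) for i in range(n)]


probs = [
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.6, 0.1],
    [0.2, 0.6, 1.0, 0.3],
    [0.1, 0.1, 0.3, 1.0],
]
for t in (0.5, 0.7):
    print(t, segment_at_threshold(probs, t))  # 0.5 -> [0, 0, 0, 1]; 0.7 -> [0, 0, 1, 2]
```

Raising the threshold removes weak edges, so instances split. Evaluating several thresholds and tracking the metrics of each is one way to see which setting segments best.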
- class tardis_em.dist_pytorch.trainer.CDistTrainer(**kwargs)
Trainer class for the classification (semantic segmentation) variant of DIST.
The CDistTrainer class extends BasicTrainer and provides functionality for training and validating a model with a given loss criterion and classification metrics. It is designed for node- and edge-based input, and handles mid-training evaluation, early stopping based on F1 score, and updates of performance metrics such as loss, accuracy, precision, recall, F1 score, and threshold. Both the training and validation phases use DataLoader objects for their respective datasets.
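The metrics listed above (accuracy, precision, recall, F1 score, and a threshold) can be computed from predicted probabilities and binary labels. Below is a minimal pure-Python sketch. It assumes the threshold is chosen by scanning candidate values for the best F1 score. The helper names binary_metrics and best_threshold and the candidate grid are hypothetical, not the trainer's actual procedure.

```python
from typing import Dict, List, Sequence, Tuple


def binary_metrics(probs: Sequence[float], labels: Sequence[int], threshold: float) -> Dict[str, float]:
    """Accuracy, precision, recall, and F1 for the predictions probs >= threshold."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    tn = len(labels) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(labels), "precision": precision, "recall": recall, "f1": f1}


def best_threshold(
    probs: Sequence[float], labels: Sequence[int], candidates: List[float]
) -> Tuple[float, Dict[str, float]]:
    """Return the candidate threshold with the highest F1 (the first one wins on ties)."""
    scored = [(t, binary_metrics(probs, labels, t)) for t in candidates]
    return max(scored, key=lambda item: item[1]["f1"])


probs = [0.9, 0.8, 0.4, 0.3, 0.6, 0.1]
labels = [1, 1, 1, 0, 0, 0]
t, metrics = best_threshold(probs, labels, [0.25, 0.5, 0.75])
print(t, round(metrics["f1"], 3))  # 0.75 0.8
```

On this toy data, the candidate F1 scores are 0.75, about 0.667, and 0.8, so the highest threshold wins. Its precision is 1.0 and its recall is 2/3.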