DIST
DIST Model Class
- class tardis_em.dist_pytorch.dist.BasicDIST(n_out=1, node_input=0, node_dim=None, edge_dim=128, num_layers=6, num_heads=8, num_cls=None, rgb_embed_sigma=1.0, coord_embed_sigma=1.0, dropout_rate=0, structure='full', predict=False, edge_angles=False)
General DIST (Dimensionless Instance Segmentation Transformer) model.
- Parameters:
n_out (int) – Number of channels in the output layer.
node_input (int) – Length of the flattened image patch.
node_dim (int, None) – Feature dimension of the image input for the linear transformation.
edge_dim (int) – Feature dimension of the coordinate input for the linear transformation.
num_layers (int) – Number of DIST layers to initialize.
num_heads (int) – Number of heads for multi-head attention (MHA).
num_cls (int, None) – Number of predicted classes.
coord_embed_sigma (float) – Sigma value used to embed coordinate distance features.
dropout_rate (float) – Dropout rate used in the MHA dropout layer.
structure (str) – DIST network structure: one of full, triang, dualtriang, quad, or attn.
predict (bool) – If True, apply a sigmoid to the output.
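The constructor arguments above can be collected and sanity-checked before a model is built. The sketch below is a hypothetical `DistSettings` dataclass (not part of tardis_em) that mirrors the defaults in the BasicDIST signature and rejects a `structure` outside the five documented options; the other range checks are assumptions, not rules confirmed for the library.

```python
from dataclasses import asdict, dataclass
from typing import Optional

# Structures listed in the BasicDIST documentation.
VALID_STRUCTURES = {"full", "triang", "dualtriang", "quad", "attn"}


@dataclass
class DistSettings:
    """Hypothetical container mirroring the BasicDIST keyword arguments."""

    n_out: int = 1
    node_input: int = 0
    node_dim: Optional[int] = None
    edge_dim: int = 128
    num_layers: int = 6
    num_heads: int = 8
    num_cls: Optional[int] = None
    rgb_embed_sigma: float = 1.0
    coord_embed_sigma: float = 1.0
    dropout_rate: float = 0.0
    structure: str = "full"
    predict: bool = False
    edge_angles: bool = False

    def __post_init__(self) -> None:
        # Documented option: structure must be one of the five listed names.
        if self.structure not in VALID_STRUCTURES:
            raise ValueError(f"unknown structure: {self.structure!r}")
        # Assumed sanity checks, not taken from the library.
        if self.num_heads < 1 or self.num_layers < 1:
            raise ValueError("num_heads and num_layers must be positive")
        if not 0.0 <= self.dropout_rate < 1.0:
            raise ValueError("dropout_rate must be in [0, 1)")
```

Assuming the field names match the signature shown above, `asdict(DistSettings(structure="triang"))` could then be unpacked into a model constructor as keyword arguments.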
- embed_input(coords: Tensor, node_features: Tensor | None = None)
Embed the input features.
- Parameters:
coords (torch.Tensor) – Coordinate features.
node_features (torch.Tensor, None) – Optional Node features.
- Returns:
Embedded features for prediction.
- Return type:
torch.Tensor
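`coord_embed_sigma` controls how coordinate distances are embedded. As an illustration only, the sketch below assumes a Gaussian kernel, exp(-d^2 / (2 * sigma^2)), applied to all pairwise Euclidean distances; the function name `gaussian_distance_embedding` is hypothetical, and the exact formula used inside `embed_input` may differ.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def gaussian_distance_embedding(coords: List[Point], sigma: float = 1.0) -> List[List[float]]:
    """Map pairwise Euclidean distances d to exp(-d^2 / (2 * sigma^2)).

    Assumption: a Gaussian kernel stands in for the DIST coordinate embedding.
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    n = len(coords)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Squared Euclidean distance between points i and j.
            d2 = sum((a - b) ** 2 for a, b in zip(coords[i], coords[j]))
            out[i][j] = math.exp(-d2 / (2.0 * sigma ** 2))
    return out
```

A larger sigma keeps distant points similar to each other; a smaller one makes the embedding decay faster with distance.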
- forward(coords: Tensor, node_features=None) → Tuple[Tensor, Tensor] | Tensor
Forward pass of the DIST model.
- Parameters:
coords (torch.Tensor) – Coordinate input of shape [Batch x Length x Channels].
node_features (torch.Tensor, None) – Image patch input of shape [Batch x Length x Channels].
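Both inputs use the [Batch x Length x Channels] layout. The hypothetical helper below checks that convention on plain shape tuples before a forward call; it assumes that node features, when given, describe the same points as the coordinates (one image patch per coordinate), so Batch and Length must agree while Channels may differ.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_forward_shapes(coords_shape: Shape, node_shape: Optional[Shape] = None) -> Tuple[int, int]:
    """Validate [Batch x Length x Channels] inputs and return (Batch, Length).

    Hypothetical helper, not part of tardis_em.
    """
    if len(coords_shape) != 3:
        raise ValueError(f"coords must be 3-D, got {coords_shape}")
    batch, length, _ = coords_shape
    if node_shape is not None:
        if len(node_shape) != 3:
            raise ValueError(f"node_features must be 3-D, got {node_shape}")
        # Assumption: one node feature vector per coordinate point.
        if tuple(node_shape[:2]) != (batch, length):
            raise ValueError("coords and node_features disagree on Batch/Length")
    return batch, length
```

Since `torch.Size` is a tuple subclass, `tensor.shape` can be passed directly.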
- class tardis_em.dist_pytorch.dist.DIST(**kwargs)
Main DIST (Dimensionless Instance Segmentation Transformer).
This transformer takes into account the positional encoding of each coordinate point and attends to it together with the image patch that corresponds to that coordinate. Through this attention, the transformer is trained to output a graph from which the point cloud can be segmented.
- Returns:
DIST prediction after the sigmoid (prediction) or after the last linear layer (training).
- Return type:
torch.Tensor
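The graph output described above is what makes segmentation possible. A simplified way to turn a predicted edge-probability matrix into instances is to keep edges above a threshold and label connected components, as in the hypothetical `graph_to_instances` below. The [Length x Length] matrix layout, the 0.5 threshold, and the union-find grouping are assumptions; this is a stand-in, not necessarily the procedure used by tardis_em.

```python
from typing import Dict, List


def graph_to_instances(probs: List[List[float]], threshold: float = 0.5) -> List[int]:
    """Label points by connected components of edges with probability > threshold."""
    n = len(probs)
    parent = list(range(n))

    def find(i: int) -> int:
        # Path halving keeps the union-find trees shallow.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            # Average both directions in case the prediction is not symmetric.
            if (probs[i][j] + probs[j][i]) / 2.0 > threshold:
                ri, rj = find(i), find(j)
                if ri != rj:
                    parent[rj] = ri

    # Relabel component roots as consecutive instance ids in order of appearance.
    ids: Dict[int, int] = {}
    return [ids.setdefault(find(i), len(ids)) for i in range(n)]
```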
- class tardis_em.dist_pytorch.dist.CDIST(**kwargs)
Main classifying DIST (Dimensionless Instance Segmentation Transformer).
This transformer takes into account the positional encoding of each coordinate point and attends to it together with the image patch that corresponds to that coordinate. Through this attention, the transformer is trained to output a graph from which the point cloud can be segmented, together with a class prediction.
- Returns:
DIST prediction and DIST class prediction, after the sigmoid (prediction) or after the last linear layer (training).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
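Because the sigmoid is monotonic, picking the most likely class per point gives the same answer whether it is applied to the raw outputs of the last linear layer (training) or to the sigmoid outputs (prediction). The sketch assumes the class prediction can be read as one row of `num_cls` scores per point; `class_labels` and `sigmoid` are illustrative helpers, not tardis_em functions.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow for large negative x
    return z / (1.0 + z)


def class_labels(scores: List[List[float]]) -> List[int]:
    """Pick the highest-scoring class for each point (rows = points)."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]
```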
- tardis_em.dist_pytorch.dist.build_dist_network(network_type: str, structure: dict, prediction: bool)
Wrapper for building a DIST model.
The wrapper takes the DIST parameters and a predefined network type (e.g. DIST, C_DIST) and builds the corresponding DIST model.
- Parameters:
network_type (str) – Network type name.
structure (dict) – Dictionary with all network settings.
prediction (bool) – If True, build the network in prediction mode.
- Returns:
DIST network structure.
- Return type:
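The wrapper's role, choosing a model class from `network_type` and passing the settings through, can be sketched with stand-in classes. `DistStub`, `CDistStub`, `NETWORKS`, and `build_network_sketch` are hypothetical; the sketch assumes the `structure` keys map one-to-one onto constructor keyword arguments and that `prediction` feeds the `predict` flag, and it raises `ValueError` for an unknown type as a design choice.

```python
from typing import Any, Dict, Type


class DistStub:
    """Hypothetical stand-in for the DIST model class."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs  # keep the settings so they can be inspected


class CDistStub(DistStub):
    """Hypothetical stand-in for the CDIST (C_DIST) model class."""


# Assumed mapping from network_type names (as written in the docs) to classes.
NETWORKS: Dict[str, Type[DistStub]] = {"DIST": DistStub, "C_DIST": CDistStub}


def build_network_sketch(network_type: str, structure: Dict[str, Any], prediction: bool) -> DistStub:
    """Pick a model class by name and pass the settings as keyword arguments."""
    key = network_type.upper()
    if key not in NETWORKS:
        raise ValueError(f"unknown network type: {network_type!r}")
    settings = dict(structure)  # copy so the caller's dictionary is not modified
    settings["predict"] = prediction  # assumption: prediction sets the predict flag
    return NETWORKS[key](**settings)
```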
DIST Train Module
- tardis_em.dist_pytorch.train.train_dist(dataset_type: str, edge_angles: bool, train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=0.001, lr_scheduler=False, early_stop_rate=10, device='gpu', epochs=1000)
Training wrapper for DIST or C_DIST models.
- Parameters:
dataset_type (str) – Type of input dataset.
edge_angles (bool) – If True, use an angle for embedding.
train_dataloader (torch.utils.data.DataLoader) – DataLoader with the training dataset.
test_dataloader (torch.utils.data.DataLoader) – DataLoader with the test dataset.
model_structure (dict) – Dictionary with model setting.
checkpoint (str, None) – Optional DIST model checkpoint.
loss_function (str) – Type of loss function.
learning_rate (float) – Learning rate.
lr_scheduler (bool) – If True, a learning-rate scheduler is used during training.
early_stop_rate (int) – Maximum number of epochs without improvement after which training is stopped.
device (torch.device) – Device on which model is trained.
epochs (int) – Maximum number of epochs.
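`early_stop_rate` and `epochs` together bound how long training runs. The hypothetical `epochs_until_stop` below sketches a common patience rule: stop once `early_stop_rate` consecutive epochs fail to improve on the best value seen, and never exceed `epochs`. It assumes the monitored value is a validation loss where lower is better; which metric tardis_em actually monitors is not stated here.

```python
from typing import Iterable, Optional


def epochs_until_stop(val_losses: Iterable[float], early_stop_rate: int = 10, epochs: int = 1000) -> int:
    """Return how many epochs run before early stopping or the epoch cap."""
    best: Optional[float] = None
    stale = 0  # consecutive epochs without improvement
    run = 0
    for loss in val_losses:
        if run >= epochs:
            break  # hard cap on the number of epochs
        run += 1
        if best is None or loss < best:
            best, stale = loss, 0
        else:
            stale += 1
            if stale >= early_stop_rate:
                break  # patience exhausted
    return run
```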
DIST Trainer Wrapper
- class tardis_em.dist_pytorch.trainer.SparseDistTrainer(**kwargs)
DIST model trainer.
- class tardis_em.dist_pytorch.trainer.DistTrainer(**kwargs)
DIST model trainer.
- class tardis_em.dist_pytorch.trainer.CDistTrainer(**kwargs)
C_DIST model trainer.