DIST

DIST Model Class

class tardis_em.dist_pytorch.dist.BasicDIST(n_out=1, node_input=0, node_dim=None, edge_dim=128, num_layers=6, num_heads=8, num_cls=None, rgb_embed_sigma=1.0, coord_embed_sigma=1.0, dropout_rate=0, structure='full', predict=False, edge_angles=False)

General DIST (Dimensionless Instance Segmentation Transformer) model.

Parameters:
  • n_out (int) – Number of channels in the output layer.

  • node_input (int) – Length of the flattened image patch.

  • node_dim (int, None) – Number of image (node) features used in the linear transformation.

  • edge_dim (int) – Number of coordinate (edge) features used in the linear transformation.

  • num_layers (int) – Number of DIST layers to initialize.

  • num_heads (int) – Number of heads for multi-head attention (MHA).

  • num_cls (int, None) – Number of predicted classes.

  • coord_embed_sigma (float) – Sigma value used to embed coordinate distance features.

  • dropout_rate (float) – Dropout rate used in the MHA dropout layer.

  • structure (str) – DIST network structure; one of full, triang, dualtriang, quad, or attn.

  • predict (bool) – If True, apply a sigmoid to the output.

embed_input(coords: Tensor, node_features: Tensor | None = None)

Embed the input features.

Parameters:
  • coords (torch.Tensor) – Coordinate features.

  • node_features (torch.Tensor, None) – Optional node features.

Returns:

Embedded features for prediction.

Return type:

torch.Tensor
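As an illustration of what a coordinate distance embedding controlled by coord_embed_sigma might look like, the sketch below maps every pairwise Euclidean distance to a Gaussian edge weight. The kernel exp(-d² / (2σ²)) and the function name embed_coord_distances are assumptions for illustration, not the tardis_em implementation.

```python
import math
from typing import List, Sequence


def embed_coord_distances(
    coords: Sequence[Sequence[float]], sigma: float = 1.0
) -> List[List[float]]:
    """Hypothetical sketch: map pairwise distances to Gaussian edge features.

    Assumes the kernel exp(-d^2 / (2 * sigma^2)); DIST's own embedding
    layer may use a different formula.
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    n = len(coords)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Squared Euclidean distance between point i and point j.
            d2 = sum((a - b) ** 2 for a, b in zip(coords[i], coords[j]))
            out[i][j] = math.exp(-d2 / (2.0 * sigma**2))
    return out


points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
edges = embed_coord_distances(points, sigma=1.0)
```

With this kernel, a larger sigma makes distant points look more similar, while a smaller sigma keeps only close neighbours strongly connected.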

forward(coords: Tensor, node_features=None) → Tuple[Tensor, Tensor] | Tensor

Forward pass of the DIST model.

Parameters:
  • coords (torch.Tensor) – Coordinates input of a shape [Batch x Length x Channels].

  • node_features (torch.Tensor, None) – Image patch input of a shape [Batch x Length x Channels].
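Both inputs share the [Batch x Length x Channels] layout, so each point is paired with its own image patch. A minimal sketch of a pre-flight shape check, assuming that the batch and length dimensions of the two inputs must match; check_dist_inputs is a hypothetical helper, not part of tardis_em. With PyTorch tensors, pass tensor.shape (torch.Size is a tuple subclass).

```python
from typing import Optional, Sequence


def check_dist_inputs(
    coords_shape: Sequence[int], node_shape: Optional[Sequence[int]] = None
) -> None:
    """Hypothetical helper: validate [Batch x Length x Channels] input shapes."""
    if len(coords_shape) != 3:
        raise ValueError(f"coords must be 3-D [B, L, C], got {tuple(coords_shape)}")
    if node_shape is None:
        return  # node_features is optional
    if len(node_shape) != 3:
        raise ValueError(f"node_features must be 3-D [B, L, C], got {tuple(node_shape)}")
    # Assumption: one image patch per point, so batch and length must agree.
    if tuple(node_shape[:2]) != tuple(coords_shape[:2]):
        raise ValueError(
            f"batch/length mismatch: coords {tuple(coords_shape)} "
            f"vs node_features {tuple(node_shape)}"
        )


# 500 points with 3-D coordinates and 64-value flattened patches (e.g. 8x8).
check_dist_inputs((1, 500, 3), (1, 500, 64))
```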

class tardis_em.dist_pytorch.dist.DIST(**kwargs)

Main DIST (Dimensionless Instance Segmentation Transformer) model.

This transformer takes into account the positional encoding of each coordinate point and attends to it together with the image patch corresponding to that coordinate. This attention trains the transformer to output a graph from which the point cloud can be segmented.
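To make the last step concrete, here is one simple way to turn a predicted edge-probability graph into instance labels: threshold the edges and take connected components. This is a swapped-in, generic technique shown for illustration; tardis_em's own instance-building step may differ, and segment_from_graph is a hypothetical name. The sketch assumes a symmetric probability matrix.

```python
from typing import List, Sequence


def segment_from_graph(
    adj_prob: Sequence[Sequence[float]], threshold: float = 0.5
) -> List[int]:
    """Hypothetical sketch: label points by connected components of a thresholded graph."""
    n = len(adj_prob)
    labels = [-1] * n
    current = 0
    for start in range(n):
        if labels[start] != -1:
            continue  # already assigned to an instance
        # Iterative depth-first search over edges above the threshold.
        labels[start] = current
        stack = [start]
        while stack:
            i = stack.pop()
            for j in range(n):
                if labels[j] == -1 and adj_prob[i][j] > threshold:
                    labels[j] = current
                    stack.append(j)
        current += 1
    return labels


probs = [
    [1.0, 0.9, 0.1, 0.0],
    [0.9, 1.0, 0.2, 0.0],
    [0.1, 0.2, 1.0, 0.8],
    [0.0, 0.0, 0.8, 1.0],
]
print(segment_from_graph(probs))  # [0, 0, 1, 1]
```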

Returns:

DIST prediction after the sigmoid (prediction) or after the last linear layer (training).

Return type:

torch.Tensor
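The two return modes differ only in whether a sigmoid is applied to the final logits. A minimal sketch of that switch on plain floats, assuming training keeps raw logits for a loss that applies the sigmoid internally (such as PyTorch's BCEWithLogitsLoss); dist_output is a hypothetical name, and the sigmoid is written in a numerically stable form.

```python
import math
from typing import List


def _sigmoid(x: float) -> float:
    # Stable form: avoid overflow in math.exp for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def dist_output(logits: List[float], predict: bool) -> List[float]:
    """Sketch: raw logits in training mode, sigmoid probabilities in prediction mode."""
    if not predict:
        # Training: the loss function is assumed to apply the sigmoid itself.
        return list(logits)
    return [_sigmoid(x) for x in logits]


print(dist_output([0.0, 2.0], predict=False))  # [0.0, 2.0]
print(dist_output([0.0], predict=True))  # [0.5]
```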

class tardis_em.dist_pytorch.dist.CDIST(**kwargs)

Main classifying DIST (C_DIST) model.

This transformer takes into account the positional encoding of each coordinate point and attends to it together with the image patch corresponding to that coordinate. This attention trains the transformer to output a graph from which the point cloud can be segmented.

Returns:

DIST graph prediction and DIST class prediction, after the sigmoid (prediction) or after the last linear layer (training).

Return type:

torch.Tensor

tardis_em.dist_pytorch.dist.build_dist_network(network_type: str, structure: dict, prediction: bool)

Wrapper for building a DIST model.

The wrapper takes the DIST parameters and a predefined network type (e.g. DIST, C_DIST) and builds the corresponding DIST model.

Parameters:
  • network_type (str) – Network type name.

  • structure (dict) – Dictionary with all network settings.

  • prediction (bool) – If True, build the network for prediction.

Returns:

DIST network structure.

Return type:

DIST
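A wrapper like this is essentially a dispatch from a type name to a model class, with the settings dictionary forwarded as keyword arguments. The sketch below assumes that structure holds BasicDIST keyword arguments (e.g. num_layers, edge_dim) and that prediction maps to the predict flag; the stub classes, the accepted type strings, and build_network_sketch are all hypothetical and may not match the names tardis_em accepts.

```python
from typing import Any, Dict, Type


class DISTStub:
    """Hypothetical stand-in for tardis_em.dist_pytorch.dist.DIST."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class CDISTStub(DISTStub):
    """Hypothetical stand-in for tardis_em.dist_pytorch.dist.CDIST."""


# Assumed mapping from normalised type names to model classes.
_NETWORKS: Dict[str, Type[DISTStub]] = {"dist": DISTStub, "cdist": CDISTStub}


def build_network_sketch(
    network_type: str, structure: Dict[str, Any], prediction: bool
) -> DISTStub:
    # Normalise "C_DIST", "cdist", "DIST", ... to a lookup key.
    key = network_type.lower().replace("_", "")
    if key not in _NETWORKS:
        raise ValueError(f"Unknown network type: {network_type!r}")
    # The prediction flag overrides any 'predict' entry in the settings dict.
    kwargs = {**structure, "predict": prediction}
    return _NETWORKS[key](**kwargs)


model = build_network_sketch("C_DIST", {"num_layers": 6, "edge_dim": 128}, prediction=True)
```

Merging the flag into a copy of the dictionary avoids a duplicate-keyword error when structure already contains a predict entry.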

DIST Train Module

tardis_em.dist_pytorch.train.train_dist(dataset_type: str, edge_angles: bool, train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=0.001, lr_scheduler=False, early_stop_rate=10, device='gpu', epochs=1000)

Training wrapper for DIST or C_DIST models.

Parameters:
  • dataset_type (str) – Type of input dataset.

  • train_dataloader (torch.utils.data.DataLoader) – DataLoader with the training dataset.

  • test_dataloader (torch.utils.data.DataLoader) – DataLoader with the test dataset.

  • model_structure (dict) – Dictionary with model setting.

  • checkpoint (str, None) – Optional DIST model checkpoint.

  • loss_function (str) – Type of loss function.

  • learning_rate (float) – Learning rate.

  • lr_scheduler (bool) – If True, a learning-rate scheduler is used during training.

  • early_stop_rate (int) – Maximum number of epochs without improvement after which training is stopped.

  • device (torch.device) – Device on which the model is trained.

  • epochs (int) – Maximum number of epochs.
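The early_stop_rate and epochs parameters together bound the length of training. A minimal sketch of the counting logic, assuming the monitored quantity is a validation loss where lower is better (the trainer may track a different metric); epochs_until_early_stop is a hypothetical name.

```python
from typing import Iterable, Optional


def epochs_until_early_stop(
    val_losses: Iterable[float], early_stop_rate: int = 10
) -> Optional[int]:
    """Sketch: return the 0-based epoch at which training stops, or None if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember it and reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_rate:
                return epoch
    return None


# Loss stalls after epoch 1; with early_stop_rate=3, training stops at epoch 4.
print(epochs_until_early_stop([1.0, 0.8, 0.9, 0.9, 0.9], early_stop_rate=3))  # 4
```

If the loop runs out of epochs first (the epochs limit), the function returns None, i.e. training ended without triggering early stopping.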

DIST Trainer Wrapper

class tardis_em.dist_pytorch.trainer.SparseDistTrainer(**kwargs)

DIST model trainer.

class tardis_em.dist_pytorch.trainer.DistTrainer(**kwargs)

DIST model trainer.

class tardis_em.dist_pytorch.trainer.CDistTrainer(**kwargs)

C_DIST model trainer.