CNN

CNN Model Class

tardis_em.cnn.cnn.build_cnn_network(network_type: str, structure: dict, img_size: int, prediction: bool)

Builds and returns a CNN instance for the given network type and structure configuration. Supported architectures are UNet, ResUNet, UNet3Plus, FNet, and FNet with an attention mechanism. The function validates network_type and configures the network from the architectural parameters in the structure dictionary.

Parameters:
  • network_type – The type of CNN network to build. Possible values are “unet”, “resunet”, “unet3plus”, “fnet”, or “fnet_attn”.

  • structure – A dictionary containing the structural parameters for the network, including attributes like in_channel, out_channel, dropout, conv_kernel, etc.

  • img_size – An integer specifying the size of the input image patches fed to the network.

  • prediction – A boolean indicating whether the network is used in prediction mode or not. Affects specific layers or configurations in the network instance.

Returns:

An instance of the specified CNN network configured with the provided attributes. Returns None if the input network type is invalid.
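
The dispatch behaviour described above (validate network_type, read attributes from structure, return None for an unknown type) can be sketched as follows. This is a minimal illustration and not the tardis_em implementation: PlaceholderNet and build_network_sketch are hypothetical names, and only the in_channel and out_channel keys mentioned above are read from the dictionary.

```python
from dataclasses import dataclass
from typing import Optional

# Network type names taken from the parameter description above.
SUPPORTED_TYPES = ("unet", "resunet", "unet3plus", "fnet", "fnet_attn")


@dataclass
class PlaceholderNet:
    """Hypothetical stand-in for a real network class (assumption)."""

    name: str
    in_channel: int
    out_channel: int
    img_size: int
    prediction: bool


def build_network_sketch(
    network_type: str, structure: dict, img_size: int, prediction: bool
) -> Optional[PlaceholderNet]:
    """Return a configured placeholder network, or None for an unknown type."""
    if network_type not in SUPPORTED_TYPES:
        return None  # mirrors the documented "returns None if invalid"
    return PlaceholderNet(
        name=network_type,
        in_channel=structure["in_channel"],
        out_channel=structure["out_channel"],
        img_size=img_size,
        prediction=prediction,
    )
```

Because an invalid type yields None rather than an exception, callers of a function with this contract would normally check the result before using it.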

CNN Train Module

tardis_em.cnn.train.train_cnn(train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=1, learning_rate_scheduler=False, early_stop_rate=10, device='gpu', warmup=100, epochs=1000)

Trains a Convolutional Neural Network (CNN) using the provided data loaders, model structure, and training parameters. The function initializes the model, sets up the training pipeline, and manages checkpoints, learning rate scheduling, and optimization steps. It supports retraining from a checkpoint and offers several loss functions.

Parameters:
  • train_dataloader (torch.utils.data.DataLoader) – A DataLoader used for training the CNN, containing the training dataset.

  • test_dataloader (torch.utils.data.DataLoader) – A DataLoader used for validation/testing the CNN, containing the test dataset.

  • model_structure (dict) – A dictionary defining the model’s structure, including network type, input/output channels, and CNN configurations.

  • checkpoint (Optional[str]) – Optional path to a checkpoint file for retraining a pre-existing model.

  • loss_function (str) – Specifies the loss function to be used during training. Defaults to “bce” (Binary Cross-Entropy). Available options include AdaptiveDiceLoss, BCELoss, and others.

  • learning_rate (float) – The initial learning rate for the optimizer. Defaults to 1.

  • learning_rate_scheduler (bool) – Boolean flag indicating whether to use a learning rate scheduler during training. Defaults to False.

  • early_stop_rate (int) – The number of consecutive epochs of non-improvement in validation metrics before early stopping occurs. Defaults to 10.

  • device (Union[str, torch.device]) – The computational device to run the training on. Can be “gpu”, “cpu”, or a torch.device instance. Defaults to “gpu”.

  • warmup (int) – The number of warmup steps for the optimizer, used when a learning rate scheduler is enabled. Defaults to 100.

  • epochs (int) – The maximum number of training epochs. Defaults to 1000.

Returns:

None
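
The early_stop_rate rule described above (stop after a number of consecutive epochs without improvement) can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: early_stop_epoch is a hypothetical helper, and it assumes the tracked validation metric is a loss where lower is better.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    val_losses: Sequence[float], early_stop_rate: int = 10
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Assumption: the validation metric is a loss, so lower is better.
    """
    best = float("inf")
    stale_epochs = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            stale_epochs = 0  # improvement resets the counter
        else:
            stale_epochs += 1
            if stale_epochs >= early_stop_rate:
                return epoch
    return None  # patience was never exhausted
```

For example, early_stop_epoch([1.0, 0.8, 0.9, 0.85, 0.81], early_stop_rate=3) returns 4, because epochs 2 to 4 all fail to beat the best loss of 0.8.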

CNN Trainer Wrapper

class tardis_em.cnn.trainer.CNNTrainer(**kwargs)

Class for training and validation of a Convolutional Neural Network (CNN).

Handles the entire process of training and validating a CNN model, including data loading, forward and backward passes, evaluation, progress tracking, and parameter updates. Designed to work with various configurations such as classification tasks, learning rate scheduling, and early stopping based on evaluation metrics.