CNN

CNN Model Class

tardis_em.cnn.cnn.build_cnn_network(network_type: str, structure: dict, img_size: int, prediction: bool)

Wrapper for building a CNN model.

The wrapper takes the CNN parameters and a predefined network type (e.g. unet) and builds the corresponding CNN model.

Parameters:
  • network_type (str) – Name of network [unet, resunet, unet3plus, fnet].

  • structure (dict) – Dictionary with all settings needed to build the CNN.

  • img_size (int) – Image patch size used for CNN.

  • prediction (bool) – If True, build the CNN in prediction mode.
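
The following is a minimal sketch of how a builder wrapper of this kind can dispatch on network_type. It is not the tardis_em implementation: the function name build_network_sketch, the BUILDERS table, and the placeholder builders are hypothetical, and the real function returns a model rather than a dictionary. Only the four network names and the parameter names are taken from the signature above.

```python
from typing import Callable, Dict

Builder = Callable[[dict, int, bool], dict]


def _placeholder_builder(name: str) -> Builder:
    """Hypothetical stand-in for a real model constructor."""

    def builder(structure: dict, img_size: int, prediction: bool) -> dict:
        # A real builder would return a network object; this sketch only
        # records the arguments it received.
        return {
            "type": name,
            "structure": structure,
            "img_size": img_size,
            "prediction": prediction,
        }

    return builder


# Network names as listed for the network_type parameter.
BUILDERS: Dict[str, Builder] = {
    name: _placeholder_builder(name)
    for name in ("unet", "resunet", "unet3plus", "fnet")
}


def build_network_sketch(
    network_type: str, structure: dict, img_size: int, prediction: bool
) -> dict:
    """Look up the builder for network_type and call it."""
    if network_type not in BUILDERS:
        raise ValueError(
            f"Unknown network_type {network_type!r}; expected one of {sorted(BUILDERS)}"
        )
    return BUILDERS[network_type](structure, img_size, prediction)


# Usage: an empty structure dict is used because the expected keys
# are not documented in this section.
model = build_network_sketch("unet", structure={}, img_size=64, prediction=False)
```

A lookup table keeps the set of supported names in one place, so an unsupported name fails early with a clear error instead of silently building the wrong network.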

CNN Train Module

tardis_em.cnn.train.train_cnn(train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=1, learning_rate_scheduler=False, early_stop_rate=10, device='gpu', warmup=100, epochs=1000)

Wrapper for training CNN models.

Parameters:
  • train_dataloader (torch.DataLoader) – DataLoader with train dataset.

  • test_dataloader (torch.DataLoader) – DataLoader with test dataset.

  • model_structure (dict) – Dictionary with model setting.

  • checkpoint (str | None, optional) – CNN model checkpoint.

  • loss_function (str) – Type of loss function.

  • learning_rate (float) – Learning rate.

  • learning_rate_scheduler (bool) – If True, LR_scheduler is used with training.

  • early_stop_rate (int) – Maximum number of epochs without improvement after which training is stopped.

  • device (torch.device) – Device on which model is trained.

  • warmup (int) – Number of warm-up steps.

  • epochs (int) – Maximum number of epochs.
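
To illustrate the early_stop_rate parameter, here is a minimal sketch of patience-based early stopping in plain Python. It assumes that "improvement" means a lower validation loss than the best value seen so far; the metric tardis_em actually monitors is not stated in this section, and the function name epoch_of_early_stop is hypothetical.

```python
from typing import Optional, Sequence


def epoch_of_early_stop(
    val_losses: Sequence[float], early_stop_rate: int
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None.

    Training stops once early_stop_rate consecutive epochs pass without
    the loss dropping below the best value seen so far (an assumption).
    """
    best = float("inf")
    stale_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best value: reset the counter.
            best = loss
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs >= early_stop_rate:
            return epoch
    return None


# The loss improves until epoch 3, then stalls for three epochs.
losses = [0.90, 0.70, 0.65, 0.66, 0.70, 0.68, 0.67]
stop_epoch = epoch_of_early_stop(losses, early_stop_rate=3)
no_stop = epoch_of_early_stop(losses, early_stop_rate=10)
```

With early_stop_rate=3 the sketch stops at epoch 6, and with the default value of 10 the same loss curve runs to the end without stopping.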

CNN Trainer Wrapper

class tardis_em.cnn.trainer.CNNTrainer(**kwargs)

GENERAL CNN TRAINER