CNN
CNN Model Class
- tardis_em.cnn.cnn.build_cnn_network(network_type: str, structure: dict, img_size: int, prediction: bool)
Wrapper for building a CNN model.
The wrapper takes the CNN parameters and a predefined network type (e.g. unet, resunet) and builds the corresponding CNN model.
- Parameters:
network_type (str) – Name of network [unet, resunet, unet3plus, fnet].
structure (dict) – Dictionary with all settings needed to build the CNN.
img_size (int) – Image patch size used by the CNN.
prediction (bool) – If True, build the CNN in prediction mode.
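The dispatch this wrapper describes, from a network name to a concrete model, can be pictured as a name-to-builder lookup. The sketch below is hypothetical and does not reproduce tardis_em's code. `build_network_sketch`, the `NETWORKS` registry, the placeholder builders (which return a plain dict instead of a real model) and the `"layers"` structure key are all invented for illustration. Only the four network names come from the parameter list above.

```python
from typing import Any, Callable, Dict

# Hypothetical builder signature: (structure, img_size, prediction) -> model.
# Here the "model" is just a dict recording the inputs, not real layers.
Builder = Callable[[Dict[str, Any], int, bool], Dict[str, Any]]


def _placeholder_builder(name: str) -> Builder:
    """Return a builder that records its inputs instead of creating layers."""
    def build(structure: Dict[str, Any], img_size: int, prediction: bool) -> Dict[str, Any]:
        return {
            "network": name,
            "img_size": img_size,
            "prediction": prediction,
            "structure": dict(structure),
        }
    return build


# Registry keyed by the network names listed in the documentation.
NETWORKS: Dict[str, Builder] = {
    name: _placeholder_builder(name)
    for name in ("unet", "resunet", "unet3plus", "fnet")
}


def build_network_sketch(network_type: str, structure: Dict[str, Any],
                         img_size: int, prediction: bool) -> Dict[str, Any]:
    """Look up the requested network type and delegate to its builder."""
    try:
        builder = NETWORKS[network_type]
    except KeyError:
        raise ValueError(
            f"Unknown network_type {network_type!r}; "
            f"expected one of {sorted(NETWORKS)}"
        ) from None
    return builder(structure, img_size, prediction)


# "layers" is a hypothetical structure key used only for this example.
model = build_network_sketch("unet", {"layers": 4}, img_size=64, prediction=False)
print(model["network"], model["img_size"])  # unet 64
```

Keeping the name-to-builder mapping in one dictionary means an unsupported name fails early with a clear error instead of deep inside model construction.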
CNN Train Module
- tardis_em.cnn.train.train_cnn(train_dataloader, test_dataloader, model_structure: dict, checkpoint: str | None = None, loss_function='bce', learning_rate=1, learning_rate_scheduler=False, early_stop_rate=10, device='gpu', warmup=100, epochs=1000)
Wrapper for training CNN models.
- Parameters:
train_dataloader (torch.DataLoader) – DataLoader with train dataset.
test_dataloader (torch.DataLoader) – DataLoader with test dataset.
model_structure (dict) – Dictionary with model setting.
checkpoint (str | None, optional) – Optional CNN model checkpoint.
loss_function (str) – Name of the loss function (default 'bce').
learning_rate (float) – Learning rate.
learning_rate_scheduler (bool) – If True, a learning-rate scheduler is used during training.
early_stop_rate (int) – Maximum number of epochs without improvement after which training is stopped.
device (torch.device) – Device on which the model is trained (default 'gpu').
warmup (int) – Number of warm-up steps.
epochs (int) – Maximum number of epochs.
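The interplay of warmup, early_stop_rate and epochs can be sketched as below. This is an illustration under stated assumptions, not train_cnn's implementation. It assumes a linear warm-up ramp and that "improvement" means a strictly lower validation loss than the best seen so far. The helper names `warmup_lr` and `epochs_run` are hypothetical.

```python
from typing import Sequence


def warmup_lr(step: int, base_lr: float, warmup: int) -> float:
    """Assumed linear warm-up: ramp from base_lr / warmup up to base_lr.

    After `warmup` steps (or if warmup <= 0) the base learning rate is used.
    """
    if warmup <= 0 or step >= warmup:
        return base_lr
    return base_lr * (step + 1) / warmup


def epochs_run(val_losses: Sequence[float], early_stop_rate: int, epochs: int) -> int:
    """Return how many epochs run before early stopping or hitting `epochs`.

    Training stops once `early_stop_rate` consecutive epochs pass without
    the validation loss improving on the best value seen so far.
    """
    best = float("inf")
    without_improvement = 0
    for epoch, loss in enumerate(val_losses[:epochs], start=1):
        if loss < best:
            best = loss
            without_improvement = 0
        else:
            without_improvement += 1
        if without_improvement >= early_stop_rate:
            return epoch
    return min(len(val_losses), epochs)


# Epochs 3, 4 and 5 do not beat the best loss (0.8), so training stops at epoch 5.
print(epochs_run([1.0, 0.8, 0.9, 0.85, 0.95], early_stop_rate=3, epochs=1000))  # 5
```

With the documented defaults (early_stop_rate=10, epochs=1000), a run would end either after 10 stagnant epochs or at epoch 1000, whichever comes first.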
CNN Trainer Wrapper
- class tardis_em.cnn.trainer.CNNTrainer(**kwargs)
General CNN trainer.
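A general trainer configured through **kwargs commonly merges caller settings over defaults and leaves the per-epoch work to subclasses. The class below is a hypothetical sketch of that pattern, not the CNNTrainer API. `GenericTrainerSketch`, `ToyTrainer`, the `train_epoch`/`validate` hooks and the rejection of unknown keys are all invented here. The default values are borrowed from the train_cnn signature above.

```python
from typing import Any, Dict, List


class GenericTrainerSketch:
    """Hypothetical base trainer configured entirely through keyword arguments.

    Subclasses override `train_epoch` and `validate`; `run` owns the epoch loop.
    """

    # Defaults borrowed from the train_cnn signature; purely illustrative.
    DEFAULTS: Dict[str, Any] = {"epochs": 1000, "learning_rate": 1.0}

    def __init__(self, **kwargs: Any) -> None:
        unknown = set(kwargs) - set(self.DEFAULTS)
        if unknown:
            raise TypeError(f"Unexpected settings: {sorted(unknown)}")
        # Caller-supplied values override the defaults.
        self.config: Dict[str, Any] = {**self.DEFAULTS, **kwargs}

    def train_epoch(self, epoch: int) -> None:
        raise NotImplementedError

    def validate(self, epoch: int) -> float:
        raise NotImplementedError

    def run(self) -> int:
        """Train for `epochs` epochs; return the epoch with the lowest validation loss."""
        best_loss, best_epoch = float("inf"), 0
        for epoch in range(1, self.config["epochs"] + 1):
            self.train_epoch(epoch)
            loss = self.validate(epoch)
            if loss < best_loss:  # a real trainer might save a checkpoint here
                best_loss, best_epoch = loss, epoch
        return best_epoch


class ToyTrainer(GenericTrainerSketch):
    """Toy subclass that replays a fixed list of validation losses."""

    LOSSES: List[float] = [0.9, 0.5, 0.7, 0.6]

    def train_epoch(self, epoch: int) -> None:
        pass  # a real trainer would run forward/backward passes here

    def validate(self, epoch: int) -> float:
        return self.LOSSES[epoch - 1]


print(ToyTrainer(epochs=4).run())  # 2
```

Rejecting unknown keys is a design choice that catches misspelled settings, which a bare **kwargs signature would otherwise accept silently.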