Global -> Predictor Handlers

class tardis_em.utils.predictor.GeneralPredictor(predict: str, dir_: str | tuple[ndarray] | ndarray, binary_mask: bool, output_format: str, patch_size: int, convolution_nn: str, cnn_threshold: str, dist_threshold: float, points_in_patch: int, predict_with_rotation: bool, instances: bool, device_: str, debug: bool, checkpoint: list | None = None, model_version: int | None = None, correct_px: float | None = None, normalize_px: float | None = None, amira_prefix: str | None = None, filter_by_length: int | None = None, connect_splines: int | None = None, connect_cylinder: int | None = None, amira_compare_distance: int | None = None, amira_inter_probability: float | None = None, tardis_logo: bool = True)

MAIN WRAPPER FOR PREDICTION MT/MEM WITH TARDIS-PYTORCH

Parameters:
  • predict (str) – Dataset type name.

  • dir_ (str, tuple[np.ndarray], np.ndarray) – Dataset directory, or the data itself given as an array or a tuple of arrays.

  • output_format (str) – Output formats for the semantic and instance predictions.

  • patch_size (int) – Size of the 3D image crop (patch).

  • cnn_threshold (str) – Threshold for CNN model.

  • dist_threshold (float) – Threshold for DIST model.

  • points_in_patch (int) – Maximum number of points per patched point cloud.

  • predict_with_rotation (bool) – If True, the CNN predicts on four 90° rotations of the input and averages the results.

  • amira_prefix (str) – Optional, Amira file prefix used for spatial graph comparison.

  • filter_by_length (int) – Optional, length threshold for filtering out short splines.

  • connect_splines (int) – Optional, filter setting for connecting nearby splines.

  • connect_cylinder (int) – Optional, filter setting for connecting splines within a cylinder radius.

  • amira_compare_distance (int) – Optional, comparison setting: maximum distance between two splines for them to be considered the same.

  • amira_inter_probability (float) – Optional, comparison setting: probability threshold used to define the comparison class.

  • instances (bool) – If True, run instance segmentation after semantic.

  • device_ (str) – Computation device.

  • debug (bool) – If True, run in debugging mode.
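The filter_by_length option removes short splines from the instance output. Below is a minimal NumPy sketch of that idea, not the TARDIS implementation. It assumes (hypothetically) that splines are stored as an (N, 4) array with columns [spline_id, x, y, z], that each spline's points are in order, and that "length" means the summed distance between consecutive points. The helper names filter_short_splines and spline_length are illustrative only.

```python
import numpy as np


def spline_length(points: np.ndarray) -> float:
    """Polyline length: sum of Euclidean distances between consecutive points."""
    if len(points) < 2:
        return 0.0
    return float(np.linalg.norm(np.diff(points, axis=0), axis=1).sum())


def filter_short_splines(splines: np.ndarray, min_length: float) -> np.ndarray:
    """Hypothetical helper: keep only splines with length >= min_length.

    Assumes `splines` has shape (N, 4) with columns [spline_id, x, y, z].
    """
    keep = []
    for spline_id in np.unique(splines[:, 0]):
        mask = splines[:, 0] == spline_id
        if spline_length(splines[mask, 1:]) >= min_length:
            keep.append(mask)
    if not keep:
        return splines[:0]  # empty (0, 4) array
    # Combine the per-spline boolean masks into one row selector.
    return splines[np.logical_or.reduce(keep)]


# Usage: spline 0 is 10 units long, spline 1 is only 1 unit long.
splines = np.array([
    [0, 0.0, 0.0, 0.0],
    [0, 10.0, 0.0, 0.0],
    [1, 0.0, 5.0, 0.0],
    [1, 0.0, 6.0, 0.0],
])
filtered = filter_short_splines(splines, min_length=5)
print(np.unique(filtered[:, 0]))  # [0.]
```

Filtering per spline ID keeps every point of a surviving spline, so no spline is ever partially truncated.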

semantic_header

Initial Setup

instance_header

Initial Setup

log_prediction

Initial Setup

omit_format

Build handlers.

create_headers()
init_check()

Run all sanity checks before TARDIS initializes prediction.

build_NN(NN: str)
load_data(id_name: str | ndarray)
predict_cnn(id_: int, id_name: str, dataloader)
postprocess_CNN(id_name: str)
preprocess_DIST(id_name: str)
predict_DIST(id_: int, id_name: str)
postprocess_DIST(id_, i)
get_file_list()
log_tardis(id_: int, i: str | ndarray, log_id: float)
save_semantic_mask(i)
save_instance_PC(i)
class tardis_em.utils.predictor.Predictor(device: device, network: str | None = None, checkpoint: str | None = None, subtype: str | None = None, model_version: int | None = None, img_size: int | None = None, model_type: str | None = None, sigma: float | None = None, sigmoid=True, _2d=False, logo=True)

WRAPPER FOR PREDICTION

Parameters:
  • device (torch.device) – Device on which to predict.

  • checkpoint (str) – Optional, local weights file.

  • network (str) – Optional, network type name.

  • subtype (str) – Optional, model subtype name.

  • model_type (str) – Optional, model type name.

  • model_version (int) – Optional, model version.

  • img_size (int) – Optional, image patch size.

  • sigmoid (bool) – If True, apply a sigmoid to the prediction output.
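With sigmoid enabled, raw network outputs (logits) are mapped to probabilities in (0, 1), which a threshold such as GeneralPredictor's cnn_threshold can then turn into a binary mask. The NumPy sketch below illustrates those two steps under stated assumptions: the threshold of 0.5 is only an example value, and treating values at or above the threshold as foreground is an assumed convention, not confirmed TARDIS behavior.

```python
import numpy as np


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Map logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))


def to_binary_mask(prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Assumed convention: probability >= threshold is foreground (1)."""
    return (prob >= threshold).astype(np.uint8)


logits = np.array([[-2.0, 0.0], [1.5, 3.0]])
prob = sigmoid(logits)  # sigmoid(0.0) == 0.5
mask = to_binary_mask(prob, threshold=0.5)
print(mask.tolist())  # [[0, 1], [1, 1]]
```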

predict(x: Tensor, y: Tensor | None = None, rotate=False) → ndarray

General predictor.

Parameters:
  • x (torch.Tensor) – Main feature used for prediction.

  • y (torch.Tensor, None) – Optional feature used for prediction.

  • rotate (bool) – Optional; if True, the CNN outputs the average of its predictions over four 90° rotations.

Returns:

Predicted features.

Return type:

np.ndarray
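The rotate option (and predict_with_rotation in GeneralPredictor) averages the CNN output over four 90° rotations, a form of test-time augmentation. The sketch below shows the general technique in NumPy, not the TARDIS code. It assumes rotation in the last two (Y, X) axes of a 3D volume and a hypothetical model callable that returns an array with the same shape as its input; TARDIS itself works on PyTorch tensors and may rotate different axes.

```python
from typing import Callable

import numpy as np


def predict_with_rotations(
    model: Callable[[np.ndarray], np.ndarray], x: np.ndarray
) -> np.ndarray:
    """Average predictions over 0, 90, 180 and 270 degree rotations.

    Each rotated input goes through `model`, the prediction is rotated
    back to the original orientation, and the four results are averaged.
    Assumes rotation in the (Y, X) plane, i.e. the last two axes.
    """
    preds = []
    for k in range(4):
        x_rot = np.rot90(x, k=k, axes=(-2, -1))
        y_rot = model(x_rot)
        preds.append(np.rot90(y_rot, k=-k, axes=(-2, -1)))  # undo rotation
    return np.mean(preds, axis=0)


# Toy, orientation-biased "model": adds a ramp 0..3 along the last axis.
def toy_model(v: np.ndarray) -> np.ndarray:
    return v + np.arange(v.shape[-1])


volume = np.zeros((2, 4, 4))
avg = predict_with_rotations(toy_model, volume)
print(avg.shape)  # (2, 4, 4)
print(np.allclose(avg, 1.5))  # True: the four rotated ramps average out
```

Rotating each prediction back before averaging is what makes the four outputs comparable voxel by voxel; the toy model shows how a direction-dependent bias cancels out in the average.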