Global -> Predictor Handlers
- class tardis_em.utils.predictor.GeneralPredictor(predict: str, dir_: str | tuple[ndarray] | ndarray, binary_mask: bool, output_format: str, patch_size: int, convolution_nn: str, cnn_threshold: str, dist_threshold: float, points_in_patch: int, predict_with_rotation: bool, instances: bool, device_: str, debug: bool, checkpoint: list | None = None, model_version: int | None = None, correct_px: float | None = None, normalize_px: float | None = None, amira_prefix: str | None = None, filter_by_length: int | None = None, connect_splines: int | None = None, connect_cylinder: int | None = None, amira_compare_distance: int | None = None, amira_inter_probability: float | None = None, tardis_logo: bool = True)
MAIN WRAPPER FOR PREDICTION MT/MEM WITH TARDIS-PYTORCH
- Parameters:
predict (str) – Dataset type name.
dir_ (str, tuple[np.ndarray], np.ndarray) – Dataset directory.
output_format (str) – Two output formats, one for the semantic and one for the instance prediction.
patch_size (int) – Image 3D crop size.
cnn_threshold (str) – Threshold for CNN model.
dist_threshold (float) – Threshold for DIST model.
points_in_patch (int) – Maximum number of points per patched point cloud.
predict_with_rotation (bool) – If True, the CNN predicts with four 90° rotations.
amira_prefix (str) – Optional, Amira file prefix used for spatial graph comparison.
filter_by_length (int) – Optional, length setting used to filter out short splines.
connect_splines (int) – Optional, filter setting for connecting nearby splines.
connect_cylinder (int) – Optional, filter setting for connecting splines within the cylinder radius.
amira_compare_distance (int) – Optional, compare setting, max distance between two splines to consider them as the same.
amira_inter_probability (float) – Optional, compare setting, probability threshold to define comparison class.
instances (bool) – If True, run instance segmentation after semantic.
device_ (str) – Define a computation device.
debug (bool) – If True, run in debugging mode.
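The constructor takes thirteen required arguments, so it can help to assemble them as keyword arguments and check the set before calling the class. The following is a minimal sketch, not the library's own code: `general_predictor_signature` is a hypothetical stand-in that only copies the parameter names from the signature above, and every value in `kwargs` is an illustrative placeholder rather than a documented default or a confirmed valid option.

```python
import inspect


# Hypothetical stand-in: it copies the parameter names from the documented
# signature so the argument set can be checked without installing tardis_em.
# It is NOT tardis_em.utils.predictor.GeneralPredictor.
def general_predictor_signature(
    predict, dir_, binary_mask, output_format, patch_size, convolution_nn,
    cnn_threshold, dist_threshold, points_in_patch, predict_with_rotation,
    instances, device_, debug,
    checkpoint=None, model_version=None, correct_px=None, normalize_px=None,
    amira_prefix=None, filter_by_length=None, connect_splines=None,
    connect_cylinder=None, amira_compare_distance=None,
    amira_inter_probability=None, tardis_logo=True,
):
    """Placeholder body; the real class runs the prediction."""


# All values are assumptions chosen for illustration only.
kwargs = dict(
    predict="Microtubule",        # assumed dataset type name
    dir_="./data",                # assumed dataset directory
    binary_mask=True,
    output_format="mrc_csv",      # assumed format string
    patch_size=128,
    convolution_nn="unet",        # assumed network name
    cnn_threshold="0.5",          # documented as str, not float
    dist_threshold=0.5,
    points_in_patch=1000,
    predict_with_rotation=False,
    instances=True,
    device_="cpu",
    debug=False,
)

# bind() raises TypeError if a required argument is missing or misspelled.
signature = inspect.signature(general_predictor_signature)
bound = signature.bind(**kwargs)
bound.apply_defaults()

# Dropping one required argument shows the failure mode.
missing_error = None
try:
    signature.bind(**{k: v for k, v in kwargs.items() if k != "device_"})
except TypeError as err:
    missing_error = str(err)
```

With the package installed, the same dictionary could be passed as `GeneralPredictor(**kwargs)` once the placeholder values are replaced by ones valid for the installed version.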
- semantic_header
Initial Setup
- instance_header
Initial Setup
- log_prediction
Initial Setup
- omit_format
Build handlers
- create_headers()
- init_check()
Run all sanity checks before TARDIS initializes prediction.
- build_NN(NN: str)
- load_data(id_name: str | ndarray)
- predict_cnn(id_: int, id_name: str, dataloader)
- postprocess_CNN(id_name: str)
- preprocess_DIST(id_name: str)
- predict_DIST(id_: int, id_name: str)
- postprocess_DIST(id_, i)
- get_file_list()
- log_tardis(id_: int, i: str | ndarray, log_id: float)
- save_semantic_mask(i)
- save_instance_PC(i)
- class tardis_em.utils.predictor.Predictor(device: device, network: str | None = None, checkpoint: str | None = None, subtype: str | None = None, model_version: int | None = None, img_size: int | None = None, model_type: str | None = None, sigma: float | None = None, sigmoid=True, _2d=False, logo=True)
WRAPPER FOR PREDICTION
- Args:
device (torch.device) – Device on which to predict.
checkpoint (str, Optional) – Local weights files.
network (str, Optional) – Optional network type name.
subtype (str, Optional) – Optional model subtype name.
model_type (str, Optional) – Optional model type name.
model_version (int, Optional) – Optional model version.
img_size (int, Optional) – Optional image patch size.
sigmoid (bool) – Predict output with sigmoid.
- predict(x: Tensor, y: Tensor | None = None, rotate=False) → ndarray
General predictor.
- Parameters:
x (torch.Tensor) – Main feature used for prediction.
y (torch.Tensor, None) – Optional feature used for prediction.
rotate (bool) – Optional flag for the CNN to output the average over four 90° rotations.
- Returns:
Predicted features.
- Return type:
np.ndarray
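The `rotate` flag (and `predict_with_rotation` on `GeneralPredictor`) is described as averaging the CNN output over four 90° rotations. The sketch below illustrates that idea in plain Python on a 2D list of lists. It is an assumption-laden illustration, not the TARDIS implementation: `rot90`, `average_over_rotations`, and `toy_model` are hypothetical names, the step that rotates each output back before averaging is assumed, and the real method works on `torch.Tensor` inputs.

```python
def rot90(img, k=1):
    """Rotate a 2D list of lists counter-clockwise by k * 90 degrees."""
    img = [list(row) for row in img]          # copy so the input is untouched
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def average_over_rotations(model, img):
    """Average model outputs over the four 90-degree rotations of img."""
    height, width = len(img), len(img[0])
    total = [[0.0] * width for _ in range(height)]
    for k in range(4):
        out = model(rot90(img, k))            # predict on the rotated input
        out = rot90(out, -k)                  # rotate the output back
        for r in range(height):
            for c in range(width):
                total[r][c] += out[r][c]
    return [[value / 4 for value in row] for row in total]


def toy_model(img):
    """Hypothetical model with a left-to-right bias: adds the column index."""
    return [[value + c for c, value in enumerate(row)] for row in img]


blank = [[0.0, 0.0], [0.0, 0.0]]
print(toy_model(blank))                          # [[0.0, 1.0], [0.0, 1.0]]
print(average_over_rotations(toy_model, blank))  # [[0.5, 0.5], [0.5, 0.5]]
```

In this toy case the directional bias of a single pass is spread evenly over the image once the four rotated predictions are averaged, which is the usual motivation for this kind of test-time averaging.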