Global -> Predictor Handlers
- class tardis_em.utils.predictor.GeneralPredictor(predict: str, dir_s: str | tuple[ndarray] | ndarray, binary_mask: bool, output_format: str, patch_size: int, convolution_nn: str, cnn_threshold: str, dist_threshold: float, points_in_patch: int, predict_with_rotation: bool, instances: bool, device_s: str, debug: bool, checkpoint: list | None = None, model_version: int | None = None, correct_px: float | None = None, normalize_px: float | None = None, amira_prefix: str | None = None, filter_by_length: int | None = None, connect_splines: int | None = None, connect_cylinder: int | None = None, amira_compare_distance: int | None = None, amira_inter_probability: float | None = None, tardis_logo: bool = True, continue_b: bool = False)
The GeneralPredictor class is designed for handling image-based predictions using various neural network architectures. It provides functionalities to set up the environment, preprocess data, run predictions with or without rotation, and post-process the results for further usage or analysis. This class also includes methods to handle different file formats, apply various normalization techniques, and output results in specified formats. The class is modular enough to work with instances, semantic predictions, and custom settings based on the input parameters and configurations.
- semantic_header
Initial Setup
- instance_header
Initial Setup
- log_prediction
Initial Setup
- omit_format
Build handlers
- create_headers()
Creates ASCII headers and initializes logging information for a spatial graph prediction process. The headers include project details, directory setup, neural network configurations for semantic and instance segmentation, and other metadata. It verifies paths, logs prediction states appropriately, and dynamically determines the model version depending on the predictive needs and available checkpoints.
- Raises:
None
- Parameters:
self.dir – Directory path for file output or processing. Defaults to the current working directory if unset.
self.output_format – Format of the output data produced during prediction.
- init_check()
Perform initialization and validation checks for a segmentation prediction task.
This method runs several checks and initializations before a segmentation task is executed. It first validates that the requested prediction type is supported. If the TARDIS logo is enabled, user configurations are checked as well, and an error message is displayed for any invalid setting. The log output is then initialized with a title that depends on the segmentation type and configuration.
It also validates the chosen output format (e.g., checking for unsupported machine types like ARM64 and ensuring that at least one valid output format is selected). Invalid configurations or unsupported settings terminate execution with an error message.
- Parameters:
self – Instance of the class to which the initialization belongs.
- Raises:
AssertionError – If the requested segmentation type is unsupported.
TardisError – If any invalid configuration is detected based on TARDIS-specific rules (e.g., invalid output format, machine-type dependencies).
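The "at least one valid output format" check can be illustrated with a minimal sketch. It assumes, purely for illustration, that the format string has the shape `<semantic>_<instance>` with the literal `None` meaning "no output"; the function name `check_output_format`, the example values, and the use of `ValueError` in place of `TardisError` are hypothetical and are not taken from the library.

```python
def check_output_format(output_format: str) -> tuple[str, str]:
    """Split an assumed '<semantic>_<instance>' string and validate it.

    Hypothetical helper: the real init_check() may parse and report
    errors differently (it is documented to raise TardisError).
    """
    parts = output_format.split("_")
    if len(parts) != 2:
        raise ValueError(
            f"Expected '<semantic>_<instance>', got {output_format!r}"
        )

    semantic, instance = parts
    # At least one of the two outputs has to be requested.
    if semantic == "None" and instance == "None":
        raise ValueError("Select at least one output format")
    return semantic, instance
```

Raising early, before any model is built or data is loaded, keeps a misconfigured run from failing only after a long prediction.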
- build_NN(NN: str)
Builds the neural network and distance prediction modules based on the specified neural network (NN) type. Depending on the NN type, the appropriate configurations and pre-trained weights are loaded for CNN and DIST networks. This method supports multiple NN types, including Actin, Microtubule, Membrane, and General models. Additionally, configurations for 2D, 3D, and other specialized models are supported.
- Parameters:
NN (str) – A string denoting the neural network type. Supported types include “Actin”, “Microtubule”, “Microtubule_tirf”, “Membrane2D”, “Membrane”, or any type starting with “General”.
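As a rough illustration of the naming rule described for NN, the sketch below checks a string against the documented names and the "General" prefix. The helper `is_supported_nn` and the constant `SUPPORTED_NN` are hypothetical; build_NN itself may dispatch differently.

```python
# Names listed in the parameter description above.
SUPPORTED_NN = (
    "Actin",
    "Microtubule",
    "Microtubule_tirf",
    "Membrane2D",
    "Membrane",
)


def is_supported_nn(nn: str) -> bool:
    """Return True for a documented name or any name starting with 'General'."""
    return nn in SUPPORTED_NN or nn.startswith("General")
```

For example, `is_supported_nn("Membrane2D")` is true, while a name outside the list that does not start with "General" is rejected.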
- load_data(id_name: str | ndarray)
Loads and processes image data or point cloud data from a specified file or array. Depending on the input type, the function determines if the data is an AmiraMesh 3D ASCII file, a general image file, or a preloaded array. It performs normalization, sanity checks, and prepares the data for further processing.
- Parameters:
id_name – Specifies either the path to the file to be loaded or an already loaded numpy array. Can be a string filename or a numpy array.
- Raises:
AssertionError –
If the Amira Spatial Graph has dimensions other than 4.
If the loaded image’s dtype is not float32 after normalization.
If the processed binary mask dtype is not int8 or uint8.
If tardis_logo is True and any of the aforementioned conditions fail.
SystemExit –
If certain errors occur during processing and tardis_logo is True.
- Returns:
None
- predict_cnn(id_i: int, id_name: str, dataloader)
Predict images using a Convolutional Neural Network (CNN) with options for image rotation and progress tracking integrated with the Tardis progress bar interface.
This method iterates over a dataloader to retrieve images, predicts their output using a CNN model (optionally with four 90° rotations), and writes the output to .tif files. The method supports progress tracking with Tardis interface updates and dynamically optimizes the progress bar refresh rate based on initial iteration timing.
- Parameters:
id_i (int) – Integer representing the ID of the image being processed.
id_name (str) – The name of the image being processed.
dataloader – An iterable object that provides access to image data and corresponding names.
- Returns:
None
- predict_cnn_napari(input_t: Tensor, name: str)
Predicts an output using the CNN model on the provided input tensor, saves the result in TIFF format, and returns the output tensor.
This function performs a prediction using the Convolutional Neural Network (CNN) model on the given input tensor. The result is saved as a TIFF file using the provided file name in the specified output directory.
- Parameters:
input_t – Input tensor on which prediction needs to be performed, should follow the required input format for the CNN model.
name – Name of the output file to save the predicted result.
- Returns:
The output tensor resulting from the CNN prediction, after processing with the input tensor.
- postprocess_CNN(id_name: str)
Post-processes the CNN prediction by stitching predicted image patches, restoring the original pixel size, applying a threshold, and optionally saving the results in the specified format. This function also performs clean-up of temporary directories after processing.
- Parameters:
id_name (str) – Identifier of the input data used to track and log the processed output.
- Returns:
None
- preprocess_DIST(id_name: str)
Preprocesses a given dataset identifier (id_name) to produce and manipulate high-density and low-density point clouds, typically used for structural or image data analysis. Depending on the prediction type and the presence of an Amira image, this function either post-processes predicted image patches to construct point clouds using provided processing utilities or applies optimization methods like voxel down-sampling for refining existing point clouds.
- Parameters:
id_name (str) – The unique dataset identifier used in debugging and processing.
- Returns:
None
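Voxel down-sampling, mentioned above as one way to refine a point cloud, can be sketched generically as follows. This is a plain-Python illustration of the technique, not the routine the library calls; the function name `voxel_downsample` and the centroid rule are assumptions.

```python
from collections import defaultdict
from math import floor


def voxel_downsample(points, voxel_size):
    """Replace all points that fall into the same cubic voxel by their centroid.

    points: iterable of (x, y, z) tuples; voxel_size: edge length of a voxel.
    """
    if voxel_size <= 0:
        raise ValueError("voxel_size must be positive")

    # Group points by the integer index of the voxel that contains them.
    buckets = defaultdict(list)
    for p in points:
        key = tuple(floor(c / voxel_size) for c in p)
        buckets[key].append(p)

    # One centroid per occupied voxel, in order of first occurrence.
    return [
        tuple(sum(axis) / len(group) for axis in zip(*group))
        for group in buckets.values()
    ]
```

A larger `voxel_size` gives a sparser (low-density) cloud, and a smaller one keeps more of the original detail.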
- predict_DIST(id_i: int, id_name: str)
Predicts DIST graphs for the given coordinates using the provided DIST prediction model. The method processes coordinate data in chunks and updates the progress bar if visual feedback is enabled. The progress bar reflects the current task, the percentage of completion, and relevant details of the segmentation process. The function ensures predictive modeling for the total images with a controlled iteration mechanism.
- Parameters:
id_i – An integer representing the identifier of the image to be processed.
id_name – A string denoting the name of the image corresponding to the ID.
- Returns:
A list of predicted graph representations for each coordinate dataset.
- postprocess_DIST(id_i, i)
Post-processes the DIST prediction for the current image.
This function adjusts the pixel data, logs information based on the prediction type, and handles the transformation of graphs to segments. It also updates the TARDIS progress bar with task-specific information.
- Parameters:
id_i (int) – Identification number for the current image being processed.
i (int) – Index of the current image.
- Returns:
None, modifies instance attributes based on the processing steps.
- Return type:
None
- get_file_list()
Retrieves and processes a list of files to be used for prediction based on the directory or input provided. Filters files according to specified formats, handles input as either single directories or lists/tuples, and logs the processed files. Additionally, performs setup tasks for the prediction workflow, including generating paths for output directories and checking prediction readiness.
- Parameters:
self – The instance whose attributes are used: dir, available_format, omit_format, continue_, tardis_logo, tardis_progress, output, am_output, predict_list, device, and title. These hold the directory paths, filtering formats, continuation settings, and progress handlers used for processing.
- Raises:
AssertionError – Raised when no recognizable files exist in the provided directory structure, based on specified formats, and appropriate progress handling or error logging is not enabled.
Exception – Additional exceptions may occur if environmental setup or file reading operations fail, depending on external utilities used and provided paths.
- Returns:
None
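The filtering step can be sketched as below, assuming that `available_format` holds accepted file endings and `omit_format` holds endings to skip. The helper `select_files`, the example file names, and the exact matching rule are hypothetical; the real method also prepares output directories and logging.

```python
def select_files(names, available_format, omit_format):
    """Keep names with a supported ending, skipping names with an omitted ending."""
    selected = []
    for name in sorted(names):
        # str.endswith accepts a tuple of candidate suffixes.
        if name.endswith(tuple(omit_format)):
            continue
        if name.endswith(tuple(available_format)):
            selected.append(name)

    # Mirrors the documented AssertionError when nothing usable is found.
    if not selected:
        raise AssertionError("No recognizable files found")
    return selected
```

Checking the omit list first means a file such as a previously written mask is skipped even though its extension is otherwise supported.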
- log_tardis(id_i: int, i: str | ndarray, log_id: float)
Logs various states and processing stages of the TARDIS application based on the provided log_id and input data. Depending on the log ID and input type, it generates log messages showcasing the progress of various computational tasks and updates a progress bar accordingly.
- Parameters:
id_i (int) – Identifier for the current image being processed in the list of input images.
i (Union[str, numpy.ndarray]) – Input data for logging, representing either a string description or a numpy array. If a numpy array is passed, it is converted into a string representation.
log_id (float) – Numeric identifier specifying the current task or processing stage. Determines the type of logging information generated, and can optionally include different subtasks.
- Returns:
None
- save_semantic_mask(i)
Saves a semantic mask prediction in a specified format and logs the prediction details. Supported formats include MRC, TIF, AM, and NPY. The function also updates a log file with prediction details and writes the semantic mask output to the appropriate directory in the chosen format.
- Parameters:
i (str) – The input file name used to derive the output file name.
- Raises:
IOError – If there are issues writing the output files or logs.
ValueError – If the specified output format is unsupported.
- save_instance_PC(i, overwrite_save=False)
Save processed prediction instance data to disk in various formats. This method handles logging, filtering, and outputting of prediction data based on the specified output format or other input parameters. It supports multiple output types like CSV, MRC, TIF, AM, STL, and NPY for different prediction types such as “Actin”, “Microtubule”, “Membrane”, and general filaments or objects. Depending on the output format, it can further refine data through filtering, save semantic masks, or interface with Amira for spatial graph comparison and exportation.
- Parameters:
i (str) – The identifier for the instance being saved.
overwrite_save (bool) – A flag to denote whether an existing file should be overwritten. Defaults to False.
- Returns:
None
- class tardis_em.utils.predictor.Predictor(device: device, network: str | None = None, checkpoint: str | None = None, subtype: str | None = None, model_version: int | None = None, img_size: int | None = None, model_type: str | None = None, sigma: float | None = None, sigmoid=True, _2d=False, logo=True)
Handles model prediction workflows for neural networks, including loading pretrained weights, configuring model architectures dynamically, and predicting data. The class abstracts away the complexities of network setup so that users can focus on running predictions with pretrained networks.
The class ensures compatibility with various deep learning frameworks, supports CNNs and distance-based networks dynamically, and provides inference-time adjustments like rotations for robustness.
- predict(x: Tensor, y: Tensor | None = None, rotate=False) → ndarray
Predicts an output based on given input data using a trained model. The method supports various modes of operation, including computing outputs for specific network types or applying rotations to the input for models with two-dimensional or three-dimensional spatial components. This function can handle data provided as PyTorch tensors or convert NumPy arrays into tensors internally. Outputs are generated either in a transformed or direct form according to the network type and additional parameters provided.
- Parameters:
x (torch.Tensor) – Input tensor containing the primary data for prediction. Expected to have shapes conforming to the model’s requirements.
y (Optional[torch.Tensor]) – Optional secondary input tensor containing additional data or node features. Defaults to None. Expected to match the compatible input feature dimensions of the model if provided.
rotate (bool) – Boolean flag indicating whether rotations should be applied to the input tensor to generate averaged transformed outputs. Defaults to False.
- Returns:
The processed output from the model, structured as a NumPy array. Its dimensions and content correspond to the defined task of the provided trained model. Output is adjusted according to whether rotations were applied or if specific computations are required by the network type.
- Return type:
np.ndarray
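The idea behind rotate=True, averaging outputs over rotated copies of the input, can be sketched in plain Python on a 2D list. This is an illustration under stated assumptions, not the library's implementation: `rot90`, `predict_with_rotation`, and the `model` callable are hypothetical, and the real method operates on PyTorch tensors.

```python
def rot90(m, k=1):
    """Rotate a 2D list counter-clockwise by k * 90 degrees."""
    for _ in range(k % 4):
        m = [list(row) for row in zip(*m)][::-1]
    return m


def predict_with_rotation(model, image):
    """Average the model output over the four 90-degree rotations of the input."""
    h, w = len(image), len(image[0])
    total = [[0.0] * w for _ in range(h)]

    for k in range(4):
        # Predict on the rotated input, then rotate the output back
        # so that all four predictions are aligned with the original.
        out = rot90(model(rot90(image, k)), -k)
        for i in range(h):
            for j in range(w):
                total[i][j] += out[i][j]

    return [[v / 4 for v in row] for row in total]
```

Rotating each output back before averaging is the essential step: without it, the four predictions would be summed in different orientations.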