DIST -> Model

DIST layer wrapper

class tardis_em.dist_pytorch.model.layers.DistStack(pairs_dim: int, node_dim: int | None = None, num_layers=1, dropout=0, ff_factor=4, num_heads=8, structure='full')

A neural network module for applying a stack of DistLayer layers to process edge and optional node features, typically for graph-based tasks.

The DistStack class is part of a Transformer-like architecture for graphs in which the stack is composed of multiple DistLayer layers. It applies these layers sequentially to graph data, with optional node features and optional attention masks (src_mask and src_key_padding_mask).

forward(edge_features: Tensor, node_features: Tensor | None = None, src_mask=None, src_key_padding_mask=None) -> Tuple[Tensor, Tensor]

Processes input edge features, and optionally node features, through all layers of the stack to produce updated node and edge tensors.

The forward method iterates over the layers in order. Each layer receives the current node and edge features together with the optional masks (src_mask and src_key_padding_mask) and returns updated versions of both, which are passed to the next layer. The method returns the node and edge features produced by the last layer.

Parameters:
  • edge_features – Input edge features as a tensor.

  • node_features – Optional input node features, provided as a tensor. Defaults to None.

  • src_mask – Optional source mask for attention-based computations.

  • src_key_padding_mask – Optional key padding mask for attention-based computations.

Returns:

A tuple of two tensors: (updated node features, updated edge features).
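The sequential loop described above can be pictured with the minimal sketch below. It is not the TARDIS implementation: ToyPairLayer and ToyStack are hypothetical stand-ins, ToyPairLayer only applies a residual linear update to the edges, and the edge layout (batch, nodes, nodes, pairs_dim) is an assumption. The point is the chaining: each layer returns (nodes, edges) and the next layer consumes them.

```python
import torch
from torch import nn


class ToyPairLayer(nn.Module):
    """Hypothetical stand-in for DistLayer: maps (h_pairs, h_nodes) -> (h_nodes, h_pairs)."""

    def __init__(self, pairs_dim: int):
        super().__init__()
        self.mix = nn.Linear(pairs_dim, pairs_dim)

    def forward(self, h_pairs, h_nodes=None, src_mask=None, src_key_padding_mask=None):
        # Residual update of the edge features; node features are passed through unchanged.
        return h_nodes, h_pairs + self.mix(h_pairs)


class ToyStack(nn.Module):
    """Hypothetical stack that applies num_layers layers one after another."""

    def __init__(self, pairs_dim: int, num_layers: int = 1):
        super().__init__()
        self.layers = nn.ModuleList([ToyPairLayer(pairs_dim) for _ in range(num_layers)])

    def forward(self, edge_features, node_features=None, src_mask=None, src_key_padding_mask=None):
        for layer in self.layers:
            # Each layer consumes the outputs of the previous one.
            node_features, edge_features = layer(
                h_pairs=edge_features,
                h_nodes=node_features,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        return node_features, edge_features


stack = ToyStack(pairs_dim=16, num_layers=3)
edges = torch.randn(2, 10, 10, 16)  # assumed layout: (batch, nodes, nodes, pairs_dim)
nodes_out, edges_out = stack(edges)
```

Without node features, the sketch returns None for the nodes and an edge tensor with the input shape.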

class tardis_em.dist_pytorch.model.layers.DistLayer(pairs_dim: int, node_dim: int | None = None, dropout=0, ff_factor=4, num_heads=8, structure='full')

The DistLayer class processes node and pair (edge) representations using multi-head attention and feed-forward blocks. The structure parameter selects how pair features interact: triangular, quadratic, or dual-triangular updates, or a full attention-based architecture. It extends PyTorch’s nn.Module and applies dropout for regularization.

DistLayer allows for versatile interaction between node and pair features through specific update routines that depend on the chosen structure.

update_nodes(h_pairs: Tensor, h_nodes: Tensor | None = None, src_mask=None, src_key_padding_mask=None) -> Tensor

Updates the node representations from the pair embeddings using a self-attention mechanism. The node embeddings and pair embeddings are combined through attention, and a feed-forward network further transforms the result. The updated node representations are returned.

Parameters:
  • h_pairs – Pairwise embeddings. A tensor that provides information about pair dependencies between nodes.

  • h_nodes – Optional initial node embeddings. If provided, these will be updated using the attention mechanism and feed-forward network.

  • src_mask – Attention mask used during the attention computation to indicate valid positions. This allows selective attention and prevents unwanted information flow.

  • src_key_padding_mask – Key padding mask used to indicate valid and invalid tokens or nodes for each sample in a batch. Useful when handling variable-length inputs.

Returns:

Updated node representations after attention computation and feed-forward network application.

Return type:

torch.Tensor

update_edges(h_pairs: Tensor, h_nodes: Tensor | None = None, mask: Tensor | None = None, src_key_padding_mask=None) -> Tensor

Updates the edge features according to the chosen structure, optionally incorporating node features and masks. Depending on the structure type, the input edge features pass through attention, edge-update, and feed-forward blocks.

Parameters:
  • h_pairs – Tensor containing initial edge features, of shape (batch_size, num_nodes, num_nodes, feature_dim).

  • h_nodes – Optional tensor containing node features, of shape (batch_size, num_nodes, feature_dim). If provided, the function incorporates these features into the edge features during the update process.

  • mask – Optional tensor of shape (batch_size, num_nodes, num_nodes). Acts as an attention mask or structural constraint for the feature update process.

  • src_key_padding_mask – Optional tensor of shape (batch_size, num_nodes). If provided, it is used to generate a mask to ignore certain nodes by expanding it along the necessary dimensions.

Returns:

A tensor of the same shape as h_pairs, representing the updated edge features after applying the selected transformations.
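The src_key_padding_mask description above says a (batch_size, num_nodes) mask is expanded to cover edges. One plausible expansion, sketched here under the assumptions that True marks a padded node (the PyTorch convention) and that an edge is ignored when either endpoint is padded, is a broadcast OR; pair_padding_mask is a hypothetical helper, not a TARDIS function.

```python
import torch


def pair_padding_mask(src_key_padding_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: expand a (B, L) node padding mask (True = padded)
    to a (B, L, L) edge mask where edge (i, j) is padded if node i or node j is."""
    # (B, L, 1) | (B, 1, L) broadcasts to (B, L, L)
    return src_key_padding_mask[:, :, None] | src_key_padding_mask[:, None, :]


pad = torch.tensor([[False, False, True]])  # one graph; the third node is padding
pair_mask = pair_padding_mask(pad)

# Zero out padded edges in a (B, L, L, C) edge tensor.
h_pairs = torch.ones(1, 3, 3, 4).masked_fill(pair_mask.unsqueeze(-1), 0.0)
```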

forward(h_pairs: Tensor, h_nodes: Tensor | None = None, src_mask=None, src_key_padding_mask=None) -> Tuple[Tensor, Tensor]

Updates the node and edge features from the input tensors by applying the node and edge update operations, with optional masking of the source inputs.

Parameters:
  • h_pairs – Tensor representing the edge features in the graph.

  • h_nodes – Tensor representing the node features in the graph. May be None, in which case no node update is performed.

  • src_mask – Optional mask applied at the source level during node update.

  • src_key_padding_mask – Optional mask to specify which elements should be ignored in the computation, typically used for padded sequences.

Returns:

Tuple of two tensors, the updated node features and the updated edge features in the graph.

DIST graph update-modules

Collection of the torch.nn.Module subclasses used in the DIST model.

class tardis_em.dist_pytorch.model.modules.PairBiasSelfAttention(embed_dim: int, pairs_dim: int, num_heads: int, init_scaling=0.7071067811865475)

Implements a self-attention mechanism that incorporates pairwise features into multi-head attention. It is designed for cases where the attention bias is computed from edge features in addition to the node embeddings. The number of heads, the embedding dimensions, and the initial scaling factor are configurable.

The attention mechanism applies normalization and linear projections to both node and edge features, calculates attention weights, and uses those weights to compute weighted combinations of feature representations.

forward(query: Tensor, pairs: Tensor, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, need_weights: bool = False, need_head_weights: bool = False) -> Tuple[Tensor, Tensor] | Tensor

Computes the forward pass of a multi-head attention module with additional pairwise positional weighting. This function first processes the query tensor through normalization and projection, prepares the attention weights using pair and masked components, applies the attention and then combines attended outputs with pairwise positional contributions. It optionally provides attention weights for inference or analysis.

This method supports key padding masks for excluding irrelevant tokens from attention calculations, as well as optional attention masks for customizing attention strength in multi-head contexts. The function includes the necessary transformations to prepare query, key, and value tensors for batched multi-head operations.

Parameters:
  • query – The input tensor of shape (target length, batch size, embedding dimension).

  • pairs – Pairwise positional tensor of shape (batch size, sequence length, sequence length, number of heads).

  • attn_mask – Optional mask tensor of shape (target length, source length) to apply additional additive masking to the attention weights.

  • key_padding_mask – Optional binary tensor of shape (batch size, source length) indicating padded positions.

  • need_weights – Boolean flag to indicate whether to return the attention weights alongside the output tensor.

  • need_head_weights – Boolean flag to indicate whether individual head weights should be returned. Overrides need_weights when True.

Returns:

The output tensor of shape (target length, batch size, embedding dimension). If need_weights is True, also returns attention weights; the shape of weights depends on the need_head_weights parameter.
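The core idea of pair-biased attention, adding a per-head bias derived from edge features to the query-key logits before the softmax, is shown in the minimal sketch below. ToyPairBiasAttention is hypothetical. Unlike the documented module it uses a batch-first layout (batch, length, embed), projects pairs_dim to one bias per head with a linear layer, and omits the extra pairwise contribution to the output. It does follow the padding convention that True marks a padded key.

```python
import math

import torch
from torch import nn


class ToyPairBiasAttention(nn.Module):
    """Hypothetical pair-biased self-attention (batch-first sketch)."""

    def __init__(self, embed_dim: int, pairs_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.h, self.d = num_heads, embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.pair_bias = nn.Linear(pairs_dim, num_heads, bias=False)  # one bias per head
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, pairs, key_padding_mask=None):
        B, L, E = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, L, E) -> (B, H, L, d)
        q, k, v = (t.reshape(B, L, self.h, self.d).transpose(1, 2) for t in (q, k, v))
        logits = q @ k.transpose(-2, -1) / math.sqrt(self.d)          # (B, H, L, L)
        logits = logits + self.pair_bias(pairs).permute(0, 3, 1, 2)   # edge-feature bias
        if key_padding_mask is not None:  # True = padded key position
            logits = logits.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
        weights = logits.softmax(dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(B, L, E)
        return self.out(out), weights


attn = ToyPairBiasAttention(embed_dim=32, pairs_dim=8, num_heads=4)
x = torch.randn(2, 5, 32)
pairs = torch.randn(2, 5, 5, 8)
mask = torch.zeros(2, 5, dtype=torch.bool)
mask[:, -1] = True  # last node is padding
y, w = attn(x, pairs, key_padding_mask=mask)
```

Because the bias is added before the softmax, edge features reweight which nodes attend to each other, while padded keys still receive exactly zero weight.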

class tardis_em.dist_pytorch.model.modules.ComparisonLayer(input_dim: int, output_dim: int, channel_dim=128)

Transforms node features of shape (Batch, Length, Feature_Dimensions) into pairwise features of shape (Batch, Length, Length, Out_Channels), using layer normalization and linear transformations.

The layer is implemented as a PyTorch nn.Module and produces pair (edge) features that downstream layers of the model can consume.

forward(x: Tensor) -> Tensor

Applies transposition, normalization, and linear transformations to the input tensor. Element-wise products and differences of the transformed tensors are computed and passed through further linear transformations to produce the output.

Parameters:

x (torch.Tensor) – Input tensor with shape (Batch, Length, Feature_Dimensions)

Returns:

Transformed tensor with shape (Batch, Length, Length, Out_Channels)

Return type:

torch.Tensor
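How a product-and-difference comparison turns (B, L, D) node features into (B, L, L, C) pair features can be illustrated with the minimal sketch below. ToyComparison is hypothetical. It reproduces only the broadcasting idea (one projected vector indexed by node i, one by node j, combined by multiplication and subtraction) and does not reproduce the exact transposition and layer order of ComparisonLayer.

```python
import torch
from torch import nn


class ToyComparison(nn.Module):
    """Hypothetical sketch: node features (B, L, D) -> pair features (B, L, L, output_dim)."""

    def __init__(self, input_dim: int, output_dim: int, channel_dim: int = 128):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.left = nn.Linear(input_dim, channel_dim)
        self.right = nn.Linear(input_dim, channel_dim)
        self.out = nn.Linear(2 * channel_dim, output_dim)

    def forward(self, x):
        x = self.norm(x)
        a = self.left(x)[:, :, None, :]   # (B, L, 1, C): indexed by node i
        b = self.right(x)[:, None, :, :]  # (B, 1, L, C): indexed by node j
        # Broadcasting yields one feature vector per node pair (i, j).
        pair = torch.cat([a * b, a - b], dim=-1)  # (B, L, L, 2C)
        return self.out(pair)


layer = ToyComparison(input_dim=16, output_dim=8, channel_dim=32)
pair_features = layer(torch.randn(2, 6, 16))
```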

class tardis_em.dist_pytorch.model.modules.TriangularEdgeUpdate(input_dim, channel_dim=128, axis=1)

The TriangularEdgeUpdate class implements a neural network module that performs triangular edge updates for edge feature tensors. This is primarily designed for processing relational or structural data, where edge updates between nodes in a triangular relationship must be computed.

The input is processed with linear layers, layer normalization, and gating mechanisms. The module supports optional masking, and the update direction is set by axis. The resulting features are computed with einsum operations over the selected dimensions.

forward(z: Tensor, mask: Tensor | None = None) -> Tensor

Processes the input tensor z using gated mechanisms and performs tensor operations to produce the output. It applies normalization, gating, and optionally masks certain elements based on the provided mask tensor. The computation involves einsum operations and a final gating mechanism for the output tensor.

Parameters:
  • z – Input edge-feature tensor.

  • mask – Optional tensor used to mask specific elements of the input tensor during processing. If provided, elements are masked where the mask tensor indicates.

Returns:

Processed tensor after applying the normalization, gating, masking (if applicable), and tensor manipulation operations.
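A triangular edge update in the style of the AlphaFold2 "triangle multiplicative update" is sketched below. This is a swapped-in reference technique, not the TARDIS code. ToyTriangularUpdate is hypothetical; it assumes edge features of shape (B, L, L, D) and a mask where 1 means keep, and the outgoing/incoming choice stands in for the axis argument. Edge (i, j) is updated from the two other edges of each triangle (i, j, k).

```python
import torch
from torch import nn


class ToyTriangularUpdate(nn.Module):
    """Hypothetical AlphaFold2-style triangle multiplicative update."""

    def __init__(self, input_dim: int, channel_dim: int = 128, axis: int = 1):
        super().__init__()
        self.axis = axis
        self.norm_in = nn.LayerNorm(input_dim)
        self.lin_a, self.gate_a = nn.Linear(input_dim, channel_dim), nn.Linear(input_dim, channel_dim)
        self.lin_b, self.gate_b = nn.Linear(input_dim, channel_dim), nn.Linear(input_dim, channel_dim)
        self.norm_out = nn.LayerNorm(channel_dim)
        self.lin_out = nn.Linear(channel_dim, input_dim)
        self.gate_out = nn.Linear(input_dim, input_dim)

    def forward(self, z, mask=None):
        # z: (B, L, L, D); mask: (B, L, L) with 1 = keep (assumed convention)
        z = self.norm_in(z)
        a = torch.sigmoid(self.gate_a(z)) * self.lin_a(z)  # gated projections
        b = torch.sigmoid(self.gate_b(z)) * self.lin_b(z)
        if mask is not None:
            m = mask.unsqueeze(-1).to(z.dtype)
            a, b = a * m, b * m
        if self.axis == 1:
            t = torch.einsum("bikc,bjkc->bijc", a, b)  # "outgoing": edges (i, k) and (j, k)
        else:
            t = torch.einsum("bkic,bkjc->bijc", a, b)  # "incoming": edges (k, i) and (k, j)
        return torch.sigmoid(self.gate_out(z)) * self.lin_out(self.norm_out(t))


z = torch.randn(2, 5, 5, 16)
out_rows = ToyTriangularUpdate(16, channel_dim=32, axis=1)(z)
out_cols = ToyTriangularUpdate(16, channel_dim=32, axis=2)(z, mask=torch.ones(2, 5, 5))
```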

class tardis_em.dist_pytorch.model.modules.QuadraticEdgeUpdate(input_dim, channel_dim=128, axis=1)

A neural network module for quadratic edge updates with gated linear units and layer normalization.

This module transforms edge features with gated linear layers, layer normalization, and einsum-based tensor contractions. It supports an optional mask to restrict the computation and allows flexible dimensional configurations.

The input is normalized first and then transformed via multiple linear and gating operations. The outputs are combined through einsum-based operations based on the configured axis, enabling contextual computations for edge features. The final results are subjected to normalization and linear transformations to produce the output tensor.

forward(z: Tensor, mask: Tensor | None = None) -> Tensor

Computes the forward pass of the layer by applying gated transformations to the input tensor, optionally considering an input mask. It applies a series of linear transformations and element-wise sigmoid activations to generate intermediate tensors, followed by a tensor contraction using Einstein summation notation for specific axes. The output is further transformed by gating and normalization operations.

Parameters:
  • z – The input edge-feature tensor with shape (B x L x L x D), where B is the batch size, L is the sequence length (number of nodes), and D is the input dimension.

  • mask – An optional binary mask tensor with shape (B x L x L) where masked positions are indicated with 1 and unmasked positions with 0. If provided, it will nullify specific computations in the output.

Returns:

A tensor with shape (B x L x L x O), representing the transformed outputs after gating, linear transformation, normalization, and contraction operations.

class tardis_em.dist_pytorch.model.modules.MultiHeadAttention(embed_dim: int, num_heads: int, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, init_scaling=0.7071067811865475)

Represents a Multi-Head Attention (MHA) mechanism that enables self-attention or cross-attention in neural networks, primarily used in transformers. MHA facilitates attention over multiple heads, allowing the model to focus on different parts of the sequence simultaneously. This module supports various options such as encoder-decoder attention, dropout, bias customization, and scaling initialization.

Multi-Head Attention is a key building block for many Natural Language Processing (NLP) and computer vision tasks, enabling the model to capture contextual dependencies efficiently.

The module expects inputs in the form of query, key, and value tensors and provides attention outputs for further processing in the neural network.

reset_parameters()

Initializes model parameters with specific initialization methods.

This method resets the weights and biases of the key, value, query, and output projection layers in the neural network to ensure consistent training results and proper initialization of the model. It uses Xavier uniform initialization for the weights and constant initialization for specific components such as biases. If optional biases bias_k or bias_v exist, they are also reset to constant values.

Raises:

TypeError – Raised if the model components are not properly initialized or an attribute reference is invalid.
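The initialization described above can be sketched with torch.nn.init. The layer names are hypothetical. The use of init_scaling (default 0.7071..., that is 1/sqrt(2)) as the Xavier gain for the query, key, and value projections is an assumption modeled on fairseq's MultiheadAttention, not something the text above confirms.

```python
import math

import torch
from torch import nn

embed_dim, init_scaling = 64, 1 / math.sqrt(2)  # 0.7071..., the documented default

# Hypothetical projections standing in for the module's query/key/value/output layers.
q_proj, k_proj, v_proj, out_proj = (nn.Linear(embed_dim, embed_dim) for _ in range(4))

# Assumption: init_scaling is the Xavier gain for q/k/v (as in fairseq).
for proj in (q_proj, k_proj, v_proj):
    nn.init.xavier_uniform_(proj.weight, gain=init_scaling)
nn.init.xavier_uniform_(out_proj.weight)
nn.init.constant_(out_proj.bias, 0.0)  # constant (zero) bias for the output projection

# Xavier-uniform samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)).
bound = init_scaling * math.sqrt(6.0 / (embed_dim + embed_dim))
```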

forward(query: Tensor, key: Tensor | None = None, value: Tensor | None = None, key_padding_mask: Tensor | None = None, need_weights: bool = False, attn_mask: Tensor | None = None, before_softmax: bool = False, need_head_weights: bool = False) -> Tuple[Tensor, Tensor] | Tensor

Computes the forward pass through a multi-head attention mechanism with support for self-attention, encoder-decoder attention, and custom input masks. The function supports options to include or exclude weights, apply dropout, and handle special cases such as biases and padding.

Parameters:
  • query – Tensor representing the input sequence to compute attention for, with dimensions (target length, batch size, embedding dimension).

  • key – Optional tensor representing the key in the attention mechanism, with dimensions (source length, batch size, embedding dimension). If None, it assumes self-attention or other specialized attention types.

  • value – Optional tensor representing the value in the attention mechanism, with dimensions (source length, batch size, embedding dimension). If None, it assumes self-attention or other specialized attention types.

  • key_padding_mask – Optional boolean tensor used to specify padding on certain input positions, with dimensions (batch size, source length). Non-zero values denote positions to be masked.

  • need_weights – Boolean flag indicating whether the function should return attention weights along with the computed output.

  • attn_mask – Optional tensor representing a mask to restrict attention to specific positions, with dimensions (target length, source length). Typical for causal masking in transformer models.

  • before_softmax – Boolean flag to determine if attention weights are returned before or after the softmax operation is applied.

  • need_head_weights – Boolean flag indicating whether attention weights per head are needed (as opposed to aggregated weights across heads).

Returns:

A tuple consisting of the attention output (tensor with dimensions (target length, batch size, embedding dimension)) and, if requested, the attention weights (as a tensor with details depending on the head weights selection). If weights are not requested, only the attention output is returned.
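PyTorch's built-in torch.nn.MultiheadAttention uses the same (length, batch, embedding) layout by default and the same key_padding_mask semantics, where True marks a position to ignore. The snippet below uses it to illustrate those conventions. It is PyTorch's class, not the TARDIS MultiHeadAttention. In PyTorch, per-head weights, the analog of need_head_weights, are requested with average_attn_weights=False.

```python
import torch
from torch import nn

T, S, B, E = 4, 6, 2, 32  # target length, source length, batch size, embedding dim
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)  # default layout: (length, batch, embed)

query = torch.randn(T, B, E)
key = value = torch.randn(S, B, E)
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # last two source positions are padding

attn_out, attn_weights = mha(query, key, value,
                             key_padding_mask=key_padding_mask,
                             need_weights=True)
# attn_out: (T, B, E); attn_weights (averaged over heads): (B, T, S)
```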

class tardis_em.dist_pytorch.model.modules.SelfAttention2D(embed_dim: int, num_heads: int, axis=None, dropout=0.0, max_size=4194304)

Implements a 2D self-attention mechanism.

This class extends the MultiHeadAttention module to perform self-attention over 2D edge features. It reshapes the input according to the axis mode (rows or columns) and, to limit memory use, can split the attention computation into batches.

forward(x: Tensor, padding_mask=None) -> Tensor

Processes input tensor through a 2D self-attention mechanism and adjusts its shape based on the specified axis. Handles padding masks if provided, allowing optional batching for memory-efficient computation when the attention matrix size exceeds a specified maximum.

Parameters:
  • x (torch.Tensor) – Input tensor containing the features to be processed using 2D self-attention.

  • padding_mask (torch.Tensor or None) – Optional mask used to ignore certain positions during the attention computation.

Returns:

Transformed tensor with the same spatial dimensions as the input, after self-attention has been applied along the selected axis.

Return type:

torch.Tensor
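Row-wise (axial) attention over 2D edge features, with chunking to bound memory, is sketched below. The sketch uses torch.nn.MultiheadAttention. row_attention and its max_rows argument are hypothetical: the documented class uses max_size, a limit that the text above relates to the attention matrix size, rather than a row count. Column attention would transpose the two node axes of x before and after the call. Because rows are independent sequences, chunking does not change the result.

```python
import torch
from torch import nn


def row_attention(x: torch.Tensor, mha: nn.MultiheadAttention, max_rows: int) -> torch.Tensor:
    """Hypothetical helper: self-attention along the rows of x with shape (B, L, L, E).

    Each of the B*L rows is an independent sequence of length L; rows are
    processed in chunks of at most max_rows to bound peak memory.
    """
    B, L, _, E = x.shape
    rows = x.reshape(B * L, L, E).transpose(0, 1)  # (L, B*L, E): (seq, batch, embed)
    outs = []
    for start in range(0, B * L, max_rows):
        chunk = rows[:, start:start + max_rows]
        out, _ = mha(chunk, chunk, chunk, need_weights=False)
        outs.append(out)
    return torch.cat(outs, dim=1).transpose(0, 1).reshape(B, L, L, E)


mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
x = torch.randn(2, 5, 5, 16)
full = row_attention(x, mha, max_rows=10)    # all 10 rows in one call
chunked = row_attention(x, mha, max_rows=3)  # 4 smaller calls
```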

class tardis_em.dist_pytorch.model.modules.GeluFeedForward(input_dim: int, ff_dim: int)

Applies a GELU-based feedforward transformation to the input tensor.

GeluFeedForward is a neural network module that normalizes the input tensor and applies a two-layer feedforward network with GELU activation. This is commonly used in transformer architectures or other deep learning models to enhance the representational power of the model.

forward(x: Tensor) -> Tensor

Passes the input through normalization, two linear transformations, and a GELU activation.

This method processes the input tensor by first normalizing it using the self.norm function. It then applies the first linear transformation (self.linear1), followed by the GELU activation, and finally the second linear transformation (self.linear2). The output is returned as a transformed tensor.

Parameters:

x (torch.Tensor) – Input tensor to the forward pass.

Returns:

Transformed tensor after normalization, linear transformations, and activation function.

Return type:

torch.Tensor
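The operation order described above (norm, linear1, GELU, linear2) can be written as the minimal sketch below. ToyGeluFeedForward is hypothetical. It assumes that linear2 maps back to input_dim and that DistLayer's ff_factor multiplies the feature width to give ff_dim.

```python
import torch
from torch import nn


class ToyGeluFeedForward(nn.Module):
    """Hypothetical sketch: LayerNorm -> Linear(input_dim, ff_dim) -> GELU -> Linear(ff_dim, input_dim)."""

    def __init__(self, input_dim: int, ff_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.linear1 = nn.Linear(input_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, input_dim)

    def forward(self, x):
        return self.linear2(nn.functional.gelu(self.linear1(self.norm(x))))


ff = ToyGeluFeedForward(input_dim=32, ff_dim=32 * 4)  # assumed: ff_dim = ff_factor * width
y = ff(torch.randn(2, 7, 32))
```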

Feature embedding

Collection of classes used for node and edge embedding.

  • Node features are RGB values or, optionally, flattened image patches.

    The node embedding uses nn.Linear to embed an n-dimensional feature object and outputs [Batch x Feature Length x Channels].

  • Edge embedding is computed directly from the nD coordinate values, where n is the dimensionality of the coordinates.

    The edge embedding applies a cdist operation to the coordinate features, producing a distance matrix for all points in the given patch. The distance matrix is then normalized with an exponential function controlled by the sigma parameter. This function assigns higher weights to smaller distances, with sigma setting the distance threshold. Because pairwise distances do not change when the points are translated or rotated, the resulting embedding is invariant to translations and to rotations in SO(n).
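The distance-kernel idea and its invariance can be checked with the minimal sketch below. It assumes the Gaussian form exp(-d² / (2σ²)); the exact normalization in TARDIS may differ. gaussian_edge_embedding is a hypothetical helper. Rotating and translating the points leaves the embedding unchanged up to floating-point error.

```python
import math

import torch


def gaussian_edge_embedding(coords: torch.Tensor, sigma: float) -> torch.Tensor:
    """Hypothetical helper: coords (B, L, n) -> edge weights (B, L, L) in (0, 1]."""
    dist = torch.cdist(coords, coords)               # pairwise Euclidean distances
    return torch.exp(-dist ** 2 / (2 * sigma ** 2))  # 1 at distance 0, decays with distance


coords = torch.randn(1, 8, 2)
theta = math.pi / 5
rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta), math.cos(theta)]])
moved = coords @ rot.T + torch.tensor([3.0, -1.0])  # rotate, then translate

e1 = gaussian_edge_embedding(coords, sigma=1.0)
e2 = gaussian_edge_embedding(moved, sigma=1.0)
```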

class tardis_em.dist_pytorch.model.embedding.NodeEmbedding(n_in: int, n_out: int, sigma=1)

NodeEmbedding class for transforming node features using either a learned linear mapping or a randomized cosine transformation, depending on the value of sigma.

This class embeds input node features (e.g., RGB values or image patches) into the desired output dimension. If sigma is 0, a trainable linear layer is applied. Otherwise, a fixed random projection followed by a cosine activation is used.

forward(input_node: Tensor | None = None) -> Tensor | None

Performs the forward pass of the module. If an input tensor is provided, it is passed through the trainable linear layer when one is defined; otherwise, the fixed random projection and cosine transformation are applied.

Parameters:

input_node (Optional[torch.Tensor]) – Input tensor to process. If None, returns None.

Returns:

A tensor where the processed input is transformed and scaled into the range [0, 1], or None if no input was provided.

Return type:

Optional[torch.Tensor]
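The two modes described above can be sketched as follows. ToyNodeEmbedding is hypothetical, and the random-projection branch is written as random Fourier features (Gaussian weights scaled by sigma, a uniform random phase, then cos). This is an assumption about what "fixed random projection with a cosine activation" means here. Registering the projection as buffers keeps it fixed: it is saved with the module but never updated by an optimizer.

```python
import math

import torch
from torch import nn


class ToyNodeEmbedding(nn.Module):
    """Hypothetical sketch. sigma == 0: trainable nn.Linear(n_in, n_out).
    sigma != 0: fixed projection W ~ N(0, sigma^2) plus random phase, then cos()."""

    def __init__(self, n_in: int, n_out: int, sigma: float = 1.0):
        super().__init__()
        if sigma == 0:
            self.linear = nn.Linear(n_in, n_out)
        else:
            self.linear = None
            # Buffers: stored in state_dict, not returned by parameters().
            self.register_buffer("w", torch.randn(n_in, n_out) * sigma)
            self.register_buffer("b", torch.rand(n_out) * 2 * math.pi)

    def forward(self, input_node=None):
        if input_node is None:
            return None
        if self.linear is not None:
            return self.linear(input_node)
        return torch.cos(input_node @ self.w + self.b)


rgb = torch.rand(2, 10, 3)  # (batch, nodes, RGB)
fixed = ToyNodeEmbedding(3, 16, sigma=1.0)
learned = ToyNodeEmbedding(3, 16, sigma=0)
```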

class tardis_em.dist_pytorch.model.embedding.EdgeEmbedding(n_out: int, sigma: int | float | list)

The EdgeEmbedding layer computes edge representations for a graph by applying Gaussian radial basis functions (RBF) to the pairwise distances. Depending on the sigma configuration, the distances are encoded with a single kernel or with several kernels.

EdgeEmbedding computes Gaussian kernel representations of edge distances and optionally applies a linear transformation to produce embeddings of a specified dimensionality. sigma may be a single value or a list defining a range of values, one kernel per value, and a linear layer can map the kernel outputs to the requested output dimension.

forward(input_coord: Tensor) -> Tensor

Computes a transformation of the pairwise distances between input coordinates and applies an optional linear transformation. The transformation depends on the configuration, that is, on whether a single sigma value or a _range of values is provided. NaN distance values are replaced with zeros.

Parameters:

input_coord (torch.Tensor) – Tensor of shape (…, L, D) where L is the number of coordinate points and D is the dimensionality of each point. It represents the input feature coordinates for which pairwise distances will be computed.

Returns:

Tensor of shape (…, L, L, K) where K is either 1 or the length of the _range attribute. Represents the transformed pairwise distances. If self.linear is provided, the returned tensor is further transformed using the linear layer.

Return type:

torch.Tensor
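A multi-scale version, with one Gaussian kernel channel per sigma value followed by a linear projection, is sketched below. ToyEdgeEmbedding is hypothetical. It takes an explicit list of sigma values (in TARDIS a list may instead describe a _range), always applies the projection, and follows the literal reading of the text above by zeroing NaN distances before the kernel. Those edges therefore receive the maximal kernel value; zeroing after the kernel would suppress them instead.

```python
import torch
from torch import nn


class ToyEdgeEmbedding(nn.Module):
    """Hypothetical multi-scale Gaussian RBF edge embedding."""

    def __init__(self, n_out: int, sigmas):
        super().__init__()
        # One kernel channel per sigma value: K = len(sigmas)
        self.register_buffer("sigmas", torch.as_tensor(sigmas, dtype=torch.float32))
        self.linear = nn.Linear(len(sigmas), n_out)

    def forward(self, input_coord):
        dist = torch.cdist(input_coord, input_coord)  # (..., L, L)
        dist = torch.nan_to_num(dist, nan=0.0)        # NaN distances -> 0
        k = torch.exp(-dist.unsqueeze(-1) ** 2 / (2 * self.sigmas ** 2))  # (..., L, L, K)
        return self.linear(k)                          # (..., L, L, n_out)


emb = ToyEdgeEmbedding(n_out=8, sigmas=[0.5, 1.0, 2.0])
coords = torch.randn(2, 6, 3)
coords[0, 0] = float("nan")  # a missing point
out = emb(coords)
```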