DIST -> Model

DIST layer wrapper

class tardis_em.dist_pytorch.model.layers.DistStack(pairs_dim: int, node_dim: int | None = None, num_layers=1, dropout=0, ff_factor=4, num_heads=8, structure='full')

WRAPPER FOR DIST LAYER

This wrapper stacks a user-defined number of DIST layers.

Parameters:
  • pairs_dim (int) – Number of input dimensions for pair features.

  • node_dim (int, optional) – Number of input dimensions for node features.

  • num_layers (int) – Number of DIST (GraphFormer) layers; minimum 1.

  • dropout (float) – Dropout rate.

  • ff_factor (int) – Feed forward factor.

  • num_heads (int) – Number of heads in multi-head attention.

  • structure (str) – Defines the DIST structure; one of ['full', 'full_af', 'self_attn', 'triang', 'dualtriang', 'quad'].

forward(edge_features: Tensor, node_features: Tensor | None = None, src_mask=None, src_key_padding_mask=None) Tuple[Tensor, Tensor]

Forward pass through the stacked DIST layers.

Parameters:
  • edge_features (torch.Tensor) – Edge features as a tensor of shape [Batch x Length x Length x Channels].

  • node_features (torch.Tensor, optional) – Optional node features as a tensor of shape [Batch x Length x Channels].

  • src_mask (torch.Tensor, optional) – Optional source mask for masking over batches.

  • src_key_padding_mask (torch.Tensor, optional) – Optional mask used for feature padding.

Returns:

Updated node and edge feature representations.

Return type:

Tuple[torch.Tensor, torch.Tensor]
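
Example (a minimal usage sketch based on the signature and shapes documented above; the dimension values and the output ordering are assumptions, not confirmed by this page):

    import torch
    from tardis_em.dist_pytorch.model.layers import DistStack

    # Two stacked DIST layers; all dimension values are illustrative.
    stack = DistStack(pairs_dim=128, node_dim=64, num_layers=2)

    edges = torch.rand(1, 32, 32, 128)  # Batch x Length x Length x Channels
    nodes = torch.rand(1, 32, 64)       # Batch x Length x Channels

    # Returns the updated graph representation as a two-tensor tuple;
    # the (node, edge) ordering here is an assumption.
    h_nodes, h_pairs = stack(edges, nodes)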

class tardis_em.dist_pytorch.model.layers.DistLayer(pairs_dim: int, node_dim: int | None = None, dropout=0, ff_factor=4, num_heads=8, structure='full')

MAIN DIST LAYER

DistLayer takes an embedded input and performs paired-bias self-attention (a modified multi-head attention), followed by a GeLU feed-forward update with normalization of the node-embedded information. The feed-forward update is then summed with the edge feature map. As output, the layer produces an attention tensor for the given input that encodes attention between nodes and pairs (edges).

Parameters:
  • pairs_dim (int) – Number of input dimensions for pair features.

  • node_dim (int, optional) – Number of input dimensions for node features.

  • dropout (float) – Dropout rate.

  • ff_factor (int) – Feedforward factor used for GeLuFFN.

  • num_heads (int) – Number of heads in multi-head self-attention.

  • structure (str) – Structure of the layer; one of ['full', 'full_af', 'self_attn', 'triang', 'dualtriang', 'quad'].

update_nodes(h_pairs: Tensor, h_nodes: Tensor | None = None, src_mask=None, src_key_padding_mask=None) Tensor

Transformer update of the node features, weighted by the pair representations.

Input:

h_pairs -> Batch x Length x Length x Channels
h_nodes -> Length x Batch x Channels

Output:

h_nodes -> Length x Batch x Channels

Parameters:
  • h_pairs (torch.Tensor) – Edge features.

  • h_nodes (torch.Tensor) – Node features.

  • src_mask (torch.Tensor) – Attention mask applied across the batch.

  • src_key_padding_mask (torch.Tensor) – Attention key padding mask.

Returns:

Updated node features.

Return type:

torch.Tensor

update_edges(h_pairs: Tensor, h_nodes: Tensor | None = None, mask: Tensor | None = None, src_key_padding_mask=None) Tensor

Update the edge representations based on nodes.

Input:

h_pairs -> Batch x Length x Length x Channels
h_nodes -> Length x Batch x Channels

Output:

h_pairs -> Batch x Length x Length x Channels

Parameters:
  • h_pairs – Edge features.

  • h_nodes – Node features.

  • mask – Attention mask.

  • src_key_padding_mask – Attention key padding mask.

Returns:

Updated edge features.

Return type:

torch.Tensor

forward(h_pairs: Tensor, h_nodes: Tensor | None = None, src_mask=None, src_key_padding_mask=None) Tuple[Tensor, Tensor]

Forward pass through a single DIST layer.

Parameters:
  • h_pairs – Pairs representation.

  • h_nodes – Node feature representation.

  • src_mask – Optional attention mask.

  • src_key_padding_mask – Optional padding mask for attention.

Return type:

Tuple[torch.Tensor, torch.Tensor]
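
Example (a minimal usage sketch under the same assumptions as the DistStack example above; the node-feature layout and output ordering are assumed):

    import torch
    from tardis_em.dist_pytorch.model.layers import DistLayer

    layer = DistLayer(pairs_dim=128, node_dim=64, structure='full')

    h_pairs = torch.rand(1, 32, 32, 128)  # Batch x Length x Length x Channels
    h_nodes = torch.rand(1, 32, 64)       # Batch x Length x Channels (assumed)

    h_pairs, h_nodes = layer(h_pairs, h_nodes)  # output ordering assumed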

DIST graph update modules

Collection of all modules wrapped around ‘torch.nn.Module’ used in the DIST model.

class tardis_em.dist_pytorch.model.modules.PairBiasSelfAttention(embed_dim: int, pairs_dim: int, num_heads: int, init_scaling=0.7071067811865475)

SELF-ATTENTION WITH EDGE FEATURE-BASED BIAS AND PRE-LAYER NORMALIZATION

Self-attention block that attends over coordinate and image patches or RGB values.

forward(query: Tensor, pairs: Tensor, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, need_weights: bool = False, need_head_weights: bool = False) Tuple[Tensor, Tensor] | Tensor

Forward attention over node features.

Parameters:
  • query (torch.Tensor) – Nodes features [Length x Batch x Channel].

  • pairs (torch.Tensor) – Edges features [Batch x Length x Length x Channel].

  • attn_mask (torch.Tensor) – Typically used to implement causal attention, where the mask prevents the attention from looking forward in time.

  • key_padding_mask (torch.Tensor) – Mask to exclude keys that are pads, of shape [Batch, src_len].

  • need_weights (bool) – If True, return the attention weights, averaged over heads.

  • need_head_weights (bool) – If True, return the attention weights for each head.

Returns:

Attention tensor for node features.

Return type:

Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
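
Example (a minimal usage sketch; shapes follow the documented parameters, and with need_weights left False a single-tensor return is assumed):

    import torch
    from tardis_em.dist_pytorch.model.modules import PairBiasSelfAttention

    attn = PairBiasSelfAttention(embed_dim=64, pairs_dim=128, num_heads=8)

    query = torch.rand(32, 1, 64)       # Length x Batch x Channel (node features)
    pairs = torch.rand(1, 32, 32, 128)  # Batch x Length x Length x Channel (edge features)

    node_update = attn(query, pairs)    # attention tensor for node features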

class tardis_em.dist_pytorch.model.modules.ComparisonLayer(input_dim: int, output_dim: int, channel_dim=128)

COMPARISON MODULE BETWEEN PAIRS AND NODE INPUTS

This module converts the node representation of shape (Length x Batch x Channels) into a (Batch x Length x Length x Channels) representation that can be compared with the pair representation.

Parameters:
  • input_dim (int) – Input dimension as in pairs features.

  • output_dim (int) – Output dimension as in node features.

  • channel_dim (int) – Number of output channels.

forward(x: Tensor) Tensor

Forward node compatible layer.

Parameters:

x (torch.Tensor) – Node features after attention layer.

Returns:

Converted Node features to [Batch x Length x Length x Out_Channels] shape.

Return type:

torch.Tensor
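
Example (a minimal usage sketch based on the shapes documented above; dimension values are illustrative):

    import torch
    from tardis_em.dist_pytorch.model.modules import ComparisonLayer

    comp = ComparisonLayer(input_dim=64, output_dim=128, channel_dim=128)

    x = torch.rand(32, 1, 64)  # Length x Batch x Channels (node features after attention)
    pairs_like = comp(x)       # Batch x Length x Length x Out_Channels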

class tardis_em.dist_pytorch.model.modules.TriangularEdgeUpdate(input_dim, channel_dim=128, axis=1)

TRIANGULAR UPDATE MODULE FOR EDGE FEATURES

This module takes the edge feature representation and performs a triangular update for each point, similar to the AlphaFold2 approach.

Parameters:
  • input_dim (int) – Number of input channels.

  • channel_dim (int) – Number of output channels.

  • axis (int) – Indicates the axis along which the attention is computed.

forward(z: Tensor, mask: Tensor | None = None) Tensor

Forward Triangular edge update.

Parameters:
  • z (torch.Tensor) – Edge features.

  • mask (torch.Tensor, optional) – Optional attention mask.

Returns:

Updated edge features.

Return type:

torch.Tensor
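
Example (a minimal usage sketch; the Batch x Length x Length x Channels edge-feature layout is assumed from the rest of this page):

    import torch
    from tardis_em.dist_pytorch.model.modules import TriangularEdgeUpdate

    tri = TriangularEdgeUpdate(input_dim=128, channel_dim=128, axis=1)

    z = torch.rand(1, 32, 32, 128)  # edge features (assumed Batch x Length x Length x Channels)
    z_update = tri(z)               # updated edge features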

class tardis_em.dist_pytorch.model.modules.QuadraticEdgeUpdate(input_dim, channel_dim=128, axis=1)

QUADRATIC UPDATE MODULE FOR EDGE FEATURES

This module takes the edge feature representation and performs a quadratic update for each point. This is a modified version of the AlphaFold2 approach.

Parameters:
  • input_dim (int) – Number of input channels.

  • channel_dim (int) – Number of output channels.

  • axis (int) – Indicates the axis along which the attention is computed.

forward(z: Tensor, mask: Tensor | None = None) Tensor

Forward Quadratic edge update.

Parameters:
  • z (torch.Tensor) – Edge features.

  • mask (torch.Tensor, optional) – Optional mask torch.Tensor layer.

Returns:

Updated edge features.

Return type:

torch.Tensor
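
Example (usage mirrors TriangularEdgeUpdate, under the same assumptions about the edge-feature layout):

    import torch
    from tardis_em.dist_pytorch.model.modules import QuadraticEdgeUpdate

    quad = QuadraticEdgeUpdate(input_dim=128, channel_dim=128, axis=1)

    z = torch.rand(1, 32, 32, 128)  # edge features (assumed Batch x Length x Length x Channels)
    z_update = quad(z)              # updated edge features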

class tardis_em.dist_pytorch.model.modules.MultiHeadAttention(embed_dim: int, num_heads: int, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, init_scaling=0.7071067811865475)

MULTI-HEADED ATTENTION

See “Attention Is All You Need” for more details. Modified from ‘fairseq’.

Parameters:
  • embed_dim (int) – Number of embedded dimensions for node features.

  • num_heads (int) – Number of heads for multi-head attention.

  • kdim – Key dimensions.

  • vdim – Values dimensions.

  • dropout (float) – Dropout probability.

  • bias (bool) – If True add bias.

  • add_bias_kv (bool) – If True add bias for keys and values.

  • add_zero_attn (bool) – If True, append a zero batch to the key and value sequences.

  • self_attention (bool) – If True self-attention is used.

  • encoder_decoder_attention (bool) – If True, encoder-decoder attention is used.

  • init_scaling (float) – The initial scaling factor used for reset parameters.

reset_parameters()

Initial parameter and bias scaling.

forward(query: Tensor, key: Tensor | None = None, value: Tensor | None = None, key_padding_mask: Tensor | None = None, need_weights: bool = False, attn_mask: Tensor | None = None, before_softmax: bool = False, need_head_weights: bool = False) Tuple[Tensor, Tensor] | Tensor

Forward for MHA.

Parameters:
  • query (torch.Tensor) – Query input.

  • key (torch.Tensor, optional) – Key input.

  • value (torch.Tensor, optional) – Value input.

  • key_padding_mask (torch.Tensor, optional) – Mask to exclude keys that are pads, of shape (batch, src_len).

  • need_weights (bool, optional) – Return the attention weights, averaged over heads.

  • attn_mask (torch.Tensor, optional) – Typically used to implement causal attention.

  • before_softmax (bool, optional) – Return the raw attention weights and values before the attention softmax.

  • need_head_weights (bool, optional) – Return the attention weights for each head. Implies need_weights.
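
Example (a minimal self-attention usage sketch; the Length x Batch x Channel input layout follows the fairseq convention this module is modified from, and is an assumption here):

    import torch
    from tardis_em.dist_pytorch.model.modules import MultiHeadAttention

    mha = MultiHeadAttention(embed_dim=64, num_heads=8, self_attention=True)

    x = torch.rand(32, 1, 64)  # Length x Batch x Channel (assumed fairseq layout)
    out = mha(x, x, x)         # with need_weights=False a single tensor is assumed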

class tardis_em.dist_pytorch.model.modules.SelfAttention2D(embed_dim: int, num_heads: int, axis=None, dropout=0.0, max_size=4194304)

COMPUTE SELF-ATTENTION OVER 2D INPUT

Perform self-attention over 2D input for node features using multi-head attention.

Parameters:
  • embed_dim (int) – Number of embedded dimensions for node features.

  • num_heads (int) – Number of heads for multi-head attention.

  • axis (int) – Indicates the axis over which the attention is performed.

  • dropout (float) – Dropout probability.

  • max_size (int) – Maximum size of the batch.

forward(x: Tensor, padding_mask=None) Tensor

Forward self-attention over 2D-edge features.

Reshapes the input depending on the axis attention mode; flattens over rows and columns for full N*M x N*M attention.

Parameters:
  • x (torch.Tensor) – Edge features to update via self-attention, of shape [num_rows x num_cols x batch_size x embed_dim].

  • padding_mask (torch.Tensor) – Optional padding mask. [batch_size X num_rows X num_cols].
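
Example (a minimal usage sketch; shapes follow the documented forward() parameters):

    import torch
    from tardis_em.dist_pytorch.model.modules import SelfAttention2D

    attn2d = SelfAttention2D(embed_dim=128, num_heads=8, axis=1)

    x = torch.rand(32, 32, 1, 128)  # num_rows x num_cols x batch_size x embed_dim
    out = attn2d(x)                 # self-attention over the chosen axis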

class tardis_em.dist_pytorch.model.modules.GeluFeedForward(input_dim: int, ff_dim: int)

FEED-FORWARD TRANSFORMER MODULE USING GELU

Input: Batch x … x Dim
Output: Batch x … x Dim

Parameters:
  • input_dim (int) – Number of input dimensions for linear transformation.

  • ff_dim (int) – Number of feed-forward dimensions in linear transformation.

forward(x: Tensor) Tensor

Forward pass of the GeLU feed-forward block.

Parameters:

x (torch.Tensor) – Any Tensor of shape [B x … x D].

Returns:

GeLU-transformed tensor.

Return type:

torch.Tensor
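
Example (a minimal usage sketch; setting ff_dim = input_dim * ff_factor matches how ff_factor is described elsewhere on this page, but is an assumption here):

    import torch
    from tardis_em.dist_pytorch.model.modules import GeluFeedForward

    ffn = GeluFeedForward(input_dim=64, ff_dim=256)  # ff_dim assumed to be input_dim * ff_factor

    x = torch.rand(1, 32, 64)  # Batch x ... x Dim
    out = ffn(x)               # same [B x ... x D] shape after the GeLU feed-forward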

Feature embedding

Collection of classes used for Node and Edge embedding.

  • Node embedding is composed of an RGB value or, optionally, flattened image patches. The node embedding uses only 'nn.Linear' to embed the (n)-dimensional feature object and outputs [Batch x Feature Length x Channels].

  • Edge embedding is composed directly from the (n)D coordinate values, where n is the coordinate dimension. The edge embedding computes a 'cdist' operation on the coordinate features and produces a distance matrix for all points in the given patch. The distance matrix is then normalized with an exponential function controlled by the sigma parameter, which puts a higher weight on lower distance values (thresholded by sigma). This allows the network to embed distances while preserving invariance under rotation and translation (see the sketch after this list).
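
The following is a minimal, self-contained sketch of the distance normalization described above, written in plain PyTorch. The kernel form and the sigma value are assumptions for illustration; the exact function used by EdgeEmbedding may differ:

    import torch

    coords = torch.rand(1, 64, 3)        # Batch x Length x (n)D coordinates
    dist = torch.cdist(coords, coords)   # Batch x Length x Length distance matrix

    # Exponential weighting: short distances map to weights near 1, long
    # distances decay toward 0. Sigma acts as the soft distance threshold.
    sigma = 1.0                          # hypothetical value
    weights = torch.exp(-(dist ** 2) / (sigma ** 2))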

class tardis_em.dist_pytorch.model.embedding.NodeEmbedding(n_in: int, n_out: int, sigma=1)

NODE FEATURE EMBEDDING

Parameters:
  • n_in (int) – Number of input features.

  • n_out (int) – Number of output features.

forward(input_node: Tensor | None = None) Tensor | None

Forward node feature embedding.

Input: Batch x Length x Dim
Output: Batch x Length x Dim

Parameters:

input_node (torch.Tensor) – Node features (RGB or image patches).

Returns:

Embedded features.

Return type:

torch.Tensor
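
Example (a minimal usage sketch based on the signature above; the RGB input and dimension values are illustrative):

    import torch
    from tardis_em.dist_pytorch.model.embedding import NodeEmbedding

    node_embed = NodeEmbedding(n_in=3, n_out=64)  # e.g. RGB values -> 64 channels

    rgb = torch.rand(1, 32, 3)  # Batch x Length x Dim
    h_nodes = node_embed(rgb)   # Batch x Length x n_out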

class tardis_em.dist_pytorch.model.embedding.EdgeEmbedding(n_out: int, sigma: int | float | list)

COORDINATE EMBEDDING INTO GRAPH

A set of coordinates is used to build a distance matrix, which is then normalized using a negative parabolic function.

Input: Batch x Length x Dim
Output: Batch x Length x Length x Dim

Parameters:
  • n_out (int) – Number of features to output.

  • sigma (int, float, or list) – Sigma value for the exponential function used to normalize distances.

forward(input_coord: Tensor) Tensor

Forward edge feature embedding.

Parameters:

input_coord (torch.Tensor) – Coordinates array of shape [N, 2] or [N, 3] from which edge features are built.

Returns:

Embedded features.

Return type:

torch.Tensor
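
Example (a minimal usage sketch; the batched Batch x Length x Dim input follows the Input/Output shapes documented above):

    import torch
    from tardis_em.dist_pytorch.model.embedding import EdgeEmbedding

    edge_embed = EdgeEmbedding(n_out=128, sigma=1.0)

    coords = torch.rand(1, 32, 3)  # Batch x Length x Dim (3D coordinates)
    h_pairs = edge_embed(coords)   # Batch x Length x Length x n_out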