Module pysimt.layers.transformers.cross_attention_sublayer_mm_hier
Expand source code
import torch

from ..attention import ScaledDotAttention
from .base_sublayer import BaseSublayer


class HierarchicalMMCrossAttentionSublayer(BaseSublayer):
    def __init__(self, model_dim, n_heads, dropout=0.1, attn_dropout=0.0,
                 is_pre_norm=False, n_hier_heads=8):
        """
        Creates a HierarchicalMMCrossAttentionSublayer.

        :param model_dim: The model dimension.
        :param n_heads: The number of attention heads for the textual and visual attention.
        :param dropout: The dropout rate for the residual connection.
        :param attn_dropout: The dropout rate for the attention weights.
        :param is_pre_norm: Whether the layer type is pre_norm. Default: False.
        :param n_hier_heads: The number of attention heads for the hierarchical attention.
        """
        super().__init__(model_dim, dropout, is_pre_norm)
        self.attn_txt = ScaledDotAttention(model_dim, n_heads, attn_dropout)
        self.attn_img = ScaledDotAttention(model_dim, n_heads, attn_dropout)
        self.attn_hierarchical = ScaledDotAttention(model_dim, n_hier_heads, attn_dropout)

    def forward(self, query, key_txt, value_txt, mask_txt, key_img, value_img, mask_img=None):
        """
        Performs a forward pass over the HierarchicalMMCrossAttentionSublayer.

        :param query: The query. For encoder-decoder attention, it is the output from the previous decoder layer.
        :param key_txt: The key for the textual modality. For encoder-decoder attention, it is the output from the encoder.
        :param value_txt: The value for the textual modality. For encoder-decoder attention, it is the output from the encoder.
        :param mask_txt: The textual features mask.
        :param key_img: The key for the visual modality.
        :param value_img: The value for the visual modality.
        :param mask_img: The visual features mask.
        :return: The output of the sublayer and a dict with the textual, visual,
            and hierarchical attention weights.
        """
        residual = query
        query = self.apply_pre_norm_if_needed(query)
        # Attend over each modality separately, with the same query.
        attn_txt, attn_weights_txt = self.attn_txt((query, key_txt, value_txt, mask_txt))
        attn_img, attn_weights_img = self.attn_img((query, key_img, value_img, mask_img))
        # Fuse the two per-modality contexts with a second level of attention.
        attn_combined, combined_attn_weights = self._fuse_contexts(query, attn_img, attn_txt)
        out = self.apply_residual(residual, attn_combined)
        out = self.apply_post_norm_if_needed(out)
        return out, {'txt': attn_weights_txt, 'img': attn_weights_img,
                     'hier': combined_attn_weights}

    def _fuse_contexts(self, query, attn_img, attn_txt, combined_mask=None):
        seq_len, batch_size, model_dim = query.shape
        # Stack the textual and visual contexts into a key/value "sequence" of
        # length 2, flattening (seq_len, batch) into a single batch axis so that
        # every query position attends over its own pair of modality contexts.
        combined_key_value = torch.stack((attn_txt, attn_img), dim=0).view(2, -1, model_dim)
        combined_attn, combined_attn_weights = self.attn_hierarchical(
            (query.view(1, -1, model_dim), combined_key_value,
             combined_key_value, combined_mask))
        return combined_attn.view(seq_len, batch_size, model_dim), combined_attn_weights
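The fusion step treats the two modality contexts as a two-element key/value sequence per (token, batch) position. A minimal shape sketch of that reshaping, using plain torch and made-up sizes:

import torch

seq_len, batch_size, model_dim = 7, 4, 512
attn_txt = torch.rand(seq_len, batch_size, model_dim)  # textual context
attn_img = torch.rand(seq_len, batch_size, model_dim)  # visual context

# Stack into (2, seq_len * batch_size, model_dim): a "sequence" of two
# modality contexts for each of the seq_len * batch_size query positions.
kv = torch.stack((attn_txt, attn_img), dim=0).view(2, -1, model_dim)
assert kv.shape == (2, seq_len * batch_size, model_dim)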
Classes
class HierarchicalMMCrossAttentionSublayer (model_dim, n_heads, dropout=0.1, attn_dropout=0.0, is_pre_norm=False, n_hier_heads=8)
Creates a HierarchicalMMCrossAttentionSublayer.

:param model_dim: The model dimension.
:param n_heads: The number of attention heads for the textual and visual attention.
:param dropout: The dropout rate for the residual connection.
:param attn_dropout: The dropout rate for the attention weights.
:param is_pre_norm: Whether the layer type is pre_norm. Default: False.
:param n_hier_heads: The number of attention heads for the hierarchical attention.
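A minimal usage sketch, assuming a working pysimt installation. The shapes follow the (seq_len, batch_size, model_dim) layout implied by _fuse_contexts; the masks are set to None here for brevity (mask_img defaults to None in the signature, and mask_txt is assumed to accept None as well):

import torch
from pysimt.layers.transformers.cross_attention_sublayer_mm_hier import \
    HierarchicalMMCrossAttentionSublayer

layer = HierarchicalMMCrossAttentionSublayer(model_dim=512, n_heads=8)

tgt_len, src_len, n_img_feats, batch_size = 7, 9, 36, 4
query = torch.rand(tgt_len, batch_size, 512)                    # previous decoder layer output
key_txt = value_txt = torch.rand(src_len, batch_size, 512)      # encoder output
key_img = value_img = torch.rand(n_img_feats, batch_size, 512)  # visual features

out, attn_weights = layer(query, key_txt, value_txt, None, key_img, value_img)
print(out.shape)  # torch.Size([7, 4, 512])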
Ancestors
- BaseSublayer
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, query, key_txt, value_txt, mask_txt, key_img, value_img, mask_img=None)
Performs a forward pass over the HierarchicalMMCrossAttentionSublayer.

:param query: The query. For encoder-decoder attention, it is the output from the previous decoder layer.
:param key_txt: The key for the textual modality. For encoder-decoder attention, it is the output from the encoder.
:param value_txt: The value for the textual modality. For encoder-decoder attention, it is the output from the encoder.
:param mask_txt: The textual features mask.
:param key_img: The key for the visual modality.
:param value_img: The value for the visual modality.
:param mask_img: The visual features mask.
:return: The output of the sublayer and a dict with the textual, visual, and hierarchical attention weights.
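Since the hierarchical attention runs over a two-element key/value sequence, the 'hier' entry of the returned dict indicates, for each query position, how the fused context balances the textual against the visual modality. A short illustration of unpacking the return value, continuing the hypothetical call above (exact weight shapes depend on ScaledDotAttention's internals):

out, attn_weights = layer(query, key_txt, value_txt, None, key_img, value_img)
txt_w = attn_weights['txt']    # text cross-attention weights
img_w = attn_weights['img']    # image cross-attention weights
hier_w = attn_weights['hier']  # modality-level (text vs. image) weights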
Inherited members