Module pysimt.layers.transformers.base_sublayer

Expand source code
from torch import nn


class BaseSublayer(nn.Module):
    """
    Base class for transformer sublayers.

    Holds a layer-normalization module and a dropout module, and provides helpers
    that apply the normalization either before the sublayer (pre-norm, when
    ``is_pre_norm`` is True) or after it (post-norm, otherwise), plus a residual
    connection that applies dropout to the sublayer output. Subclasses must
    override ``forward``.
    """

    def __init__(self, model_dim, dropout=0.1, is_pre_norm=False):
        """
        Creates a BaseSublayer.
        :param model_dim: The model dimension, i.e. the size of the last axis normalized by ``self.layer_norm``.
        :param dropout: The dropout probability used by ``self.dropout`` (passed to ``nn.Dropout``). Default: 0.1.
        :param is_pre_norm: Whether it should use pre_norm transformer layers. Default: False.
        """
        super().__init__()
        self.is_pre_norm = is_pre_norm
        # A single normalization module, used either as pre-norm or as post-norm
        # depending on ``is_pre_norm`` (see the apply_*_norm_if_needed helpers).
        self.layer_norm = nn.LayerNorm(model_dim, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *args, **kwargs):
        """
        Must be overridden by subclasses.

        Accepts arbitrary positional and keyword arguments so that calling an
        instance directly (e.g. ``sublayer(x)``) raises the intended
        NotImplementedError instead of a TypeError about an unexpected
        positional argument.
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("BaseSublayer does not implement forward.")

    def apply_pre_norm_if_needed(self, x):
        """
        Applies pre_norm to the input if needed. If pre_norm is false, the input remains unchanged.
        :param x: The input.
        :return: ``self.layer_norm(x)`` when ``is_pre_norm`` is True, otherwise ``x`` itself.
        """
        if self.is_pre_norm:
            x = self.layer_norm(x)
        return x

    def apply_post_norm_if_needed(self, x):
        """
        Applies post_norm to the input if needed. If pre_norm is true, the input remains unchanged.
        :param x: The input.
        :return: ``self.layer_norm(x)`` when ``is_pre_norm`` is False, otherwise ``x`` itself.
        """
        if not self.is_pre_norm:
            x = self.layer_norm(x)
        return x

    def apply_residual(self, residual, x):
        """
        Applies the residual connection: ``residual + dropout(x)``.
        :param residual: The residual (the tensor the sublayer output is added to).
        :param x: The sublayer output; dropout is applied to it before the addition.
        :return: The output of the residual connection.
        """
        return residual + self.dropout(x)

Classes

class BaseSublayer (model_dim, dropout=0.1, is_pre_norm=False)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Toy two-layer convolutional network illustrating submodule registration."""

    def __init__(self):
        super().__init__()
        # Assigning modules as attributes registers them as submodules.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        out = self.conv2(hidden)
        return F.relu(out)

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool

Creates a BaseSublayer.

:param model_dim: The model dimension.
:param dropout: The dropout probability (passed to `nn.Dropout`). Default: 0.1.
:param is_pre_norm: Whether it should use pre_norm transformer layers. Default: False.

Expand source code
class BaseSublayer(nn.Module):
    """
    Base class for transformer sublayers.

    Holds a layer-normalization module and a dropout module, and provides helpers
    that apply the normalization either before the sublayer (pre-norm, when
    ``is_pre_norm`` is True) or after it (post-norm, otherwise), plus a residual
    connection that applies dropout to the sublayer output. Subclasses must
    override ``forward``.
    """

    def __init__(self, model_dim, dropout=0.1, is_pre_norm=False):
        """
        Creates a BaseSublayer.
        :param model_dim: The model dimension, i.e. the size of the last axis normalized by ``self.layer_norm``.
        :param dropout: The dropout probability used by ``self.dropout`` (passed to ``nn.Dropout``). Default: 0.1.
        :param is_pre_norm: Whether it should use pre_norm transformer layers. Default: False.
        """
        super().__init__()
        self.is_pre_norm = is_pre_norm
        # A single normalization module, used either as pre-norm or as post-norm
        # depending on ``is_pre_norm`` (see the apply_*_norm_if_needed helpers).
        self.layer_norm = nn.LayerNorm(model_dim, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *args, **kwargs):
        """
        Must be overridden by subclasses.

        Accepts arbitrary positional and keyword arguments so that calling an
        instance directly (e.g. ``sublayer(x)``) raises the intended
        NotImplementedError instead of a TypeError about an unexpected
        positional argument.
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("BaseSublayer does not implement forward.")

    def apply_pre_norm_if_needed(self, x):
        """
        Applies pre_norm to the input if needed. If pre_norm is false, the input remains unchanged.
        :param x: The input.
        :return: ``self.layer_norm(x)`` when ``is_pre_norm`` is True, otherwise ``x`` itself.
        """
        if self.is_pre_norm:
            x = self.layer_norm(x)
        return x

    def apply_post_norm_if_needed(self, x):
        """
        Applies post_norm to the input if needed. If pre_norm is true, the input remains unchanged.
        :param x: The input.
        :return: ``self.layer_norm(x)`` when ``is_pre_norm`` is False, otherwise ``x`` itself.
        """
        if not self.is_pre_norm:
            x = self.layer_norm(x)
        return x

    def apply_residual(self, residual, x):
        """
        Applies the residual connection: ``residual + dropout(x)``.
        :param residual: The residual (the tensor the sublayer output is added to).
        :param x: The sublayer output; dropout is applied to it before the addition.
        :return: The output of the residual connection.
        """
        return residual + self.dropout(x)

Ancestors

  • torch.nn.modules.module.Module

Subclasses

Class variables

var dump_patches : bool
var training : bool

Methods

def apply_post_norm_if_needed(self, x)

Applies post_norm to the input if needed. If pre_norm is true, the input remains unchanged.

:param x: The input.
:return: The output.

Expand source code
def apply_post_norm_if_needed(self, x):
    """
    Normalizes the input after the sublayer when post-norm is in use.

    Post-norm is in effect whenever ``self.is_pre_norm`` is falsy; the input is
    then passed through ``self.layer_norm``. In pre-norm mode it is returned as-is.
    :param x: The input.
    :return: ``self.layer_norm(x)`` in post-norm mode, otherwise ``x`` unchanged.
    """
    return x if self.is_pre_norm else self.layer_norm(x)
def apply_pre_norm_if_needed(self, x)

Applies pre_norm to the input if needed. If pre_norm is false, the input remains unchanged.

:param x: The input.
:return: The output.

Expand source code
def apply_pre_norm_if_needed(self, x):
    """
    Normalizes the input before the sublayer when pre-norm is in use.

    When ``self.is_pre_norm`` is falsy the input is returned untouched;
    otherwise it is passed through ``self.layer_norm``.
    :param x: The input.
    :return: ``self.layer_norm(x)`` in pre-norm mode, otherwise ``x`` unchanged.
    """
    if not self.is_pre_norm:
        return x
    return self.layer_norm(x)
def apply_residual(self, residual, x)

Applies the residual connection.

:param residual: The residual.
:param x: The input x.
:return: The output of the residual connection.

Expand source code
def apply_residual(self, residual, x):
    """
    Adds the dropout-regularized input onto the residual branch.

    :param residual: The residual that the (dropped-out) input is added to.
    :param x: The input; it is passed through ``self.dropout`` first.
    :return: ``residual + self.dropout(x)``.
    """
    dropped_out = self.dropout(x)
    return residual + dropped_out
def forward(self, **kwargs) -> Callable[..., Any]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, *args, **kwargs):
    """
    Placeholder forward pass; concrete sublayers must override it.

    Accepts arbitrary positional and keyword arguments so that calling it always
    raises the intended NotImplementedError rather than a TypeError caused by an
    unexpected positional argument (e.g. when an instance is called as ``m(x)``).
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("BaseSublayer does not implement forward.")