Module pysimt.layers.attention.scaled_dot

Expand source code
# -*- coding: utf-8 -*-
import math
import torch


class ScaledDotAttention(torch.nn.Module):

    def __init__(self, model_dim, n_heads, dropout=0.0):
        """
        Creates a ScaledDotAttention.
        :param model_dim: The model dimensions.
        :param n_heads: The number of heads.
        :param dropout: The dropout value. Default 0.0.
        """
        super().__init__()
        self.model_dim = model_dim
        self.n_heads = n_heads

        self.lin_k = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.lin_q = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.lin_v = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        self.lin_o = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)

        self.head_dim = self.model_dim // self.n_heads
        self.scale = math.sqrt(self.head_dim)

    def forward(self, inputs):
        """Scaled dot-product attention forward-pass

        :param inputs: a tuple of query, key, value and mask tensors;
            the query, key and value tensors have shape (tstep, bsize, dim)
            and the mask, which can be None, has shape (bsize, query_len, key_len)

        :return: the output of shape (query_len, bsize, dim) and the
            attention weights of shape (bsize, n_heads, query_len, key_len)
        """
        q, k, v, mask = inputs
        q_len, _, _ = q.shape
        query, key, value = self._project_and_reshape(q, k, v)

        attn_weights = self._compute_attn_weights(query, key, mask)
        attn_probs = self.dropout(attn_weights)
        scores = torch.matmul(attn_probs, value)
        out = self.lin_o(self._view_as_concat(scores, q_len))

        return out, attn_weights

    def _project_and_reshape(self, q, k, v):
        """
        Projects q, k and v and reshapes them into shape (bsize, n_heads, q|k|v_len, head_dim).
        :param q: q of shape (q_len, b_size, model_dim)
        :param k: k of shape (k_len, b_size, model_dim)
        :param v: v of shape (v_len, b_size, model_dim)
        :return: The query, key, value of shape (b_size, n_heads, q|k|v_len, head_dim).
        """
        query = self._view_as_headed(self.lin_q(q))
        key = self._view_as_headed(self.lin_k(k))
        value = self._view_as_headed(self.lin_v(v))
        return query, key, value

    def _compute_attn_weights(self, query, key, mask):
        """
        Computes the normalized attention scores.
        :param query: The query of shape (b_size, n_heads, q_len, head_dim).
        :param key: The key of shape (b_size, n_heads, k_len, head_dim).
        :param mask: The mask of shape (b_size, q_len, k_len), or None.
        :return: The normalized attention scores of shape (b_size, n_heads, q_len, k_len).
        """
        attn = torch.matmul(query.div(self.scale), key.transpose(-2, -1))
        attn = self._apply_mask(mask, attn)
        return attn.softmax(dim=-1)

    def _view_as_headed(self, data):
        """
        Reshapes the data into a head format.
        :param data: (seq_len, b_size, model_dim)
        :return: (b_size, n_heads, seq_len, head_dim).
        """
        return data.view(data.shape[0], data.shape[1], self.n_heads, -1).permute(1, 2, 0, 3)

    def _view_as_concat(self, data, q_len):
        """Merges the heads back: (b_size, n_heads, q_len, head_dim) -> (q_len, b_size, model_dim)."""
        return data.permute(2, 0, 1, 3).contiguous().view(q_len, -1, self.model_dim)

    @staticmethod
    def _apply_mask(mask, attn):
        """Fills positions where the mask is True with a large negative value
        so that they receive (near) zero weight after the softmax."""
        if mask is not None:
            mask = mask.unsqueeze(1)
            attn.masked_fill_(mask, -1e8)
        return attn
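
A minimal usage sketch follows (sizes are hypothetical; shapes follow the
forward() docstring, with time-major inputs and a boolean mask where True
marks positions to ignore):

import torch
from pysimt.layers.attention.scaled_dot import ScaledDotAttention

model_dim, n_heads = 512, 8                 # hypothetical sizes; model_dim must be divisible by n_heads
q_len, k_len, bsize = 7, 11, 4

attn = ScaledDotAttention(model_dim, n_heads, dropout=0.1)

q = torch.randn(q_len, bsize, model_dim)    # (tstep, bsize, dim)
k = torch.randn(k_len, bsize, model_dim)
v = torch.randn(k_len, bsize, model_dim)
mask = torch.zeros(bsize, q_len, k_len, dtype=torch.bool)   # True entries would be masked out

out, attn_weights = attn((q, k, v, mask))
print(out.shape)            # torch.Size([7, 4, 512])
print(attn_weights.shape)   # torch.Size([4, 8, 7, 11])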

Classes

class ScaledDotAttention (model_dim, n_heads, dropout=0.0)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

training (bool): whether this module is in training or evaluation mode.

Creates a ScaledDotAttention.

:param model_dim: The model dimensions.
:param n_heads: The number of heads.
:param dropout: The dropout value. Default 0.0.
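
The per-head width and the softmax scaling factor are derived from these two
arguments; a quick illustration with hypothetical sizes:

import math

model_dim, n_heads = 512, 8        # hypothetical; model_dim should be divisible by n_heads
head_dim = model_dim // n_heads    # 64, stored as self.head_dim
scale = math.sqrt(head_dim)        # 8.0, stored as self.scale and used to divide the queries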

Expand source code
class ScaledDotAttention(torch.nn.Module):

    def __init__(self, model_dim, n_heads, dropout=0.0):
        """
        Creates a ScaledDotAttention.
        :param model_dim: The model dimensions.
        :param n_heads: The number of heads.
        :param dropout: The dropout value. Default 0.0.
        """
        super().__init__()
        self.model_dim = model_dim
        self.n_heads = n_heads

        self.lin_k = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.lin_q = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.lin_v = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        self.lin_o = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)

        self.head_dim = self.model_dim // self.n_heads
        self.scale = math.sqrt(self.head_dim)

    def forward(self, inputs):
        """Scaled dot-product attention forward-pass

        :param inputs: a tuple of query, key, value and mask tensors;
            the query, key and value tensors have shape (tstep, bsize, dim)
            and the mask, which can be None, has shape (bsize, query_len, key_len)

        :return: the output of shape (query_len, bsize, dim) and the
            attention weights of shape (bsize, n_heads, query_len, key_len)
        """
        q, k, v, mask = inputs
        q_len, _, _ = q.shape
        query, key, value = self._project_and_reshape(q, k, v)

        attn_weights = self._compute_attn_weights(query, key, mask)
        attn_probs = self.dropout(attn_weights)
        scores = torch.matmul(attn_probs, value)
        out = self.lin_o(self._view_as_concat(scores, q_len))

        return out, attn_weights

    def _project_and_reshape(self, q, k, v):
        """
        Projects q, k and v and reshapes them into shape (bsize, n_heads, q|k|v_len, head_dim).
        :param q: q of shape (q_len, b_size, model_dim)
        :param k: k of shape (k_len, b_size, model_dim)
        :param v: v of shape (v_len, b_size, model_dim)
        :return: The query, key, value of shape (b_size, n_heads, q|k|v_len, head_dim).
        """
        query = self._view_as_headed(self.lin_q(q))
        key = self._view_as_headed(self.lin_k(k))
        value = self._view_as_headed(self.lin_v(v))
        return query, key, value

    def _compute_attn_weights(self, query, key, mask):
        """
        Computes the normalized attention scores.
        :param query: The query of shape (b_size, n_heads, q_len, head_dim).
        :param key: The key of shape (b_size, n_heads, k_len, head_dim).
        :param mask: The mask of shape (b_size, q_len, k_len), or None.
        :return: The normalized attention scores of shape (b_size, n_heads, q_len, k_len).
        """
        attn = torch.matmul(query.div(self.scale), key.transpose(-2, -1))
        attn = self._apply_mask(mask, attn)
        return attn.softmax(dim=-1)

    def _view_as_headed(self, data):
        """
        Reshapes the data into a head format.
        :param data: (seq_len, b_size, model_dim)
        :return: (b_size, n_heads, seq_len, head_dim).
        """
        return data.view(data.shape[0], data.shape[1], self.n_heads, -1).permute(1, 2, 0, 3)

    def _view_as_concat(self, data, q_len):
        """Merges the heads back: (b_size, n_heads, q_len, head_dim) -> (q_len, b_size, model_dim)."""
        return data.permute(2, 0, 1, 3).contiguous().view(q_len, -1, self.model_dim)

    @staticmethod
    def _apply_mask(mask, attn):
        """Fills positions where the mask is True with a large negative value
        so that they receive (near) zero weight after the softmax."""
        if mask is not None:
            mask = mask.unsqueeze(1)
            attn.masked_fill_(mask, -1e8)
        return attn
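
The two reshaping helpers are inverses of each other; a small sketch of the
shape bookkeeping (sizes are hypothetical):

import torch

seq_len, bsize, n_heads, head_dim = 5, 2, 4, 16
model_dim = n_heads * head_dim

x = torch.randn(seq_len, bsize, model_dim)

# _view_as_headed: (seq_len, bsize, model_dim) -> (bsize, n_heads, seq_len, head_dim)
headed = x.view(seq_len, bsize, n_heads, head_dim).permute(1, 2, 0, 3)

# _view_as_concat: (bsize, n_heads, seq_len, head_dim) -> (seq_len, bsize, model_dim)
concat = headed.permute(2, 0, 1, 3).contiguous().view(seq_len, -1, model_dim)

assert torch.equal(x, concat)      # the round trip recovers the original layout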

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, inputs)

Scaled dot-product attention forward-pass

:param inputs: a tuple of query, key, value and mask tensors; the query, key and value tensors have shape (tstep, bsize, dim) and the mask, which can be None, has shape (bsize, query_len, key_len)

:return: the output of shape (query_len, bsize, dim) and the attention weights of shape (bsize, n_heads, query_len, key_len)

Expand source code
def forward(self, inputs):
    """Scaled dot-product attention forward-pass

    :param inputs: a tuple of query, key, value and mask tensors;
        the query, key and value tensors have shape (tstep, bsize, dim)
        and the mask, which can be None, has shape (bsize, query_len, key_len)

    :return: the output of shape (query_len, bsize, dim) and the
        attention weights of shape (bsize, n_heads, query_len, key_len)
    """
    q, k, v, mask = inputs
    q_len, _, _ = q.shape
    query, key, value = self._project_and_reshape(q, k, v)

    attn_weights = self._compute_attn_weights(query, key, mask)
    attn_probs = self.dropout(attn_weights)
    scores = torch.matmul(attn_probs, value)
    out = self.lin_o(self._view_as_concat(scores, q_len))

    return out, attn_weights
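
Positions where the mask is True are filled with a large negative value before
the softmax, so their attention weight is effectively zero. A short
self-attention sketch with a causal mask (names and sizes are illustrative):

import torch
from pysimt.layers.attention.scaled_dot import ScaledDotAttention

attn = ScaledDotAttention(model_dim=64, n_heads=4)
x = torch.randn(5, 2, 64)                      # self-attention: q = k = v, shape (tstep, bsize, dim)

# Causal mask: True above the diagonal blocks attention to future positions.
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
mask = causal.unsqueeze(0).expand(2, -1, -1)   # (bsize, query_len, key_len)

out, weights = attn((x, x, x, mask))
print(weights[0, 0])                           # rows sum to 1; future positions get ~0 weight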