Module pysimt.layers.attention.multihead
Expand source code
import math
import numpy as np
import torch
from torch import nn
class MultiheadAttention(nn.Module):
"""General purpose multihead attention implementation."""
def __init__(self, input_dim, proj_dim, n_heads=1, dropout=0.0,
attn_type='cross', initializer='xavier_uniform'):
assert proj_dim % n_heads == 0, "proj_dim not divisible by n_heads."
super().__init__()
self.input_dim = input_dim
self.proj_dim = proj_dim
self.n_heads = n_heads
self.head_dim = self.proj_dim // self.n_heads
self.scale = math.sqrt(self.head_dim)
self.minus_inf = float('-inf')
self.attn_type = attn_type
self.initializer = initializer
self.p_dropout = dropout
self._apply_projections_and_reshape = getattr(
self, f'_apply_projections_and_reshape_{self.attn_type}')
# dropout over attention probability
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else lambda x: x
self._create_layers()
self._reset_parameters(getattr(nn.init, f'{initializer}_'))
def __repr__(self):
s = f"MultiheadAttention({self.input_dim} -> {self.proj_dim}, {self.n_heads} heads, "
s += f"type={self.attn_type!r}, dropout={self.p_dropout})"
return s
def view_as_headed(self, x):
"""Returns a view of shape `[bsz, n_heads, seq_len, head_dim]`
from `[bsz, seq_len, head_dim * n_heads]`."""
return x.view(x.size(0), x.size(1), self.n_heads, -1).transpose(1, 2)
@staticmethod
def view_as_concat(x):
"""Returns a view of shape `[bsz, seq_len, head_dim * n_heads]`
from `[bsz, n_heads, seq_len, head_dim]`."""
return x.transpose(1, 2).contiguous().view(x.size(0), x.size(2), -1)
def _reset_parameters(self, init_fn):
"""Reinitializes layer weights."""
for param in self.parameters():
init_fn(param)
def _create_layers(self):
"""Create projection layer weights."""
self.lin_o = nn.Parameter(torch.Tensor(self.proj_dim, self.proj_dim))
if self.attn_type != 'self':
self.lin_k = nn.Parameter(torch.Tensor(self.input_dim, self.proj_dim))
self.lin_q = nn.Parameter(torch.Tensor(self.input_dim, self.proj_dim))
self.lin_v = nn.Parameter(torch.Tensor(self.input_dim, self.proj_dim))
else:
self.lin_k = nn.Parameter(torch.Tensor(self.input_dim, 3 * self.proj_dim))
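# the fused self-attention weight above ([input_dim, 3 * proj_dim]) is chunked
# into separate K, V and Q projections in _apply_projections_and_reshape_self()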
def _apply_projections_and_reshape_self(self, k, v=None, q=None):
"""Projects key, value and queries and returns multi-head view
for self-attention variant.
Args:
k: Tensor of shape `[batch_size, v_len, dim]`.
v: `None` for self-attention. This is not used.
q: `None` for self-attention. This is not used.
Returns:
Three tensors (returned as a generator) for the k, v, q projections, each with shape
`[batch_size, n_heads, v_len, head_dim]`.
"""
return (
self.view_as_headed(t) for t in k.matmul(self.lin_k).chunk(3, dim=-1))
def _apply_projections_and_reshape_cross(self, k, v, q):
"""Projects key, value and queries and returns multi-head view
for cross-attention variant.
Args:
k: Tensor of shape `[batch_size, v_len, dim]`.
v: Tensor of shape `[batch_size, v_len, dim]`.
q: Tensor of shape `[batch_size, q_len, dim]`.
Returns:
A tuple of 3 tensors for k,v,q projections, each with shape
`[batch_size, n_heads, (v|q)_len, head_dim]`.
"""
return (self.view_as_headed(k.matmul(self.lin_k)),
self.view_as_headed(v.matmul(self.lin_v)),
self.view_as_headed(q.matmul(self.lin_q)))
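# in the cross variant, K and V are projected from the attended memory (length
# v_len) while Q is projected from the querying states (length q_len)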
def _compute_scores(self, query, key, k_mask=None):
"""Computes normalized scaled dot-product scores between query and key.
Args:
query: Tensor of shape `[batch_size, n_heads, q_len, dim]`.
key: Tensor of shape `[batch_size, n_heads, v_len, dim]`.
k_mask: Tensor of shape `[batch_size, v_len]`.
Returns:
Tensor of shape `[batch_size, n_heads, q_len, v_len]` with
normalized attention weights.
"""
scores = torch.matmul(query.div(self.scale), key.transpose(-2, -1))
if k_mask is not None:
# mask <pad>'ded positions
scores.masked_fill_(k_mask[:, None, None, :], self.minus_inf)
return self.dropout(scores.softmax(dim=-1))
def _apply_scores(self, p, value, q_mask=None):
"""Applies normalized attention weights on `value`. `q_mask`
is used to zero padded positions afterwards.
Args:
p: Tensor of shape `[batch_size, n_heads, q_len, v_len]`.
value: Tensor of shape `[batch_size, n_heads, v_len, dim]`.
q_mask: Tensor of shape `[batch_size, q_len]`.
Returns:
Tensor of shape `[batch_size, n_heads, q_len, dim]`.
"""
ctx = torch.matmul(p, value)
if q_mask is not None:
# zero out <pad>'ded positions
ctx.mul_(q_mask[:, None, :, None].logical_not())
return ctx
def forward(self, k, v=None, q=None, k_mask=None, q_mask=None):
kp, vp, qp = self._apply_projections_and_reshape(k, v, q)
# Get normalized scores
alpha = self._compute_scores(qp, kp, k_mask)
# Get weighted contexts for each head -> concat -> project
return self.view_as_concat(
self._apply_scores(alpha, vp, q_mask)).matmul(self.lin_o)
def get_upstream_impl(dim, n_heads):
mha = nn.MultiheadAttention(dim, n_heads, bias=False)
nn.init.eye_(mha.out_proj.weight.data)
# in_proj_weight packs the Q, K and V input projections as a [3*dim, dim] matrix;
# set each dim x dim block below to the identity so no real projection is applied
nn.init.eye_(mha.in_proj_weight.data[:dim])
nn.init.eye_(mha.in_proj_weight.data[dim:2 * dim])
nn.init.eye_(mha.in_proj_weight.data[-dim:])
return mha
def get_own_self_impl(i_dim, p_dim, n_heads):
self_att = MultiheadAttention(input_dim=i_dim, proj_dim=p_dim, n_heads=n_heads, attn_type='self')
print(self_att)
nn.init.eye_(self_att.lin_o.data)
list(map(lambda x: nn.init.eye_(x), self_att.lin_k.data.chunk(3, dim=-1)))
return self_att
def get_own_cross_impl(i_dim, p_dim, n_heads):
cross_att = MultiheadAttention(input_dim=i_dim, proj_dim=p_dim, n_heads=n_heads)
print(cross_att)
nn.init.eye_(cross_att.lin_o.data)
nn.init.eye_(cross_att.lin_k.data)
nn.init.eye_(cross_att.lin_q.data)
nn.init.eye_(cross_att.lin_v.data)
return cross_att
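# with identity projections, both this implementation and the upstream torch
# layer reduce to plain scaled dot-product attention (split across heads),
# so their outputs can be compared directly in main() below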
def main():
np.random.seed(2)
torch.manual_seed(3)
torch.cuda.manual_seed(4)
input_dim = 512
batch_size = 100
vocab_size = 1000
# Create the embeddings
embs = nn.Embedding(vocab_size, embedding_dim=input_dim, padding_idx=0)
# Sample sequence lengths
src_seq_lens = np.random.normal(6, 1, size=(batch_size,)).astype('int')
trg_seq_lens = np.random.normal(6, 1, size=(batch_size,)).astype('int')
# Sample random vocab IDs
src_idxs = torch.randint(
low=1, high=vocab_size, size=(batch_size, src_seq_lens.max()))
trg_idxs = torch.randint(
low=1, high=vocab_size, size=(batch_size, trg_seq_lens.max()))
# pad short sequences
for seq, seqlen in enumerate(src_seq_lens):
src_idxs[seq, seqlen:].fill_(0)
for seq, seqlen in enumerate(trg_seq_lens):
trg_idxs[seq, seqlen:].fill_(0)
# masks with `True` for padded positions
src_padding_mask = src_idxs.eq(0)
trg_padding_mask = trg_idxs.eq(0)
# Verify lengths
assert np.allclose(src_seq_lens, src_idxs.ne(0).sum(1))
assert np.allclose(trg_seq_lens, trg_idxs.ne(0).sum(1))
# get embeddings
x = embs(src_idxs)
y = embs(trg_idxs)
# Verify lengths using embeddings
assert np.allclose(src_seq_lens, x.sum(-1).ne(0.0).sum(1))
assert np.allclose(trg_seq_lens, y.sum(-1).ne(0.0).sum(1))
mha = get_upstream_impl(input_dim, 1)
xp = x.transpose(0, 1)
yp = y.transpose(0, 1)
h_mha_self, p_mha_self = mha(
query=xp, key=xp, value=xp, key_padding_mask=src_padding_mask)
h_mha_cross, p_mha_cross = mha(
query=yp, key=xp, value=xp, key_padding_mask=src_padding_mask)
h_mha_self.transpose_(0, 1)
h_mha_cross.transpose_(0, 1)
# self attention
# q_mask: src
self_att = get_own_self_impl(input_dim, input_dim, n_heads=1)
h_self = self_att(k=x, v=x, q=x, k_mask=src_padding_mask, q_mask=None)
assert torch.allclose(h_self, h_mha_self, atol=1e-1)
# self attention with identity projections should produce the query itself
assert torch.allclose(
self_att(x, x, x, src_padding_mask, src_padding_mask), x, atol=1e-1)
# cross attention
# q_mask: trg
cross_att = get_own_cross_impl(input_dim, input_dim, n_heads=1)
h_cross = cross_att(k=x, v=x, q=y, k_mask=src_padding_mask, q_mask=trg_padding_mask)
assert torch.allclose(
cross_att(x, x, y, src_padding_mask, None), h_mha_cross, atol=1e-1)
#################
# multi-head test
#################
for nh in (1, 2, 4, 8, 16, 32):
print(f'# heads: {nh}')
self_att = get_own_self_impl(input_dim, input_dim, n_heads=nh)
cross_att = get_own_cross_impl(input_dim, input_dim, n_heads=nh)
torc_att = get_upstream_impl(input_dim, nh)
h_torc, p_torc = torc_att(xp, xp, xp, key_padding_mask=src_padding_mask)
h_torc.transpose_(0, 1)
h_self = self_att(k=x, k_mask=src_padding_mask, q_mask=None)
h_cross = cross_att(x, x, x, k_mask=src_padding_mask, q_mask=None)
assert torch.allclose(h_self, h_torc, atol=1e-1)
assert torch.allclose(h_cross, h_torc, atol=1e-1)
self_att = get_own_self_impl(input_dim, 256, n_heads=2)
cross_att = get_own_cross_impl(input_dim, 256, n_heads=2)
h_self = self_att(k=x, k_mask=src_padding_mask, q_mask=None)
h_cross = cross_att(x, x, x, k_mask=src_padding_mask, q_mask=None)
if __name__ == '__main__':
main()
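For reference, a minimal standalone usage sketch of the cross-attention variant; the tensor names and sizes below are illustrative assumptions, not part of the module. Running the module itself (e.g. `python -m pysimt.layers.attention.multihead`) executes the equivalence tests in `main()`.

import torch
from pysimt.layers.attention.multihead import MultiheadAttention

att = MultiheadAttention(input_dim=512, proj_dim=512, n_heads=8, attn_type='cross')
src = torch.rand(16, 20, 512)                      # encoder states [batch, src_len, dim]
trg = torch.rand(16, 12, 512)                      # decoder states [batch, trg_len, dim]
src_mask = torch.zeros(16, 20, dtype=torch.bool)   # True marks padded key positions
ctx = att(k=src, v=src, q=trg, k_mask=src_mask)    # -> [16, 12, 512]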
Functions
def get_own_cross_impl(i_dim, p_dim, n_heads)
-
Builds a cross-attention `MultiheadAttention` with identity output and K/Q/V projection weights; a test helper for comparing against `torch.nn.MultiheadAttention`.
def get_own_self_impl(i_dim, p_dim, n_heads)
-
Builds a self-attention `MultiheadAttention` (attn_type='self') with an identity output projection and identity blocks in the fused K/V/Q weight; a test helper.
def get_upstream_impl(dim, n_heads)
-
Builds a reference `torch.nn.MultiheadAttention` without biases and with identity input/output projections, so that it computes plain scaled dot-product attention for comparison.
def main()
-
Self-test entry point: builds randomly padded source/target batches and asserts that this implementation matches `torch.nn.MultiheadAttention` for single- and multi-head, self- and cross-attention configurations (see the module source above).
Classes
class MultiheadAttention (input_dim, proj_dim, n_heads=1, dropout=0.0, attn_type='cross', initializer='xavier_uniform')
-
General purpose multihead attention implementation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
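A minimal sketch of the self-attention variant (shapes are illustrative assumptions): the layer takes a single input, projects it with one fused weight, and ignores the `v` and `q` arguments.

import torch
from pysimt.layers.attention.multihead import MultiheadAttention

att = MultiheadAttention(input_dim=512, proj_dim=512, n_heads=8, attn_type='self')
x = torch.rand(16, 20, 512)                     # [batch, seq_len, dim]
pad = torch.zeros(16, 20, dtype=torch.bool)     # True at padded positions
out = att(k=x, k_mask=pad)                      # -> [16, 20, 512]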
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Static methods
def view_as_concat(x)
-
Returns a view of shape `[bsz, seq_len, head_dim * n_heads]` from `[bsz, n_heads, seq_len, head_dim]`.
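A quick shape check (values are illustrative):

import torch
from pysimt.layers.attention.multihead import MultiheadAttention

heads = torch.rand(4, 8, 10, 64)                 # [bsz, n_heads, seq_len, head_dim]
flat = MultiheadAttention.view_as_concat(heads)  # -> [4, 10, 512]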
Methods
def forward(self, k, v=None, q=None, k_mask=None, q_mask=None)
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
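In this implementation, `k`, `v` and `q` are `[batch, seq_len, dim]` tensors (only `k` is needed for the 'self' variant); `k_mask` marks padded key positions with `True` and is applied before the softmax, while `q_mask` marks padded query positions whose output rows are zeroed. A typical cross-attention call (names are illustrative):

ctx = att(k=encoder_states, v=encoder_states, q=decoder_states, k_mask=src_pad_mask)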
def view_as_headed(self, x)
-
Returns a view of shape `[bsz, n_heads, seq_len, head_dim]` from `[bsz, seq_len, head_dim * n_heads]`.
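A quick check that `view_as_headed` and `view_as_concat` invert each other (values are illustrative):

import torch
from pysimt.layers.attention.multihead import MultiheadAttention

att = MultiheadAttention(input_dim=512, proj_dim=512, n_heads=8)
x = torch.rand(4, 10, 512)                       # [bsz, seq_len, proj_dim]
headed = att.view_as_headed(x)                   # -> [4, 8, 10, 64]
assert torch.equal(MultiheadAttention.view_as_concat(headed), x)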