Module pysimt.layers.decoders.conditional
Expand source code
import torch
from torch import nn
import torch.nn.functional as F
from collections import defaultdict
from ...utils.nn import get_activation_fn
from .. import FF, Fusion
from ..attention import get_attention
class ConditionalGRUDecoder(nn.Module):
"""A conditional decoder with attention à la dl4mt-tutorial. It supports
multimodal attention if more than one source modality is available. The
    initial state of the decoder RNN is set to `zero` and cannot be modified,
    for simplicity in simultaneous MT."""
def __init__(self, input_size, hidden_size, n_vocab, encoders,
rnn_type='gru', tied_emb=False, att_type='mlp',
att_activ='tanh', att_bottleneck='ctx', att_temp=1.0,
dropout_out=0, out_logic='simple', dec_inp_activ=None,
mm_fusion_op=None, mm_fusion_dropout=0.0):
super().__init__()
# Normalize case
self.rnn_type = rnn_type.upper()
# Safety checks
assert self.rnn_type in ('GRU',), f"{rnn_type!r} unknown"
assert mm_fusion_op in ('sum', 'concat', None), "mm_fusion_op unknown"
RNN = getattr(nn, '{}Cell'.format(self.rnn_type))
# Other arguments
self.input_size = input_size
self.hidden_size = hidden_size
self.n_vocab = n_vocab
self.dec_inp_activ_fn = get_activation_fn(dec_inp_activ)
# Create target embeddings
self.emb = nn.Embedding(self.n_vocab, self.input_size, padding_idx=0)
# Create attention layer(s)
self.att = nn.ModuleDict()
for key, enc in encoders.items():
Attention = get_attention(att_type)
self.att[str(key)] = Attention(
enc.ctx_size,
self.hidden_size,
transform_ctx=True,
ctx2hid=True,
mlp_bias=False,
att_activ=att_activ,
att_bottleneck=att_bottleneck,
temp=att_temp)
# return the only c_t from the list
self.fusion = lambda x: x[0]
if len(encoders) > 1:
# Multiple inputs (multimodal NMT)
ctx_sizes = [ll.ctx_size for ll in encoders.values()]
if mm_fusion_op == 'concat':
mm_inp_size = sum(ctx_sizes)
else:
assert len(set(ctx_sizes)) == 1, \
"Context sizes are not compatible with mm_fusion_op!"
mm_inp_size = ctx_sizes[0]
fusion = [Fusion(mm_fusion_op, input_size=mm_inp_size, output_size=self.hidden_size)]
if mm_fusion_dropout > 0:
fusion.append(nn.Dropout(mm_fusion_dropout))
self.fusion = nn.Sequential(*fusion)
# Create decoders
self.dec0 = RNN(self.input_size, self.hidden_size)
self.dec1 = RNN(self.hidden_size, self.hidden_size)
# Output dropout
if dropout_out > 0:
self.do_out = nn.Dropout(p=dropout_out)
else:
self.do_out = lambda x: x
# Output bottleneck: maps hidden states to target emb dim
# simple: tanh(W*h)
# deep: tanh(W*h + U*emb + V*ctx)
out_inp_size = self.hidden_size
        # Identity op that returns the hidden state as-is for the simple output
self.out_merge_fn = lambda h, e, c: h
if out_logic == 'deep':
out_inp_size += (self.input_size + self.hidden_size)
self.out_merge_fn = lambda h, e, c: torch.cat((h, e, c), dim=1)
# Final transformation that receives concatenated outputs or only h
self.hid2out = FF(out_inp_size, self.input_size,
bias_zero=True, activ='tanh')
# Final softmax
self.out2prob = FF(self.input_size, self.n_vocab)
# Tie input embedding matrix and output embedding matrix
if tied_emb:
self.out2prob.weight = self.emb.weight
self.nll_loss = nn.NLLLoss(reduction="sum", ignore_index=0)
def get_emb(self, idxs):
"""Returns time-step based embeddings."""
return self.emb(idxs)
def f_init(self, state_dict=None):
"""Returns the initial h_0 for the decoder."""
self.history = defaultdict(list)
return None
def f_next(self, state_dict, y, h, hypothesis=None):
"""Applies one timestep of recurrence. `state_dict` may contain
partial source information depending on how the model constructs it."""
# Get hidden states from the first decoder (purely cond. on LM)
h1 = self.dec0(y, h)
query = h1.unsqueeze(0)
# Obtain attention for each different input context in encoder
atts = []
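        # Each value of state_dict is a (source states, mask) tuple keyed by modality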
for k, (s, m) in state_dict.items():
alpha, ctx = self.att[k](query, s, m)
atts.append(ctx)
if not self.training:
self.history[f'alpha_{k}'].append(alpha.cpu())
# Fuse input contexts
c_t = self.fusion(atts)
        # Run the second decoder cell with h1 (returned by the GRU above) as its
        # hidden state; the optional input activation on c_t keeps the comparison
        # fair with the MMT model.
h2 = self.dec1(self.dec_inp_activ_fn(c_t), h1)
        # Output logic: merge -> hid2out -> dropout -> out2prob,
        # producing a B x V logit (V: vocab size)
logit = self.out2prob(
self.do_out(self.hid2out(self.out_merge_fn(h2, y, c_t))))
# Compute log_softmax over token dim
log_p = F.log_softmax(logit, dim=-1)
# Return log probs and new hidden states
return log_p, h2
Classes
class ConditionalGRUDecoder (input_size, hidden_size, n_vocab, encoders, rnn_type='gru', tied_emb=False, att_type='mlp', att_activ='tanh', att_bottleneck='ctx', att_temp=1.0, dropout_out=0, out_logic='simple', dec_inp_activ=None, mm_fusion_op=None, mm_fusion_dropout=0.0)
-
A conditional decoder with attention à la dl4mt-tutorial. It supports multimodal attention if more than one source modality is available. The initial state of the decoder RNN is set to zero and cannot be modified, for simplicity in simultaneous MT.
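A minimal construction sketch (not part of the generated API text): the encoder stub, the 'src' key and the sizes below are illustrative placeholders. The constructor only reads the ctx_size attribute of each encoder, so a simple namespace object is enough to instantiate the decoder in isolation.
from types import SimpleNamespace
from pysimt.layers.decoders.conditional import ConditionalGRUDecoder

# Hypothetical encoder stub: only `ctx_size` is read by the constructor.
# A real setup would pass actual pysimt encoder modules here.
encoders = {'src': SimpleNamespace(ctx_size=512)}

decoder = ConditionalGRUDecoder(
    input_size=256,      # target embedding dimension (toy value)
    hidden_size=512,     # GRU hidden dimension (toy value)
    n_vocab=32000,       # target vocabulary size (toy value)
    encoders=encoders,
    att_type='mlp',
)
With more than one encoder, the attended contexts are fused according to mm_fusion_op: 'concat' concatenates them (the Fusion layer then receives the sum of the ctx_size values), while 'sum' requires all ctx_size values to match.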
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def f_init(self, state_dict=None)
-
Returns the initial h_0 for the decoder.
def f_next(self, state_dict, y, h, hypothesis=None)
-
Applies one timestep of recurrence. `state_dict` may contain partial source information depending on how the model constructs it.
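A greedy decoding sketch built on f_init, get_emb and f_next, continuing from the construction sketch above. The tensor shapes, the 'src' key, bos_idx and max_len are assumptions for illustration; in pysimt the model builds state_dict itself, possibly from only the source prefix read so far.
import torch

B = 4                                     # toy batch size
enc_states = torch.rand(20, B, 512)       # assumed S x B x ctx_size encoder states
state_dict = {'src': (enc_states, None)}  # (states, mask); mask may be None
bos_idx, max_len = 1, 10                  # hypothetical <bos> index and length cap

h = decoder.f_init(state_dict)            # resets attention history, returns None
y_t = torch.full((B,), bos_idx, dtype=torch.long)

hyps = []
for _ in range(max_len):
    log_p, h = decoder.f_next(state_dict, decoder.get_emb(y_t), h)
    y_t = log_p.argmax(dim=-1)            # greedy pick over the target vocabulary
    hyps.append(y_t)
Passing h=None on the first step relies on GRUCell defaulting to a zero hidden state, which matches the zero initial state described in the class docstring.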
def forward(self, *input: Any) -> NoneType
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
def get_emb(self, idxs)
-
Returns time-step based embeddings.