Module pysimt.models.snmt_rnn_waitk
Expand source code
import logging
from . import SimultaneousNMT
logger = logging.getLogger('pysimt')
"""This is the training-time wait-k model from:
Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398
The only required parameter is the `k` argument for training. When decoding,
pass the `k` argument explicitly to `pysimt translate`. A large enough `k`
should produce the same results as the `snmt.py` model.
"""
class SimultaneousWaitKNMT(SimultaneousNMT):
    def set_defaults(self):
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def __init__(self, opts):
        super().__init__(opts)

    def forward(self, batch, **kwargs):
        """Training forward-pass with explicit timestep-based loop."""
        loss = 0.0
        k = int(self.opts.model['translator_args']['k'])

        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = int(1e4)

        # Cache encoder states first
        self.cache_enc_states(batch)

        # Initial state is None i.e. 0.
        h = self.dec.f_init()

        # Convert target token indices to embeddings -> T*B*E
        y = batch[self.tl]
        y_emb = self.dec.emb(y)

        # -1: So that we skip the timestep where input is <eos>
        for t in range(y_emb.size(0) - 1):
            ###########################################
            # waitk: pass partial context incrementally
            ###########################################
            state_dict = self.get_enc_state_dict(up_to=k + t)
            log_p, h = self.dec.f_next(state_dict, y_emb[t], h)
            loss += self.dec.nll_loss(log_p, y[t + 1])

        return {
            'loss': loss,
            'n_items': y[1:].nonzero(as_tuple=False).size(0),
        }
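The `n_items` entry returned by `forward()` is what normalizes the summed loss into a per-token figure: `y[1:].nonzero(as_tuple=False).size(0)` counts the non-zero target indices after the first timestep, i.e. the number of real (non-padded) tokens the loss was accumulated over, assuming the pad token maps to index 0. A minimal sketch of that count, using a small hypothetical padded target batch (the indices below are illustrative, not pysimt's actual vocabulary):

import torch

# Hypothetical T x B batch of target indices (pad = 0): T=5 timesteps, B=2 sentences
y = torch.tensor([
    [2, 2],   # first timestep (e.g. <bos>) is skipped by y[1:]
    [7, 9],
    [8, 3],
    [3, 0],   # second sentence already ended; padded with 0
    [0, 0],
])

# Same expression as in forward(): count non-pad tokens from timestep 1 onwards
n_items = y[1:].nonzero(as_tuple=False).size(0)
print(n_items)  # 5 real target tokens contribute to the summed NLL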
Global variables
var logger
-
This is the training-time wait-k model from: Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398
The only required parameter is the `k` argument for training. When decoding, pass the `k` argument explicitly to `pysimt translate`. A large enough `k` should produce the same results as the `snmt.py` model.
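To see why a large enough `k` recovers the full-sentence behaviour of `snmt.py`: at target step `t`, the training loop in `forward()` requests encoder states with `up_to=k + t`, so the decoder attends to at most the first `k + t` source positions (the cap at the actual source length is presumably applied inside `get_enc_state_dict`). A minimal sketch of that schedule, with a hypothetical helper name and source length:

def visible_source_positions(t: int, k: int, src_len: int) -> int:
    """Number of source positions the wait-k decoder may attend to at target step t."""
    return min(k + t, src_len)

src_len = 6  # hypothetical source sentence length

# Strict wait-3 schedule: the visible context grows by one position per target step
print([visible_source_positions(t, k=3, src_len=src_len) for t in range(6)])
# -> [3, 4, 5, 6, 6, 6]

# A large enough k (the 1e4 default) always exposes the full source,
# which matches full-sentence decoding as in snmt.py
print([visible_source_positions(t, k=10_000, src_len=src_len) for t in range(6)])
# -> [6, 6, 6, 6, 6, 6]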
Classes
class SimultaneousWaitKNMT (opts)
-
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

training (bool): whether this module is in training or evaluation mode.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class SimultaneousWaitKNMT(SimultaneousNMT):
    def set_defaults(self):
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def __init__(self, opts):
        super().__init__(opts)

    def forward(self, batch, **kwargs):
        """Training forward-pass with explicit timestep-based loop."""
        loss = 0.0
        k = int(self.opts.model['translator_args']['k'])

        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = int(1e4)

        # Cache encoder states first
        self.cache_enc_states(batch)

        # Initial state is None i.e. 0.
        h = self.dec.f_init()

        # Convert target token indices to embeddings -> T*B*E
        y = batch[self.tl]
        y_emb = self.dec.emb(y)

        # -1: So that we skip the timestep where input is <eos>
        for t in range(y_emb.size(0) - 1):
            ###########################################
            # waitk: pass partial context incrementally
            ###########################################
            state_dict = self.get_enc_state_dict(up_to=k + t)
            log_p, h = self.dec.f_next(state_dict, y_emb[t], h)
            loss += self.dec.nll_loss(log_p, y[t + 1])

        return {
            'loss': loss,
            'n_items': y[1:].nonzero(as_tuple=False).size(0),
        }
Ancestors
- SimultaneousNMT
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def set_defaults(self)
-
Expand source code
def set_defaults(self):
    super().set_defaults()
    self.defaults.update({
        # Decoding/training simultaneous NMT args
        'translator_type': 'wk',        # This model implements train-time wait-k
        'translator_args': {'k': 1e4},  # k as in wait-k in training
        'consecutive_warmup': 0,        # consecutive training for this many epochs
    })
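The `consecutive_warmup` default interacts with `k` inside `forward()`: for the first `consecutive_warmup` epochs the model is trained consecutively (full source context) by forcing `k` to `1e4`, and only afterwards does the configured wait-k schedule take effect. A minimal sketch of that epoch-dependent effective `k`, using a hypothetical helper and assuming the `ectr` epoch counter is 1-based:

def effective_k(epoch: int, k: int, consecutive_warmup: int) -> int:
    """Mirror of the check in forward(): warm-up epochs use a huge k, i.e. full context."""
    return int(1e4) if epoch <= consecutive_warmup else k

# With k=3 and two warm-up epochs, epochs 1-2 train with full context, epoch 3 onwards with wait-3
print([effective_k(epoch, k=3, consecutive_warmup=2) for epoch in range(1, 6)])
# -> [10000, 10000, 3, 3, 3]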
Inherited members