Module pysimt.models.snmt_tf_waitk
import logging
from . import SimultaneousTFNMT
logger = logging.getLogger('pysimt')
"""This is the training-time wait-k model from:
Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398
The only required parameter is the `k` argument for training. When decoding,
pass the `k` argument explicitly to `pysimt translate`. A large enough `k`
should produce the same results as the `snmt.py` model.
"""
class SimultaneousTFWaitKNMT(SimultaneousTFNMT):

    def set_defaults(self):
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def __init__(self, opts):
        super().__init__(opts)
        assert not self.opts.model['enc_bidirectional'], \
            'Bidirectional TF encoder is not currently supported for simultaneous MT.'

    def forward(self, batch, **kwargs):
        """
        Performs a forward pass.

        :param batch: The batch.
        :param kwargs: Any extra arguments.
        :return: The output from the forward pass.
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = int(1e4)
        # Pass 'k' to the model.
        return super().forward(batch, k=k)
Global variables
var logger
-
This is the training-time wait-k model from: Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398
The only required parameter is the `k` argument for training. When decoding, pass the `k` argument explicitly to `pysimt translate`. A large enough `k` should produce the same results as the `snmt.py` model.
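To make the wait-k schedule concrete, the following is a minimal standalone sketch of the prefix-visibility rule from Ma et al. (2018); the function is illustrative only and not part of the pysimt API:

def waitk_visible_source(k: int, t: int, src_len: int) -> int:
    # Number of source tokens the decoder may attend to when producing
    # target token t (0-based) under a wait-k schedule.
    return min(k + t, src_len)

# With k=2 and a 5-token source, the schedule first reads 2 tokens, then
# reads one more source token per target token written until the source
# is exhausted: t=0 -> 2, t=1 -> 3, t=2 -> 4, t>=3 -> 5 (full context).
# A very large k, such as the default 1e4, always exposes the full source,
# which is why it reproduces the consecutive `snmt.py` behaviour.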
Classes
class SimultaneousTFWaitKNMT (opts)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc.
:ivar training: Boolean represents whether this module is in training or evaluation mode.
:vartype training: bool
Creates a SimultaneousNMTTransformer.
:param opts: The options.
class SimultaneousTFWaitKNMT(SimultaneousTFNMT):

    def set_defaults(self):
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def __init__(self, opts):
        super().__init__(opts)
        assert not self.opts.model['enc_bidirectional'], \
            'Bidirectional TF encoder is not currently supported for simultaneous MT.'

    def forward(self, batch, **kwargs):
        """
        Performs a forward pass.

        :param batch: The batch.
        :param kwargs: Any extra arguments.
        :return: The output from the forward pass.
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = int(1e4)
        # Pass 'k' to the model.
        return super().forward(batch, k=k)
Ancestors
- SimultaneousTFNMT
- SimultaneousNMT
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, batch, **kwargs) ‑> Callable[..., Any]
-
Performs a forward pass.
:param batch: The batch.
:param kwargs: Any extra arguments.
:return: The output from the forward pass.
def forward(self, batch, **kwargs):
    """
    Performs a forward pass.

    :param batch: The batch.
    :param kwargs: Any extra arguments.
    :return: The output from the forward pass.
    """
    k = int(self.opts.model['translator_args']['k'])
    if self.training:
        epoch_count = kwargs['ectr']
        if epoch_count <= self.opts.model['consecutive_warmup']:
            # warming up, use full contexts
            k = int(1e4)
    # Pass 'k' to the model.
    return super().forward(batch, k=k)
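The epoch gate above can be restated as a small standalone helper for illustration; this is a sketch assuming `ectr` is the 1-based epoch counter supplied by the trainer, not part of the model itself:

def effective_k(configured_k: int, ectr: int, warmup: int) -> int:
    # During the first `warmup` epochs, fall back to a very large k,
    # i.e. full-context (consecutive) training; afterwards, use wait-k.
    return int(1e4) if ectr <= warmup else configured_k

assert effective_k(3, ectr=1, warmup=2) == 10000  # still warming up
assert effective_k(3, ectr=3, warmup=2) == 3      # wait-3 training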
def set_defaults(self)
-
def set_defaults(self):
    super().set_defaults()
    self.defaults.update({
        # Decoding/training simultaneous NMT args
        'translator_type': 'wk',        # This model implements train-time wait-k
        'translator_args': {'k': 1e4},  # k as in wait-k in training
        'consecutive_warmup': 0,        # consecutive training for this many epochs
    })
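For reference, a hypothetical `opts.model` fragment overriding these defaults might look as follows; the key names mirror set_defaults() above, while the surrounding configuration format is an assumption:

model_opts = {
    'translator_type': 'wk',       # train-time wait-k translator
    'translator_args': {'k': 3},   # train with a wait-3 schedule
    'consecutive_warmup': 2,       # full-context training for the first 2 epochs
    'enc_bidirectional': False,    # required: __init__ asserts this is off
}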
Inherited members