Module pysimt.models.snmt_tf_waitk

Expand source code
"""Training-time wait-k simultaneous Transformer NMT model.

This is the training-time wait-k model from:
    Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398

The only required parameter is the `k` argument for training. When decoding,
pass the `k` argument explicitly to `pysimt translate`. A large enough `k`
should produce the same results as the `snmt.py` model.
"""
# NOTE: the docstring above must be the first statement of the module;
# placed after the imports it would be a no-op string expression and
# `__doc__` would be None.
import logging

from . import SimultaneousTFNMT

logger = logging.getLogger('pysimt')


class SimultaneousTFWaitKNMT(SimultaneousTFNMT):
    """Simultaneous Transformer NMT model trained with a wait-k policy.

    Implements the training-time wait-k model of Ma et al. (2018, STACL).
    The wait-k value is read from ``opts.model['translator_args']['k']`` and
    passed to the parent ``forward`` as ``k``. For the first
    ``consecutive_warmup`` training epochs a very large ``k`` (``1e4``) is
    used instead, so that the model is trained on full source contexts.
    """

    def set_defaults(self):
        """Extend the parent's defaults with the wait-k specific options."""
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',  # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,  # consecutive training for this many epochs
        })

    def __init__(self, opts):
        """
        Creates the wait-k model.

        :param opts: The options.
        """
        super().__init__(opts)
        assert not self.opts.model['enc_bidirectional'], \
            'Bidirectional TF encoder is not currently supported for simultaneous MT.'

    def forward(self, batch, **kwargs):
        """
        Performs a forward pass.

        :param batch: The batch.
        :param kwargs: Any extra arguments. In training mode ``ectr`` (the
            epoch counter) decides whether warm-up applies; it is required
            only when ``consecutive_warmup`` is positive. Other keyword
            arguments are not forwarded to the parent.
        :return: The output of the parent ``forward`` called with the
            effective ``k``.
        :raises KeyError: In training mode, if ``ectr`` is missing while a
            positive ``consecutive_warmup`` is configured.
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            warmup = self.opts.model['consecutive_warmup']
            epoch_count = kwargs.get('ectr')
            if epoch_count is None:
                # Without an epoch counter we cannot tell whether we are still
                # warming up; that is only acceptable when warm-up is disabled.
                if warmup > 0:
                    raise KeyError(
                        "'ectr' (epoch counter) must be passed to forward() in "
                        "training mode when consecutive_warmup > 0")
            elif epoch_count <= warmup:
                # warming up, use full contexts (1e4 is assumed to exceed any
                # source length)
                k = int(1e4)

        # Pass 'k' to the model.
        return super().forward(batch, k=k)

Global variables

var logger

This is the training-time wait-k model from: Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398

The only required parameter is the k argument for training. When decoding, pass the k argument explicitly to pysimt translate. A large enough k should produce the same results as the snmt.py model.

Classes

class SimultaneousTFWaitKNMT (opts)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x``."""
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc.

:ivar training: Boolean that represents whether this module is in training or evaluation mode. :vartype training: bool

Creates a SimultaneousNMTTransformer. :param opts: The options.

Expand source code
class SimultaneousTFWaitKNMT(SimultaneousTFNMT):
    """Simultaneous Transformer NMT model trained with a wait-k policy.

    Implements the training-time wait-k model of Ma et al. (2018, STACL).
    The wait-k value is read from ``opts.model['translator_args']['k']`` and
    passed to the parent ``forward`` as ``k``. For the first
    ``consecutive_warmup`` training epochs a very large ``k`` (``1e4``) is
    used instead, so that the model is trained on full source contexts.
    """

    def set_defaults(self):
        """Extend the parent's defaults with the wait-k specific options."""
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',  # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,  # consecutive training for this many epochs
        })

    def __init__(self, opts):
        """
        Creates the wait-k model.

        :param opts: The options.
        """
        super().__init__(opts)
        assert not self.opts.model['enc_bidirectional'], \
            'Bidirectional TF encoder is not currently supported for simultaneous MT.'

    def forward(self, batch, **kwargs):
        """
        Performs a forward pass.

        :param batch: The batch.
        :param kwargs: Any extra arguments. In training mode ``ectr`` (the
            epoch counter) decides whether warm-up applies; it is required
            only when ``consecutive_warmup`` is positive. Other keyword
            arguments are not forwarded to the parent.
        :return: The output of the parent ``forward`` called with the
            effective ``k``.
        :raises KeyError: In training mode, if ``ectr`` is missing while a
            positive ``consecutive_warmup`` is configured.
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            warmup = self.opts.model['consecutive_warmup']
            epoch_count = kwargs.get('ectr')
            if epoch_count is None:
                # Without an epoch counter we cannot tell whether we are still
                # warming up; that is only acceptable when warm-up is disabled.
                if warmup > 0:
                    raise KeyError(
                        "'ectr' (epoch counter) must be passed to forward() in "
                        "training mode when consecutive_warmup > 0")
            elif epoch_count <= warmup:
                # warming up, use full contexts (1e4 is assumed to exceed any
                # source length)
                k = int(1e4)

        # Pass 'k' to the model.
        return super().forward(batch, k=k)

Ancestors

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, batch, **kwargs) -> Callable[..., Any]

Performs a forward pass. :param batch: The batch. :param kwargs: Any extra arguments. :return: The output from the forward pass.

Expand source code
def forward(self, batch, **kwargs):
    """
    Performs a forward pass.

    :param batch: The batch.
    :param kwargs: Any extra arguments. In training mode ``ectr`` (the
        epoch counter) decides whether warm-up applies; it is required
        only when ``consecutive_warmup`` is positive. Other keyword
        arguments are not forwarded to the parent.
    :return: The output of the parent ``forward`` called with the
        effective ``k``.
    :raises KeyError: In training mode, if ``ectr`` is missing while a
        positive ``consecutive_warmup`` is configured.
    """
    k = int(self.opts.model['translator_args']['k'])
    if self.training:
        warmup = self.opts.model['consecutive_warmup']
        epoch_count = kwargs.get('ectr')
        if epoch_count is None:
            # Without an epoch counter we cannot tell whether we are still
            # warming up; that is only acceptable when warm-up is disabled.
            if warmup > 0:
                raise KeyError(
                    "'ectr' (epoch counter) must be passed to forward() in "
                    "training mode when consecutive_warmup > 0")
        elif epoch_count <= warmup:
            # warming up, use full contexts (1e4 is assumed to exceed any
            # source length)
            k = int(1e4)

    # Pass 'k' to the model.
    return super().forward(batch, k=k)
def set_defaults(self)
Expand source code
def set_defaults(self):
    """Register the wait-k specific defaults on top of the parent's ones."""
    super().set_defaults()
    # Decoding/training simultaneous NMT args
    waitk_defaults = dict(
        translator_type='wk',  # This model implements train-time wait-k
        translator_args={'k': 1e4},  # k as in wait-k in training
        consecutive_warmup=0,  # consecutive training for this many epochs
    )
    self.defaults.update(waitk_defaults)

Inherited members