Module pysimt.models.snmt_rnn_waitk

Expand source code
"""This is the training-time wait-k model from:
    Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398

The only required parameter is the `k` argument for training. When decoding,
pass the `k` argument explicitly to `pysimt translate`. A large enough `k`
should produce the same results as the `snmt.py` model.
"""
# NOTE: this docstring must be the first statement of the module. When it was
# placed after `logger = ...`, it was not the module `__doc__`, and pdoc
# rendered it as the docstring of the `logger` variable instead.
import logging

from . import SimultaneousNMT

logger = logging.getLogger('pysimt')


class SimultaneousWaitKNMT(SimultaneousNMT):
    """Training-time wait-k simultaneous NMT model.

    Implements the training-time wait-k policy of Ma et al. (2018), STACL,
    arXiv:1810.08398. At decoding step ``t`` during the forward pass, the
    decoder only receives the partial source context returned by
    ``get_enc_state_dict(up_to=k + t)``. Here ``k`` is
    ``opts.model['translator_args']['k']``.

    For epochs ``<= consecutive_warmup``, training instead uses full context
    (``k`` = 10000). A large enough ``k`` should reproduce the ``snmt.py``
    model. When decoding, pass ``k`` explicitly to ``pysimt translate``.
    """

    # Effectively unbounded `k`, used for full-context ("consecutive")
    # warm-up training; same value as the default `translator_args['k']`.
    _FULL_CONTEXT_K = int(1e4)

    def set_defaults(self):
        """Extend the parent's defaults with the wait-k training options."""
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def forward(self, batch, **kwargs):
        """Training forward-pass with explicit timestep-based loop.

        Args:
            batch: Provides the target token indices under ``self.tl`` and
                whatever ``cache_enc_states`` consumes.
            **kwargs: When ``self.training`` is True, must contain ``ectr``
                (the epoch counter), which is compared against
                ``consecutive_warmup``. A missing ``ectr`` raises ``KeyError``.

        Returns:
            dict: ``'loss'`` is the NLL summed over all decoding steps.
            ``'n_items'`` is the number of non-zero target indices after the
            first position (presumably the non-padding tokens, assuming the
            pad index is 0 -- TODO confirm).
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = self._FULL_CONTEXT_K

        # Cache encoder states first
        self.cache_enc_states(batch)

        # Initial state is None i.e. 0.
        h = self.dec.f_init()

        # Convert target token indices to embeddings -> T*B*E
        y = batch[self.tl]
        y_emb = self.dec.emb(y)

        loss = 0.0
        # -1: So that we skip the timestep where input is <eos>
        for t in range(y_emb.size(0) - 1):
            # wait-k: the visible source context starts at `k` and grows by
            # one position per target step.
            state_dict = self.get_enc_state_dict(up_to=k + t)
            log_p, h = self.dec.f_next(state_dict, y_emb[t], h)
            loss += self.dec.nll_loss(log_p, y[t + 1])

        # Same count as `y[1:].nonzero(...).size(0)`, without materialising
        # the (N x ndim) index tensor that `nonzero()` builds.
        n_items = (y[1:] != 0).sum().item()

        return {
            'loss': loss,
            'n_items': n_items,
        }

Global variables

var logger

This is the training-time wait-k model from: Ma et al. (2018), STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework, arXiv:1810.08398

The only required parameter is the k argument for training. When decoding, pass the k argument explicitly to pysimt translate. A large enough k should produce the same results as the snmt.py model.

Classes

class SimultaneousWaitKNMT (opts)

(Inherited from the `torch.nn.Module` docstring.) Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

`training` (bool): whether this module is in training or evaluation mode.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class SimultaneousWaitKNMT(SimultaneousNMT):
    """Training-time wait-k simultaneous NMT model.

    Implements the training-time wait-k policy of Ma et al. (2018), STACL,
    arXiv:1810.08398. At decoding step ``t`` during the forward pass, the
    decoder only receives the partial source context returned by
    ``get_enc_state_dict(up_to=k + t)``. Here ``k`` is
    ``opts.model['translator_args']['k']``.

    For epochs ``<= consecutive_warmup``, training instead uses full context
    (``k`` = 10000). A large enough ``k`` should reproduce the ``snmt.py``
    model. When decoding, pass ``k`` explicitly to ``pysimt translate``.
    """

    # Effectively unbounded `k`, used for full-context ("consecutive")
    # warm-up training; same value as the default `translator_args['k']`.
    _FULL_CONTEXT_K = int(1e4)

    def set_defaults(self):
        """Extend the parent's defaults with the wait-k training options."""
        super().set_defaults()
        self.defaults.update({
            # Decoding/training simultaneous NMT args
            'translator_type': 'wk',        # This model implements train-time wait-k
            'translator_args': {'k': 1e4},  # k as in wait-k in training
            'consecutive_warmup': 0,        # consecutive training for this many epochs
        })

    def forward(self, batch, **kwargs):
        """Training forward-pass with explicit timestep-based loop.

        Args:
            batch: Provides the target token indices under ``self.tl`` and
                whatever ``cache_enc_states`` consumes.
            **kwargs: When ``self.training`` is True, must contain ``ectr``
                (the epoch counter), which is compared against
                ``consecutive_warmup``. A missing ``ectr`` raises ``KeyError``.

        Returns:
            dict: ``'loss'`` is the NLL summed over all decoding steps.
            ``'n_items'`` is the number of non-zero target indices after the
            first position (presumably the non-padding tokens, assuming the
            pad index is 0 -- TODO confirm).
        """
        k = int(self.opts.model['translator_args']['k'])
        if self.training:
            epoch_count = kwargs['ectr']
            if epoch_count <= self.opts.model['consecutive_warmup']:
                # warming up, use full contexts
                k = self._FULL_CONTEXT_K

        # Cache encoder states first
        self.cache_enc_states(batch)

        # Initial state is None i.e. 0.
        h = self.dec.f_init()

        # Convert target token indices to embeddings -> T*B*E
        y = batch[self.tl]
        y_emb = self.dec.emb(y)

        loss = 0.0
        # -1: So that we skip the timestep where input is <eos>
        for t in range(y_emb.size(0) - 1):
            # wait-k: the visible source context starts at `k` and grows by
            # one position per target step.
            state_dict = self.get_enc_state_dict(up_to=k + t)
            log_p, h = self.dec.f_next(state_dict, y_emb[t], h)
            loss += self.dec.nll_loss(log_p, y[t + 1])

        # Same count as `y[1:].nonzero(...).size(0)`, without materialising
        # the (N x ndim) index tensor that `nonzero()` builds.
        n_items = (y[1:] != 0).sum().item()

        return {
            'loss': loss,
            'n_items': n_items,
        }

Ancestors

SimultaneousNMT

Class variables

var dump_patches : bool
var training : bool

Methods

def set_defaults(self)
Expand source code
def set_defaults(self):
    """Extend the parent's defaults with the wait-k training options."""
    super().set_defaults()
    wait_k_defaults = {
        # Decoding/training simultaneous NMT args
        'translator_type': 'wk',        # This model implements train-time wait-k
        'translator_args': {'k': 1e4},  # k as in wait-k in training
        'consecutive_warmup': 0,        # consecutive training for this many epochs
    }
    self.defaults.update(wait_k_defaults)

Inherited members