Module pysimt.translators.greedy

Batched vanilla greedy search without any simultaneous translation features.

Source code

"""Batched vanilla greedy search without any simultaneous translation
features."""

import logging
import math

import torch

from ..utils.device import DEVICE
from ..utils.io import progress_bar
from ..utils.data import sort_predictions

logger = logging.getLogger('pysimt')


class GreedySearch:
    def __init__(self, model, data_loader, out_prefix, batch_size, filter_chain=None,
                 max_len=100, **kwargs):
        self.model = model
        self.data_loader = data_loader
        self.batch_size = batch_size
        self.filter_chain = filter_chain
        self.out_prefix = out_prefix

        self.vocab = self.model.trg_vocab
        self.n_vocab = len(self.vocab)
        self.unk = self.vocab['<unk>']
        self.eos = self.vocab['<eos>']
        self.bos = self.vocab['<bos>']
        self.pad = self.vocab['<pad>']

        self.max_len = max_len
        self.do_dump = out_prefix != ''

    def dump_results(self, hyps, suffix=''):
        suffix = 'gs' if not suffix else f'{suffix}.gs'

        # Dump raw ones (BPE/SPM etc.)
        self.dump_lines(hyps, suffix + '.raw')
        if self.filter_chain is not None:
            self.dump_lines(self.filter_chain.apply(hyps), suffix)

    def dump_lines(self, lines, suffix):
        fname = f'{self.out_prefix}.{suffix}'
        with open(fname, 'w') as f:
            for line in lines:
                f.write(f'{line}\n')

    def decoder_step(self, state_dict, next_word_idxs, h, hypothesis=None):
        logp, h = self.model.dec.f_next(
            state_dict, self.model.dec.get_emb(next_word_idxs), h, hypothesis)

        # Similar to the logic in fairseq https://bit.ly/3agXAa7
        # Never select the pad token or the bos token
        logp[:, self.pad] = -math.inf
        logp[:, self.bos] = -math.inf

        # Compute most likely word idxs
        next_word_idxs = logp.argmax(dim=-1)
        return logp, h, next_word_idxs

    def decoder_init(self, state_dict=None):
        return self.model.dec.f_init(state_dict)

    def run_all(self):
        return self.run()

    def run(self, **kwargs):
        # effective batch size may be different
        max_batch_size = self.data_loader.batch_sampler.batch_size

        translations = []
        hyps = torch.zeros(
            (self.max_len, max_batch_size), dtype=torch.long, device=DEVICE)

        for batch in progress_bar(self.data_loader, unit='batch'):
            batch.device(DEVICE)

            # Reset hypotheses
            hyps.zero_()

            # Cache encoder states
            self.model.cache_enc_states(batch)

            # Get encoder hidden states
            state_dict = self.model.get_enc_state_dict()

            # Get the initial decoder state (None, i.e. zeros; some decoders
            # ignore state_dict here)
            h = self.decoder_init(state_dict)

            # last batch could be smaller than the requested batch size
            cur_batch_size = batch.size

            # Track sentences that have already produced </s>
            track_fini = torch.zeros(cur_batch_size, dtype=torch.bool, device=DEVICE)

            # Start all sentences with <s>
            next_word_idxs = self.model.get_bos(cur_batch_size).to(DEVICE)

            # The Transformer decoder requires the full hypothesis prefix,
            # starting with <bos>, at every prediction step
            tf_decoder_input = next_word_idxs.unsqueeze(0)

            # A maximum of `max_len` decoding steps
            for t in range(self.max_len):
                if track_fini.all():
                    # All hypotheses produced </s>, early stop!
                    break

                logp, h, next_word_idxs = self.decoder_step(
                    state_dict, next_word_idxs, h, tf_decoder_input)

                # Mark hypotheses that just produced </s> as finished
                track_fini |= next_word_idxs.eq(self.eos)

                # Insert most probable words for timestep `t` into tensor
                hyps[t, :cur_batch_size] = next_word_idxs

                # Append the predicted word to the decoder input (used by Transformer models)
                tf_decoder_input = torch.cat((tf_decoder_input, next_word_idxs.unsqueeze(0)), dim=0)

            # All finished, convert translations to python lists on CPU
            sent_idxs = hyps[:, :cur_batch_size].t().cpu().tolist()
            translations.extend(self.vocab.list_of_idxs_to_sents(sent_idxs))

        hyps = sort_predictions(self.data_loader, translations)

        if self.do_dump:
            self.dump_results(hyps)

        return (hyps,)

Global variables

var logger

The package-wide logger instance, created with logging.getLogger('pysimt').

Classes

class GreedySearch (model, data_loader, out_prefix, batch_size, filter_chain=None, max_len=100, **kwargs)
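A batched greedy decoder for trained pysimt models, without any
simultaneous-translation features. A minimal usage sketch (model and
data_loader stand for a trained pysimt model and a matching data loader; all
names and the output prefix below are illustrative placeholders):

gs = GreedySearch(model, data_loader, out_prefix='out/test', batch_size=32)
(hyps,) = gs.run()  # hypothesis strings, restored to corpus order
# Because out_prefix is non-empty, run() also writes 'out/test.gs.raw'
# and, if a filter_chain was given, the post-processed 'out/test.gs'.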

Methods

def decoder_init(self, state_dict=None)
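Returns the initial decoder hidden state, delegating to
model.dec.f_init(state_dict).
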
def decoder_step(self, state_dict, next_word_idxs, h, hypothesis=None)
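Performs one decoding step: embeds the previous target words, advances the
decoder through model.dec.f_next, forces the <pad> and <bos> log-probabilities
to -inf (similar to fairseq's masking), and returns the per-hypothesis argmax
as the next words. The masking can be seen in isolation in this minimal
sketch (a toy batch-of-2, vocabulary-of-5 score tensor; the PAD and BOS
indices are made up for illustration):

import math
import torch

logp = torch.randn(2, 5).log_softmax(dim=-1)  # (batch, vocab) log-probabilities
PAD, BOS = 0, 1                               # hypothetical special-token indices

# Setting the scores to -inf guarantees that argmax never selects these tokens
logp[:, PAD] = -math.inf
logp[:, BOS] = -math.inf
next_word_idxs = logp.argmax(dim=-1)
assert not (next_word_idxs == PAD).any() and not (next_word_idxs == BOS).any()
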
def dump_lines(self, lines, suffix)
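Writes each element of lines, one per line, to the file
'<out_prefix>.<suffix>'.
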
def dump_results(self, hyps, suffix='')
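Dumps the hypotheses to disk. For example, with out_prefix='hyp', the raw
(BPE/SPM-segmented) hypotheses go to 'hyp.gs.raw' and, if a filter_chain is
set, its post-processed output goes to 'hyp.gs'. A non-empty suffix argument
is inserted before 'gs', yielding 'hyp.<suffix>.gs' and 'hyp.<suffix>.gs.raw'.
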
def run(self, **kwargs)
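Decodes the entire data loader, batch by batch: encoder states are cached
once per batch, every hypothesis starts with <bos>, and all hypotheses are
extended greedily in lockstep for at most max_len steps, stopping early once
each one has produced <eos>. Returns a one-element tuple with the hypothesis
strings restored to corpus order by sort_predictions.

The hypothesis buffer is time-major (max_len x batch), hence the transpose
before converting to per-sentence index lists; a toy illustration with
made-up token ids:

import torch

hyps = torch.zeros(4, 2, dtype=torch.long)   # (max_len=4, batch=2)
hyps[0] = torch.tensor([5, 7])               # step-0 predictions
hyps[1] = torch.tensor([2, 9])               # e.g. 2 == <eos> for sentence 0
sent_idxs = hyps.t().tolist()                # [[5, 2, 0, 0], [7, 9, 0, 0]]
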
def run_all(self)
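Convenience wrapper that simply calls run().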