Module pysimt.translators.waitk_greedy

import logging
import time

from .sim_greedy import SimultaneousGreedySearch
from ..utils.data import sort_predictions
from ..utils.device import DEVICE
from ..utils.io import progress_bar

logger = logging.getLogger('pysimt')


class SimultaneousWaitKGreedySearch(SimultaneousGreedySearch):
    def run_all(self):
        """Run one decoding pass for every wait-k value in the grid and dump the results."""

        # Interpret each `s_0` value in the grid as a wait-k value `k`
        for k in self.list_of_s_0:
            # Run the decoding
            hyps, actions, up_time = self.run(int(k))

            # Dump two files: one with segmentation preserved, another
            # with the post-processing filters applied
            self.dump_results(hyps, suffix=f'wait{k}')

            # Dump actions
            self.dump_lines(actions, suffix=f'wait{k}.acts')

    def run(self, k):
        """Decode the test set with a wait-k policy; return (hyps, actions, up_time)."""
        # R/W actions generated for the whole test set
        actions = []

        # Final translations
        translations = []

        start = time.time()
        for batch in progress_bar(self.data_loader, unit='batch'):
            self.clear_states()

            batch.device(DEVICE)

            # Encode the full source at once and cache the encoder states
            self.cache_encoder_states(batch)

            # Read the first k words and get the trimmed encoder states
            state_dict = self.read_more(k)
            self.update_s(k)

            # The initial decoder state is None (i.e. zeros);
            # `state_dict` is not used by decoder_init()
            self.prev_h = self.decoder_init(state_dict)

            # The last batch may be smaller than the requested batch size
            cur_batch_size = batch.size

            # Start all sentences with <s>
            self.set_first_word_to_bos(cur_batch_size)

            # We will start by writing
            next_action = self.ACT_WRITE

            while not self.eos_written and self.t_ptr < self.max_len:
                if next_action == self.ACT_WRITE or self.is_src_read():
                    logp, new_h, new_word = self.decoder_step(
                        state_dict, self.prev_word, self.prev_h, self.tf_decoder_input)
                    self.write(new_word, new_h)
                else:
                    # READ
                    state_dict = self.read_more(1)
                    self.update_s(1)

                # Invert the last committed action for interleaved decoding
                next_action = 1 - self.actions[-1]
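                # Note: once the source is exhausted this flip still selects
                # READ after a WRITE, but is_src_read() above forces WRITE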

            # All finished; move the generated tokens to a Python list on the CPU
            idxs = self.buffer[self.buffer.ne(0)].tolist()
            if not idxs or idxs[-1] != self.eos:
                # If <eos> was not produced and the loop above ran on until
                # max_len, append an explicit <eos> for correctness
                idxs.append(self.eos)

            # Record the action sequence from which latency metrics will be computed
            actions.append(' '.join(map(str, self.actions)))
            translations.append(self.vocab.idxs_to_sent(idxs))

        up_time = time.time() - start

        hyps = sort_predictions(self.data_loader, translations)
        actions = sort_predictions(self.data_loader, actions)
        return (hyps, actions, up_time)
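
The read/write schedule that the loop above realizes can be sketched in
isolation. The following is a minimal, self-contained simulation and not
pysimt code; it assumes WRITE = 0 and READ = 1 purely so that
`1 - actions[-1]` flips between them (the real constants live on
SimultaneousGreedySearch), mirroring how run() alternates the two actions
after the initial k reads and forces WRITE once the source is exhausted.

# Minimal sketch of the wait-k schedule (assumption: WRITE=0, READ=1)
def waitk_schedule(k, src_len, tgt_len):
    WRITE, READ = 0, 1
    actions = [READ] * min(k, src_len)   # initial wait-k reads
    n_read, n_written = len(actions), 0
    next_action = WRITE                  # wait-k starts by writing
    while n_written < tgt_len:
        if next_action == WRITE or n_read == src_len:
            actions.append(WRITE)        # emit one target token
            n_written += 1
        else:
            actions.append(READ)         # consume one source token
            n_read += 1
        next_action = 1 - actions[-1]    # invert the last action
    return actions

# e.g. waitk_schedule(2, 5, 6) -> R R W R W R W R W W W
# i.e. [1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0]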

Classes

class SimultaneousWaitKGreedySearch (model, data_loader, out_prefix, batch_size, filter_chain=None, max_len=100, **kwargs)
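
Greedy simultaneous decoding with a fixed wait-k policy. run_all() treats
each value in self.list_of_s_0 as a wait-k value k and, for each one,
decodes the whole test set and dumps the hypotheses (suffix 'wait{k}')
and the R/W action sequences (suffix 'wait{k}.acts').

A minimal usage sketch, assuming a trained pysimt model and a test-set
data loader are already built; `model`, `loader` and the output prefix
below are placeholders rather than pysimt API, and how `list_of_s_0`
gets populated is defined by the parent class:

# Hypothetical usage; `model` and `loader` stand in for a trained
# pysimt model and a test-set iterator constructed elsewhere.
translator = SimultaneousWaitKGreedySearch(
    model, loader, out_prefix='hyps/test', batch_size=1, max_len=100)

# One greedy decoding pass per wait-k value in translator.list_of_s_0;
# each pass dumps hypotheses ('wait{k}') and actions ('wait{k}.acts').
translator.run_all()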

Ancestors

pysimt.translators.sim_greedy.SimultaneousGreedySearch

Methods

def run(self, k)

Decode the test set with a fixed wait-k policy: after reading the first
k source tokens, alternate single WRITE and READ actions, forcing WRITE
once the source is fully read, until <eos> is produced or max_len is
reached. Returns (hyps, actions, up_time), with hypotheses and action
strings restored to corpus order by sort_predictions().

Inherited members