Module pysimt.translators.sim_greedy
Expand source code
import time
import logging
import itertools
import torch
from ..utils.device import DEVICE
from ..utils.io import progress_bar
from ..utils.data import sort_predictions
from .greedy import GreedySearch
logger = logging.getLogger('pysimt')
class SimultaneousGreedySearch(GreedySearch):
    """Greedy search that interleaves READ and WRITE actions per sentence.

    Decoding starts after ``s_0`` source positions have been read. At every
    step the prediction obtained with the current source prefix is compared
    with the prediction obtained with ``delta`` more positions; a criterion
    method looked up by name on this object (``wait_if_diff`` and
    ``wait_if_worse`` are defined here) decides whether to READ more source
    or to WRITE the current prediction.
    """

    # Action codes recorded in `self.actions`
    ACT_READ, ACT_WRITE = 0, 1

    def __init__(self, model, data_loader, out_prefix, batch_size, filter_chain=None,
                 max_len=100, **kwargs):
        """Check the model options and parse the grid-search settings.

        Args:
            model: Model used for decoding. Its options must not enable
                ``enc_bidirectional`` and ``dec_init`` must be ``'zero'``
                (the value assumed when the option is absent).
            data_loader: Passed through to ``GreedySearch``.
            out_prefix: Passed through to ``GreedySearch``.
            batch_size: Only logged; the search always runs with a batch
                size of 1.
            filter_chain: Passed through to ``GreedySearch``.
            max_len: Passed through to ``GreedySearch``.
            **kwargs: ``s_0``, ``delta`` and ``criteria`` are read as
                comma-separated strings; each defaults to ``''``.
        """
        assert not model.opts.model['enc_bidirectional'], \
            "Bidirectional models can not be used for simultaneous MT."
        assert model.opts.model.get('dec_init', 'zero') == 'zero', \
            "`dec_init` should be 'zero' for simplicity."

        logger.info(f'Ignoring batch_size {batch_size} for simultaneous greedy search')
        batch_size = 1

        super().__init__(model, data_loader, out_prefix,
                         batch_size, filter_chain, max_len)

        # Partial modality i.e. text
        self._partial_key = str(model.sl)
        # Hypothesis buffer; allocated on the first `clear_states` call
        self.buffer = None

        # Grid-search values, kept as strings until `run_all` converts them
        self.list_of_s_0 = kwargs.pop('s_0', '').split(',')
        self.list_of_delta = kwargs.pop('delta', '').split(',')
        self.criteria = kwargs.pop('criteria', '').split(',')

        # Words fed to the decoder so far; set by `set_first_word_to_bos`
        self.tf_decoder_input = None

    @staticmethod
    def wait_if_diff(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred):
        """If the candidate changes with more context, READ. Otherwise WRITE."""
        return cand_next_pred.ne(cur_next_pred)

    @staticmethod
    def wait_if_worse(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred):
        """If confidence for the candidate decreases WAIT/READ. Otherwise WRITE."""
        return cand_log_p[0, cur_next_pred] < cur_log_p[0, cur_next_pred]

    def write(self, new_word, new_h):
        """Write the new word, move the pointer and accept the hidden state."""
        self.prev_word, self.buffer[self.t_ptr] = new_word, new_word
        self.prev_h = new_h
        self.actions.append(self.ACT_WRITE)
        self.t_ptr += 1
        self.eos_written = new_word.item() == self.eos
        # Extend the running decoder input with the word just written
        self.tf_decoder_input = torch.cat(
            (self.tf_decoder_input, new_word.unsqueeze(0)), dim=0)

    def update_s(self, increment):
        """Update read pointer, logging one READ per position actually consumed."""
        # Never move past the end of the source
        new_pos = min(self.s_len, self.s_ptr + increment)
        n_reads = new_pos - self.s_ptr
        self.actions.extend([self.ACT_READ] * n_reads)
        self.s_ptr = new_pos

    def clear_states(self):
        """Reset pointers, decoder state, action log and hypothesis buffer."""
        self.s_ptr = 0
        self.t_ptr = 0
        self.prev_h = None
        self._c_states = None
        self.prev_word = None
        self.eos_written = False
        self.actions = []

        if self.buffer is None:
            # Write buffer
            self.buffer = torch.zeros((self.max_len, ), dtype=torch.long, device=DEVICE)
        else:
            # Reset hypothesis buffer
            self.buffer.zero_()

    def is_src_read(self):
        """Return True once the read pointer has reached the source length."""
        return self.s_ptr >= self.s_len

    def cache_encoder_states(self, batch):
        """Encode full source sentence and cache the states."""
        self.model.cache_enc_states(batch)
        # Source length is the size of the first dimension of the source entry
        self.s_len = batch[self._partial_key].size(0)

    def read_more(self, n):
        """Reads more source words and computes new states."""
        return self.model.get_enc_state_dict(up_to=self.s_ptr + n)

    def get_src_prefix_str(self, batch):
        """Return the source prefix read so far as a sentence string."""
        idxs = batch[self._partial_key][:self.s_ptr, 0].tolist()
        return self.model.vocabs[self._partial_key].idxs_to_sent(idxs)

    def get_trg_prefix_str(self):
        """Return the target prefix written so far as a sentence string."""
        idxs = self.buffer[:self.t_ptr].tolist()
        return self.model.vocabs['trg'].idxs_to_sent(idxs)

    def run_all(self):
        """Do a grid search over the given list of parameters."""
        settings = itertools.product(
            self.list_of_s_0,
            self.list_of_delta,
            self.criteria,
        )
        for s_0, delta, crit in settings:
            # Run the decoding; the elapsed time is not used here
            hyps, actions, _ = self.run(int(s_0), int(delta), crit)

            # Dumps two files one with segmentations preserved, another
            # with post-processing filters applied
            self.dump_results(hyps, suffix=f's{s_0}_d{delta}_{crit}')

            # Dump actions
            self.dump_lines(actions, suffix=f's{s_0}_d{delta}_{crit}.acts')

    def run(self, s_0, delta, criterion):
        """Decode the whole data loader with one (s_0, delta, criterion) setting.

        Args:
            s_0: Number of source positions read before the first decoding step.
            delta: Number of extra source positions considered per READ.
            criterion: Name of the method on this object used to decide
                between READ and WRITE.

        Returns:
            A ``(hyps, actions, up_time)`` tuple: the hypotheses and the
            space-joined action strings, both passed through
            ``sort_predictions``, and the elapsed decoding time in seconds.
        """
        # R/W actions generated for the whole test set
        actions = []
        # Final translations
        translations = []

        # Set criterion
        crit_fn = getattr(self, criterion)

        start = time.time()
        for batch in progress_bar(self.data_loader, unit='batch'):
            self.clear_states()
            batch.device(DEVICE)

            # Compute all at once
            self.cache_encoder_states(batch)

            # Read some words and get trimmed states
            state_dict = self.read_more(s_0)
            self.update_s(s_0)

            # Initial state is None i.e. 0. state_dict is not used
            self.prev_h = self.decoder_init(state_dict)

            # Start all sentences with <s>
            self.set_first_word_to_bos(batch.size)

            while not self.eos_written and self.t_ptr < self.max_len:
                logp, new_h, new_word = self.decoder_step(
                    state_dict, self.prev_word, self.prev_h, self.tf_decoder_input)

                if self.is_src_read():
                    # All source words are read, no choice but writing
                    self.write(new_word, new_h)
                    continue

                # C' is empty
                if self._c_states is None:
                    self._c_states = self.read_more(delta)

                # Evaluate candidate
                cand_logp, _, cand_new_word = self.decoder_step(
                    self._c_states, self.prev_word, self.prev_h, self.tf_decoder_input)

                if crit_fn(logp, new_word, cand_logp, cand_new_word):
                    # Wait/Read more words and do another decoding attempt
                    state_dict = self._c_states
                    self._c_states = None
                    self.update_s(delta)
                else:
                    # Commit the last candidate
                    self.write(new_word, new_h)

            # All finished, convert translations to python lists on CPU.
            # Only non-zero buffer entries are kept.
            idxs = self.buffer[self.buffer.ne(0)].tolist()
            if not idxs or idxs[-1] != self.eos:
                # In cases where <eos> not produced and the above loop
                # went on until max_len, add an explicit <eos> for correctness.
                # The emptiness check avoids an IndexError when nothing
                # non-zero was written.
                idxs.append(self.eos)

            # compute action sequence from which metrics will be computed
            actions.append(' '.join(map(str, self.actions)))
            translations.append(self.vocab.idxs_to_sent(idxs))

        up_time = time.time() - start

        hyps = sort_predictions(self.data_loader, translations)
        actions = sort_predictions(self.data_loader, actions)
        return (hyps, actions, up_time)

    def set_first_word_to_bos(self, cur_batch_size):
        """Start decoding from the model's begin-of-sentence word."""
        self.prev_word = self.model.get_bos(cur_batch_size).to(DEVICE)
        self.tf_decoder_input = self.prev_word.unsqueeze(0)
Classes
class SimultaneousGreedySearch (model, data_loader, out_prefix, batch_size, filter_chain=None, max_len=100, **kwargs)
-
Expand source code
class SimultaneousGreedySearch(GreedySearch):
    """Greedy search that interleaves READ and WRITE actions per sentence.

    Decoding starts after ``s_0`` source positions have been read. At every
    step the prediction obtained with the current source prefix is compared
    with the prediction obtained with ``delta`` more positions; a criterion
    method looked up by name on this object (``wait_if_diff`` and
    ``wait_if_worse`` are defined here) decides whether to READ more source
    or to WRITE the current prediction.
    """

    # Action codes recorded in `self.actions`
    ACT_READ, ACT_WRITE = 0, 1

    def __init__(self, model, data_loader, out_prefix, batch_size, filter_chain=None,
                 max_len=100, **kwargs):
        """Check the model options and parse the grid-search settings.

        Args:
            model: Model used for decoding. Its options must not enable
                ``enc_bidirectional`` and ``dec_init`` must be ``'zero'``
                (the value assumed when the option is absent).
            data_loader: Passed through to ``GreedySearch``.
            out_prefix: Passed through to ``GreedySearch``.
            batch_size: Only logged; the search always runs with a batch
                size of 1.
            filter_chain: Passed through to ``GreedySearch``.
            max_len: Passed through to ``GreedySearch``.
            **kwargs: ``s_0``, ``delta`` and ``criteria`` are read as
                comma-separated strings; each defaults to ``''``.
        """
        assert not model.opts.model['enc_bidirectional'], \
            "Bidirectional models can not be used for simultaneous MT."
        assert model.opts.model.get('dec_init', 'zero') == 'zero', \
            "`dec_init` should be 'zero' for simplicity."

        logger.info(f'Ignoring batch_size {batch_size} for simultaneous greedy search')
        batch_size = 1

        super().__init__(model, data_loader, out_prefix,
                         batch_size, filter_chain, max_len)

        # Partial modality i.e. text
        self._partial_key = str(model.sl)
        # Hypothesis buffer; allocated on the first `clear_states` call
        self.buffer = None

        # Grid-search values, kept as strings until `run_all` converts them
        self.list_of_s_0 = kwargs.pop('s_0', '').split(',')
        self.list_of_delta = kwargs.pop('delta', '').split(',')
        self.criteria = kwargs.pop('criteria', '').split(',')

        # Words fed to the decoder so far; set by `set_first_word_to_bos`
        self.tf_decoder_input = None

    @staticmethod
    def wait_if_diff(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred):
        """If the candidate changes with more context, READ. Otherwise WRITE."""
        return cand_next_pred.ne(cur_next_pred)

    @staticmethod
    def wait_if_worse(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred):
        """If confidence for the candidate decreases WAIT/READ. Otherwise WRITE."""
        return cand_log_p[0, cur_next_pred] < cur_log_p[0, cur_next_pred]

    def write(self, new_word, new_h):
        """Write the new word, move the pointer and accept the hidden state."""
        self.prev_word, self.buffer[self.t_ptr] = new_word, new_word
        self.prev_h = new_h
        self.actions.append(self.ACT_WRITE)
        self.t_ptr += 1
        self.eos_written = new_word.item() == self.eos
        # Extend the running decoder input with the word just written
        self.tf_decoder_input = torch.cat(
            (self.tf_decoder_input, new_word.unsqueeze(0)), dim=0)

    def update_s(self, increment):
        """Update read pointer, logging one READ per position actually consumed."""
        # Never move past the end of the source
        new_pos = min(self.s_len, self.s_ptr + increment)
        n_reads = new_pos - self.s_ptr
        self.actions.extend([self.ACT_READ] * n_reads)
        self.s_ptr = new_pos

    def clear_states(self):
        """Reset pointers, decoder state, action log and hypothesis buffer."""
        self.s_ptr = 0
        self.t_ptr = 0
        self.prev_h = None
        self._c_states = None
        self.prev_word = None
        self.eos_written = False
        self.actions = []

        if self.buffer is None:
            # Write buffer
            self.buffer = torch.zeros((self.max_len, ), dtype=torch.long, device=DEVICE)
        else:
            # Reset hypothesis buffer
            self.buffer.zero_()

    def is_src_read(self):
        """Return True once the read pointer has reached the source length."""
        return self.s_ptr >= self.s_len

    def cache_encoder_states(self, batch):
        """Encode full source sentence and cache the states."""
        self.model.cache_enc_states(batch)
        # Source length is the size of the first dimension of the source entry
        self.s_len = batch[self._partial_key].size(0)

    def read_more(self, n):
        """Reads more source words and computes new states."""
        return self.model.get_enc_state_dict(up_to=self.s_ptr + n)

    def get_src_prefix_str(self, batch):
        """Return the source prefix read so far as a sentence string."""
        idxs = batch[self._partial_key][:self.s_ptr, 0].tolist()
        return self.model.vocabs[self._partial_key].idxs_to_sent(idxs)

    def get_trg_prefix_str(self):
        """Return the target prefix written so far as a sentence string."""
        idxs = self.buffer[:self.t_ptr].tolist()
        return self.model.vocabs['trg'].idxs_to_sent(idxs)

    def run_all(self):
        """Do a grid search over the given list of parameters."""
        settings = itertools.product(
            self.list_of_s_0,
            self.list_of_delta,
            self.criteria,
        )
        for s_0, delta, crit in settings:
            # Run the decoding; the elapsed time is not used here
            hyps, actions, _ = self.run(int(s_0), int(delta), crit)

            # Dumps two files one with segmentations preserved, another
            # with post-processing filters applied
            self.dump_results(hyps, suffix=f's{s_0}_d{delta}_{crit}')

            # Dump actions
            self.dump_lines(actions, suffix=f's{s_0}_d{delta}_{crit}.acts')

    def run(self, s_0, delta, criterion):
        """Decode the whole data loader with one (s_0, delta, criterion) setting.

        Args:
            s_0: Number of source positions read before the first decoding step.
            delta: Number of extra source positions considered per READ.
            criterion: Name of the method on this object used to decide
                between READ and WRITE.

        Returns:
            A ``(hyps, actions, up_time)`` tuple: the hypotheses and the
            space-joined action strings, both passed through
            ``sort_predictions``, and the elapsed decoding time in seconds.
        """
        # R/W actions generated for the whole test set
        actions = []
        # Final translations
        translations = []

        # Set criterion
        crit_fn = getattr(self, criterion)

        start = time.time()
        for batch in progress_bar(self.data_loader, unit='batch'):
            self.clear_states()
            batch.device(DEVICE)

            # Compute all at once
            self.cache_encoder_states(batch)

            # Read some words and get trimmed states
            state_dict = self.read_more(s_0)
            self.update_s(s_0)

            # Initial state is None i.e. 0. state_dict is not used
            self.prev_h = self.decoder_init(state_dict)

            # Start all sentences with <s>
            self.set_first_word_to_bos(batch.size)

            while not self.eos_written and self.t_ptr < self.max_len:
                logp, new_h, new_word = self.decoder_step(
                    state_dict, self.prev_word, self.prev_h, self.tf_decoder_input)

                if self.is_src_read():
                    # All source words are read, no choice but writing
                    self.write(new_word, new_h)
                    continue

                # C' is empty
                if self._c_states is None:
                    self._c_states = self.read_more(delta)

                # Evaluate candidate
                cand_logp, _, cand_new_word = self.decoder_step(
                    self._c_states, self.prev_word, self.prev_h, self.tf_decoder_input)

                if crit_fn(logp, new_word, cand_logp, cand_new_word):
                    # Wait/Read more words and do another decoding attempt
                    state_dict = self._c_states
                    self._c_states = None
                    self.update_s(delta)
                else:
                    # Commit the last candidate
                    self.write(new_word, new_h)

            # All finished, convert translations to python lists on CPU.
            # Only non-zero buffer entries are kept.
            idxs = self.buffer[self.buffer.ne(0)].tolist()
            if not idxs or idxs[-1] != self.eos:
                # In cases where <eos> not produced and the above loop
                # went on until max_len, add an explicit <eos> for correctness.
                # The emptiness check avoids an IndexError when nothing
                # non-zero was written.
                idxs.append(self.eos)

            # compute action sequence from which metrics will be computed
            actions.append(' '.join(map(str, self.actions)))
            translations.append(self.vocab.idxs_to_sent(idxs))

        up_time = time.time() - start

        hyps = sort_predictions(self.data_loader, translations)
        actions = sort_predictions(self.data_loader, actions)
        return (hyps, actions, up_time)

    def set_first_word_to_bos(self, cur_batch_size):
        """Start decoding from the model's begin-of-sentence word."""
        self.prev_word = self.model.get_bos(cur_batch_size).to(DEVICE)
        self.tf_decoder_input = self.prev_word.unsqueeze(0)
Ancestors
Subclasses
Class variables
var ACT_READ
var ACT_WRITE
Static methods
def wait_if_diff(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred)
-
If the candidate changes with more context, READ. Otherwise WRITE.
Expand source code
@staticmethod def wait_if_diff(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred): """If the candidate changes with more context, READ. Otherwise WRITE.""" return cand_next_pred.ne(cur_next_pred)
def wait_if_worse(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred)
-
If confidence for the candidate decreases WAIT/READ. Otherwise WRITE.
Expand source code
@staticmethod def wait_if_worse(cur_log_p, cur_next_pred, cand_log_p, cand_next_pred): """If confidence for the candidate decreases WAIT/READ. Otherwise WRITE.""" return cand_log_p[0, cur_next_pred] < cur_log_p[0, cur_next_pred]
Methods
def cache_encoder_states(self, batch)
-
Encode full source sentence and cache the states.
Expand source code
def cache_encoder_states(self, batch):
    """Let the model cache its encoder states and record the source length."""
    self.model.cache_enc_states(batch)
    source = batch[self._partial_key]
    # The source length is the size of the first dimension
    self.s_len = source.size(0)
def clear_states(self)
-
Expand source code
def clear_states(self):
    """Reset pointers, decoder state, action log and hypothesis buffer."""
    self.s_ptr = self.t_ptr = 0
    self.prev_h = self._c_states = self.prev_word = None
    self.eos_written = False
    self.actions = []

    if self.buffer is not None:
        # Reset hypothesis buffer
        self.buffer.zero_()
    else:
        # Write buffer
        self.buffer = torch.zeros((self.max_len, ), dtype=torch.long, device=DEVICE)
def get_src_prefix_str(self, batch)
-
Expand source code
def get_src_prefix_str(self, batch):
    """Return the source prefix read so far as a sentence string."""
    key = self._partial_key
    # Rows up to the read pointer, column 0 of the source entry
    prefix = batch[key][:self.s_ptr, 0]
    idxs = prefix.tolist()
    vocab = self.model.vocabs[key]
    return vocab.idxs_to_sent(idxs)
def get_trg_prefix_str(self)
-
Expand source code
def get_trg_prefix_str(self):
    """Return the target prefix written so far as a sentence string."""
    written = self.buffer[:self.t_ptr]
    idxs = written.tolist()
    trg_vocab = self.model.vocabs['trg']
    return trg_vocab.idxs_to_sent(idxs)
def is_src_read(self)
-
Expand source code
def is_src_read(self):
    """Return True once the read pointer has reached the source length."""
    return self.s_len <= self.s_ptr
def read_more(self, n)
-
Reads more source words and computes new states.
Expand source code
def read_more(self, n):
    """Ask the model for encoder states up to ``n`` positions past the read pointer.

    The read pointer itself is not moved here.
    """
    limit = self.s_ptr + n
    return self.model.get_enc_state_dict(up_to=limit)
def run(self, s_0, delta, criterion)
-
Expand source code
def run(self, s_0, delta, criterion):
    """Decode the whole data loader with one (s_0, delta, criterion) setting.

    Args:
        s_0: Number of source positions read before the first decoding step.
        delta: Number of extra source positions considered per READ.
        criterion: Name of the method on this object used to decide
            between READ and WRITE.

    Returns:
        A ``(hyps, actions, up_time)`` tuple: the hypotheses and the
        space-joined action strings, both passed through
        ``sort_predictions``, and the elapsed decoding time in seconds.
    """
    # R/W actions generated for the whole test set
    actions = []
    # Final translations
    translations = []

    # Set criterion
    crit_fn = getattr(self, criterion)

    start = time.time()
    for batch in progress_bar(self.data_loader, unit='batch'):
        self.clear_states()
        batch.device(DEVICE)

        # Compute all at once
        self.cache_encoder_states(batch)

        # Read some words and get trimmed states
        state_dict = self.read_more(s_0)
        self.update_s(s_0)

        # Initial state is None i.e. 0. state_dict is not used
        self.prev_h = self.decoder_init(state_dict)

        # Start all sentences with <s>
        self.set_first_word_to_bos(batch.size)

        while not self.eos_written and self.t_ptr < self.max_len:
            logp, new_h, new_word = self.decoder_step(
                state_dict, self.prev_word, self.prev_h, self.tf_decoder_input)

            if self.is_src_read():
                # All source words are read, no choice but writing
                self.write(new_word, new_h)
                continue

            # C' is empty
            if self._c_states is None:
                self._c_states = self.read_more(delta)

            # Evaluate candidate
            cand_logp, _, cand_new_word = self.decoder_step(
                self._c_states, self.prev_word, self.prev_h, self.tf_decoder_input)

            if crit_fn(logp, new_word, cand_logp, cand_new_word):
                # Wait/Read more words and do another decoding attempt
                state_dict = self._c_states
                self._c_states = None
                self.update_s(delta)
            else:
                # Commit the last candidate
                self.write(new_word, new_h)

        # All finished, convert translations to python lists on CPU.
        # Only non-zero buffer entries are kept.
        idxs = self.buffer[self.buffer.ne(0)].tolist()
        if not idxs or idxs[-1] != self.eos:
            # In cases where <eos> not produced and the above loop
            # went on until max_len, add an explicit <eos> for correctness.
            # The emptiness check avoids an IndexError when nothing
            # non-zero was written.
            idxs.append(self.eos)

        # compute action sequence from which metrics will be computed
        actions.append(' '.join(map(str, self.actions)))
        translations.append(self.vocab.idxs_to_sent(idxs))

    up_time = time.time() - start

    hyps = sort_predictions(self.data_loader, translations)
    actions = sort_predictions(self.data_loader, actions)
    return (hyps, actions, up_time)
def run_all(self)
-
Do a grid search over the given list of parameters.
Expand source code
def run_all(self):
    """Decode once per (s_0, delta, criterion) combination and dump the outputs.

    For every combination the hypotheses go to ``dump_results`` and the
    action strings to ``dump_lines``, each tagged with the setting in the
    file suffix.
    """
    grid = itertools.product(self.list_of_s_0, self.list_of_delta, self.criteria)

    for s_0, delta, crit in grid:
        # The settings are stored as strings; `run` expects integers
        result = self.run(int(s_0), int(delta), crit)
        hyps, actions = result[0], result[1]

        # Hypotheses: dump_results writes the translation files
        self.dump_results(hyps, suffix=f's{s_0}_d{delta}_{crit}')

        # READ/WRITE action sequences
        self.dump_lines(actions, suffix=f's{s_0}_d{delta}_{crit}.acts')
def set_first_word_to_bos(self, cur_batch_size)
-
Expand source code
def set_first_word_to_bos(self, cur_batch_size):
    """Start decoding from the model's begin-of-sentence word.

    The word returned by ``model.get_bos`` is moved to ``DEVICE`` and also
    becomes the first entry of the running decoder input.
    """
    bos = self.model.get_bos(cur_batch_size)
    self.prev_word = bos.to(DEVICE)
    self.tf_decoder_input = self.prev_word.unsqueeze(0)
def update_s(self, increment)
-
Update read pointer.
Expand source code
def update_s(self, increment):
    """Advance the read pointer, logging one READ per position consumed.

    The pointer is clamped to the source length, so fewer than
    ``increment`` READ actions are recorded near the end of the source.
    """
    target = self.s_ptr + increment
    if target > self.s_len:
        target = self.s_len

    consumed = target - self.s_ptr
    self.actions += [self.ACT_READ] * consumed
    self.s_ptr = target
def write(self, new_word, new_h)
-
Write the new word, move the pointer and accept the hidden state.
Expand source code
def write(self, new_word, new_h):
    """Commit ``new_word`` to the hypothesis and accept ``new_h`` as decoder state.

    Records a WRITE action, advances the target pointer, flags whether the
    written word is the end-of-sentence symbol and appends the word to the
    running decoder input.
    """
    self.prev_word = new_word
    self.buffer[self.t_ptr] = new_word
    self.prev_h = new_h

    self.actions.append(self.ACT_WRITE)
    self.t_ptr += 1
    self.eos_written = new_word.item() == self.eos

    # Grow the decoder input along the first dimension
    grown = (self.tf_decoder_input, new_word.unsqueeze(0))
    self.tf_decoder_input = torch.cat(grown, dim=0)