Module pysimt.stranslator
Expand source code
import sys
import logging
from pathlib import Path
import torch
from .utils.misc import load_pt_file
from .utils.filterchain import FilterChain
from .utils.data import make_dataloader
from .utils.device import DEVICE
from . import models
from .config import Options
from .translators import get_translator
logger = logging.getLogger('pysimt')
class STranslator:
    """A utility class that wraps single-model simultaneous translation stuff.

    All keyword arguments passed to the constructor are stored directly as
    instance attributes (see bin/pysimt for the full list). The ones read
    here are: ``func``, ``model``, ``override``, ``batch_size``, ``splits``,
    ``source``, ``disable_filters``, ``output``, ``max_len``, ``delta``,
    ``n_init_tokens`` and ``criteria``.
    """

    def __init__(self, **kwargs):
        # Store attributes directly. See bin/pysimt for their list.
        self.__dict__.update(kwargs)
        for key, value in kwargs.items():
            logger.info('-- {} -> {}'.format(key, value))

        # Resolve the decoding function name to a translator class.
        try:
            self._translator_type = get_translator(self.func)
        except KeyError:
            # Fatal: log at error level before exiting.
            logger.error(f'Error: decoding function {self.func!r} unknown.')
            sys.exit(1)

        # Only greedy search ('gs') supports batched decoding; the other
        # simultaneous decoders are forced to sentence-by-sentence mode.
        if self.func != 'gs':
            logger.info(f'STranslator: setting batch_size=1 for {self.func!r} decoding')
            self.batch_size = 1

        # Load the checkpoint: weights, training history and serialized opts.
        # NOTE: at this point self.model is the checkpoint path; it is
        # replaced by the actual model instance at the end of __init__.
        data = load_pt_file(self.model)
        weights, _, opts = data['model'], data['history'], data['opts']
        opts = Options.from_dict(opts, override_list=self.override)

        # Create the model instance from the type recorded in the checkpoint.
        instance = getattr(models, opts.train['model_type'])(opts=opts)
        # Setup layers
        instance.setup(is_train=False)
        # Load weights (non-strict: tolerates missing/unexpected keys).
        instance.load_state_dict(weights, strict=False)
        # Move to device and switch to eval mode.
        instance.to(DEVICE)
        instance.train(False)
        logger.info(instance)

        # `splits` arrives as a comma-separated string of split names.
        self.splits = self.splits.split(',')

        # Sanity check: explicit sources (-S) only make sense for one split.
        if self.source and len(self.splits) > 1:
            logger.error('You can only give one split name when -S is provided.')
            sys.exit(1)

        # Setup post-processing filters
        eval_filters = instance.opts.train['eval_filters']
        if self.disable_filters or not eval_filters:
            logger.info('Post-processing filters disabled.')
            self.filter = None
        else:
            logger.info('Post-processing filters enabled.')
            self.filter = FilterChain(eval_filters)

        # Can be a comma separated list of hardcoded test splits
        logger.info('Will translate "{}"'.format(self.splits))

        if self.source:
            # Override the data sources of the (single) requested split with
            # the "key:path" pairs given on the command line.
            split_set = '{}_set'.format(self.splits[0])
            input_dict = instance.opts.data.get(split_set, {})
            logger.info('Input configuration:')
            for data_source in self.source.split(','):
                key, path = data_source.split(':', 1)
                input_dict[key] = Path(path)
                logger.info(' {}: {}'.format(key, input_dict[key]))
            # Overwrite config's set name
            instance.opts.data[split_set] = input_dict

        # Override with the actual model (self.model was the checkpoint path).
        self.model = instance

    def translate(self, split):
        """Translates the given split and writes hypotheses to disk.

        Decoding is delegated to the translator type selected in the
        constructor; the hypotheses (optionally post-processed by the
        filter chain) are written to '<output>.<split>' by the translator
        itself. This method returns None.

        Arguments:
            split(str): A test split defined in the .conf file before
                training.
        """
        # Load data
        dataset = self.model.load_data(split, self.batch_size, mode='beam')
        loader = make_dataloader(dataset)

        # Instantiate the decoding method.
        translator = self._translator_type(
            self.model, loader,
            f"{self.output}.{split}",
            self.batch_size, self.filter,
            max_len=self.max_len, delta=self.delta, s_0=self.n_init_tokens,
            criteria=self.criteria)

        logger.info(f'Starting translation for {split!r}')
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            translator.run_all()

    def __call__(self):
        """Translates every requested split in order."""
        for split_name in self.splits:
            self.translate(split_name)
Classes
class STranslator (**kwargs)
-
A utility class that wraps single-model simultaneous translation stuff.
Expand source code
class STranslator: """A utility class that wraps single-model simultaneous translation stuff.""" def __init__(self, **kwargs): # Store attributes directly. See bin/pysimt for their list. self.__dict__.update(kwargs) for key, value in kwargs.items(): logger.info('-- {} -> {}'.format(key, value)) try: self._translator_type = get_translator(self.func) except KeyError: logger.info(f'Error: decoding function {self.func!r} unknown.') sys.exit(1) # Handle batch size if self.func != 'gs': logger.info(f'STranslator: setting batch_size=1 for {self.func!r} decoding') self.batch_size = 1 data = load_pt_file(self.model) weights, _, opts = data['model'], data['history'], data['opts'] opts = Options.from_dict(opts, override_list=self.override) # Create model instance instance = getattr(models, opts.train['model_type'])(opts=opts) # Setup layers instance.setup(is_train=False) # Load weights instance.load_state_dict(weights, strict=False) # Move to device instance.to(DEVICE) # Switch to eval mode instance.train(False) logger.info(instance) # Split the string self.splits = self.splits.split(',') # Do some sanity-check if self.source and len(self.splits) > 1: logger.info('You can only give one split name when -S is provided.') sys.exit(1) # Setup post-processing filters eval_filters = instance.opts.train['eval_filters'] if self.disable_filters or not eval_filters: logger.info('Post-processing filters disabled.') self.filter = None else: logger.info('Post-processing filters enabled.') self.filter = FilterChain(eval_filters) # Can be a comma separated list of hardcoded test splits logger.info('Will translate "{}"'.format(self.splits)) if self.source: # We have to have single split name in this case split_set = '{}_set'.format(self.splits[0]) input_dict = instance.opts.data.get(split_set, {}) logger.info('Input configuration:') for data_source in self.source.split(','): key, path = data_source.split(':', 1) input_dict[key] = Path(path) logger.info(' {}: {}'.format(key, input_dict[key])) 
# Overwrite config's set name instance.opts.data[split_set] = input_dict # Override with the actual model self.model = instance def translate(self, split): """Returns the hypotheses generated by translating the given split using the given model instance. Arguments: split(str): A test split defined in the .conf file before training. Returns: list: A list of optionally post-processed string hypotheses. """ # Load data dataset = self.model.load_data(split, self.batch_size, mode='beam') loader = make_dataloader(dataset) # Pick decoding method translator = self._translator_type( self.model, loader, f"{self.output}.{split}", self.batch_size, self.filter, max_len=self.max_len, delta=self.delta, s_0=self.n_init_tokens, criteria=self.criteria) logger.info(f'Starting translation for {split!r}') with torch.no_grad(): translator.run_all() def __call__(self): for input_ in self.splits: self.translate(input_)
Methods
def translate(self, split)
-
Returns the hypotheses generated by translating the given split using the given model instance.
Arguments
split(str): A test split defined in the .conf file before training.
Returns
None: hypotheses (optionally post-processed) are written to '<output>.<split>' rather than returned.
Expand source code
def translate(self, split): """Returns the hypotheses generated by translating the given split using the given model instance. Arguments: split(str): A test split defined in the .conf file before training. Returns: list: A list of optionally post-processed string hypotheses. """ # Load data dataset = self.model.load_data(split, self.batch_size, mode='beam') loader = make_dataloader(dataset) # Pick decoding method translator = self._translator_type( self.model, loader, f"{self.output}.{split}", self.batch_size, self.filter, max_len=self.max_len, delta=self.delta, s_0=self.n_init_tokens, criteria=self.criteria) logger.info(f'Starting translation for {split!r}') with torch.no_grad(): translator.run_all()