Module pysimt.stranslator

Expand source code
import sys
import logging
from pathlib import Path

import torch

from .utils.misc import load_pt_file
from .utils.filterchain import FilterChain
from .utils.data import make_dataloader
from .utils.device import DEVICE

from . import models
from .config import Options
from .translators import get_translator

logger = logging.getLogger('pysimt')


class STranslator:
    """A utility class that wraps single-model simultaneous translation stuff.

    The keyword arguments given to the constructor are stored verbatim as
    instance attributes (see bin/pysimt for their list). The attributes read
    by this class are ``func``, ``batch_size``, ``model``, ``override``,
    ``splits``, ``source``, ``disable_filters``, ``output``, ``max_len``,
    ``delta``, ``n_init_tokens`` and ``criteria``.
    """

    def __init__(self, **kwargs):
        """Load the checkpoint, build the model and prepare decoding.

        After construction ``self.model`` holds the model instance instead
        of the checkpoint reference it was given, ``self.splits`` is a list
        of split names and ``self.filter`` is either a ``FilterChain`` or
        ``None``.

        Raises:
            SystemExit: if the decoding function is unknown, if more than
                one split is given together with a source override, or if
                a source override entry is not of the form ``key:path``.
        """
        # Store attributes directly. See bin/pysimt for their list.
        self.__dict__.update(kwargs)

        for key, value in kwargs.items():
            logger.info('-- {} -> {}'.format(key, value))

        try:
            self._translator_type = get_translator(self.func)
        except KeyError:
            logger.info(f'Error: decoding function {self.func!r} unknown.')
            sys.exit(1)

        # Handle batch size: every decoding function other than 'gs' is
        # forced to process one sample at a time
        if self.func != 'gs':
            logger.info(f'STranslator: setting batch_size=1 for {self.func!r} decoding')
            self.batch_size = 1

        # The checkpoint is indexed with 'model', 'history' and 'opts';
        # the history entry is read but not used here
        data = load_pt_file(self.model)
        weights, _, opts = data['model'], data['history'], data['opts']
        opts = Options.from_dict(opts, override_list=self.override)

        # Create model instance
        instance = getattr(models, opts.train['model_type'])(opts=opts)

        # Setup layers
        instance.setup(is_train=False)

        # Load weights
        instance.load_state_dict(weights, strict=False)

        # Move to device
        instance.to(DEVICE)

        # Switch to eval mode
        instance.train(False)

        logger.info(instance)

        # Split the string
        self.splits = self.splits.split(',')

        # Do some sanity-check
        if self.source and len(self.splits) > 1:
            logger.info('You can only give one split name when -S is provided.')
            sys.exit(1)

        # Setup post-processing filters
        eval_filters = instance.opts.train['eval_filters']

        if self.disable_filters or not eval_filters:
            logger.info('Post-processing filters disabled.')
            self.filter = None
        else:
            logger.info('Post-processing filters enabled.')
            self.filter = FilterChain(eval_filters)

        # Can be a comma separated list of hardcoded test splits
        logger.info('Will translate "{}"'.format(self.splits))
        if self.source:
            # We have to have single split name in this case
            split_set = '{}_set'.format(self.splits[0])
            input_dict = instance.opts.data.get(split_set, {})

            # Report a malformed entry like the other user errors above
            # instead of letting a bare ValueError traceback escape
            try:
                sources = self._parse_sources(self.source)
            except ValueError as exc:
                logger.info(f'Error: {exc}')
                sys.exit(1)

            logger.info('Input configuration:')
            for key, path in sources:
                input_dict[key] = path
                logger.info(' {}: {}'.format(key, input_dict[key]))

            # Overwrite config's set name
            instance.opts.data[split_set] = input_dict

        # Override with the actual model
        self.model = instance

    @staticmethod
    def _parse_sources(source):
        """Parse a ``key:path[,key:path...]`` string into (key, Path) pairs.

        Only the first ``:`` of an entry separates the key from the path,
        so the path itself may contain colons. The pairs are returned in
        the order they were given.

        Arguments:
            source(str): The comma separated source specification.

        Returns:
            list:
                A list of ``(key, pathlib.Path)`` tuples.

        Raises:
            ValueError: if an entry does not contain a ``:``.
        """
        pairs = []
        for entry in source.split(','):
            key, sep, path = entry.partition(':')
            if not sep:
                raise ValueError(
                    f'malformed source {entry!r}, expected "key:path".')
            pairs.append((key, Path(path)))
        return pairs

    def translate(self, split):
        """Translate the given split using the wrapped model instance.

        A translator object is created for the split and its ``run_all()``
        method is called with gradient tracking disabled. The hypotheses
        are not returned by this method; the translator receives
        ``"{output}.{split}"`` as an argument, presumably the name under
        which it stores them.

        Arguments:
            split(str): A test split defined in the .conf file before
                training.

        Returns:
            None
        """

        # Load data
        dataset = self.model.load_data(split, self.batch_size, mode='beam')
        loader = make_dataloader(dataset)

        # Pick decoding method
        translator = self._translator_type(
            self.model, loader,
            f"{self.output}.{split}",
            self.batch_size, self.filter,
            max_len=self.max_len, delta=self.delta, s_0=self.n_init_tokens,
            criteria=self.criteria)

        logger.info(f'Starting translation for {split!r}')
        with torch.no_grad():
            translator.run_all()

    def __call__(self):
        """Translate every split in ``self.splits``, in order."""
        for input_ in self.splits:
            self.translate(input_)

Classes

class STranslator (**kwargs)

A utility class that wraps single-model simultaneous translation stuff.

Expand source code
class STranslator:
    """A utility class that wraps single-model simultaneous translation stuff.

    The keyword arguments given to the constructor are stored verbatim as
    instance attributes (see bin/pysimt for their list). The attributes read
    by this class are ``func``, ``batch_size``, ``model``, ``override``,
    ``splits``, ``source``, ``disable_filters``, ``output``, ``max_len``,
    ``delta``, ``n_init_tokens`` and ``criteria``.
    """

    def __init__(self, **kwargs):
        """Load the checkpoint, build the model and prepare decoding.

        After construction ``self.model`` holds the model instance instead
        of the checkpoint reference it was given, ``self.splits`` is a list
        of split names and ``self.filter`` is either a ``FilterChain`` or
        ``None``.

        Raises:
            SystemExit: if the decoding function is unknown, if more than
                one split is given together with a source override, or if
                a source override entry is not of the form ``key:path``.
        """
        # Store attributes directly. See bin/pysimt for their list.
        self.__dict__.update(kwargs)

        for key, value in kwargs.items():
            logger.info('-- {} -> {}'.format(key, value))

        try:
            self._translator_type = get_translator(self.func)
        except KeyError:
            logger.info(f'Error: decoding function {self.func!r} unknown.')
            sys.exit(1)

        # Handle batch size: every decoding function other than 'gs' is
        # forced to process one sample at a time
        if self.func != 'gs':
            logger.info(f'STranslator: setting batch_size=1 for {self.func!r} decoding')
            self.batch_size = 1

        # The checkpoint is indexed with 'model', 'history' and 'opts';
        # the history entry is read but not used here
        data = load_pt_file(self.model)
        weights, _, opts = data['model'], data['history'], data['opts']
        opts = Options.from_dict(opts, override_list=self.override)

        # Create model instance
        instance = getattr(models, opts.train['model_type'])(opts=opts)

        # Setup layers
        instance.setup(is_train=False)

        # Load weights
        instance.load_state_dict(weights, strict=False)

        # Move to device
        instance.to(DEVICE)

        # Switch to eval mode
        instance.train(False)

        logger.info(instance)

        # Split the string
        self.splits = self.splits.split(',')

        # Do some sanity-check
        if self.source and len(self.splits) > 1:
            logger.info('You can only give one split name when -S is provided.')
            sys.exit(1)

        # Setup post-processing filters
        eval_filters = instance.opts.train['eval_filters']

        if self.disable_filters or not eval_filters:
            logger.info('Post-processing filters disabled.')
            self.filter = None
        else:
            logger.info('Post-processing filters enabled.')
            self.filter = FilterChain(eval_filters)

        # Can be a comma separated list of hardcoded test splits
        logger.info('Will translate "{}"'.format(self.splits))
        if self.source:
            # We have to have single split name in this case
            split_set = '{}_set'.format(self.splits[0])
            input_dict = instance.opts.data.get(split_set, {})

            # Report a malformed entry like the other user errors above
            # instead of letting a bare ValueError traceback escape
            try:
                sources = self._parse_sources(self.source)
            except ValueError as exc:
                logger.info(f'Error: {exc}')
                sys.exit(1)

            logger.info('Input configuration:')
            for key, path in sources:
                input_dict[key] = path
                logger.info(' {}: {}'.format(key, input_dict[key]))

            # Overwrite config's set name
            instance.opts.data[split_set] = input_dict

        # Override with the actual model
        self.model = instance

    @staticmethod
    def _parse_sources(source):
        """Parse a ``key:path[,key:path...]`` string into (key, Path) pairs.

        Only the first ``:`` of an entry separates the key from the path,
        so the path itself may contain colons. The pairs are returned in
        the order they were given.

        Arguments:
            source(str): The comma separated source specification.

        Returns:
            list:
                A list of ``(key, pathlib.Path)`` tuples.

        Raises:
            ValueError: if an entry does not contain a ``:``.
        """
        pairs = []
        for entry in source.split(','):
            key, sep, path = entry.partition(':')
            if not sep:
                raise ValueError(
                    f'malformed source {entry!r}, expected "key:path".')
            pairs.append((key, Path(path)))
        return pairs

    def translate(self, split):
        """Translate the given split using the wrapped model instance.

        A translator object is created for the split and its ``run_all()``
        method is called with gradient tracking disabled. The hypotheses
        are not returned by this method; the translator receives
        ``"{output}.{split}"`` as an argument, presumably the name under
        which it stores them.

        Arguments:
            split(str): A test split defined in the .conf file before
                training.

        Returns:
            None
        """

        # Load data
        dataset = self.model.load_data(split, self.batch_size, mode='beam')
        loader = make_dataloader(dataset)

        # Pick decoding method
        translator = self._translator_type(
            self.model, loader,
            f"{self.output}.{split}",
            self.batch_size, self.filter,
            max_len=self.max_len, delta=self.delta, s_0=self.n_init_tokens,
            criteria=self.criteria)

        logger.info(f'Starting translation for {split!r}')
        with torch.no_grad():
            translator.run_all()

    def __call__(self):
        """Translate every split in ``self.splits``, in order."""
        for input_ in self.splits:
            self.translate(input_)

Methods

def translate(self, split)

Translates the given split using the wrapped model instance by creating a translator for it and calling its run_all() method. Note that the method itself does not return the hypotheses.

Arguments

split(str): A test split defined in the .conf file before training.

Returns

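None: the method has no return statement. The translator is given "{output}.{split}" as an argument, presumably the name under which the optionally post-processed string hypotheses are stored.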

Expand source code
def translate(self, split):
    """Run the configured decoding method over one split.

    Builds a data loader for the split, instantiates the translator type
    selected at construction time and calls its ``run_all()`` method with
    gradient tracking disabled.

    Arguments:
        split(str): A test split defined in the .conf file before
            training.

    Returns:
        None. The hypotheses are not returned by this method; the
        translator is handed ``"{output}.{split}"``, presumably the name
        under which it stores them.
    """
    model = self.model
    batch_size = self.batch_size

    # Dataset for the split, wrapped into a loader
    loader = make_dataloader(model.load_data(split, batch_size, mode='beam'))

    # Decoding options forwarded to the translator as keyword arguments
    decoding_args = {
        'max_len': self.max_len,
        'delta': self.delta,
        's_0': self.n_init_tokens,
        'criteria': self.criteria,
    }
    out_name = f"{self.output}.{split}"

    translator = self._translator_type(
        model, loader, out_name, batch_size, self.filter, **decoding_args)

    logger.info(f'Starting translation for {split!r}')
    with torch.no_grad():
        translator.run_all()