Module pysimt.datasets.multimodal

Source code
import logging
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler, RandomSampler

from . import get_dataset
from .collate import get_collate
from ..samplers import get_sampler

logger = logging.getLogger('pysimt')


class MultimodalDataset(Dataset):
    """Returns a Dataset for parallel multimodal corpora

    Arguments:
        data(dict): [data] section's relevant split dictionary
        mode(str): One of train/eval/beam.
        batch_size(int): Batch size.
        vocabs(dict): dictionary mapping keys to Vocabulary() objects
        topology(Topology): A topology object.
        bucket_by(str): String identifier of the modality which will define how
            the batches will be bucketed, i.e. sort key. If `None`, no
            bucketing will be performed but the layers and models should
            support packing/padding/masking for this to work.
        max_len(int, optional): Maximum sequence length for ``bucket_by``
            modality to reject batches with long sequences. Does not have an effect
            if mode != 'train'.
        bucket_order (str, optional): ``ascending`` or ``descending`` to
            perform length-based curriculum learning. Default is ``None``
            which shuffles bucket order. Does not have an effect if mode != 'train'.
        sampler_type(str, optional): 'bucket' or 'approximate' (Default: 'bucket')
        kwargs (dict): Additional arguments to pass to the dataset constructors.
    """
    def __init__(self, data, mode, batch_size, vocabs, topology,
                 bucket_by, bucket_order=None, max_len=None,
                 sampler_type='bucket', **kwargs):
        self.datasets = {}
        self.mode = mode
        self.vocabs = vocabs
        self.batch_size = batch_size
        self.topology = topology
        self.bucket_by = bucket_by
        self.sampler_type = sampler_type

        # Disable filtering if not training
        self.max_len = max_len if self.mode == 'train' else None

        # This is only useful for training
        self.bucket_order = bucket_order if self.mode == 'train' else None

        # Collect dataset sizes
        self.size_dict = {}

        # Warn when bucketing is disabled; the model must then handle
        # packing/padding/masking itself
        if self.bucket_by is None:
            logger.info(
                'WARNING: Bucketing disabled. It is up to the model '
                'to take care of packing/padding/masking if any.')

        for key, ds in self.topology.all.items():
            if self.mode == 'beam' and ds.trg:
                # Skip target streams for beam-search
                logger.info("Skipping '{}' as target".format(key))
                continue

            try:
                # Get the relevant dataset class
                dataset_constructor = get_dataset(ds._type)
            except KeyError:
                raise RuntimeError(f"Unknown dataset type {ds._type!r}")

            logger.info("Initializing dataset for '{}'...".format(ds))
            if key in data:
                # Force <eos> for target side, relax it for source side
                kwargs['eos'] = kwargs.get('eos', True) or ds.trg
                # Construct the dataset
                self.datasets[ds] = dataset_constructor(
                    fname=data[key],
                    vocab=vocabs.get(key, None), bos=ds.trg, **kwargs)
                self.size_dict[ds] = len(self.datasets[ds])
            else:
                logger.info("  Skipping as '{}' not defined. This may be an issue.".format(key))

        # Set dataset size
        if len(set(self.size_dict.values())) > 1:
            raise RuntimeError("Underlying datasets are not parallel!")
        else:
            self.size = list(self.size_dict.values())[0]

        # Set list of available datasets
        self.keys = list(self.datasets.keys())

        # Get collator
        self.collate_fn = get_collate(self.keys)

        if self.bucket_by in self.datasets:
            self.sort_lens = self.datasets[self.bucket_by].lengths
            assert self.sort_lens is not None

            # Get a batch sampler
            self.sampler = get_sampler(self.sampler_type)(
                batch_size=self.batch_size,
                sort_lens=self.sort_lens,
                max_len=self.max_len,
                store_indices=self.mode != 'train',
                order=self.bucket_order)
        else:
            # bucket_by was only valid for training
            if self.bucket_by:
                self.bucket_by = None
                logger.info('Disabling bucketing for data loader.')
            # No modality provided to bucket sequential batches
            # Used for beam-search in image->text tasks
            if self.mode == 'beam':
                sampler = SequentialSampler(self)
                self.sampler_type = 'sequential'
            else:
                sampler = RandomSampler(self)
                self.sampler_type = 'random'
            # Create a batch sampler
            self.sampler = BatchSampler(
                sampler, batch_size=self.batch_size, drop_last=False)

        # Set some metadata
        self.n_sources = len([k for k in self.keys if k.src])
        self.n_targets = len([k for k in self.keys if k.trg])

    def __getitem__(self, idx):
        return {k: self.datasets[k][idx] for k in self.keys}

    def __len__(self):
        return self.size

    def __repr__(self):
        s = "{} - ({} source(s) / {} target(s))\n".format(
            self.__class__.__name__, self.n_sources, self.n_targets)
        s += "  Sampler type: {}, bucket_by: {}\n".format(
            self.sampler_type, self.bucket_by)

        if self.n_sources > 0:
            s += "  Sources:\n"
            for name in filter(lambda k: k.src, self.keys):
                dstr = self.datasets[name].__repr__()
                s += f'    --> {dstr}\n'
        if self.n_targets > 0:
            s += "  Targets:\n"
            for name in filter(lambda k: k.trg, self.keys):
                dstr = self.datasets[name].__repr__()
                s += f'    --> {dstr}\n'
        return s
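
The constructed object exposes the pieces a data loader needs: self.sampler is a batch sampler and self.collate_fn is a matching collator. A minimal sketch of how these are typically wired into a PyTorch DataLoader (assuming dataset is an already-constructed MultimodalDataset instance):

from torch.utils.data import DataLoader

# The dataset already carries a batch sampler and a matching collate
# function, so the DataLoader only needs to connect them.
loader = DataLoader(
    dataset,
    batch_sampler=dataset.sampler,
    collate_fn=dataset.collate_fn,
    num_workers=0)

for batch in loader:
    ...  # one collated batch covering all loaded source/target streams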

Classes

class MultimodalDataset (data, mode, batch_size, vocabs, topology, bucket_by, bucket_order=None, max_len=None, sampler_type='bucket', **kwargs)

Returns a Dataset for parallel multimodal corpora

Arguments

  • data (dict): The [data] section's split dictionary for the relevant split.
  • mode (str): One of train/eval/beam.
  • batch_size (int): Batch size.
  • vocabs (dict): Dictionary mapping keys to Vocabulary() objects.
  • topology (Topology): A topology object.
  • bucket_by (str): String identifier of the modality that defines how batches are bucketed, i.e. the sort key. If None, no bucketing is performed; the layers and models must then support packing/padding/masking.
  • max_len (int, optional): Maximum sequence length for the bucket_by modality; batches with longer sequences are rejected. Has no effect if mode != 'train'.
  • bucket_order (str, optional): ascending or descending, to perform length-based curriculum learning. Default is None, which shuffles the bucket order. Has no effect if mode != 'train'.
  • sampler_type (str, optional): 'bucket' or 'approximate' (default: 'bucket').
  • kwargs (dict): Additional arguments passed to the underlying dataset constructors.
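
A minimal construction sketch. The file paths below are hypothetical placeholders, and the vocabs and topology objects are assumed to come from the experiment configuration (pysimt's Vocabulary and Topology classes); only the keyword arguments themselves follow the signature above:

# Hypothetical split dictionary: stream key -> corpus file path
data = {'en': 'train.en', 'de': 'train.de'}

# 'vocabs' maps the same keys to Vocabulary objects; 'topology' marks
# 'en' as a source stream and 'de' as a target stream (both assumed here).
dataset = MultimodalDataset(
    data=data, mode='train', batch_size=64,
    vocabs=vocabs, topology=topology,
    bucket_by='en', max_len=100)

print(dataset)   # summarizes sources/targets, sampler type and bucket_by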


Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic
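
Indexing the dataset returns one item per loaded stream, keyed by data source, and len(dataset) gives the number of parallel examples. A short illustration, reusing the hypothetical dataset instance from the sketches above:

sample = dataset[0]                          # dict with one entry per stream
print(len(dataset))                          # number of parallel examples
print(dataset.n_sources, dataset.n_targets)  # stream counts reported by __repr__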