Module pysimt.samplers.bucket

Expand source code
import math
import logging
from collections import defaultdict

import numpy as np

from torch.utils.data.sampler import Sampler

logger = logging.getLogger('pysimt')


class BucketBatchSampler(Sampler):
    r"""Samples batch indices from sequence-length buckets efficiently
    with very little memory overhead.

    Epoch overhead for a 5M-sample dataset with batch_size=32 is around 400ms.

    Arguments:
        batch_size (int): Size of mini-batch.
        sort_lens (list): List of source or target lengths corresponding to each
            item in the dataset.
        max_len (int, optional): A maximum sequence length that will be used
            to filter out very long sequences. ``None`` means no filtering.
        store_indices (bool, optional): If ``True``, indices that will unsort
            the dataset will be stored. This is used by beam search/inference.
        order (str, optional): Default is ``None``, i.e. buckets are shuffled.
            If ``ascending`` or ``descending``, buckets are iterated w.r.t.
            their lengths to implement length-based curriculum learning.

    Example:
        >>> # Generate dummy length information
        >>> lengths = np.random.randint(1, 20, size=10000)
        >>> sampler = BucketBatchSampler(batch_size=10, sort_lens=lengths)
        >>> batch = list(sampler)[0]
        >>> batch
        [7526, 8473, 9194, 1030, 1568, 4182, 3082, 827, 3688, 9336]
        >>> # All samples in the batch have the same length
        >>> [lengths[i] for i in batch]
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]

    """

    def __init__(self, batch_size, sort_lens,
                 max_len=None, store_indices=False, order=None):
        self.batch_size = batch_size
        self.max_len = max_len
        self.store_indices = store_indices
        self.n_rejects = 0
        self.order = order

        # fix the seed
        self._rng = np.random.RandomState(seed=12345)

        assert sort_lens is not None, \
            "BucketBatchSampler() received `sort_lens` == None"

        assert self.order in (None, 'ascending', 'descending'), \
            "order should be None, 'ascending' or 'descending'"

        # Buckets: sort_lens -> list of sample indices
        self.buckets = defaultdict(list)

        # Fill the buckets while optionally filtering out long sequences
        if self.max_len is not None:
            for idx, len_ in enumerate(sort_lens):
                if len_ <= self.max_len:
                    self.buckets[len_].append(idx)
                else:
                    self.n_rejects += 1
            logger.info(f'{self.n_rejects} samples rejected because of '
                        f'length filtering @ {self.max_len}')
        else:
            # No length filtering
            for idx, len_ in enumerate(sort_lens):
                self.buckets[len_].append(idx)

        # Pre-compute how many times a bucket will be sampled
        self.bucket_idxs = []

        self.stats = {k: len(self.buckets[k]) for k in sorted(self.buckets)}

        for len_ in self.buckets:
            # Convert bucket to a numpy array
            np_bucket = np.array(self.buckets[len_])

            # How many batches will be drawn from this bucket?
            bucket_bs = np_bucket.size / self.batch_size
            idxs = [len_] * math.ceil(bucket_bs)

            self.buckets[len_] = np_bucket
            self.bucket_idxs.extend(idxs)

        # Convert to numpy array
        self.bucket_idxs = np.array(self.bucket_idxs)

        # Set number of batches
        self.n_batches = len(self.bucket_idxs)

    def __iter__(self):
        # Keep offsets for each bucket for efficiency
        bucket_offsets = {}

        # Random access indices
        bucket_views = {}

        # When doing beam search with ordered batches, the original
        # indices are needed to unsort the outputs.
        self.orig_idxs = []

        # Create permuted access indices for each bucket
        # to avoid shuffling the lists
        for len_, elems in self.buckets.items():
            bucket_offsets[len_] = 0
            perms = self._rng.permutation(len(elems))
            bucket_views[len_] = perms

        if self.order is None:
            # Shuffle bucket order
            shuf_idxs = self._rng.permutation(self.bucket_idxs)
        elif self.order == "ascending":
            # Start from shortest sequences and increase
            shuf_idxs = np.sort(self.bucket_idxs)
        elif self.order == "descending":
            # Start from longest sequences and decrease
            shuf_idxs = -np.sort(-self.bucket_idxs)

        # For each bucket, slide the window to yield the next batch
        for bidx in shuf_idxs:
            # Get offset pointer for this bucket: 0 initially
            offset = bucket_offsets[bidx]

            # Slice the next window from the permuted view
            idxs = bucket_views[bidx][offset: offset + self.batch_size]

            # Increment offset
            bucket_offsets[bidx] += len(idxs)

            # Get actual sample indices
            sidxs = self.buckets[bidx][idxs]

            if self.store_indices:
                self.orig_idxs.extend(sidxs)

            # Return sample indices
            yield sidxs

    def __len__(self):
        """Returns how many batches are inside."""
        return self.n_batches

    def __repr__(self):
        return f"BucketBatchSampler(order={self.order}, max_len={self.max_len}, n_rejects={self.n_rejects})"

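Because every batch is drawn from a single length bucket, the epoch length is the sum of per-bucket ceilings rather than ceil(N / batch_size). A minimal sketch checking this invariant (assuming the package is importable under the module path shown above; names and sizes are arbitrary):

    import math
    from collections import Counter

    import numpy as np

    from pysimt.samplers.bucket import BucketBatchSampler

    lengths = np.random.randint(1, 20, size=10000)
    sampler = BucketBatchSampler(batch_size=32, sort_lens=lengths)

    # Each length bucket contributes ceil(bucket_size / batch_size) batches,
    # so partial batches appear once per bucket rather than once per epoch.
    counts = Counter(lengths.tolist())
    expected = sum(math.ceil(n / 32) for n in counts.values())
    assert len(sampler) == expected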
Classes

class BucketBatchSampler (batch_size, sort_lens, max_len=None, store_indices=False, order=None)

Samples batch indices from sequence-length buckets efficiently with very little memory overhead.

Epoch overhead for a 5M-sample dataset with batch_size=32 is around 400ms.

Arguments

  • batch_size (int): Size of each mini-batch.
  • sort_lens (list): List of source or target lengths corresponding to each item in the dataset.
  • max_len (int, optional): A maximum sequence length used to filter out very long sequences. None means no filtering.
  • store_indices (bool, optional): If True, indices that will unsort the dataset will be stored. This is used by beam search/inference.
  • order (str, optional): Default is None, i.e. buckets are shuffled. If ascending or descending, buckets are iterated w.r.t. their lengths to implement length-based curriculum learning (see the sketch after this list).

Example

>>> # Generate dummy length information
>>> lengths = np.random.randint(1, 20, size=10000)
>>> sampler = BucketBatchSampler(batch_size=10, sort_lens=lengths)
>>> batch = list(sampler)[0]
>>> batch
[7526, 8473, 9194, 1030, 1568, 4182, 3082, 827, 3688, 9336]
>>> # All samples in the batch have the same length
>>> [lengths[i] for i in batch]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
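Since each yield is a complete batch of indices, the sampler slots into torch.utils.data.DataLoader via its batch_sampler argument. With store_indices=True, the orig_idxs attribute records the yielded order after a full pass, and its argsort restores the original dataset order. A sketch under these assumptions, with a toy TensorDataset standing in for real model outputs:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pysimt.samplers.bucket import BucketBatchSampler

    lengths = np.random.randint(1, 20, size=100)
    dataset = TensorDataset(torch.arange(100))

    sampler = BucketBatchSampler(batch_size=10, sort_lens=lengths,
                                 store_indices=True)
    loader = DataLoader(dataset, batch_sampler=sampler)

    # Collect per-sample "outputs" in the bucketed order they were yielded.
    outputs = []
    for (batch,) in loader:
        outputs.extend(batch.tolist())  # stand-in for model predictions

    # `orig_idxs` holds the dataset index of every yielded sample;
    # argsort gives the permutation that unsorts back to dataset order.
    unsort = np.argsort(sampler.orig_idxs)
    restored = [outputs[i] for i in unsort]
    assert restored == list(range(100))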


Ancestors

  • torch.utils.data.sampler.Sampler
  • typing.Generic

Subclasses