Module pysimt.datasets.text

Source code
import logging
from pathlib import Path
from typing import Tuple, List

import torch

from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from ..utils.io import fopen, progress_bar
from ..vocabulary import Vocabulary

logger = logging.getLogger('pysimt')


class TextDataset(Dataset):
    """A convenience dataset for reading monolingual text files.

    Args:
        fname: A string or ``pathlib.Path`` object giving
            the corpus. Glob patterns are allowed; if several
            files match, the first one in sorted order is used.
        vocab: A ``Vocabulary`` instance for the given corpus.
        bos: Optional; if ``True``, a special beginning-of-sentence
            `<bos>` marker will be prepended to sequences.
        eos: Optional; if ``True``, a special end-of-sentence
            `<eos>` marker will be appended to sequences.
    """

    def __init__(self, fname, vocab, bos=False, eos=True, **kwargs):
        self.path = Path(fname)
        self.vocab = vocab
        self.bos = bos
        self.eos = eos

        # Detect glob patterns
        self.fnames = sorted(self.path.parent.glob(self.path.name))

        if len(self.fnames) == 0:
            raise RuntimeError('{} does not exist.'.format(self.path))
        elif len(self.fnames) > 1:
            logger.info('Multiple files found, using first: {}'.format(self.fnames[0]))

        # Read the sentences and map them to vocabulary
        self.data, self.lengths = self.read_sentences(
            self.fnames[0], self.vocab, bos=self.bos, eos=self.eos)

        # Dataset size
        self.size = len(self.data)

    @staticmethod
    def read_sentences(fname: str,
                       vocab: Vocabulary,
                       bos: bool = False,
                       eos: bool = True) -> Tuple[List[List[int]], List[int]]:
        """Reads sentences from `fname` and maps them to vocabulary indices."""
        lines = []
        lens = []
        with fopen(fname) as f:
            for idx, line in enumerate(progress_bar(f, unit='sents')):
                line = line.strip()

                # Empty lines will cause a lot of headaches,
                # get rid of them during preprocessing!
                assert line, "Empty line (%d) found in %s" % (idx + 1, fname)

                # Map and append
                seq = vocab.sent_to_idxs(line, explicit_bos=bos, explicit_eos=eos)
                lines.append(seq)
                lens.append(len(seq))

        return lines, lens

    @staticmethod
    def to_torch(batch, **kwargs):
        """Pads a batch of index sequences into a time-major LongTensor."""
        return pad_sequence(
            [torch.tensor(b, dtype=torch.long) for b in batch], batch_first=False)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.size

    def __repr__(self):
        s = "{} '{}' ({} sentences)".format(
            self.__class__.__name__, self.fnames[0].name, self.__len__())
        return s

Classes

class TextDataset (fname, vocab, bos=False, eos=True, **kwargs)

A convenience dataset for reading monolingual text files.

Args

fname
    A string or pathlib.Path object giving the corpus. Glob patterns
    are allowed; if several files match, the first one in sorted order
    is used.
vocab
    A Vocabulary instance for the given corpus.
bos
    Optional; if True, a special beginning-of-sentence <bos> marker
    will be prepended to sequences.
eos
    Optional; if True, a special end-of-sentence <eos> marker will be
    appended to sequences.
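
A minimal usage sketch, assuming a pre-tokenized corpus file with one
sentence per line and a compatible Vocabulary; the file names and the
Vocabulary constructor arguments below are hypothetical:

from pysimt.vocabulary import Vocabulary
from pysimt.datasets.text import TextDataset

# Hypothetical artifacts: a vocabulary file and a tokenized corpus.
vocab = Vocabulary('train.vocab.en', name='en')  # constructor args assumed
dataset = TextDataset('train.tok.en', vocab, bos=False, eos=True)

print(len(dataset))  # number of sentences in the corpus
print(dataset[0])    # index sequence of the first sentence, ending with <eos>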

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Static methods

def read_sentences(fname: str, vocab: Vocabulary, bos: bool = False, eos: bool = True) -> Tuple[List[List[int]], List[int]]
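Reads the corpus at fname line by line, asserts that no line is empty,
maps each sentence to a list of vocabulary indices through
vocab.sent_to_idxs, and returns the index sequences together with
their lengths.

A small sketch of calling it directly, reusing the hypothetical vocab
from the earlier example:

data, lengths = TextDataset.read_sentences('train.tok.en', vocab,
                                           bos=False, eos=True)
assert all(len(seq) == n for seq, n in zip(data, lengths))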
def to_torch(batch, **kwargs)
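Collates a list of index sequences into a single LongTensor via
pad_sequence. Since batch_first=False, the result is time-major, i.e.
of shape (max_len, batch_size), padded with index 0.

A minimal sketch of plugging it into a PyTorch DataLoader as the
collate function, reusing the dataset from the earlier example; the
batch size is arbitrary:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, collate_fn=TextDataset.to_torch)

batch = next(iter(loader))
# Time-major padded batch: (max_len, batch_size), dtype torch.long,
# padded with index 0 (assumed here to map to the <pad> token).
print(batch.shape, batch.dtype)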