Module pysimt.datasets.kaldi

Expand source code
from pathlib import Path
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from ..utils.kaldi import readMatrixShape, readMatrixByOffset



class KaldiDataset(Dataset):
    """A PyTorch dataset for Kaldi .scp/ark.

    Arguments:
        fname (str or pathlib.Path): A string or ``Path`` object for a folder
            that contains ``feats_local.scp`` and optionally a
            ``segments.len`` file with per-segment lengths for efficient
            batching. If the latter is not found, the lengths will be read
            dynamically from the ``.ark`` files, which is rather slow.
    """

    def __init__(self, fname, **kwargs):
        self._data = []
        self._lengths = []
        self.root = Path(fname)
        self.scp_path = self.root / 'feats_local.scp'
        self.len_path = self.root / 'segments.len'

        if not self.scp_path.exists():
            raise RuntimeError('{} does not exist.'.format(self.scp_path))

        if self.len_path.exists():
            read_lengths = False
            # Read lengths file
            with open(self.len_path) as f:
                for line in f:
                    self._lengths.append(int(line.strip()))
        else:
            # Read them below (this is slow)
            read_lengths = True

        with open(self.scp_path) as scp_input_file:
            # Each line is '<uttid> <arkfile>:<byte offset>'
            for line in tqdm(scp_input_file, unit='segments'):
                uttid, pointer = line.strip().split()
                arkfile, offset = pointer.rsplit(':', 1)
                offset = int(offset)
                self._data.append((arkfile, offset))
                if read_lengths:
                    with open(arkfile, "rb") as g:
                        g.seek(offset)
                        feat_len = readMatrixShape(g)[0]

                    self._lengths.append(feat_len)

        # Set dataset size
        self.size = len(self._data)

        if self.size != len(self._lengths):
            raise RuntimeError("Dataset size and number of lengths do not match.")

    @staticmethod
    def collate(elems):
        """Pad a list of variable-length feature matrices into a batch tensor."""
        return pad_sequence([torch.FloatTensor(e) for e in elems])

    def __getitem__(self, idx):
        """Read segment features from the actual .ark file."""
        # An lru_cache()-decorated version of readMatrixByOffset() would keep
        # all features cached in memory after the first epoch.
        return readMatrixByOffset(*self._data[idx])

    def __len__(self):
        return self.size
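
Since reading segment lengths from the .ark headers is slow, segments.len can be generated once, ahead of time. Below is a minimal sketch of that preprocessing step, assuming the folder layout described in the class docstring; write_segment_lengths is a hypothetical helper, not part of pysimt:

from pathlib import Path

from pysimt.utils.kaldi import readMatrixShape


def write_segment_lengths(root):
    """Hypothetical helper: precompute segments.len for a KaldiDataset folder."""
    root = Path(root)
    with open(root / 'feats_local.scp') as scp, \
            open(root / 'segments.len', 'w') as out:
        # One integer per line, in the same order as feats_local.scp
        for line in scp:
            _uttid, pointer = line.strip().split()
            arkfile, offset = pointer.rsplit(':', 1)
            with open(arkfile, 'rb') as ark:
                ark.seek(int(offset))
                # First element of the matrix shape is the frame count
                out.write('{}\n'.format(readMatrixShape(ark)[0]))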

Classes

class KaldiDataset (fname, **kwargs)

A PyTorch dataset for Kaldi .scp/ark.

Arguments

fname (str or pathlib.Path): A string or Path object for a folder that contains feats_local.scp and optionally a segments.len file with per-segment lengths for efficient batching. If the latter is not found, the lengths will be read dynamically from the .ark files, which is rather slow.
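
A minimal usage sketch, assuming a placeholder folder data/train that contains feats_local.scp; collate is passed as the DataLoader's collate_fn so that variable-length segments are zero-padded into a single batch tensor:

from torch.utils.data import DataLoader

from pysimt.datasets.kaldi import KaldiDataset

dataset = KaldiDataset('data/train')  # placeholder path

# pad_sequence defaults to batch_first=False, so each batch has shape
# (max_len, batch_size, feat_dim)
loader = DataLoader(dataset, batch_size=16, collate_fn=KaldiDataset.collate)

for batch in loader:
    print(batch.shape)
    break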

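As the comment in __getitem__ suggests, the .ark reads can be memoized so that all features stay in memory after the first epoch. A minimal sketch of that idea, not part of pysimt, assuming the full feature set fits in RAM:

from functools import lru_cache

from pysimt.datasets.kaldi import KaldiDataset
from pysimt.utils.kaldi import readMatrixByOffset

# maxsize=None caches every (arkfile, offset) result indefinitely
cached_read = lru_cache(maxsize=None)(readMatrixByOffset)


class CachedKaldiDataset(KaldiDataset):
    """Hypothetical subclass that serves repeated reads from the cache."""

    def __getitem__(self, idx):
        return cached_read(*self._data[idx])

Note that with multi-process data loading (num_workers > 0), each worker process keeps its own cache.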

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Methods

def collate(elems)
Expand source code
@staticmethod
def collate(elems):
    """Pad a list of variable-length feature matrices into a batch tensor."""
    return pad_sequence([torch.FloatTensor(e) for e in elems])
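
For illustration, a tiny sketch of the padding behavior with made-up shapes: two segments with 5 and 3 frames of 4-dimensional features are stacked into a (max_len, batch, feat_dim) tensor, the shorter one zero-padded along time:

import numpy as np

from pysimt.datasets.kaldi import KaldiDataset

batch = [np.zeros((5, 4), dtype='float32'), np.ones((3, 4), dtype='float32')]
out = KaldiDataset.collate(batch)
print(out.shape)  # torch.Size([5, 2, 4])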