Module pysimt.datasets.kaldi
Source code
from pathlib import Path

from tqdm import tqdm

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from ..utils.kaldi import readMatrixShape, readMatrixByOffset


class KaldiDataset(Dataset):
    """A PyTorch dataset for Kaldi .scp/.ark files.

    Arguments:
        fname (str or pathlib.Path): A string or ``Path`` object for a folder
            which contains ``feats_local.scp`` and optionally a
            ``segments.len`` file with segment lengths for efficient
            batching. If the latter is not found, the lengths will be
            read dynamically, which is a rather slow operation.
    """
    def __init__(self, fname, **kwargs):
        self._data = []
        self._lengths = []

        self.root = Path(fname)
        self.scp_path = self.root / 'feats_local.scp'
        self.len_path = self.root / 'segments.len'

        if not self.scp_path.exists():
            raise RuntimeError('{} does not exist.'.format(self.scp_path))

        if self.len_path.exists():
            read_lengths = False
            # Read pre-computed segment lengths
            with open(self.len_path) as f:
                for line in f:
                    self._lengths.append(int(line.strip()))
        else:
            # Read them from the .ark files below (this is slow)
            read_lengths = True

        with open(self.scp_path) as scp_input_file:
            for line in tqdm(scp_input_file, unit='segments'):
                # Each .scp line is '<uttid> <arkfile>:<byte offset>'
                uttid, pointer = line.strip().split()
                arkfile, offset = pointer.rsplit(':', 1)
                offset = int(offset)
                self._data.append((arkfile, offset))
                if read_lengths:
                    with open(arkfile, "rb") as g:
                        g.seek(offset)
                        # The first dimension of the stored matrix is the
                        # segment length in frames
                        feat_len = readMatrixShape(g)[0]
                        self._lengths.append(feat_len)

        # Set dataset size
        self.size = len(self._data)

        if self.size != len(self._lengths):
            raise RuntimeError("Dataset and lengths sizes do not match.")

    @staticmethod
    def collate(elems):
        """Pads a list of variable-length feature matrices into a batch tensor."""
        return pad_sequence([torch.FloatTensor(e) for e in elems])

    def __getitem__(self, idx):
        """Reads segment features from the actual .ark file."""
        # an lru_cache() decorated version of readMatrixByOffset() will make
        # sure that all data is cached into memory after 1 epoch.
        return readMatrixByOffset(*self._data[idx])

    def __len__(self):
        return self.size
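For illustration, here is a minimal sketch of how this dataset might be wired into a standard DataLoader. The folder path data/train and the batch size are hypothetical, and length-based bucketing (the point of segments.len) is omitted for brevity.

from torch.utils.data import DataLoader

from pysimt.datasets.kaldi import KaldiDataset

# 'data/train' is a hypothetical folder containing feats_local.scp
# (and optionally segments.len).
dset = KaldiDataset('data/train')

# collate zero-pads variable-length segments; pad_sequence defaults to
# batch_first=False, so batches come out time-major.
loader = DataLoader(dset, batch_size=16, collate_fn=KaldiDataset.collate)

for batch in loader:
    print(batch.shape)  # torch.Size([max_len, 16, feat_dim])
    break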
Classes
class KaldiDataset(fname, **kwargs)
A PyTorch dataset for Kaldi .scp/.ark files.
Arguments
fname (str or pathlib.Path): A string or Path object for a folder which contains feats_local.scp and optionally a segments.len file with segment lengths for efficient batching. If the latter is not found, the lengths will be read dynamically, which is a rather slow operation.
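Because the dynamic length probing is slow, segments.len can be generated once ahead of time. Below is a minimal sketch, assuming the folder follows the layout above; write_segment_lengths is a hypothetical helper, not part of the library.

from pathlib import Path

from pysimt.utils.kaldi import readMatrixShape

def write_segment_lengths(root):
    """Hypothetical helper: pre-computes segments.len from feats_local.scp."""
    root = Path(root)
    with open(root / 'feats_local.scp') as scp, \
            open(root / 'segments.len', 'w') as out:
        for line in scp:
            # Each .scp line is '<uttid> <arkfile>:<byte offset>'
            _uttid, pointer = line.strip().split()
            arkfile, offset = pointer.rsplit(':', 1)
            with open(arkfile, 'rb') as g:
                g.seek(int(offset))
                # The first dimension of the stored matrix is the segment length
                out.write('{}\n'.format(readMatrixShape(g)[0]))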
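The comment in __getitem__ suggests memoising readMatrixByOffset so that, after one epoch, all features are served from memory. A minimal sketch, assuming the whole feature set fits in RAM:

from functools import lru_cache

from pysimt.utils.kaldi import readMatrixByOffset

# Memoised reader: each (arkfile, offset) pair is read from disk once
# and cached afterwards. Bound maxsize if memory is a concern.
cached_read = lru_cache(maxsize=None)(readMatrixByOffset)

__getitem__ would then call cached_read(*self._data[idx]) instead of readMatrixByOffset(*self._data[idx]).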
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Methods
def collate(elems)
Collates a list of variable-length feature matrices into a single zero-padded batch tensor.

@staticmethod
def collate(elems):
    """Pads a list of variable-length feature matrices into a batch tensor."""
    return pad_sequence([torch.FloatTensor(e) for e in elems])
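A small illustration of the padding behaviour: pad_sequence is called with its defaults, so the result is time-major and zero-padded. The segment shapes below are hypothetical.

import torch

from pysimt.datasets.kaldi import KaldiDataset

# Two hypothetical segments of 5 and 3 frames with 4-dim features
a = torch.rand(5, 4)
b = torch.rand(3, 4)

batch = KaldiDataset.collate([a, b])
print(batch.shape)  # torch.Size([5, 2, 4]) -> (max_len, batch_size, feat_dim)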