Module pysimt.datasets.base
Expand source code
import json
from abc import ABCMeta, abstractmethod
class BaseDataset(metaclass=ABCMeta):
r"""Abstract dataset class that all nmtpytorch datasets should be
extending.
A sane enough `__repr__` is provided which will neatly dump the
instance attributes not prefixed with `_`. So make sure to hide
non-public attributes by prefixing them with `_` to avoid cluttering
on the terminal. You can also redefine this method in the deriving
classes.
"""
requires_vocab = False
@abstractmethod
def __len__(self):
pass
@abstractmethod
def __getitem__(self, idx):
"""Returns the `idx`'th item of the dataset."""
pass
@abstractmethod
def collate(self, elems):
"""Collates a list of elements into a `torch.Tensor`."""
pass
def get_batch_tensor(self, idxs):
"""Collates by looking up elements of the `idxs` list."""
return self.collate([self.__getitem__(idx) for idx in idxs])
def __repr__(self):
public_attrs = [s for s in self.__dict__.keys() if s[0] != '_']
info_dict = {k: self.__dict__[k] for k in public_attrs}
public_str = json.dumps(
info_dict, indent=1, default=lambda obj: str(obj))
return f"\n{self.__class__.__name__} {public_str}\n"
Classes
class BaseDataset
-
Abstract dataset class that all nmtpytorch datasets should be extending.
A sane enough
__repr__
is provided which will neatly dump the instance attributes not prefixed with_
. So make sure to hide non-public attributes by prefixing them with_
to avoid cluttering on the terminal. You can also redefine this method in the deriving classes.Expand source code
class BaseDataset(metaclass=ABCMeta): r"""Abstract dataset class that all nmtpytorch datasets should be extending. A sane enough `__repr__` is provided which will neatly dump the instance attributes not prefixed with `_`. So make sure to hide non-public attributes by prefixing them with `_` to avoid cluttering on the terminal. You can also redefine this method in the deriving classes. """ requires_vocab = False @abstractmethod def __len__(self): pass @abstractmethod def __getitem__(self, idx): """Returns the `idx`'th item of the dataset.""" pass @abstractmethod def collate(self, elems): """Collates a list of elements into a `torch.Tensor`.""" pass def get_batch_tensor(self, idxs): """Collates by looking up elements of the `idxs` list.""" return self.collate([self.__getitem__(idx) for idx in idxs]) def __repr__(self): public_attrs = [s for s in self.__dict__.keys() if s[0] != '_'] info_dict = {k: self.__dict__[k] for k in public_attrs} public_str = json.dumps( info_dict, indent=1, default=lambda obj: str(obj)) return f"\n{self.__class__.__name__} {public_str}\n"
Subclasses
Class variables
var requires_vocab
Methods
def collate(self, elems)
-
Collates a list of elements into a
torch.Tensor
.Expand source code
@abstractmethod def collate(self, elems): """Collates a list of elements into a `torch.Tensor`.""" pass
def get_batch_tensor(self, idxs)
-
Collates by looking up elements of the
idxs
list.Expand source code
def get_batch_tensor(self, idxs): """Collates by looking up elements of the `idxs` list.""" return self.collate([self.__getitem__(idx) for idx in idxs])