Module pysimt.datasets.collate
Expand source code
class Batch(dict):
    """A custom dictionary representing a batch.

    Maps each data source (key) to a tensor (value). All tensors must have
    the same size along dimension 1, which is treated as the batch
    dimension; that shared size is stored in ``self.size``.

    Raises:
        AssertionError: if no data sources are given, or if the tensors
            disagree on the size of dimension 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        dim1s = {x.size(1) for x in self.values()}
        # Checked explicitly instead of with ``assert`` so the validation is
        # not stripped under ``python -O``. AssertionError is kept so callers
        # see the same exception type as before.
        if not dim1s:
            raise AssertionError(
                "Cannot build a Batch without any data sources.")
        if len(dim1s) != 1:
            raise AssertionError(
                "Incompatible batch dimension (1) between modalities.")
        self.size = dim1s.pop()

    def device(self, device):
        """Move every stored tensor to ``device``, replacing entries in place.

        All tensors are converted before any entry is replaced, so a failing
        ``.to()`` call leaves the batch unchanged. Returns ``None``.
        """
        self.update({k: v.to(device) for k, v in self.items()})

    def __repr__(self):
        lines = ["Batch(size={})".format(self.size)]
        for data_source, tensor in self.items():
            lines.append(" {:10s} -> {} - {}".format(
                str(data_source), tensor.shape, tensor.device))
        # Trailing newline kept to match the historical repr output.
        return "\n".join(lines) + "\n"
def get_collate(data_sources):
    """Build a ``collate_fn`` that assembles a ``Batch`` keyed by data source.

    The returned function takes a list of samples, each indexable by every
    entry of ``data_sources``. For each data source it gathers that source's
    entry from every sample, converts the list with the source's
    ``torchify`` method, and stores the result under the same key.
    """
    def collate_fn(batch):
        tensors = {}
        for source in data_sources:
            samples = [sample[source] for sample in batch]
            tensors[source] = source.torchify(samples)
        return Batch(tensors)

    return collate_fn
Functions
def get_collate(data_sources)
-
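Returns a `collate_fn` that groups a list of samples by the given DataSource keys. For each data source, it collects that source's entry from every sample and converts the list with the source's `torchify` method. It then wraps the results in a `Batch`.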
Expand source code
def get_collate(data_sources):
    """Build a ``collate_fn`` that assembles a ``Batch`` keyed by data source.

    The returned function takes a list of samples, each indexable by every
    entry of ``data_sources``. For each data source it gathers that source's
    entry from every sample, converts the list with the source's
    ``torchify`` method, and stores the result under the same key.
    """
    def collate_fn(batch):
        tensors = {}
        for source in data_sources:
            samples = [sample[source] for sample in batch]
            tensors[source] = source.torchify(samples)
        return Batch(tensors)

    return collate_fn
Classes
class Batch (*args, **kwargs)
-
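A custom dictionary representing a batch. It maps each data source to a tensor, and all tensors must share the same size along dimension 1 (the batch dimension). That size is stored in the `size` attribute.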
Expand source code
class Batch(dict):
    """A custom dictionary representing a batch.

    Maps each data source (key) to a tensor (value). All tensors must have
    the same size along dimension 1, which is treated as the batch
    dimension; that shared size is stored in ``self.size``.

    Raises:
        AssertionError: if no data sources are given, or if the tensors
            disagree on the size of dimension 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        dim1s = {x.size(1) for x in self.values()}
        # Checked explicitly instead of with ``assert`` so the validation is
        # not stripped under ``python -O``. AssertionError is kept so callers
        # see the same exception type as before.
        if not dim1s:
            raise AssertionError(
                "Cannot build a Batch without any data sources.")
        if len(dim1s) != 1:
            raise AssertionError(
                "Incompatible batch dimension (1) between modalities.")
        self.size = dim1s.pop()

    def device(self, device):
        """Move every stored tensor to ``device``, replacing entries in place.

        All tensors are converted before any entry is replaced, so a failing
        ``.to()`` call leaves the batch unchanged. Returns ``None``.
        """
        self.update({k: v.to(device) for k, v in self.items()})

    def __repr__(self):
        lines = ["Batch(size={})".format(self.size)]
        for data_source, tensor in self.items():
            lines.append(" {:10s} -> {} - {}".format(
                str(data_source), tensor.shape, tensor.device))
        # Trailing newline kept to match the historical repr output.
        return "\n".join(lines) + "\n"
Ancestors
- builtins.dict
Methods
def device(self, device)
-
Expand source code
def device(self, device):
    """Move every stored tensor to ``device``, replacing entries in place.

    All values are converted first and only then written back, so a failing
    ``.to()`` call leaves the mapping untouched. Returns ``None``.
    """
    moved = dict((key, tensor.to(device)) for key, tensor in self.items())
    self.update(moved)