Module pysimt.utils.data
Expand source code
from torch.utils.data import DataLoader
def sort_predictions(data_loader, results):
"""Recovers the dataset order when bucketing samplers are used."""
if getattr(data_loader.batch_sampler, 'store_indices', False):
results = [results[i] for i, j in sorted(
enumerate(data_loader.batch_sampler.orig_idxs), key=lambda k: k[1])]
return results
def make_dataloader(dataset, pin_memory=False, num_workers=0):
return DataLoader(
dataset, batch_sampler=dataset.sampler,
collate_fn=dataset.collate_fn,
pin_memory=pin_memory, num_workers=num_workers)
Functions
def make_dataloader(dataset, pin_memory=False, num_workers=0)
-
Expand source code
def make_dataloader(dataset, pin_memory=False, num_workers=0): return DataLoader( dataset, batch_sampler=dataset.sampler, collate_fn=dataset.collate_fn, pin_memory=pin_memory, num_workers=num_workers)
def sort_predictions(data_loader, results)
-
Recovers the dataset order when bucketing samplers are used.
Expand source code
def sort_predictions(data_loader, results): """Recovers the dataset order when bucketing samplers are used.""" if getattr(data_loader.batch_sampler, 'store_indices', False): results = [results[i] for i, j in sorted( enumerate(data_loader.batch_sampler.orig_idxs), key=lambda k: k[1])] return results