Module pysimt.datasets.imagefolder
Expand source code
from functools import lru_cache
from pathlib import Path

from PIL import Image

import torch
from torchvision import transforms

from .base import BaseDataset


class ImageFolderDataset(BaseDataset):
    """A variant of torchvision.datasets.ImageFolder which drops support for
    target loading, i.e. this only loads images not attached to any other
    label.

    This class also makes use of ``lru_cache`` to cache an image file once
    opened to avoid repetitive disk access.

    Arguments:
        root (str): The root folder that contains the images and index.txt
        resize (int, optional): An optional integer to be given to
            ``torchvision.transforms.Resize``. Default: ``None``.
        crop (int, optional): An optional integer to be given to
            ``torchvision.transforms.CenterCrop``. Default: ``None``.
        replicate(int, optional): Replicate the image names ``replicate``
            times in order to process the same image ``replicate`` times
            if ``replicate`` sentences are available during training time.
        warmup(bool, optional): If ``True``, the images will be read once
            at the beginning to fill the cache.
    """
    def __init__(self, root, resize=None, crop=None,
                 replicate=1, warmup=False):
        self.root = Path(root).expanduser().resolve()
        self.replicate = replicate
        self.warmup = warmup
        self._image_files = []

        # Image list in dataset order
        self.index = self.root / 'index.txt'
        assert self.index.exists(), f"{self.index} could not be found"

        _transforms = []
        if resize is not None:
            _transforms.append(transforms.Resize(resize))
        if crop is not None:
            _transforms.append(transforms.CenterCrop(crop))
        _transforms.append(transforms.ToTensor())
        _transforms.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
        self.transform = transforms.Compose(_transforms)

        with self.index.open() as f:
            for fname in f:
                fname = self.root / fname.strip()
                assert fname.exists(), "{} does not exist.".format(fname)
                self._image_files.append(str(fname))

        # Setup reader
        self.read_image = lru_cache(maxsize=self.__len__())(self._read_image)

        if self.warmup:
            for idx in range(self.__len__()):
                self[idx]

        # Replicate the list if requested
        self._image_files = self._image_files * self.replicate

    def _read_image(self, fname):
        with open(fname, 'rb') as f:
            img = Image.open(f).convert('RGB')
            return self.transform(img)

    def collate(elems):
        return torch.stack(elems)

    def __getitem__(self, idx):
        return self.read_image(self._image_files[idx])

    def __len__(self):
        return len(self._image_files)
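For reference, a minimal usage sketch follows. The folder path, the contents of index.txt and the batch size are hypothetical placeholders; the sketch only assumes that index.txt lists one image filename per line relative to the root folder, as the constructor above requires.

from torch.utils.data import DataLoader

from pysimt.datasets.imagefolder import ImageFolderDataset

# Hypothetical image folder containing an index.txt with one relative
# filename per line, e.g. "000001.jpg".
dataset = ImageFolderDataset('/data/images', resize=256, crop=224)

# `collate` stacks the per-image 3xHxW tensors into a BxCxHxW batch.
# It is accessed through the class so that no `self` argument gets bound.
loader = DataLoader(dataset, batch_size=16,
                    collate_fn=ImageFolderDataset.collate)

for batch in loader:
    print(batch.shape)  # e.g. torch.Size([16, 3, 224, 224])
    break

Because the lru_cache is sized to hold every image in the dataset, each decoded tensor stays in memory after its first read; passing warmup=True simply performs those first reads up front, at the cost of keeping the whole resized and cropped dataset in RAM.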
Classes
class ImageFolderDataset (root, resize=None, crop=None, replicate=1, warmup=False)
A variant of torchvision.datasets.ImageFolder which drops support for
target loading, i.e. this only loads images not attached to any other
label.

This class also makes use of lru_cache to cache an image file once
opened to avoid repetitive disk access.

Arguments

root (str): The root folder that contains the images and index.txt
resize (int, optional): An optional integer to be given to
    torchvision.transforms.Resize. Default: None.
crop (int, optional): An optional integer to be given to
    torchvision.transforms.CenterCrop. Default: None.
replicate (int, optional): Replicate the image names replicate times
    in order to process the same image replicate times if replicate
    sentences are available during training time.
warmup (bool, optional): If True, the images will be read once
    at the beginning to fill the cache.
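As an illustration of the replicate argument (the file names and counts below are invented), the image list is simply concatenated with itself replicate times, so indices i and i + N refer to the same file and, thanks to the cache, the image is decoded from disk only once:

from pysimt.datasets.imagefolder import ImageFolderDataset

# Hypothetical index.txt with three entries: dog.jpg, cat.jpg, bird.jpg
dataset = ImageFolderDataset('/data/toy', replicate=2)
len(dataset)              # 6: dog, cat, bird, dog, cat, bird
dataset[0] is dataset[3]  # True: both indices map to dog.jpg and
                          # return the same cached tensor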
Ancestors

pysimt.datasets.base.BaseDataset
Inherited members