diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 2bb81a693d4..26d84bbe2cc 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -13,6 +13,7 @@ from .flickr import Flickr8k, Flickr30k from .voc import VOCSegmentation, VOCDetection from .cityscapes import Cityscapes +from .dtd import DTD __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', @@ -20,4 +21,4 @@ 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', - 'VOCSegmentation', 'VOCDetection', 'Cityscapes') + 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'DTD') diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py new file mode 100644 index 00000000000..acde10fe25d --- /dev/null +++ b/torchvision/datasets/dtd.py @@ -0,0 +1,138 @@ +import os +from .folder import ImageFolder +from torch.utils import data +from .utils import download_url, check_integrity + + +class FullDTD(ImageFolder): + """Full `DTD `_ Dataset. + + Args: + root (string): Root directory of dataset where directory + ``dtd`` exists. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + transform (callable, optional): A function/transform that takes in an + PIL image and returns a transformed version. E.g, + ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + loader (callable, optional): A function to load an image given its path. + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + image_folder = os.path.join('dtd', 'images') + label_folder = os.path.join('dtd', 'labels') + url = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz' + filename = 'dtd-r1.0.1.tar.gz' + tgz_md5 = 'fff73e5086ae6bdbea199a49dfb8a4c1' + + def __init__(self, root, download=False, **kwargs): + root = self.root = os.path.expanduser(root) + + if download: + self.download() + + super(FullDTD, self).__init__(os.path.join(self.root, + self.image_folder), + **kwargs) + # super class sets this to the root of the image folder, which is inside + # the data folder + self.root = root + + def download(self): + import tarfile + + if not check_integrity(os.path.join(self.root, self.filename), + self.tgz_md5): + download_url(self.url, self.root, self.filename, + self.tgz_md5) + + cwd = os.getcwd() + tar = tarfile.open(os.path.join(self.root, self.filename), + "r:gz") + os.chdir(self.root) + tar.extractall() + tar.close() + os.chdir(cwd) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + transform_repr = self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) + fmt_str += '{0}{1}\n'.format(tmp, transform_repr) + tmp = ' Target Transforms (if any): ' + transform_repr = self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) + fmt_str += '{0}{1}'.format(tmp, transform_repr) + return fmt_str + + +class DTD(data.Subset): + """`DTD `_ Dataset. + + Args: + root (string): Root directory of dataset where directory + ``dtd`` exists. + split (string, optional): The image split to use, ``train``, ``test`` + or ``val`` + fold (int, optional): The image fold to use, ``[1 ... 10]`` + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + transform (callable, optional): A function/transform that takes in an + PIL image and returns a transformed version. E.g, + ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + loader (callable, optional): A function to load an image given its path. + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + split (string): image split + fold (int): image fold + """ + def __init__(self, root, split='train', fold=1, **kwargs): + assert split in ('train', 'val', 'test'), \ + "split should be train, val or test" + self.split = split + + assert fold in range(1, 11), "fold should be integer in [1, 10]" + self.fold = fold + + dataset = FullDTD(root, **kwargs) + indices = self._make_indices(dataset) + super(DTD, self).__init__(dataset, indices) + + def _make_indices(self, dataset): + image_folder = os.path.join(dataset.root, dataset.image_folder) + image_paths = [path for path, target in dataset.imgs] + + label_folder = os.path.join(dataset.root, dataset.label_folder) + file_name = '{}{}.txt'.format(self.split, self.fold) + file_path = os.path.join(label_folder, file_name) + with open(file_path, 'r') as f: + image_paths_subset = f.read() + image_paths_subset = [os.path.join(image_folder, path) + for path in image_paths_subset.splitlines()] + + return [image_paths.index(image_path_subset) for + image_path_subset in image_paths_subset] + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Split: {}\n'.format(self.split) + fmt_str += ' Fold: {}\n'.format(self.fold) + fmt_str += ' Root Location: {}\n'.format(self.dataset.root) + tmp = ' Transforms (if any): ' + transform_repr = self.dataset.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) + fmt_str += '{0}{1}\n'.format(tmp, transform_repr) + tmp = ' Target Transforms (if any): ' + transform_repr = self.dataset.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) + fmt_str += '{0}{1}'.format(tmp, transform_repr) + return fmt_str