
Add vggface2 dataset #2910


Closed

wants to merge 23 commits

Commits
53e0ffd
add vggface dataset class
jgbradley1 Oct 26, 2020
5e75f3c
fix flake8 errors
jgbradley1 Oct 26, 2020
79a8b49
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Oct 26, 2020
dc36580
add standard dataset arguments
jgbradley1 Oct 26, 2020
ab31f88
fix code formatting and standardize dataset class
jgbradley1 Oct 27, 2020
9d3590d
add dataset citation
jgbradley1 Oct 27, 2020
0ea263a
more formatting fixes
jgbradley1 Oct 27, 2020
cdfc612
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Oct 27, 2020
7bb168e
docstring update
jgbradley1 Oct 30, 2020
d7da9f6
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Oct 30, 2020
7510edf
formatting update
jgbradley1 Oct 30, 2020
170eff1
remove unused variable
jgbradley1 Oct 30, 2020
d8d5e02
use double quoted strings
jgbradley1 Oct 31, 2020
58dfd04
add vggface2 unit test
jgbradley1 Oct 31, 2020
f464212
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Oct 31, 2020
baf22b7
add pandas check
jgbradley1 Oct 31, 2020
1bb70df
Merge branch 'add-vggface2-dataset' of github.com:jgbradley1/vision i…
jgbradley1 Oct 31, 2020
1256a35
add docstring to vggface fakedata generator
jgbradley1 Oct 31, 2020
941682b
fix docstring indentation
jgbradley1 Nov 1, 2020
a392514
use local variable scope and fixed minor docstring formatting
jgbradley1 Nov 2, 2020
7089156
minor style fixes
jgbradley1 Nov 12, 2020
a363104
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Nov 12, 2020
7a03699
Merge branch 'master' into add-vggface2-dataset
jgbradley1 Nov 21, 2020
108 changes: 108 additions & 0 deletions test/fakedata_generation.py
@@ -171,6 +171,114 @@ def _make_devkit_archive(root):
yield root


@contextlib.contextmanager
def vggface2_root():
"""
Generates a dataset with the following folder structure and yields the root path:
<root>
└── vggface2
├── bb_landmark.tar.gz ('bb_landmark' when uncompressed)
├── vggface2_train.tar.gz ('train' when uncompressed)
├── vggface2_test.tar.gz ('test' when uncompressed)
├── train_list.txt
└── test_list.txt

The dataset consists of 1 image in the train set and 1 image in the test set.
"""

class_id = 'n000001'
image_id = '0001'
face_id = '01'
basefolder = 'vggface2'

def _make_image(file):
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

def _make_tar(archive, content, arcname=None, compress=False):
mode = 'w:gz' if compress else 'w'
if arcname is None:
arcname = os.path.basename(content)
with tarfile.open(archive, mode) as fh:
fh.add(content, arcname=arcname)

def _make_image_list_files(root):
image_list_contents = os.path.join(class_id, image_id + '_' + face_id + '.jpg') # e.g. n000001/0001_01.jpg
# train image list
image_list_file = os.path.join(root, "train_list.txt")
with open(image_list_file, "w") as txt_file:
txt_file.write(image_list_contents)
# test image list
image_list_file = os.path.join(root, "test_list.txt")
with open(image_list_file, "w") as txt_file:
txt_file.write(image_list_contents)

def _make_train_archive(root):
with get_tmp_dir() as tmp:
extracted_dir = os.path.join(tmp, 'train', class_id)
os.makedirs(extracted_dir)
print("EXTRACTED_DIR: " + extracted_dir)
_make_image(os.path.join(extracted_dir, image_id + '_' + face_id + '.jpg'))
train_archive = os.path.join(root, 'vggface2_train.tar.gz')
top_level_dir = os.path.join(tmp, 'train')
_make_tar(train_archive, top_level_dir, arcname='train', compress=True)

def _make_test_archive(root):
with get_tmp_dir() as tmp:
extracted_dir = os.path.join(tmp, 'test', class_id)
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, image_id + '_' + face_id + '.jpg'))
test_archive = os.path.join(root, 'vggface2_test.tar.gz')
top_level_dir = os.path.join(tmp, 'test')
_make_tar(test_archive, top_level_dir, arcname='test', compress=True)

def _make_bb_landmark_archive(root):
train_bbox_contents = 'NAME_ID,X,Y,W,H\n"n000001/0001_01",161,140,224,324'
test_bbox_contents = 'NAME_ID,X,Y,W,H\n"n000001/0001_01",161,140,224,324'
train_landmark_contents = ('NAME_ID,P1X,P1Y,P2X,P2Y,P3X,P3Y,P4X,P4Y,P5X,P5Y\n'
'"n000001/0001_01",75.81253,110.2077,103.1778,104.6074,'
'90.06353,133.3624,85.39182,149.4176,114.9009,144.9259')
test_landmark_contents = ('NAME_ID,P1X,P1Y,P2X,P2Y,P3X,P3Y,P4X,P4Y,P5X,P5Y\n'
'"n000001/0001_01",75.81253,110.2077,103.1778,104.6074,'
'90.06353,133.3624,85.39182,149.4176,114.9009,144.9259')

with get_tmp_dir() as tmp:
extracted_dir = os.path.join(tmp, 'bb_landmark')
os.makedirs(extracted_dir)

# bbox training file
bbox_file = os.path.join(extracted_dir, "loose_bb_train.csv")
with open(bbox_file, "w") as csv_file:
csv_file.write(train_bbox_contents)

# bbox testing file
bbox_file = os.path.join(extracted_dir, "loose_bb_test.csv")
with open(bbox_file, "w") as csv_file:
csv_file.write(test_bbox_contents)

# landmark training file
landmark_file = os.path.join(extracted_dir, "loose_landmark_train.csv")
with open(landmark_file, "w") as csv_file:
csv_file.write(train_landmark_contents)

# landmark testing file
landmark_file = os.path.join(extracted_dir, "loose_landmark_test.csv")
with open(landmark_file, "w") as csv_file:
csv_file.write(test_landmark_contents)

archive = os.path.join(root, 'bb_landmark.tar.gz')
_make_tar(archive, extracted_dir, compress=True)

with get_tmp_dir() as root:
root_base = os.path.join(root, basefolder)
os.makedirs(root_base)
_make_train_archive(root_base)
_make_test_archive(root_base)
_make_image_list_files(root_base)
_make_bb_landmark_archive(root_base)

yield root


@contextlib.contextmanager
def cityscapes_root():

21 changes: 20 additions & 1 deletion test/test_datasets.py
@@ -9,12 +9,18 @@
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, vggface2_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools


try:
import pandas
HAS_PANDAS = True
except ImportError:
HAS_PANDAS = False

try:
import scipy
HAS_SCIPY = True
@@ -139,6 +145,19 @@ def test_imagenet(self, mock_verify):
dataset = torchvision.datasets.ImageNet(root, split='val')
self.generic_classification_dataset_test(dataset)

@unittest.skipIf(not HAS_PANDAS, "pandas unavailable")
def test_vggface2(self):
with vggface2_root() as root:
dataset = torchvision.datasets.VGGFace2(root, split='train')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))

dataset = torchvision.datasets.VGGFace2(root, split='test')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))

@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
@@ -18,6 +18,7 @@
from .celeba import CelebA
from .sbd import SBDataset
from .vision import VisionDataset
from .vggface2 import VGGFace2
from .usps import USPS
from .kinetics import Kinetics400
from .hmdb51 import HMDB51
@@ -32,4 +33,4 @@
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
'VGGFace2', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
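For context, the change above re-exports the new class from the usual datasets namespace. A quick check (a sketch, run against this PR branch) could look like:

import torchvision.datasets as datasets

# This diff adds the import and the __all__ entry, so both should be visible
assert "VGGFace2" in datasets.__all__
print(datasets.VGGFace2)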
157 changes: 157 additions & 0 deletions torchvision/datasets/vggface2.py
@@ -0,0 +1,157 @@
from functools import partial
from PIL import Image
import os
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .utils import check_integrity, extract_archive, verify_str_arg
from .vision import VisionDataset


class VGGFace2(VisionDataset):
""" VGGFace2 <http://zeus.robots.ox.ac.uk/vgg_face2/>`_ Dataset.

Citation:
@inproceedings{Cao18,
author = "Cao, Q. and Shen, L. and Xie, W. and Parkhi, O. M. and Zisserman, A.",
title = "VGGFace2: A dataset for recognising faces across pose and age",
booktitle = "International Conference on Automatic Face and Gesture Recognition",
year = "2018"}

Args:
root (string): Root directory of the VGGFace2 Dataset.
Expects the following folder structure if download=False:
<root>
└── vggface2
├── bb_landmark.tar.gz (or 'bb_landmark' if uncompressed)
├── vggface2_train.tar.gz (or 'train' if uncompressed)
├── vggface2_test.tar.gz (or 'test' if uncompressed)
├── train_list.txt
└── test_list.txt
split (string): The dataset split to use. One of {``train``, ``test``}.
Defaults to ``train``.
target_type (string): The type of target to use. One of
{``class_id``, ``image_id``, ``face_id``, ``bbox``, ``landmarks``, ``""``}.
Can also be a list to output a tuple with all specified target types.
The targets represent:
``class_id`` (string)
``image_id`` (string)
``face_id`` (string)
``bbox`` (torch.tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
``landmarks`` (torch.tensor shape=(10,) dtype=float): values that
represent five points (P1X, P1Y, P2X, P2Y, P3X, P3Y, P4X, P4Y, P5X, P5Y)
Defaults to ``bbox``. If empty, ``None`` will be returned as target.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

BASE_FOLDER = "vggface2"
FILE_LIST = [
# Filename MD5 Hash Uncompressed filename
("vggface2_train.tar.gz", "88813c6b15de58afc8fa75ea83361d7f", "train"),
("vggface2_test.tar.gz", "bb7a323824d1004e14e00c23974facd3", "test"),
("bb_landmark.tar.gz", "26f7ba288a782862d137348a1cb97540", "bb_landmark")
]

def __init__(
self,
root: str,
split: str = "train",
target_type: Union[List[str], str] = "bbox",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
import pandas
super(VGGFace2, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
transform=transform,
target_transform=target_transform)

# stay consistent with other datasets and check for a download option
if download:
msg = ("The dataset is not publicly accessible. You must login and "
"download the archives externally and place them in the root "
"directory.")
raise RuntimeError(msg)

# check arguments
self.split = verify_str_arg(split, "split", ("train", "test"))
self.img_info: List[Dict[str, object]] = []

if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.target_type = [verify_str_arg(t, "target_type",
("class_id", "image_id", "face_id", "bbox", "landmarks", ""))
for t in self.target_type]

if not self.target_type and self.target_transform is not None:
raise RuntimeError("target_transform is specified but target_type is empty")

image_list_file = "train_list.txt" if self.split == "train" else "test_list.txt"
image_list_file = os.path.join(self.root, image_list_file)

# prepare dataset
for (filename, _, extracted_dir) in self.FILE_LIST:
filename = os.path.join(self.root, filename)
extracted_dir_path = os.path.join(self.root, extracted_dir)
if not os.path.isdir(extracted_dir_path):
extract_archive(filename)

# process dataset
fn = partial(os.path.join, self.root, self.FILE_LIST[2][2])
bbox_frames = [pandas.read_csv(fn("loose_bb_train.csv"), index_col=0),
pandas.read_csv(fn("loose_bb_test.csv"), index_col=0)]
self.bbox = pandas.concat(bbox_frames)
landmark_frames = [pandas.read_csv(fn("loose_landmark_train.csv"), index_col=0),
pandas.read_csv(fn("loose_landmark_test.csv"), index_col=0)]
self.landmarks = pandas.concat(landmark_frames)

with open(image_list_file, "r") as f:
for img_file in f:
img_file = img_file.strip()
img_filename, _ = os.path.splitext(img_file) # e.g. ("n004332/0317_01", ".jpg")
class_id, image_face_id = img_filename.split("/") # e.g. ["n004332", "0317_01"]
class_id = class_id[1:] # drop the leading "n" from the class folder name
image_id, face_id = image_face_id.split("_")
img_filepath = os.path.join(self.root, self.split, img_file)
self.img_info.append({
"img_path": img_filepath,
"class_id": class_id,
"image_id": image_id,
"face_id": face_id,
"bbox": torch.tensor(self.bbox.loc[img_filename].values),
"landmarks": torch.tensor(self.landmarks.loc[img_filename].values),
})

def __len__(self) -> int:
return len(self.img_info)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
# prepare image
img = Image.open(self.img_info[index]["img_path"])
if self.transform:
img = self.transform(img)

# prepare target
target: Any = []
for t in self.target_type:
if t == "":
target = None
break
target.append(self.img_info[index][t])
if target:
target = tuple(target) if len(target) > 1 else target[0]
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def extra_repr(self) -> str:
lines = ["Target type: {target_type}", "Split: {split}"]
return "\n".join(lines).format(**self.__dict__)