Skip to content

FER2013 dataset #5120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
EMNIST
FakeData
FashionMNIST
FER2013
Flickr8k
Flickr30k
FlyingChairs
Expand Down
34 changes: 34 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bz2
import contextlib
import csv
import io
import itertools
import json
Expand Down Expand Up @@ -2241,5 +2242,38 @@ def inject_fake_data(self, tmpdir: str, config):
return len(image_ids_in_config)


class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FER2013
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, "fer2013")
os.makedirs(base_folder)

num_samples = 5
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
writer.writeheader()
for _ in range(num_samples):
row = dict(
pixels=" ".join(
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
)
)
if config["split"] == "train":
row["emotion"] = str(int(torch.randint(0, 7, ())))

writer.writerow(row)

return num_samples


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
Expand Down Expand Up @@ -81,4 +82,5 @@
"HD1K",
"Food101",
"DTD",
"FER2013",
)
75 changes: 75 additions & 0 deletions torchvision/datasets/fer2013.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple

import torch
from PIL import Image

from .utils import verify_str_arg, check_integrity
from .vision import VisionDataset


class FER2013(VisionDataset):
"""`FER2013
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``root/fer2013`` exists.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""

_RESOURCES = {
"train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
"test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
}

def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
super().__init__(root, transform=transform, target_transform=target_transform)

base_folder = pathlib.Path(self.root) / "fer2013"
file_name, md5 = self._RESOURCES[self._split]
data_file = base_folder / file_name
if not check_integrity(str(data_file), md5=md5):
raise RuntimeError(
f"{file_name} not found in {base_folder} or corrupted. "
f"You can download it from "
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
)

with open(data_file, "r", newline="") as file:
self._samples = [
(
torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
int(row["emotion"]) if "emotion" in row else None,
)
for row in csv.DictReader(file)
]

def __len__(self) -> int:
return len(self._samples)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_tensor, target = self._samples[idx]
image = Image.fromarray(image_tensor.numpy())

if self.transform is not None:
image = self.transform(image)

if self.target_transform is not None:
target = self.target_transform(target)

return image, target

def extra_repr(self) -> str:
return f"split={self._split}"
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .dtd import DTD
from .fer2013 import FER2013
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
Expand Down
80 changes: 80 additions & 0 deletions torchvision/prototype/datasets/_builtin/fer2013.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Union, cast

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
KaggleDownloadResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image


class FER2013(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fer2013",
type=DatasetType.RAW,
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
valid_options=dict(split=("train", "test")),
)

_CHECKSUMS = {
"train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = KaggleDownloadResource(
cast(str, self.info.homepage),
file_name=f"{config.split}.csv.zip",
sha256=self._CHECKSUMS[config.split],
)
return [archive]

def _collate_and_decode_sample(
self,
data: Dict[str, Any],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
label_id = data.get("emotion")
label_idx = int(label_id) if label_id is not None else None

image: Union[Image, io.BytesIO]
if decoder is raw:
image = Image(raw_image)
else:
image_buffer = image_buffer_from_array(raw_image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]

return dict(
image=image,
label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None,
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
14 changes: 14 additions & 0 deletions torchvision/prototype/datasets/utils/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,17 @@ def _download(self, root: pathlib.Path) -> NoReturn:
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)


class KaggleDownloadResource(ManualDownloadResource):
def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
instructions = "\n".join(
(
"1. Register and login at https://www.kaggle.com",
f"2. Navigate to {challenge_url}",
"3. Click 'Join Competition' and follow the instructions there",
"4. Navigate to the 'Data' tab",
f"5. Select {file_name} in the 'Data Explorer' and click the download button",
)
)
super().__init__(instructions, file_name=file_name, **kwargs)