Support training with a local image folder #152
Merged · 2 commits · Aug 3, 2022
examples/README.md (67 additions & 2 deletions)
@@ -22,7 +22,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:

```bash
accelerate launch train_unconditional.py \
- --dataset="huggan/flowers-102-categories" \
+ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
@@ -46,7 +46,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:

```bash
accelerate launch train_unconditional.py \
- --dataset="huggan/pokemon" \
+ --dataset_name="huggan/pokemon" \
--resolution=64 \
--output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \
@@ -62,3 +62,68 @@ An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64
A full training run takes 2 hours on 4xV100 GPUs.

<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />


### Using your own data

To use your own dataset, there are two ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.

Below, we explain both in more detail.

#### Provide the dataset as a folder

If you provide your own folders with images, the script expects the following directory structure:

```bash
data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png
```

In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:

```bash
accelerate launch train_unconditional.py \
--train_data_dir <path-to-train-directory> \
<other-arguments>
```

Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
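For instance, you can sanity-check what the script will gather before launching a run. A minimal sketch (the folder path is a placeholder for your own data directory):

```python
from datasets import load_dataset

# ImageFolder scans the directory recursively and decodes each file as a PIL image.
dataset = load_dataset("imagefolder", data_dir="path/to/data_dir", split="train")

print(dataset.num_rows)     # how many images were gathered
print(dataset[0]["image"])  # a PIL.Image.Image instance
```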

#### Upload your data to the hub, as a (possibly private) repo

It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:

```python
from datasets import load_dataset

# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")

# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")

# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")

# example 4: providing several splits
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
```

`ImageFolder` will create an `image` column containing the PIL-encoded images.
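Before pushing, you may want to double-check the result. A short sketch, assuming the `dataset` from example 1 above (loaded without a `split` argument, so it is a `DatasetDict`):

```python
# A "train" split is created by default for a local folder.
print(dataset["train"].features)  # e.g. {'image': Image(...)}, plus a 'label'
                                  # column if the folder has class subdirectories

dataset["train"][0]["image"]  # a PIL image; try .size or .show()
```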

Next, push it to the hub!

```python
# assuming you have run the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")

# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```

and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
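If you made the repo private, also pass `--use_auth_token` when launching the script so it can authenticate. Under the hood this amounts to the following (a sketch, assuming you are logged in via `huggingface-cli login`; the dataset name is the placeholder from above):

```python
from datasets import load_dataset

# Equivalent of --dataset_name name_of_your_dataset --use_auth_token:
dataset = load_dataset("name_of_your_dataset", split="train", use_auth_token=True)
```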

More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
examples/train_unconditional.py (20 additions & 3 deletions)
@@ -75,7 +75,17 @@ def main(args):
            Normalize([0.5], [0.5]),
        ]
    )
-    dataset = load_dataset(args.dataset, split="train")
+
+    if args.dataset_name is not None:
+        dataset = load_dataset(
+            args.dataset_name,
+            args.dataset_config_name,
+            cache_dir=args.cache_dir,
+            use_auth_token=True if args.use_auth_token else None,
+            split="train",
+        )
+    else:
+        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
@@ -179,9 +189,12 @@ def transforms(examples):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
+    parser.add_argument("--dataset_name", type=str, default=None)
+    parser.add_argument("--dataset_config_name", type=str, default=None)
+    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
+    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
@@ -201,6 +214,7 @@ def transforms(examples):
parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--use_auth_token", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
@@ -222,4 +236,7 @@ def transforms(examples):
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
    main(args)
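
For reference, the dataset-selection logic these hunks add can be read as a single standalone sketch (illustrative only; the function name is made up here, and the argument names mirror the new CLI flags):

```python
from datasets import load_dataset

def resolve_dataset(dataset_name=None, dataset_config_name=None,
                    train_data_dir=None, cache_dir=None, use_auth_token=False):
    # Mirrors the branch added to main(): prefer a hub dataset when
    # --dataset_name is given, otherwise fall back to a local image folder.
    if dataset_name is None and train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
    if dataset_name is not None:
        return load_dataset(
            dataset_name,
            dataset_config_name,
            cache_dir=cache_dir,
            use_auth_token=True if use_auth_token else None,
            split="train",
        )
    return load_dataset("imagefolder", data_dir=train_data_dir, cache_dir=cache_dir, split="train")
```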