Skip to content

Commit 78e0a76

Browse files
authored
Merge pull request #1206 from kohya-ss/dataset-cache
Add metadata caching for DreamBooth dataset
2 parents 5a2afb3 + c86e356 commit 78e0a76

File tree

6 files changed

+92
-28
lines changed

6 files changed

+92
-28
lines changed

docs/config_README-en.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ Options related to the configuration of DreamBooth subsets.
177177
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
178178
| `caption_extension` | `".txt"` | o | o | o |
179179
| `class_tokens` | `"sks girl"` | - | - | o |
180+
| `cache_info` | `false` | o | o | o |
180181
| `is_reg` | `false` | - | - | o |
181182

182183
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
@@ -187,6 +188,9 @@ Firstly, note that for `image_dir`, the path to the image files must be specifie
187188
* `class_tokens`
188189
* Sets the class tokens.
189190
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
191+
* `cache_info`
192+
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
193+
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
190194
* `is_reg`
191195
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
192196

docs/config_README-ja.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
173173
| `image_dir` | `'C:\hoge'` | - | - | o(必須) |
174174
| `caption_extension` | `".txt"` | o | o | o |
175175
| `class_tokens` | `"sks girl"` | - | - | o |
176+
| `cache_info` | `false` | o | o | o |
176177
| `is_reg` | `false` | - | - | o |
177178

178179
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
@@ -183,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
183184
* `class_tokens`
184185
* クラストークンを設定します。
185186
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
187+
* `cache_info`
188+
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。
189+
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
186190
* `is_reg`
187191
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
188192

library/config_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams):
8585
is_reg: bool = False
8686
class_tokens: Optional[str] = None
8787
caption_extension: str = ".caption"
88+
cache_info: bool = False
8889

8990

9091
@dataclass
@@ -96,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams):
9697
class ControlNetSubsetParams(BaseSubsetParams):
9798
conditioning_data_dir: str = None
9899
caption_extension: str = ".caption"
100+
cache_info: bool = False
99101

100102

101103
@dataclass
@@ -205,6 +207,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
205207
DB_SUBSET_ASCENDABLE_SCHEMA = {
206208
"caption_extension": str,
207209
"class_tokens": str,
210+
"cache_info": bool,
208211
}
209212
DB_SUBSET_DISTINCT_SCHEMA = {
210213
Required("image_dir"): str,
@@ -217,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
217220
}
218221
CN_SUBSET_ASCENDABLE_SCHEMA = {
219222
"caption_extension": str,
223+
"cache_info": bool,
220224
}
221225
CN_SUBSET_DISTINCT_SCHEMA = {
222226
Required("image_dir"): str,

library/train_util.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from huggingface_hub import hf_hub_download
6464
import numpy as np
6565
from PIL import Image
66+
import imagesize
6667
import cv2
6768
import safetensors.torch
6869
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -410,6 +411,7 @@ def __init__(
410411
is_reg: bool,
411412
class_tokens: Optional[str],
412413
caption_extension: str,
414+
cache_info: bool,
413415
num_repeats,
414416
shuffle_caption,
415417
caption_separator: str,
@@ -458,6 +460,7 @@ def __init__(
458460
self.caption_extension = caption_extension
459461
if self.caption_extension and not self.caption_extension.startswith("."):
460462
self.caption_extension = "." + self.caption_extension
463+
self.cache_info = cache_info
461464

462465
def __eq__(self, other) -> bool:
463466
if not isinstance(other, DreamBoothSubset):
@@ -527,6 +530,7 @@ def __init__(
527530
image_dir: str,
528531
conditioning_data_dir: str,
529532
caption_extension: str,
533+
cache_info: bool,
530534
num_repeats,
531535
shuffle_caption,
532536
caption_separator,
@@ -574,6 +578,7 @@ def __init__(
574578
self.caption_extension = caption_extension
575579
if self.caption_extension and not self.caption_extension.startswith("."):
576580
self.caption_extension = "." + self.caption_extension
581+
self.cache_info = cache_info
577582

578583
def __eq__(self, other) -> bool:
579584
if not isinstance(other, ControlNetSubset):
@@ -1081,8 +1086,7 @@ def cache_text_encoder_outputs(
10811086
)
10821087

10831088
def get_image_size(self, image_path):
1084-
image = Image.open(image_path)
1085-
return image.size
1089+
return imagesize.get(image_path)
10861090

10871091
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
10881092
img = load_image(image_path)
@@ -1411,6 +1415,8 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
14111415

14121416

14131417
class DreamBoothDataset(BaseDataset):
1418+
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
1419+
14141420
def __init__(
14151421
self,
14161422
subsets: Sequence[DreamBoothSubset],
@@ -1485,26 +1491,54 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
14851491
logger.warning(f"not directory: {subset.image_dir}")
14861492
return [], []
14871493

1488-
img_paths = glob_images(subset.image_dir, "*")
1489-
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
1490-
1491-
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
1492-
captions = []
1493-
missing_captions = []
1494-
for img_path in img_paths:
1495-
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
1496-
if cap_for_img is None and subset.class_tokens is None:
1494+
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
1495+
use_cached_info_for_subset = subset.cache_info
1496+
if use_cached_info_for_subset:
1497+
logger.info(
1498+
f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
1499+
)
1500+
if not os.path.isfile(info_cache_file):
14971501
logger.warning(
1498-
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
1502+
f"image info file not found. You can ignore this warning if this is the first time to use this subset"
1503+
+ " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}"
14991504
)
1500-
captions.append("")
1501-
missing_captions.append(img_path)
1502-
else:
1503-
if cap_for_img is None:
1504-
captions.append(subset.class_tokens)
1505+
use_cached_info_for_subset = False
1506+
1507+
if use_cached_info_for_subset:
1508+
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
1509+
with open(info_cache_file, "r", encoding="utf-8") as f:
1510+
metas = json.load(f)
1511+
img_paths = list(metas.keys())
1512+
sizes = [meta["resolution"] for meta in metas.values()]
1513+
1514+
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
1515+
else:
1516+
img_paths = glob_images(subset.image_dir, "*")
1517+
sizes = [None] * len(img_paths)
1518+
1519+
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
1520+
1521+
if use_cached_info_for_subset:
1522+
captions = [meta["caption"] for meta in metas.values()]
1523+
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
1524+
else:
1525+
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
1526+
captions = []
1527+
missing_captions = []
1528+
for img_path in img_paths:
1529+
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
1530+
if cap_for_img is None and subset.class_tokens is None:
1531+
logger.warning(
1532+
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
1533+
)
1534+
captions.append("")
15051535
missing_captions.append(img_path)
15061536
else:
1507-
captions.append(cap_for_img)
1537+
if cap_for_img is None:
1538+
captions.append(subset.class_tokens)
1539+
missing_captions.append(img_path)
1540+
else:
1541+
captions.append(cap_for_img)
15081542

15091543
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
15101544

@@ -1521,7 +1555,19 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
15211555
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
15221556
break
15231557
logger.warning(missing_caption)
1524-
return img_paths, captions
1558+
1559+
if not use_cached_info_for_subset and subset.cache_info:
1560+
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
1561+
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
1562+
matas = {}
1563+
for img_path, caption, size in zip(img_paths, captions, sizes):
1564+
matas[img_path] = {"caption": caption, "resolution": list(size)}
1565+
with open(info_cache_file, "w", encoding="utf-8") as f:
1566+
json.dump(matas, f, ensure_ascii=False, indent=2)
1567+
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
1568+
1569+
# if sizes are not set, image size will be read in make_buckets
1570+
return img_paths, captions, sizes
15251571

15261572
logger.info("prepare images.")
15271573
num_train_images = 0
@@ -1540,7 +1586,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
15401586
)
15411587
continue
15421588

1543-
img_paths, captions = load_dreambooth_dir(subset)
1589+
img_paths, captions, sizes = load_dreambooth_dir(subset)
15441590
if len(img_paths) < 1:
15451591
logger.warning(
15461592
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1552,8 +1598,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
15521598
else:
15531599
num_train_images += subset.num_repeats * len(img_paths)
15541600

1555-
for img_path, caption in zip(img_paths, captions):
1601+
for img_path, caption, size in zip(img_paths, captions, sizes):
15561602
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
1603+
if size is not None:
1604+
info.image_size = size
15571605
if subset.is_reg:
15581606
reg_infos.append((info, subset))
15591607
else:
@@ -1842,7 +1890,8 @@ def __init__(
18421890
subset.image_dir,
18431891
False,
18441892
None,
1845-
subset.caption_extension,
1893+
subset.caption_extension,
1894+
subset.cache_info,
18461895
subset.num_repeats,
18471896
subset.shuffle_caption,
18481897
subset.caption_separator,
@@ -3384,6 +3433,12 @@ def add_dataset_arguments(
33843433
parser.add_argument(
33853434
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
33863435
)
3436+
parser.add_argument(
3437+
"--cache_info",
3438+
action="store_true",
3439+
help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
3440+
+ " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
3441+
)
33873442
parser.add_argument(
33883443
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
33893444
)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ easygui==0.98.3
1717
toml==0.10.2
1818
voluptuous==0.13.1
1919
huggingface-hub==0.20.1
20+
# for Image utils
21+
imagesize==1.4.1
2022
# for BLIP captioning
2123
# requests==2.28.2
2224
# timm==0.6.12

train_network.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,14 @@
1414
import torch
1515
from library.device_utils import init_ipex, clean_memory_on_device
1616

17-
1817
init_ipex()
1918

20-
from torch.nn.parallel import DistributedDataParallel as DDP
21-
2219
from accelerate.utils import set_seed
2320
from diffusers import DDPMScheduler
2421
from library import deepspeed_utils, model_util
2522

2623
import library.train_util as train_util
27-
from library.train_util import (
28-
DreamBoothDataset,
29-
)
24+
from library.train_util import DreamBoothDataset
3025
import library.config_util as config_util
3126
from library.config_util import (
3227
ConfigSanitizer,

0 commit comments

Comments
 (0)