Skip to content

Commit ee3bbbc

Browse files
authored
Merge pull request kohya-ss#1207 from kohya-ss/masked-loss
Add masked loss
2 parents 69b3200 + f5eb30a commit ee3bbbc

11 files changed

+112
-27
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
133133

134134
## Change History
135135

136+
### Masked loss
137+
138+
`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support the masked loss. `--masked_loss` option is added.
139+
140+
NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.
141+
142+
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
143+
144+
136145
### Working in progress
137146

138147
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.

docs/train_lllite_README-ja.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k
2121
## モデルの学習
2222

2323
### データセットの準備
24-
通常のdatasetに加え`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です
24+
DreamBooth 方式の dataset で`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。
2525

26-
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
26+
(finetuning 方式の dataset はサポートしていません。)
27+
28+
conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
29+
30+
たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。
2731

2832
```toml
2933
[[datasets.subsets]]

docs/train_lllite_README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
2626

2727
### Preparing the dataset
2828

29-
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
29+
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
30+
31+
(We do not support the finetuning method dataset.)
3032

3133
```toml
3234
[[datasets.subsets]]

library/config_util.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
253253
}
254254

255255
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
256-
assert (
257-
support_dreambooth or support_finetuning or support_controlnet
258-
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
256+
assert support_dreambooth or support_finetuning or support_controlnet, (
257+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
258+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
259+
)
259260

260261
self.db_subset_schema = self.__merge_dict(
261262
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -322,7 +323,10 @@ def validate_flex_dataset(dataset_config: dict):
322323

323324
self.dataset_schema = validate_flex_dataset
324325
elif support_dreambooth:
325-
self.dataset_schema = self.db_dataset_schema
326+
if support_controlnet:
327+
self.dataset_schema = self.cn_dataset_schema
328+
else:
329+
self.dataset_schema = self.db_dataset_schema
326330
elif support_finetuning:
327331
self.dataset_schema = self.ft_dataset_schema
328332
elif support_controlnet:

library/custom_train_functions.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import random
44
import re
55
from typing import List, Optional, Union
6-
from .utils import setup_logging
6+
from .utils import setup_logging
7+
78
setup_logging()
8-
import logging
9+
import logging
10+
911
logger = logging.getLogger(__name__)
1012

13+
1114
def prepare_scheduler_for_custom_training(noise_scheduler, device):
1215
if hasattr(noise_scheduler, "all_snr"):
1316
return
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
6467
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
6568
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
6669
if v_prediction:
67-
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
70+
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
6871
else:
6972
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
7073
loss = loss * snr_weight
@@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
9295
loss = loss + loss / scale * v_pred_like_loss
9396
return loss
9497

98+
9599
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
96100
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
97101
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
98-
weight = 1/torch.sqrt(snr_t)
102+
weight = 1 / torch.sqrt(snr_t)
99103
loss = weight * loss
100104
return loss
101105

106+
102107
# TODO train_utilと分散しているのでどちらかに寄せる
103108

104109

@@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
474479
return noise
475480

476481

482+
def apply_masked_loss(loss, batch):
483+
# mask image is -1 to 1. we need to convert it to 0 to 1
484+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
485+
486+
# resize to the same size as the loss
487+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
488+
mask_image = mask_image / 2 + 0.5
489+
loss = loss * mask_image
490+
return loss
491+
492+
477493
"""
478494
##########################################
479495
# Perlin Noise

library/train_util.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,6 +1835,9 @@ def __init__(
18351835

18361836
db_subsets = []
18371837
for subset in subsets:
1838+
assert (
1839+
not subset.random_crop
1840+
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
18381841
db_subset = DreamBoothSubset(
18391842
subset.image_dir,
18401843
False,
@@ -1885,7 +1888,7 @@ def __init__(
18851888

18861889
# assert all conditioning data exists
18871890
missing_imgs = []
1888-
cond_imgs_with_img = set()
1891+
cond_imgs_with_pair = set()
18891892
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
18901893
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
18911894
subset = None
@@ -1899,23 +1902,29 @@ def __init__(
18991902
logger.warning(f"not directory: {subset.conditioning_data_dir}")
19001903
continue
19011904

1902-
img_basename = os.path.basename(info.absolute_path)
1903-
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
1904-
if not os.path.exists(ctrl_img_path):
1905+
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
1906+
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
1907+
if len(ctrl_img_path) < 1:
19051908
missing_imgs.append(img_basename)
1909+
continue
1910+
ctrl_img_path = ctrl_img_path[0]
1911+
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path
19061912

19071913
info.cond_img_path = ctrl_img_path
1908-
cond_imgs_with_img.add(ctrl_img_path)
1914+
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive
19091915

19101916
extra_imgs = []
19111917
for subset in subsets:
19121918
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
1913-
extra_imgs.extend(
1914-
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
1915-
)
1919+
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
1920+
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
19161921

1917-
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
1918-
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
1922+
assert (
1923+
len(missing_imgs) == 0
1924+
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
1925+
assert (
1926+
len(extra_imgs) == 0
1927+
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
19191928

19201929
self.conditioning_image_transforms = IMAGE_TRANSFORMS
19211930

@@ -3049,6 +3058,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
30493058
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
30503059
) # TODO move to SDXL training, because it is not supported by SD1/2
30513060
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
3061+
30523062
parser.add_argument(
30533063
"--ddp_timeout",
30543064
type=int,
@@ -3111,6 +3121,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
31113121
default=None,
31123122
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
31133123
)
3124+
31143125
parser.add_argument(
31153126
"--noise_offset",
31163127
type=float,
@@ -3284,6 +3295,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
32843295
)
32853296

32863297

3298+
def add_masked_loss_arguments(parser: argparse.ArgumentParser):
3299+
parser.add_argument(
3300+
"--conditioning_data_dir",
3301+
type=str,
3302+
default=None,
3303+
help="conditioning data directory / 条件付けデータのディレクトリ",
3304+
)
3305+
parser.add_argument(
3306+
"--masked_loss",
3307+
action="store_true",
3308+
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
3309+
)
3310+
3311+
32873312
def verify_training_args(args: argparse.Namespace):
32883313
r"""
32893314
Verify training arguments. Also reflect highvram option to global variable

sdxl_train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from library.device_utils import init_ipex, clean_memory_on_device
1414

15+
1516
init_ipex()
1617

1718
from accelerate.utils import set_seed
@@ -40,6 +41,7 @@
4041
scale_v_prediction_loss_like_noise_prediction,
4142
add_v_prediction_like_loss,
4243
apply_debiased_estimation,
44+
apply_masked_loss,
4345
)
4446
from library.sdxl_original_unet import SdxlUNet2DConditionModel
4547

@@ -126,7 +128,7 @@ def train(args):
126128

127129
# データセットを準備する
128130
if args.dataset_class is None:
129-
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
131+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
130132
if args.dataset_config is not None:
131133
logger.info(f"Load dataset config from {args.dataset_config}")
132134
user_config = config_util.load_user_config(args.dataset_config)
@@ -595,9 +597,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
595597
or args.scale_v_pred_loss_like_noise_pred
596598
or args.v_pred_like_loss
597599
or args.debiased_estimation_loss
600+
or args.masked_loss
598601
):
599602
# do not mean over batch dimension for snr weight or scale v-pred loss
600603
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
604+
if args.masked_loss:
605+
loss = apply_masked_loss(loss, batch)
601606
loss = loss.mean([1, 2, 3])
602607

603608
if args.min_snr_gamma:
@@ -763,6 +768,7 @@ def setup_parser() -> argparse.ArgumentParser:
763768
train_util.add_sd_models_arguments(parser)
764769
train_util.add_dataset_arguments(parser, True, True, True)
765770
train_util.add_training_arguments(parser, False)
771+
train_util.add_masked_loss_arguments(parser)
766772
deepspeed_utils.add_deepspeed_arguments(parser)
767773
train_util.add_sd_saving_arguments(parser)
768774
train_util.add_optimizer_arguments(parser)
@@ -799,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser:
799805
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
800806
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
801807
)
802-
803808
return parser
804809

805810

train_db.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from library import deepspeed_utils
1515
from library.device_utils import init_ipex, clean_memory_on_device
1616

17+
1718
init_ipex()
1819

1920
from accelerate.utils import set_seed
@@ -34,6 +35,7 @@
3435
apply_noise_offset,
3536
scale_v_prediction_loss_like_noise_prediction,
3637
apply_debiased_estimation,
38+
apply_masked_loss,
3739
)
3840
from library.utils import setup_logging, add_logging_arguments
3941

@@ -60,7 +62,7 @@ def train(args):
6062

6163
# データセットを準備する
6264
if args.dataset_class is None:
63-
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
65+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
6466
if args.dataset_config is not None:
6567
logger.info(f"Load dataset config from {args.dataset_config}")
6668
user_config = config_util.load_user_config(args.dataset_config)
@@ -357,6 +359,8 @@ def train(args):
357359
target = noise
358360

359361
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
362+
if args.masked_loss:
363+
loss = apply_masked_loss(loss, batch)
360364
loss = loss.mean([1, 2, 3])
361365

362366
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -482,6 +486,7 @@ def setup_parser() -> argparse.ArgumentParser:
482486
train_util.add_sd_models_arguments(parser)
483487
train_util.add_dataset_arguments(parser, True, False, True)
484488
train_util.add_training_arguments(parser, True)
489+
train_util.add_masked_loss_arguments(parser)
485490
deepspeed_utils.add_deepspeed_arguments(parser)
486491
train_util.add_sd_saving_arguments(parser)
487492
train_util.add_optimizer_arguments(parser)

train_network.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
from library.device_utils import init_ipex, clean_memory_on_device
1616

17+
1718
init_ipex()
1819

1920
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -40,6 +41,7 @@
4041
scale_v_prediction_loss_like_noise_prediction,
4142
add_v_prediction_like_loss,
4243
apply_debiased_estimation,
44+
apply_masked_loss,
4345
)
4446
from library.utils import setup_logging, add_logging_arguments
4547

@@ -159,7 +161,7 @@ def train(self, args):
159161

160162
# データセットを準備する
161163
if args.dataset_class is None:
162-
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
164+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
163165
if use_user_config:
164166
logger.info(f"Loading dataset config from {args.dataset_config}")
165167
user_config = config_util.load_user_config(args.dataset_config)
@@ -852,6 +854,8 @@ def remove_model(old_ckpt_name):
852854
target = noise
853855

854856
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
857+
if args.masked_loss:
858+
loss = apply_masked_loss(loss, batch)
855859
loss = loss.mean([1, 2, 3])
856860

857861
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -975,6 +979,7 @@ def setup_parser() -> argparse.ArgumentParser:
975979
train_util.add_sd_models_arguments(parser)
976980
train_util.add_dataset_arguments(parser, True, True, True)
977981
train_util.add_training_arguments(parser, True)
982+
train_util.add_masked_loss_arguments(parser)
978983
deepspeed_utils.add_deepspeed_arguments(parser)
979984
train_util.add_optimizer_arguments(parser)
980985
config_util.add_config_arguments(parser)

0 commit comments

Comments
 (0)