Skip to content

Commit db67529

Browse files
authored
画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (kohya-ss#1223)
* Add alpha_mask parameter and apply masked loss * Fix type hint in trim_and_resize_if_required function * Refactor code to use keyword arguments in train_util.py * Fix alpha mask flipping logic * Fix alpha mask initialization * Fix alpha_mask transformation * Cache alpha_mask * Update alpha_masks to be on CPU * Set flipped_alpha_masks to Null if option disabled * Check if alpha_mask is None * Set alpha_mask to None if option disabled * Add description of alpha_mask option to docs
1 parent febc5c5 commit db67529

10 files changed

+105
-129
lines changed

docs/train_network_README-ja.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
102102
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
103103
* `--network_args`
104104
* 複数の引数を指定できます。後述します。
105+
* `--alpha_mask`
106+
* 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
105107

106108
`--network_train_unet_only``--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
107109

docs/train_network_README-zh.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
101101
* 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
102102
* `--network_args`
103103
* 可以指定多个参数。将在下面详细说明。
104+
* `--alpha_mask`
105+
* 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
104106

105107
当未指定`--network_train_unet_only``--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
106108

library/config_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class BaseSubsetParams:
7878
caption_tag_dropout_rate: float = 0.0
7979
token_warmup_min: int = 1
8080
token_warmup_step: float = 0
81+
alpha_mask: bool = False
8182

8283

8384
@dataclass
@@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
538539
random_crop: {subset.random_crop}
539540
token_warmup_min: {subset.token_warmup_min},
540541
token_warmup_step: {subset.token_warmup_step},
542+
alpha_mask: {subset.alpha_mask},
541543
"""
542544
),
543545
" ",

library/custom_train_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
479479
return noise
480480

481481

482-
def apply_masked_loss(loss, batch):
482+
def apply_masked_loss(loss, mask_image):
483483
# mask image is -1 to 1. we need to convert it to 0 to 1
484-
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
484+
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
485+
mask_image = mask_image.to(dtype=loss.dtype)
485486

486487
# resize to the same size as the loss
487488
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")

0 commit comments

Comments
 (0)