Skip to content

Commit c632af8

Browse files
authored
Merge pull request #1715 from catboxanon/vpred-ztsnr-fixes
Update debiased estimation loss function to accommodate V-pred
2 parents 012e7e6 + 0e7c592 commit c632af8

9 files changed

+13
-10
lines changed

fine_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
386386
if args.scale_v_pred_loss_like_noise_pred:
387387
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
388388
if args.debiased_estimation_loss:
389-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
389+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
390390

391391
loss = loss.mean() # mean over batch dimension
392392
else:

library/custom_train_functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
9696
return loss
9797

9898

99-
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
99+
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
100100
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101101
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102-
weight = 1 / torch.sqrt(snr_t)
102+
if v_prediction:
103+
weight = 1 / (snr_t + 1)
104+
else:
105+
weight = 1 / torch.sqrt(snr_t)
103106
loss = weight * loss
104107
return loss
105108

sdxl_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def optimizer_hook(parameter: torch.Tensor):
730730
if args.v_pred_like_loss:
731731
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
732732
if args.debiased_estimation_loss:
733-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
733+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
734734

735735
loss = loss.mean() # mean over batch dimension
736736
else:

sdxl_train_control_net_lllite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def remove_model(old_ckpt_name):
479479
if args.v_pred_like_loss:
480480
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
481481
if args.debiased_estimation_loss:
482-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
482+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
483483

484484
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
485485

sdxl_train_control_net_lllite_old.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):
439439
if args.v_pred_like_loss:
440440
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
441441
if args.debiased_estimation_loss:
442-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
442+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
443443

444444
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
445445

train_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def train(args):
373373
if args.scale_v_pred_loss_like_noise_pred:
374374
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
375375
if args.debiased_estimation_loss:
376-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
376+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
377377

378378
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
379379

train_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ def remove_model(old_ckpt_name):
998998
if args.v_pred_like_loss:
999999
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
10001000
if args.debiased_estimation_loss:
1001-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
1001+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
10021002

10031003
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
10041004

train_textual_inversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def remove_model(old_ckpt_name):
603603
if args.v_pred_like_loss:
604604
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
605605
if args.debiased_estimation_loss:
606-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
606+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
607607

608608
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
609609

train_textual_inversion_XTI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def remove_model(old_ckpt_name):
486486
if args.scale_v_pred_loss_like_noise_pred:
487487
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
488488
if args.debiased_estimation_loss:
489-
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
489+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
490490

491491
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
492492

0 commit comments

Comments
 (0)