
Commit 20f288c

kabachuha and kohya-ss authored
Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (kohya-ss#1228)
* add huber loss and huber_c computation to train_util
* add reduction modes
* add huber_c retrieval from the timestep getter
* move timestep and huber_c retrieval into their own function
* add conditional loss to all training scripts
* add conditional loss to train_network
* add (scheduled) huber_loss to args
* fix duplicated timestep retrieval
* PHL (pseudo-Huber loss) schedule should depend on the noise scheduler's num timesteps
* 2x multiplier for the Huber loss, due to the 1/2 a^2 convention: the Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from the a^2 of the standard MSE loss; this change scales the two better against one another
* add option for smooth L1 (Huber / delta)
* unify Huber scheduling
* add SNR Huber scheduler

---------

Co-authored-by: Kohya S <[email protected]>
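Editor's note on the "2x multiplier" bullet: a sketch of the scaling argument behind it, with a = model_pred − target and c = huber_c (notation mine, not from the commit):

L_{\text{huber}}(a) = 2c\left(\sqrt{a^2 + c^2} - c\right) \approx 2c\left(c + \tfrac{a^2}{2c} - c\right) = a^2 \quad (a \to 0)

So the factor of 2 makes the small-residual regime match the a^2 of element-wise MSE; without it the loss would approach a^2 / 2.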
1 parent cbe0f50 · commit 20f288c

10 files changed: +96 −30 lines

fine_tune.py

Lines changed: 3 additions & 3 deletions

@@ -354,7 +354,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 # Predict the noise residual
                 with accelerator.autocast():

@@ -368,7 +368,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                     loss = loss.mean([1, 2, 3])

                     if args.min_snr_gamma:

@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
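Editor's note: every call site in this commit follows the same two-step pattern — compute an element-wise loss with reduction="none", reduce over the non-batch dims, apply per-sample weights, then mean over the batch. A minimal standalone sketch of why the element-wise form is needed when min-SNR weighting is active (the min(SNR, gamma)/SNR weighting below is the common formulation, not taken from this diff):

import torch

# fake batch of 4 predictions/targets and per-sample SNR values
noise_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
snr = torch.rand(4) * 10 + 0.1
gamma = 5.0

loss = torch.nn.functional.mse_loss(noise_pred, target, reduction="none")
loss = loss.mean([1, 2, 3])  # one scalar per sample, shape (4,)
weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr  # per-sample weight
loss = (loss * weight).mean()  # reduce over the batch only at the end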

library/train_util.py

Lines changed: 75 additions & 4 deletions

@@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
     )
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        default="l2",
+        choices=["l2", "huber", "smooth_l1"],
+        help="The type of loss function to use (L2, Huber, or smooth L1)",
+    )
+    parser.add_argument(
+        "--huber_schedule",
+        type=str,
+        default="exponential",
+        choices=["constant", "exponential", "snr"],
+        help="How huber_c is scheduled over timesteps when a Huber-type loss is selected (constant, exponential decay, or SNR-based)",
+    )
+    parser.add_argument(
+        "--huber_c",
+        type=float,
+        default=0.1,
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with loss_type.",
+    )

     parser.add_argument(
         "--lowram",
@@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

+
+def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
+
+    # TODO: if a Huber loss is selected, a single timestep is sampled and shared by the whole batch.
+    # In the future there may be a smarter way.
+
+    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+        timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
+        timestep = timesteps.item()
+
+        if args.huber_schedule == "exponential":
+            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+            huber_c = math.exp(-alpha * timestep)
+        elif args.huber_schedule == "snr":
+            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+        elif args.huber_schedule == "constant":
+            huber_c = args.huber_c
+        else:
+            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+        timesteps = timesteps.repeat(b_size).to(device)
+    elif args.loss_type == "l2":
+        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+        huber_c = 1  # may be anything, as it's not used
+    else:
+        raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+    timesteps = timesteps.long()
+
+    return timesteps, huber_c
+

 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
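To see what the exponential schedule above produces, here is a standalone sketch (assumes num_train_timesteps = 1000 and --huber_c=0.1; values rounded):

import math

huber_c_arg = 0.1  # --huber_c
num_train_timesteps = 1000
alpha = -math.log(huber_c_arg) / num_train_timesteps

for t in (0, 250, 500, 750, 999):
    print(t, round(math.exp(-alpha * t), 3))
# 0   1.0    -> wide quadratic region, close to MSE at low-noise timesteps
# 250 0.562
# 500 0.316
# 750 0.178
# 999 0.1    -> narrow quadratic region, more L1-like at high-noise timesteps

That is, huber_c decays smoothly from 1 at t = 0 down to the configured --huber_c at the final timestep.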
@@ -4862,8 +4914,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
-    timesteps = timesteps.long()
+    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)

     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)

@@ -4876,8 +4927,28 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-    return noise, noisy_latents, timesteps
-
+    return noise, noisy_latents, timesteps, huber_c
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1):
+
+    if loss_type == "l2":
+        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
+    elif loss_type == "huber":
+        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    elif loss_type == "smooth_l1":
+        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    else:
+        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+    return loss

 def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
     names = []
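A quick standalone check of the element-wise Huber form above (a = model_pred − target; numbers illustrative):

import torch

def pseudo_huber(a, c):
    # element-wise form from conditional_loss: 2c * (sqrt(a^2 + c^2) - c)
    return 2 * c * (torch.sqrt(a ** 2 + c ** 2) - c)

a = torch.tensor([0.01, 0.1, 1.0, 10.0])
print(pseudo_huber(a, 0.1))  # ~= a^2 for |a| << c, ~= 2*c*|a| for |a| >> c
print(a ** 2)                # element-wise MSE, for comparison

For a = 10 the pseudo-Huber value is about 1.98 versus 100 for MSE, which is the robustness-to-outliers property the commit message cites; smooth_l1 is the same expression without the outer huber_c factor.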

sdxl_train.py

Lines changed: 3 additions & 3 deletions

@@ -582,7 +582,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -600,7 +600,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])

@@ -616,7 +616,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:

sdxl_train_control_net_lllite.py

Lines changed: 2 additions & 2 deletions

@@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -458,7 +458,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # per-sample weight

sdxl_train_control_net_lllite_old.py

Lines changed: 2 additions & 2 deletions

@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +426,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # per-sample weight

train_controlnet.py

Lines changed: 3 additions & 8 deletions

@@ -420,13 +420,8 @@ def remove_model(old_ckpt_name):
                 )

                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0,
-                    noise_scheduler.config.num_train_timesteps,
-                    (b_size,),
-                    device=latents.device,
-                )
-                timesteps = timesteps.long()
+                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
+
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

@@ -457,7 +452,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # per-sample weight

train_db.py

Lines changed: 2 additions & 2 deletions

@@ -346,7 +346,7 @@ def train(args):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 # Predict the noise residual
                 with accelerator.autocast():

@@ -358,7 +358,7 @@ def train(args):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])

train_network.py

Lines changed: 2 additions & 2 deletions

@@ -843,7 +843,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -873,7 +873,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])

train_textual_inversion.py

Lines changed: 2 additions & 2 deletions

@@ -572,7 +572,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -588,7 +588,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])

train_textual_inversion_XTI.py

Lines changed: 2 additions & 2 deletions

@@ -461,7 +461,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 # Predict the noise residual
                 with accelerator.autocast():

@@ -473,7 +473,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
