Skip to content

Commit 986c892

Browse files
author
recris
committed
make timestep sampling behave in the standard way when huber loss is used
1 parent 0b7927e commit 986c892

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

library/train_util.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common(
51245124

51255125

51265126
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
    """Sample one training timestep per batch element and the matching huber_c values.

    Args:
        args: parsed training arguments; reads ``loss_type``, ``huber_schedule`` and ``huber_c``.
        min_timestep: inclusive lower bound for sampled timesteps.
        max_timestep: exclusive upper bound for sampled timesteps.
        noise_scheduler: diffusers-style scheduler; ``config.num_train_timesteps`` is read for
            the "exponential" schedule and ``alphas_cumprod`` for the "snr" schedule.
        b_size: batch size (number of timesteps / huber_c values to produce).
        device: target device for the returned tensors.

    Returns:
        (timesteps, huber_c): a ``(b_size,)`` long tensor of timesteps on ``device``, and a
        ``(b_size,)`` tensor of huber_c values on ``device`` for huber/smooth_l1 losses, or
        ``None`` for plain l2 loss.

    Raises:
        NotImplementedError: on an unrecognized ``loss_type`` or ``huber_schedule``.
    """
    # Sample on CPU first; the schedule math below indexes/combines host-side tensors.
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

    if args.loss_type in ("huber", "smooth_l1"):
        schedule = args.huber_schedule
        if schedule == "exponential":
            # Decays huber_c from 1 at t=0 toward args.huber_c at t=num_train_timesteps.
            decay = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
            huber_c = torch.exp(-decay * timesteps)
        elif schedule == "snr":
            # Per-timestep noise level sigma, mapped into the range [args.huber_c, 1].
            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
        elif schedule == "constant":
            huber_c = torch.full((b_size,), args.huber_c)
        else:
            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
        huber_c = huber_c.to(device)
    elif args.loss_type == "l2":
        # Plain MSE never consults huber_c downstream.
        huber_c = None
    else:
        raise NotImplementedError(f"Unknown loss type {args.loss_type}")

    timesteps = timesteps.long().to(device)
    return timesteps, huber_c
51565149

51575150

@@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
51905183
return noise, noisy_latents, timesteps, huber_c
51915184

51925185

5193-
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
51945186
def conditional_loss(
5195-
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
5187+
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
51965188
):
51975189

51985190
if loss_type == "l2":
51995191
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
52005192
elif loss_type == "huber":
5193+
huber_c = huber_c.view(-1,1,1,1)
52015194
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
52025195
if reduction == "mean":
52035196
loss = torch.mean(loss)
52045197
elif reduction == "sum":
52055198
loss = torch.sum(loss)
52065199
elif loss_type == "smooth_l1":
5200+
huber_c = huber_c.view(-1, 1, 1, 1)
52075201
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
52085202
if reduction == "mean":
52095203
loss = torch.mean(loss)

0 commit comments

Comments
 (0)