@@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common(


 def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-
-    # TODO: if a huber loss is selected, it will use constant timesteps for each batch
-    # as. In the future there may be a smarter way
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

     if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-        timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
-        timestep = timesteps.item()
-
         if args.huber_schedule == "exponential":
             alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-            huber_c = math.exp(-alpha * timestep)
+            huber_c = torch.exp(-alpha * timesteps)
         elif args.huber_schedule == "snr":
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
             sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
             huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
         elif args.huber_schedule == "constant":
-            huber_c = args.huber_c
+            huber_c = torch.full((b_size,), args.huber_c)
         else:
             raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-
-        timesteps = timesteps.repeat(b_size).to(device)
+        huber_c = huber_c.to(device)
     elif args.loss_type == "l2":
-        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
-        huber_c = 1  # may be anything, as it's not used
+        huber_c = None  # may be anything, as it's not used
     else:
         raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-    timesteps = timesteps.long()

+    timesteps = timesteps.long().to(device)
     return timesteps, huber_c
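For reference, the per-sample huber_c that the new vectorized "exponential" branch produces can be checked in isolation. The following is a minimal runnable sketch, not code from this PR; huber_c_arg = 0.1 and T = 1000 are assumed stand-ins for args.huber_c and noise_scheduler.config.num_train_timesteps.

import math

import torch

huber_c_arg = 0.1  # stand-in for args.huber_c (assumed)
T = 1000           # stand-in for noise_scheduler.config.num_train_timesteps (assumed)
b_size = 4

# Per-sample timesteps, as in the new code above
timesteps = torch.randint(0, T, (b_size,), device="cpu")
alpha = -math.log(huber_c_arg) / T
huber_c = torch.exp(-alpha * timesteps)  # shape (b_size,), one value per sample

# exp(-alpha * 0) == 1.0 and exp(-alpha * T) == huber_c_arg, so huber_c decays
# from 1.0 at t=0 to 0.1 at t=T: noisier timesteps get a smaller transition
# point, i.e. a more outlier-robust (more L1-like) loss.
print(timesteps.tolist())
print(huber_c.tolist())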
@@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     return noise, noisy_latents, timesteps, huber_c


-# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
 def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
+    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
 ):

     if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
     elif loss_type == "huber":
+        huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
     elif loss_type == "smooth_l1":
+        huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
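Since conditional_loss now receives huber_c as a (b_size,) tensor and reshapes it to (-1, 1, 1, 1), one value per sample broadcasts over (B, C, H, W) predictions. A small self-contained check of that broadcasting, where pseudo_huber is a hypothetical helper mirroring the loss_type == "huber" branch above, not a function from the PR:

import torch

def pseudo_huber(model_pred: torch.Tensor, target: torch.Tensor, huber_c: torch.Tensor) -> torch.Tensor:
    # (B,) -> (B, 1, 1, 1) so each sample's huber_c broadcasts over (B, C, H, W)
    huber_c = huber_c.view(-1, 1, 1, 1)
    loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
    return torch.mean(loss)

b_size = 4
model_pred = torch.randn(b_size, 4, 8, 8)
target = torch.randn(b_size, 4, 8, 8)
huber_c = torch.rand(b_size)  # per-sample values, as returned by get_timesteps_and_huber_c

print(pseudo_huber(model_pred, target, huber_c))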