-
Notifications
You must be signed in to change notification settings - Fork 6.3k
[Half precision] Make sure half-precision is correct #182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
42e6d51
4667928
760a071
f3d19e1
c7743d5
b30c8c7
468b548
387a6b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,7 +50,13 @@ def __call__( | |
self.text_encoder.to(torch_device) | ||
|
||
# get prompt text embeddings | ||
text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt") | ||
text_input = self.tokenizer( | ||
prompt, | ||
padding="max_length", | ||
max_length=self.tokenizer.model_max_length, | ||
truncation=True, | ||
return_tensors="pt", | ||
) | ||
Comment on lines
+58
to
+64
Reviewer comment (patil-suraj): This is important here — we should always pad to `max_length`, as that is how the model was trained.

Reply (patrickvonplaten): Yes, agreed — but let's make sure not to do this when we create our text-to-image training script. It is definitely cleaner to mask out padding tokens there, and it should help the model learn better, as Katherine also stated on Slack. |
||
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0] | ||
|
||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
|
@@ -74,25 +80,32 @@ def __call__( | |
latents = torch.randn( | ||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), | ||
generator=generator, | ||
device=torch_device, | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
latents = latents.to(torch_device) | ||
|
||
# set timesteps | ||
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) | ||
extra_set_kwargs = {} | ||
if accepts_offset: | ||
extra_set_kwargs["offset"] = 1 | ||
|
||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) | ||
|
||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
# and should be between [0, 1] | ||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
extra_kwargs = {} | ||
extra_forward_kwargs = {} | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if accepts_eta: | ||
extra_kwargs["eta"] = eta | ||
|
||
self.scheduler.set_timesteps(num_inference_steps) | ||
extra_forward_kwargs["eta"] = eta | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
for t in tqdm(self.scheduler.timesteps): | ||
# expand the latents if we are doing classifier free guidance | ||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
|
||
# predict the noise residual | ||
t = t + 1 | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] | ||
|
||
# perform guidance | ||
|
@@ -101,7 +114,7 @@ def __call__( | |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||
|
||
# compute the previous noisy sample x_t -> x_t-1 | ||
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"] | ||
latents = self.scheduler.step(noise_pred, t, latents, **extra_forward_kwargs)["prev_sample"] | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# scale and decode the image latents with vae | ||
latents = 1 / 0.18215 * latents | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,7 @@ def __init__( | |
trained_betas=None, | ||
timestep_values=None, | ||
clip_sample=True, | ||
do_neg_alpha_one=True, | ||
tensor_format="pt", | ||
): | ||
|
||
|
@@ -75,7 +76,7 @@ def __init__( | |
|
||
self.alphas = 1.0 - self.betas | ||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) | ||
self.one = np.array(1.0) | ||
self.negative_alpha_cumprod = np.array(1.0) if do_neg_alpha_one else self.alphas_cumprod[0] | ||
Reviewer comment: I'm not sure I understand why this is needed — could you maybe add a comment here explaining why we added it, as this [comment truncated in export]

Reply: Yeah, maybe we need a better name here indeed — let me know if you have better ideas. For every step we need to know the previous step; just at the step [comment truncated in export]

Reply: OK, I changed it to [comment truncated in export]

Reply: [comment truncated in export] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changing that actually real quick |
||
|
||
# setable values | ||
self.num_inference_steps = None | ||
|
@@ -86,19 +87,20 @@ def __init__( | |
|
||
def _get_variance(self, timestep, prev_timestep): | ||
alpha_prod_t = self.alphas_cumprod[timestep] | ||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one | ||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.negative_alpha_cumprod | ||
beta_prod_t = 1 - alpha_prod_t | ||
beta_prod_t_prev = 1 - alpha_prod_t_prev | ||
|
||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) | ||
|
||
return variance | ||
|
||
def set_timesteps(self, num_inference_steps): | ||
def set_timesteps(self, num_inference_steps, offset=0): | ||
self.num_inference_steps = num_inference_steps | ||
self.timesteps = np.arange( | ||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps | ||
)[::-1].copy() | ||
self.timesteps += offset | ||
self.set_format(tensor_format=self.tensor_format) | ||
|
||
def step( | ||
|
@@ -126,7 +128,7 @@ def step( | |
|
||
# 2. compute alphas, betas | ||
alpha_prod_t = self.alphas_cumprod[timestep] | ||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one | ||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.negative_alpha_cumprod | ||
beta_prod_t = 1 - alpha_prod_t | ||
|
||
# 3. compute predicted original sample from predicted noise also called | ||
|
Uh oh!
There was an error while loading. Please reload this page.