From 739639da4b89ef19c54850d89c0118fdb630d2e1 Mon Sep 17 00:00:00 2001 From: NotNANtoN Date: Tue, 18 Oct 2022 16:20:52 +0200 Subject: [PATCH] Fix img2img speed with LMS-Discrete Scheduler Casting `self.sigmas` to a different dtype (that of `original_samples`) and storing the result back on the scheduler is not advisable. In my img2img pipeline this makes the later `integrate.quad` call run for a long time - by long I mean more than 10x slower. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 1b8ca7c5df8d..9b20149d8a47 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -250,12 +250,10 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - self.timesteps = self.timesteps.to(original_samples.device) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - schedule_timesteps = self.timesteps - if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): deprecate( "timesteps as indices", @@ -269,7 +267,7 @@ def add_noise( else: step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)