Description
Describe the bug
When you use an LMSDiscreteScheduler on an Apple Silicon machine, you'll get the following error:
Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead
The offending line is line 134 of unet_2d_condition.py.
The current code is:
timesteps = timesteps[None].to(sample.device)
Changing that to the following stops the crash:
timesteps = timesteps[None].long().to(sample.device)
However, I think what you'd really want is a check for whether the current device is MPS, doing the dtype conversion only in that case.
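Here is a minimal sketch of that device check. `prepare_timesteps` is a hypothetical helper that stands in for line 134; it is not the library's actual code. It casts to float32 instead of calling `.long()`, because `.long()` would truncate any fractional timesteps the scheduler produces. On CPU and CUDA the tensor keeps its original dtype.

```python
import torch


def prepare_timesteps(timesteps: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for line 134 of unet_2d_condition.py.

    Adds a batch dimension to a 0-d timestep tensor and moves it to
    sample.device. MPS cannot hold float64 tensors, so on MPS only, float64
    timesteps are downcast to float32 before the device transfer.
    """
    timesteps = timesteps[None]
    if sample.device.type == "mps" and timesteps.dtype == torch.float64:
        # float32 keeps fractional timesteps, whereas .long() would truncate them.
        timesteps = timesteps.to(torch.float32)
    return timesteps.to(sample.device)
```

Because the cast happens while the tensor is still on the CPU, the float64 tensor never reaches the MPS device. Other devices see no behavior change.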
Reproduction
When you use an LMSDiscreteScheduler on an Apple Silicon machine, you should see the crash.
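The following minimal sketch isolates the failing transfer without the full pipeline. It assumes the timestep reaching line 134 is a 0-d float64 tensor, as the error message suggests. `try_move_timestep_to` is a hypothetical helper. On a machine without MPS, only the CPU transfer runs.

```python
import torch


def try_move_timestep_to(device: str):
    """Mimic line 134 with a float64 scalar timestep moved to `device`.

    Returns None if the transfer succeeds, otherwise the error message.
    """
    # Assumption: the scheduler's timesteps are float64, e.g. 50 steps from 999 down to 0.
    timesteps = torch.linspace(999, 0, 50, dtype=torch.float64)
    timestep = timesteps[0]  # 0-d float64 tensor, like the one the UNet receives
    try:
        timestep[None].to(device)
    except (TypeError, RuntimeError) as err:
        return str(err)
    return None


if __name__ == "__main__":
    print("cpu:", try_move_timestep_to("cpu"))
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # On Apple Silicon this is expected to print the float64 error quoted above.
        print("mps:", try_move_timestep_to("mps"))
```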
Logs
No response
System Info
The current main branch of the repo, since it appears to differ from the current release (0.2.4?).