
Apple MPS error in unet_2d_condition.py #358

@FahimF

Describe the bug

When you use an LMSDiscreteScheduler on an Apple Silicon machine, you'll get the following error:
Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead

The offending line is line 134 of unet_2d_condition.py.

The current code is:
timesteps = timesteps[None].to(sample.device)

Changing that to the following stops the crash:
timesteps = timesteps[None].long().to(sample.device)
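To see why the one-line change avoids the crash, here is a small CPU-only sketch. It assumes, as the LMS scheduler appears to do, that timesteps are built with np.linspace (whose default dtype is float64) and wrapped with torch.from_numpy. The 999/50 values are illustrative, not taken from the scheduler config. The sketch prints the dtype that reaches .to(sample.device) before and after .long(). Note that .long() also truncates any fractional timestep.

```python
import numpy as np
import torch

# Assumption: fractional timesteps built with np.linspace (float64 by default)
# and converted with torch.from_numpy, which preserves the float64 dtype.
all_timesteps = torch.from_numpy(np.linspace(999, 0, 50))
t = all_timesteps[1]  # a single 0-d timestep, as passed to the UNet

current = t[None]         # what line 134 produces before .to(sample.device)
patched = t[None].long()  # the proposed change

print(current.dtype)  # torch.float64 -> MPS cannot hold this dtype
print(patched.dtype)  # torch.int64   -> accepted by MPS
print(current.item(), patched.item())  # .long() drops the fractional part
```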

However, I believe you'd really want to check whether the current device is MPS and only do the dtype conversion when running on MPS.
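A minimal sketch of that device check follows. prepare_timesteps is a hypothetical helper, not an existing diffusers function. It also casts to float32 instead of calling .long(). This is an alternative to the patch above: float32 keeps fractional timesteps rather than truncating them. Other devices receive the tensor unchanged.

```python
import torch


def prepare_timesteps(timesteps: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: add a batch dim to a 0-d timestep and move it to `device`.

    Only when the target device is MPS (which lacks float64) is a float64
    tensor downcast to float32; on any other device the dtype is left alone.
    """
    timesteps = timesteps[None]
    if device.type == "mps" and timesteps.dtype == torch.float64:
        # Cast before moving: copying a float64 tensor to MPS is what raises.
        timesteps = timesteps.to(torch.float32)
    return timesteps.to(device)


# Usage: inside the UNet forward this would be called with sample.device.
sample = torch.zeros(1, 4, 8, 8)
t = torch.tensor(978.6, dtype=torch.float64)
print(prepare_timesteps(t, sample.device).dtype)  # torch.float64 (sample is on CPU)
```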

Reproduction

When you use an LMSDiscreteScheduler on an Apple Silicon machine, you should see the crash.

Logs

No response

System Info

The current main branch of the repo, since it appears to differ from the current release version (0.2.4?).

Metadata

Labels

bug (Something isn't working)
