
Add inference support for mps device (Apple Silicon) #292

@pcuenca


There is a lot of community interest in running diffusers on Apple Silicon. A first step could be to add support for the mps device, which currently requires PyTorch nightly builds. A later step could be to convert the models to Core ML and optimize them to maximize use of the ANE (Apple Neural Engine).

This issue focuses on the first step.

Describe the solution you'd like
The following should work:

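from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"  # example checkpoint on the Hugging Face Hub
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
pipe = pipe.to("mps")

# Rest of inference code remains the same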

pipe.to would check whether mps is available on the machine and raise an error if it is not.
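A minimal sketch of that availability check, assuming the requested device string is validated before any module is moved. In PyTorch the real check is torch.backends.mps.is_available() (torch.backends.mps.is_built() additionally tells whether the installed build includes MPS at all); here availability is passed in as a boolean so the logic can be read and tested without torch. The helper name resolve_device is hypothetical and not part of diffusers.

```python
def resolve_device(requested: str, mps_available: bool) -> str:
    """Validate a requested device string before moving a pipeline.

    Hypothetical helper: in a real implementation, `mps_available` would
    come from torch.backends.mps.is_available().
    """
    # Accept both the bare "mps" form and indexed forms such as "mps:0".
    device_type = requested.split(":", 1)[0]
    if device_type == "mps" and not mps_available:
        raise RuntimeError(
            "The 'mps' device was requested, but MPS is not available on this "
            "machine (it requires Apple Silicon, macOS 12.3+ and a PyTorch "
            "build with MPS support)."
        )
    # Other device types are passed through unchanged.
    return requested


print(resolve_device("mps", mps_available=True))
```

Raising early, before any weights are copied, means a missing MPS backend surfaces as one clear error instead of a failure partway through moving the components.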

One consideration is that not all components necessarily have to be moved to the mps device. It may be more efficient or practical to keep the text encoder on the CPU, for instance. If so, to() would move only the required components, and the pipelines would need to be adapted to move tensors between devices transparently.
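A sketch of how such a partial placement could be planned, under stated assumptions: the component names text_encoder, unet and vae match the attributes of StableDiffusionPipeline, but plan_placement, required_transfers and the data-flow edges are hypothetical illustrations, not diffusers API. The idea is that once each component has a device, the pipeline only needs to move tensors (for example with tensor.to(device)) on the edges where producer and consumer live on different devices.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical default: keep only the text encoder on the CPU.
DEFAULT_CPU_COMPONENTS = frozenset({"text_encoder"})


def plan_placement(
    components: Iterable[str],
    target: str,
    keep_on_cpu: frozenset = DEFAULT_CPU_COMPONENTS,
) -> Dict[str, str]:
    """Map each pipeline component name to the device it should live on."""
    return {name: ("cpu" if name in keep_on_cpu else target) for name in components}


def required_transfers(
    placement: Dict[str, str], data_flow: List[Tuple[str, str]]
) -> List[Tuple[str, str, str]]:
    """Return (producer, consumer, destination device) for every data-flow
    edge whose endpoints are on different devices. The pipeline would move
    the producer's output tensors to that device before calling the consumer."""
    return [
        (src, dst, placement[dst])
        for src, dst in data_flow
        if placement[src] != placement[dst]
    ]


placement = plan_placement(["text_encoder", "unet", "vae"], "mps")
flow = [("text_encoder", "unet"), ("unet", "vae")]
print(required_transfers(placement, flow))  # [('text_encoder', 'unet', 'mps')]
```

With the text encoder on the CPU, only the prompt embeddings need to cross devices once per generation, while the UNet loop and VAE decode stay entirely on mps.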

Describe alternatives you've considered
Conversion to Core ML, as mentioned above.

Additional context
I have tested the unet module on mps vs cpu. These are some preliminary results on my computer (M1 Max with 64 GB of unified RAM) when iterating through the 51 default steps of Stable Diffusion with the default scheduler and no classifier-free guidance:

Device: cpu, torch: 1.12.1,  time: 92.09s
Device: cpu, torch: 1.13.0.dev20220830,  time: 92.12s
Device: mps:0, torch: 1.13.0.dev20220830,  time: 17.20s
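To put these numbers in perspective, the snippet below derives the speedup and the per-step cost using only the timings reported above. Comparing cpu and mps on the same nightly build isolates the effect of the device from the PyTorch version.

```python
# Timings reported above for the 51 UNet steps, in seconds.
timings = {
    ("cpu", "1.12.1"): 92.09,
    ("cpu", "1.13.0.dev20220830"): 92.12,
    ("mps:0", "1.13.0.dev20220830"): 17.20,
}
N_STEPS = 51

cpu_nightly = timings[("cpu", "1.13.0.dev20220830")]
mps_nightly = timings[("mps:0", "1.13.0.dev20220830")]

# Speedup of mps over cpu on the same PyTorch build.
speedup = cpu_nightly / mps_nightly
print(f"mps speedup over cpu (same nightly build): {speedup:.2f}x")

# Average wall-clock time per denoising step for each configuration.
for (device, version), seconds in timings.items():
    print(f"{device:>5} torch {version}: {seconds / N_STEPS:.3f} s/step")
```

This gives a speedup of about 5.36x, or roughly 0.34 s per step on mps versus about 1.81 s per step on cpu with either PyTorch version.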

Running on mps produces a warning that one operation has to be performed on the CPU:

The operator 'aten::masked_select' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)

We need to investigate the reason for this fallback and whether an alternative operation would yield better performance.
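One plausible factor, stated here as an assumption rather than a diagnosis: the output size of masked_select depends on the values in the mask, not just on the input shapes, whereas torch.where(mask, input, other) always returns a tensor with the input's shape. The pure-Python sketch below only illustrates the two semantics on flat lists; it is not torch code, and whether the pipeline's use of masked_select can be rewritten this way still has to be checked.

```python
from typing import List, Sequence


def masked_select(values: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Mimics torch.masked_select on flat inputs: the output length equals
    the number of True entries, so it depends on the mask's data."""
    return [v for v, keep in zip(values, mask) if keep]


def where(mask: Sequence[bool], values: Sequence[float], other: float) -> List[float]:
    """Mimics torch.where(mask, values, other) on flat inputs: the output
    always has the same length as the input."""
    return [v if keep else other for v, keep in zip(values, mask)]


values = [0.5, -1.0, 2.0, 3.5]
mask = [True, False, True, False]
print(masked_select(values, mask))  # [0.5, 2.0]
print(where(mask, values, 0.0))     # [0.5, 0.0, 2.0, 0.0]
```

If the selected elements are only consumed by a reduction such as a sum, filling the masked-out positions with zero gives the same result while keeping a fixed shape, which is the kind of rewrite worth benchmarking on mps.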
