Inference support for mps device #355

Merged: pcuenca merged 43 commits into main on Sep 8, 2022

Conversation

@pcuenca (Member) commented on Sep 4, 2022

This addresses the issues identified when assessing inference on Apple Silicon, see #292 (comment) for details.

Current status

  • Stable Diffusion pipeline works.
  • Results on CPU and MPS are reproducible when using the same seeds. Generators do not work on the mps device, so some minor adjustments were needed.
  • Some incompatible ops identified by failing tests were rewritten (or we fall back to CPU); a sketch of the fallback pattern follows this list. But I have only verified test_models_unet and test_models_vae.
  • Determinism tests pass, but the solution is a workaround hack.
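
For context, the CPU-fallback pattern referenced above looks roughly like this (a minimal sketch, not the actual diffusers code; the helper name and the choice of padding as the example op are illustrative):

```python
import torch
import torch.nn.functional as F


def pad_with_cpu_fallback(x: torch.Tensor, pad, mode="constant", value=0.0):
    # Hypothetical helper: some padding ops misbehaved on mps at the time
    # (see pytorch/pytorch#84535), so run the op on CPU and move the result
    # back to the original device.
    if x.device.type == "mps":
        return F.pad(x.cpu(), pad, mode=mode, value=value).to(x.device)
    return F.pad(x, pad, mode=mode, value=value)
```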

The hack

We perform a one-time "warmup" forward pass through unet and vae because the result of the first pass differs from subsequent results. I suspect something related to randomness might be at play, but I haven't identified the root cause. We have several options here:

  • Remove the hack; it's probably overkill. We can recommend that users run a full pass through the pipeline (1 step is enough) after moving it to the device. The downside is that the determinism tests in test_models_unet and test_models_vae fail.
  • Find the cause and apply a proper solution.

I'd like to merge soon but don't like the hack. Can we remove it and live with some failing tests until we can find what's causing the issue?

Update: we now only perform the hack during testing. Users who care about reproducibility are recommended to perform an initial pass, as sketched below.
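
For reference, the user-side warmup amounts to something like the following (a sketch only; the model id, prompt, and step counts are illustrative, and the exact pipeline API may have changed since this PR):

```python
from diffusers import StableDiffusionPipeline

# After moving the pipeline to mps, run one throwaway 1-step pass so that
# subsequent results are reproducible, then generate normally.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")

prompt = "a photo of an astronaut riding a horse"
_ = pipe(prompt, num_inference_steps=1)  # warmup pass, output discarded
image = pipe(prompt, num_inference_steps=50).images[0]
```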

To do:

  • Some unet and vae tests still fail. For example, test_output_pretrained in AutoencoderKLTests.
  • Translate some incompatible ops or fall back to CPU.
  • Ensure tests pass.
  • Find a more principled workaround to perform the initial "warm up" pass. I tried to use forward hooks, but couldn't, because they don't pass keyword arguments, which we use in many of our forward implementations (see the toy example after this list).
  • Create a failing test case for the first-pass issue (results differ from those of subsequent passes). Deferred; tracked in MPS: models require an initial pass for reproducibility #372.
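
To make the forward-hook limitation concrete, here is a small self-contained toy (not diffusers code): at the time, PyTorch forward (pre-)hooks only received positional arguments, so keyword arguments used in our forward implementations never reach the hook (keyword support via with_kwargs only arrived in later PyTorch releases).

```python
import torch
from torch import nn


class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale


def pre_hook(module, args):
    # Only positional inputs show up here; the `scale=2.0` keyword passed below
    # is invisible to the hook, which is what breaks the warmup-via-hooks idea.
    print("hook saw:", args)


toy = Toy()
handle = toy.register_forward_pre_hook(pre_hook)
toy(torch.ones(1), scale=2.0)  # the hook prints only the positional tensor
handle.remove()
```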

Fixes #292.

Required when classifier-free guidance is enabled.
For some reason the first run produces results different than the rest.
This is especially important when using the mps device, because
generators are not supported there. See for example
pytorch/pytorch#84288.

In addition, the other pipelines seem to use the same approach: generate
the random samples then move to the appropriate device.

After this change, generating an image in MPS produces the same result
as when using the CPU, if the same seed is used.
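
The pattern described in this commit message amounts to roughly the following (an illustrative sketch; the latent shape and seed are placeholders, not the pipeline's actual code):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Seeded generation happens on CPU because torch.Generator is not supported on
# the mps device (see pytorch/pytorch#84288); the latents are then moved to the
# target device, so the same seed yields the same image on CPU and mps.
generator = torch.Generator(device="cpu").manual_seed(42)
latents = torch.randn((1, 4, 64, 64), generator=generator)
latents = latents.to(device)
```
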
@HuggingFaceDocBuilderDev commented on Sep 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a comment

Not a big fan of the warmup abstraction - I'd prefer to just leave it to the user to call the warmup method. Ok with adding a warmup_mps() method to the ModelMixin, but don't like that this is automatically called - wdyt @pcuenca ?

@patrickvonplaten (Contributor)

So I'd advocate for the following:

  • Add an experimental _mps_warmup(...) method to ModelMixin and PipelineMixin that the user has to call themselves.
  • With the underscore _ and with a note we make sure to tell people it's experimental and due to a hack currently.
  • We can adapt the tests with some `if device == "mps"` checks to make them pass

@pcuenca (Member, Author) commented on Sep 5, 2022

> So I'd advocate for the following:
>
>   • Add an experimental _mps_warmup(...) method to ModelMixin and PipelineMixin that the user has to call themselves.
>   • With the underscore _ and with a note we make sure to tell people it's experimental and due to a hack currently.
>   • We can adapt the tests with some `if device == "mps"` checks to make them pass

Actually I think I'd recommend users to run a pipeline pass with 1 iteration instead before using the outputs, since that's easier. We can use the underscored methods ourselves just to make tests pass.

@pcuenca (Member, Author) commented on Sep 5, 2022

> Not a big fan of the warmup abstraction - I'd prefer to just leave it to the user to call the warmup method. Ok with adding a warmup_mps() method to the ModelMixin, but don't like that this is automatically called - wdyt @pcuenca ?

Yes, you are right, it's not important enough to expose it to all users. Thanks!

@pcuenca (Member, Author) commented on Sep 5, 2022

I removed the new abstraction and the automatic warmup pass, and created a new README_mps.md file with installation instructions and the recommendation to run a 1-step inference.

@patrickvonplaten (Contributor)

Also @pcuenca - could you downgrade black to 22.3.0 so that make style works? :-) @anton-l let's upgrade black after the release on Thursday :-)

Point to the documentation instead.
@@ -38,6 +40,11 @@ def test_from_pretrained_save_pretrained(self):
        new_model.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                _ = model(**self.dummy_input)
A Contributor commented on this diff:
nice!

@patrickvonplaten (Contributor) left a comment

Looks very nice! Feel free to merge whenever @pcuenca :-)

@pcuenca (Member, Author) commented on Sep 8, 2022

Thanks for the help!

@pcuenca pcuenca merged commit 5dda173 into main Sep 8, 2022
@pcuenca pcuenca deleted the mps branch September 8, 2022 11:37
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Initial support for mps in Stable Diffusion pipeline.

* Initial "warmup" implementation when using mps.

* Make some deterministic tests pass with mps.

* Disable training tests when using mps.

* SD: generate latents in CPU then move to device.

This is especially important when using the mps device, because
generators are not supported there. See for example
pytorch/pytorch#84288.

In addition, the other pipelines seem to use the same approach: generate
the random samples then move to the appropriate device.

After this change, generating an image in MPS produces the same result
as when using the CPU, if the same seed is used.

* Remove prints.

* Pass AutoencoderKL test_output_pretrained with mps.

Sampling from `posterior` must be done in CPU.

* Style

* Do not use torch.long for log op in mps device.

* Perform incompatible padding ops in CPU.

UNet tests now pass.
See pytorch/pytorch#84535

* Style: fix import order.

* Remove unused symbols.

* Remove MPSWarmupMixin, do not apply automatically.

We do apply warmup in the tests, but not during normal use.
This adopts some PR suggestions by @patrickvonplaten.

* Add comment for mps fallback to CPU step.

* Add README_mps.md for mps installation and use.

* Apply `black` to modified files.

* Restrict README_mps to SD, show measures in table.

* Make PNDM indexing compatible with mps.

Addresses huggingface#239.

* Do not use float64 when using LDMScheduler.

Fixes huggingface#358.

* Fix typo identified by @patil-suraj

Co-authored-by: Suraj Patil <[email protected]>

* Adapt example to new output style.

* Restore 1:1 results reproducibility with CompVis.

However, mps latents need to be generated in CPU because generators
don't work in the mps device.

* Move PyTorch nightly to requirements.

* Adapt `test_scheduler_outputs_equivalence` to MPS.

* mps: skip training tests instead of ignoring silently.

* Make VQModel tests pass on mps.

* mps ddim tests: warmup, increase tolerance.

* ScoreSdeVeScheduler indexing made mps compatible.

* Make ldm pipeline tests pass using warmup.

* Style

* Simplify casting as suggested in PR.

* Add Known Issues to readme.

* `isort` import order.

* Remove _mps_warmup helpers from ModelMixin.

And just make changes to the tests.

* Skip tests using unittest decorator for consistency.

* Remove temporary var.

* Remove spurious blank space.

* Remove unused symbol.

* Remove README_mps.

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Successfully merging this pull request may close these issues: Add inference support for mps device (Apple Silicon).