Inference support for mps device #355
Merged

Changes from all commits (43 commits, all by pcuenca):
137e1d5  Initial support for mps in Stable Diffusion pipeline.
0ef1d1e  Initial "warmup" implementation when using mps.
ae5ea46  Make some deterministic tests pass with mps.
4ed22c2  Disable training tests when using mps.
bc93b51  SD: generate latents in CPU then move to device.
34c0eff  Remove prints.
314d70a  Merge remote-tracking branch 'origin/main' into mps
66b6752  Pass AutoencoderKL test_output_pretrained with mps.
db7da01  Style
f20a0dd  Do not use torch.long for log op in mps device.
7f40f24  Perform incompatible padding ops in CPU.
7103993  Style: fix import order.
c931d2a  Remove unused symbols.
d0e85f3  Remove MPSWarmupMixin, do not apply automatically.
692b1be  Add comment for mps fallback to CPU step.
36b6a46  Add README_mps.md for mps installation and use.
261a784  Apply `black` to modified files.
15d86ff  Restrict README_mps to SD, show measures in table.
5ed8889  Make PNDM indexing compatible with mps.
12f6670  Do not use float64 when using LDMScheduler.
ce1e863  Fix typo identified by @patil-suraj
8d5c595  Merge branch 'main' into mps
3943fde  Adapt example to new output style.
cec0c7e  Restore 1:1 results reproducibility with CompVis.
86299b1  Merge branch 'mps' of github.com:huggingface/diffusers into mps
1bf8c4c  Move PyTorch nightly to requirements.
220999f  Adapt `test_scheduler_outputs_equivalence` to MPS.
9280913  mps: skip training tests instead of ignoring silently.
d8a093d  Make VQModel tests pass on mps.
0f60435  mps ddim tests: warmup, increase tolerance.
3c59b39  ScoreSdeVeScheduler indexing made mps compatible.
1a2af52  Make ldm pipeline tests pass using warmup.
99d0704  Merge branch 'main' into mps
e4181f0  Style
1b05d0f  Simplify casting as suggested in PR.
df6683b  Add Known Issues to readme.
4e4bd62  `isort` import order.
c985b50  Remove _mps_warmup helpers from ModelMixin.
cdd1c41  Skip tests using unittest decorator for consistency.
44f485b  Remove temporary var.
dfd5a6e  Remove spurious blank space.
b0579c2  Remove unused symbol.
2e52457  Remove README_mps.
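Several of these commits, and the test changes below, add a throwaway "warmup" forward pass when running on mps (see #372). A minimal usage sketch of that pattern at the pipeline level follows; the model id, prompt, and the one-step warmup call are illustrative assumptions, not code taken from this PR.

```python
# Hedged sketch of mps inference with a warmup pass (illustrative, not the PR's code).
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")

prompt = "a photo of an astronaut riding a horse on mars"

# First pass on mps acts as a "warmup"; its output is discarded.
_ = pipe(prompt, num_inference_steps=1)

# Subsequent passes behave normally.
image = pipe(prompt).images[0]
image.save("astronaut.png")
```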
Changes to the shared model tests follow; each hunk is shown as the resulting code, with the original hunk header kept as a comment.

```python
# @@ -15,11 +15,13 @@
import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple

import numpy as np
import torch

from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel
```
```python
# @@ -38,6 +40,11 @@ def test_from_pretrained_save_pretrained(self):
            new_model.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                _ = model(**self.dummy_input)
                _ = new_model(**self.dummy_input)

            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.sample
```
```python
# @@ -55,7 +62,12 @@ def test_determinism(self):
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.sample
```
```python
# @@ -132,6 +144,7 @@ def test_model_from_config(self):
        self.assertEqual(output_1.shape, output_2.shape)

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```
```python
# @@ -147,6 +160,7 @@ def test_training(self):
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```
```python
# @@ -167,8 +181,13 @@ def test_ema_training(self):
    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
```
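The `t[t != t] = 0` line in `set_nan_tensor_to_zero` relies on NaN being the only value that compares unequal to itself; the tensor is moved to the CPU first because, per the comment, `aten::_index_put_impl_` was not yet implemented on mps. A small standalone illustration of the NaN-masking idiom (not part of the diff):

```python
import torch

t = torch.tensor([1.0, float("nan"), 3.0])

# NaN compares unequal to itself, so `t != t` is True exactly at NaN positions.
print(t != t)  # tensor([False,  True, False])

# Boolean-mask assignment zeroes those positions.
t[t != t] = 0
print(t)       # tensor([1., 0., 3.])
```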
```python
# @@ -198,7 +217,12 @@ def recursive_check(tuple_object, dict_object):
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            outputs_dict = model(**inputs_dict)
            outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)
```
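Beyond the test changes shown here, the commits "SD: generate latents in CPU then move to device" and "Restore 1:1 results reproducibility with CompVis" point at another key idea: sample the initial latents on the CPU so that seeded results match across devices. A rough sketch of that idea, with illustrative shapes and names not taken from the PR:

```python
import torch

# Assumption: a PyTorch build with mps support (the commit list mentions moving
# PyTorch nightly to the requirements).
device = "mps" if torch.backends.mps.is_available() else "cpu"

generator = torch.manual_seed(0)

# Sampling on the CPU uses the CPU RNG, so the same seed yields the same
# latents regardless of the device inference later runs on.
latents = torch.randn((1, 4, 64, 64), generator=generator)
latents = latents.to(device)
```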