Inference support for mps device #355

Merged: 43 commits, Sep 8, 2022

Commits

137e1d5
Initial support for mps in Stable Diffusion pipeline.
pcuenca Sep 4, 2022
0ef1d1e
Initial "warmup" implementation when using mps.
pcuenca Sep 4, 2022
ae5ea46
Make some deterministic tests pass with mps.
pcuenca Sep 4, 2022
4ed22c2
Disable training tests when using mps.
pcuenca Sep 4, 2022
bc93b51
SD: generate latents in CPU then move to device.
pcuenca Sep 4, 2022
34c0eff
Remove prints.
pcuenca Sep 4, 2022
314d70a
Merge remote-tracking branch 'origin/main' into mps
pcuenca Sep 4, 2022
66b6752
Pass AutoencoderKL test_output_pretrained with mps.
pcuenca Sep 4, 2022
db7da01
Style
pcuenca Sep 5, 2022
f20a0dd
Do not use torch.long for log op in mps device.
pcuenca Sep 5, 2022
7f40f24
Perform incompatible padding ops in CPU.
pcuenca Sep 5, 2022
7103993
Style: fix import order.
pcuenca Sep 5, 2022
c931d2a
Remove unused symbols.
pcuenca Sep 5, 2022
d0e85f3
Remove MPSWarmupMixin, do not apply automatically.
pcuenca Sep 5, 2022
692b1be
Add comment for mps fallback to CPU step.
pcuenca Sep 5, 2022
36b6a46
Add README_mps.md for mps installation and use.
pcuenca Sep 5, 2022
261a784
Apply `black` to modified files.
pcuenca Sep 5, 2022
15d86ff
Restrict README_mps to SD, show measures in table.
pcuenca Sep 5, 2022
5ed8889
Make PNDM indexing compatible with mps.
pcuenca Sep 5, 2022
12f6670
Do not use float64 when using LDMScheduler.
pcuenca Sep 5, 2022
ce1e863
Fix typo identified by @patil-suraj
pcuenca Sep 6, 2022
8d5c595
Merge branch 'main' into mps
pcuenca Sep 6, 2022
3943fde
Adapt example to new output style.
pcuenca Sep 6, 2022
cec0c7e
Restore 1:1 results reproducibility with CompVis.
pcuenca Sep 6, 2022
86299b1
Merge branch 'mps' of github.com:huggingface/diffusers into mps
pcuenca Sep 6, 2022
1bf8c4c
Move PyTorch nightly to requirements.
pcuenca Sep 6, 2022
220999f
Adapt `test_scheduler_outputs_equivalence` to MPS.
pcuenca Sep 6, 2022
9280913
mps: skip training tests instead of ignoring silently.
pcuenca Sep 6, 2022
d8a093d
Make VQModel tests pass on mps.
pcuenca Sep 6, 2022
0f60435
mps ddim tests: warmup, increase tolerance.
pcuenca Sep 6, 2022
3c59b39
ScoreSdeVeScheduler indexing made mps compatible.
pcuenca Sep 6, 2022
1a2af52
Make ldm pipeline tests pass using warmup.
pcuenca Sep 6, 2022
99d0704
Merge branch 'main' into mps
pcuenca Sep 6, 2022
e4181f0
Style
pcuenca Sep 6, 2022
1b05d0f
Simplify casting as suggested in PR.
pcuenca Sep 6, 2022
df6683b
Add Known Issues to readme.
pcuenca Sep 6, 2022
4e4bd62
`isort` import order.
pcuenca Sep 6, 2022
c985b50
Remove _mps_warmup helpers from ModelMixin.
pcuenca Sep 8, 2022
cdd1c41
Skip tests using unittest decorator for consistency.
pcuenca Sep 8, 2022
44f485b
Remove temporary var.
pcuenca Sep 8, 2022
dfd5a6e
Remove spurious blank space.
pcuenca Sep 8, 2022
b0579c2
Remove unused symbol.
pcuenca Sep 8, 2022
2e52457
Remove README_mps.
pcuenca Sep 8, 2022

Files changed

4 changes: 4 additions & 0 deletions README.md
@@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4
conda install -c conda-forge diffusers
```

**Apple Silicon (M1/M2) support**

Please refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).

## Contributing

We ❤️ contributions from the open-source community!
1 change: 1 addition & 0 deletions src/diffusers/models/attention.py
@@ -137,6 +137,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff
self.checkpoint = checkpoint

def forward(self, x, context=None):
x = x.contiguous() if x.device.type == "mps" else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
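The `contiguous()` guard above works around an MPS correctness issue with non-contiguous inputs to the attention block. A minimal sketch of the pattern (the helper name is illustrative, not part of the PR):

```python
import torch


def contiguous_on_mps(x: torch.Tensor) -> torch.Tensor:
    # Some mps ops produced wrong results for non-contiguous inputs at
    # the time; forcing a contiguous layout avoids that. contiguous()
    # is a no-op (no copy) when x is already contiguous, and the device
    # check keeps the cuda/cpu paths completely untouched.
    return x.contiguous() if x.device.type == "mps" else x
```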
5 changes: 5 additions & 0 deletions src/diffusers/models/resnet.py
@@ -446,10 +446,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape

out = input.view(-1, in_h, 1, in_w, 1, minor)

# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)

out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
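The hunk above routes the padding through the CPU because `F.pad` mis-handled this case on mps (pytorch#84535). A sketch of the same round-trip pattern as a hypothetical standalone helper:

```python
import torch
import torch.nn.functional as F


def pad_with_cpu_fallback(x: torch.Tensor, pad: list) -> torch.Tensor:
    # Run the op on cpu when the input lives on mps, then move back.
    # The round-trip costs two transfers but keeps results correct.
    if x.device.type == "mps":
        return F.pad(x.to("cpu"), pad).to("mps")
    return F.pad(x, pad)
```

The same shape works for any op that is missing or broken on mps; the diff applies it only to the specific `F.pad` calls inside `upfirdn2d_native`.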
3 changes: 2 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -149,7 +149,8 @@ def forward(
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps.to(dtype=torch.float32)
timesteps = timesteps[None].to(device=sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
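The explicit cast to `float32` matters because mps has no `float64` support: a 0-dim tensor promoted to double fails when moved to the device. A small illustration of the failure mode (guarded so it also runs off-mps):

```python
import torch

t = torch.tensor(0.5)            # 0-dim, float32 by default
t64 = t.to(torch.float64)
if torch.backends.mps.is_available():
    # t64.to("mps") would raise: MPS tensors cannot hold float64.
    t_mps = t64.to(dtype=torch.float32)[None].to("mps")  # cast first, then move
    print(t_mps.dtype, t_mps.device)
```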
5 changes: 4 additions & 1 deletion src/diffusers/models/vae.py
@@ -338,7 +338,10 @@ def __init__(self, parameters, deterministic=False):
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
x = self.mean + self.std * sample
return x

def kl(self, other=None):
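The `torch.Generator` state lives on the CPU here, and at the time `torch.randn` could not draw seeded samples directly on mps, so the sample is drawn on CPU and then moved. A sketch of the pattern as a standalone helper (name and signature are ours, not the PR's):

```python
import torch


def seeded_randn(shape, generator: torch.Generator, device) -> torch.Tensor:
    # Draw on cpu when targeting mps so the generator is honored, then
    # move; the values are identical across devices for the same seed.
    device = torch.device(device)
    sample_device = "cpu" if device.type == "mps" else device
    return torch.randn(shape, generator=generator, device=sample_device).to(device)


noise = seeded_randn((1, 4, 64, 64), torch.manual_seed(0), "cpu")
```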
1 change: 0 additions & 1 deletion src/diffusers/pipeline_utils.py
@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput):


class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"

def register_modules(self, **kwargs):
@@ -101,17 +101,22 @@ def __call__(
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# get the initial random noise unless the user supplied it

# Unlike in other pipelines, latents need to be generated on the target device
# for 1-to-1 reproducibility of results with the CompVis implementation.
# However, this currently doesn't work on `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=latents_device,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)

# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
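Because the pipeline now draws its initial latents on the CPU when running on mps, a fixed seed yields the same starting noise regardless of the device that executes the denoising loop. A short sketch of that property (the shape is illustrative, matching `unet.in_channels` and 512 // 8 for Stable Diffusion):

```python
import torch

generator = torch.manual_seed(0)
latents = torch.randn((1, 4, 64, 64), generator=generator)  # drawn on cpu
# Moving the noise afterwards preserves its values, so runs on cpu, cuda,
# and mps all start from identical latents for the same seed.
if torch.backends.mps.is_available():
    latents = latents.to("mps")
```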
3 changes: 2 additions & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
@@ -280,7 +280,8 @@ def add_noise(
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
# mps requires indices to be on the same device, so we keep them on cpu, as is the default with cuda
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
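The one-line device move exists because indexing a CPU tensor (`alphas_cumprod`) with mps-resident indices raises an error: indices must live on the same device as the table. A tiny repro of the rule, using stand-in values:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)  # lives on cpu
timesteps = torch.tensor([10, 500, 990])
if torch.backends.mps.is_available():
    timesteps = timesteps.to("mps")
# Indexing a cpu table with mps indices fails, so align devices first:
timesteps = timesteps.to(alphas_cumprod.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
```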
9 changes: 7 additions & 2 deletions src/diffusers/schedulers/scheduling_sde_ve.py
@@ -111,7 +111,9 @@ def get_adjacent_sigma(self, timesteps, t):
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt":
return torch.where(
timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)

raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
@@ -155,8 +157,11 @@ def step_pred(
) # torch.repeat_interleave(timestep, sample.shape[0])
timesteps = (timestep * (len(self.timesteps) - 1)).long()

# mps requires indices to be on the same device, so we keep them on cpu, as is the default with cuda
timesteps = timesteps.to(self.discrete_sigmas.device)

sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
drift = self.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

1 change: 1 addition & 0 deletions src/diffusers/testing_utils.py
@@ -8,6 +8,7 @@

global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "mps" if torch.backends.mps.is_available() else torch_device


def parse_flag_from_env(key, default=False):
30 changes: 27 additions & 3 deletions tests/test_modeling_common.py
@@ -15,11 +15,13 @@

import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple

import numpy as np
import torch

from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel

@@ -38,6 +40,11 @@ def test_from_pretrained_save_pretrained(self):
new_model.to(torch_device)

with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
# (inline review comment from a contributor: "nice!")

_ = new_model(**self.dummy_input)

image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
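The warmup exists because on mps the first forward pass through a freshly loaded model could return subtly wrong values (tracked in #372), so tests run one throwaway pass before any pass whose output gets compared. A sketch of the pattern, wrapped in a hypothetical helper:

```python
import torch

from diffusers.modeling_utils import ModelMixin


def mps_warmup(model, dummy_input: dict, torch_device: str) -> None:
    # Run one discarded forward pass on mps before the real one;
    # no-op on cuda/cpu, where outputs are stable from the start.
    if torch_device == "mps" and isinstance(model, ModelMixin):
        with torch.no_grad():
            model(**dummy_input)
```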
@@ -55,7 +62,12 @@ def test_determinism(self):
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)

first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
@@ -132,6 +144,7 @@ def test_model_from_config(self):

self.assertEqual(output_1.shape, output_2.shape)

@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -147,6 +160,7 @@ def test_ema_training(self):
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -167,8 +181,13 @@ def test_ema_training(self):

def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)

def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
@@ -198,7 +217,12 @@ def recursive_check(tuple_object, dict_object):
model.to(torch_device)
model.eval()

with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)

outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)

recursive_check(outputs_tuple, outputs_dict)
2 changes: 1 addition & 1 deletion tests/test_models_unet.py
@@ -191,7 +191,7 @@ def dummy_input(self, sizes=(32, 32)):
num_channels = 3

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

return {"sample": noise, "timestep": time_step}

8 changes: 8 additions & 0 deletions tests/test_models_vae.py
@@ -18,6 +18,7 @@
import torch

from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import floats_tensor, torch_device

from .test_modeling_common import ModelTesterMixin
@@ -80,6 +81,13 @@ def test_output_pretrained(self):
model = model.to(torch_device)
model.eval()

# One-time warmup pass (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
_ = model(image, sample_posterior=True).sample

torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
3 changes: 3 additions & 0 deletions tests/test_models_vq.py
@@ -85,6 +85,9 @@ def test_output_pretrained(self):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = model(image)
output = model(image).sample

output_slice = output[0, -1, -3:, -3:].flatten().cpu()
22 changes: 20 additions & 2 deletions tests/test_pipelines.py
@@ -195,6 +195,10 @@ def test_ddim(self):
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)

generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

@@ -208,8 +212,9 @@ def test_ddim(self):
expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance

def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
@@ -245,6 +250,14 @@ def test_ldm_text2img(self):
ldm.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"

# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
"sample"
]

generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
"sample"
@@ -442,6 +455,11 @@ def test_ldm_uncond(self):
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images

generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images

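For users, the same warmup trick applies at the pipeline level: run one cheap step before the generation whose output matters. A minimal sketch, assuming the Stable Diffusion weights are available locally (the model id and prompt are illustrative):

```python
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")

# One-step throwaway pass: the first mps inference can differ (see #372).
_ = pipe("warmup", num_inference_steps=1)

# The real, seeded generation.
generator = torch.manual_seed(0)
image = pipe("a photo of an astronaut riding a horse", generator=generator).images[0]
image.save("astronaut.png")
```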