Skip to content

Commit 6a472c2

Browse files
patrickvonplatenPrathik Rao
authored andcommitted
[Community] One step unet (huggingface#840)
1 parent 27cb665 commit 6a472c2

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

examples/community/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
| Example | Description | Author | Colab |
66
|:----------|:----------------------|:-----------------|----------:|
77
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [Suraj Patil](https://github.com/patil-suraj/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |
8+
| One Step U-Net (Dummy) | [Patrick von Platen](https://github.com/patrickvonplaten/) | - |

examples/community/one_step_unet.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python3
2+
import torch
3+
4+
from diffusers import DiffusionPipeline
5+
6+
7+
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
8+
def __init__(self, unet, scheduler):
9+
super().__init__()
10+
11+
self.register_modules(unet=unet, scheduler=scheduler)
12+
13+
def __call__(self):
14+
image = torch.randn(
15+
(1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
16+
)
17+
timestep = 1
18+
19+
model_output = self.unet(image, timestep).sample
20+
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
21+
22+
return scheduler_output

0 commit comments

Comments
 (0)