Skip to content

Commit 5f25818

Browse files
authored
allow custom height, width in StableDiffusionPipeline (#179)
* allow custom height width * raise if height width are not mul of 8
1 parent c25d8c9 commit 5f25818

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(
2828
def __call__(
2929
self,
3030
prompt: Union[str, List[str]],
31+
height: Optional[int] = 512,
32+
width: Optional[int] = 512,
3133
num_inference_steps: Optional[int] = 50,
3234
guidance_scale: Optional[float] = 1.0,
3335
eta: Optional[float] = 0.0,
@@ -45,6 +47,9 @@ def __call__(
4547
else:
4648
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
4749

50+
if height % 8 != 0 or width % 8 != 0:
51+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
52+
4853
self.unet.to(torch_device)
4954
self.vae.to(torch_device)
5055
self.text_encoder.to(torch_device)
@@ -72,7 +77,7 @@ def __call__(
7277

7378
# get the intial random noise
7479
latents = torch.randn(
75-
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
80+
(batch_size, self.unet.in_channels, height // 8, width // 8),
7681
generator=generator,
7782
)
7883
latents = latents.to(torch_device)

0 commit comments

Comments
 (0)