@@ -28,6 +28,8 @@ def __init__(
28
28
def __call__ (
29
29
self ,
30
30
prompt : Union [str , List [str ]],
31
+ height : Optional [int ] = 512 ,
32
+ width : Optional [int ] = 512 ,
31
33
num_inference_steps : Optional [int ] = 50 ,
32
34
guidance_scale : Optional [float ] = 1.0 ,
33
35
eta : Optional [float ] = 0.0 ,
@@ -45,6 +47,9 @@ def __call__(
45
47
else :
46
48
raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
47
49
50
+ if height % 8 != 0 or width % 8 != 0 :
51
+ raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
52
+
48
53
self .unet .to (torch_device )
49
54
self .vae .to (torch_device )
50
55
self .text_encoder .to (torch_device )
@@ -72,7 +77,7 @@ def __call__(
72
77
73
78
# get the initial random noise
74
79
latents = torch .randn (
75
- (batch_size , self .unet .in_channels , self . unet . sample_size , self . unet . sample_size ),
80
+ (batch_size , self .unet .in_channels , height // 8 , width // 8 ),
76
81
generator = generator ,
77
82
)
78
83
latents = latents .to (torch_device )
0 commit comments