@@ -195,20 +195,50 @@ def block_wrap(args):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
+
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        bs, c, h_orig, w_orig = x.shape
+        patch_size = self.patch_size
+
+        h_len = ((h_orig + (patch_size // 2)) // patch_size)
+        w_len = ((w_orig + (patch_size // 2)) // patch_size)
+        img, img_ids = self.process_img(x)
+        img_tokens = img.shape[1]
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            for ref in ref_latents:
+                h_offset = 0
+                w_offset = 0
+                if ref.shape[-2] + h > ref.shape[-1] + w:
+                    w_offset = w
+                else:
+                    h_offset = h
+
+                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+                img = torch.cat([img, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+                h = max(h, ref.shape[-2] + h_offset)
+                w = max(w, ref.shape[-1] + w_offset)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+        out = out[:, :img_tokens]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
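
For context on what the new position ids encode: each patch token carries an (index, row, column) triple, the main latent uses index 0, and each reference latent is packed with index 1 plus row/column offsets so successive references land on distinct coordinates. Below is a minimal standalone sketch of that id construction (a simplified hypothetical make_ids helper of my own, not ComfyUI's API; offsets here are already in patch units, whereas the diff first converts them via (offset + patch_size // 2) // patch_size):

import torch

def make_ids(h_len, w_len, index=0, h_offset=0, w_offset=0):
    # One (index, row, col) id per patch; offsets are already in patch units here.
    ids = torch.zeros((h_len, w_len, 3))
    ids[:, :, 0] = index                                                  # 0 = main latent, 1 = reference
    ids[:, :, 1] = torch.arange(h_offset, h_offset + h_len).unsqueeze(1)  # row coordinate
    ids[:, :, 2] = torch.arange(w_offset, w_offset + w_len).unsqueeze(0)  # column coordinate
    return ids.reshape(-1, 3)                                             # flatten to (h_len * w_len, 3)

main_ids = make_ids(4, 4)                      # main 4x4 patch grid at the origin
ref_ids = make_ids(4, 4, index=1, w_offset=4)  # a reference grid offset to its right
ids = torch.cat([main_ids, ref_ids], dim=0)    # analogous to torch.cat([img_ids, kontext_ids], dim=1)
print(ids.shape)                               # torch.Size([32, 3])

Concatenating the reference tokens after the main tokens is also why the new forward keeps img_tokens and slices out = out[:, :img_tokens] before unpatchifying: only the main image's tokens are turned back into a latent.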