Skip to content

Commit ef5266b

Browse files
Support Flux Kontext Dev model. (#8679)
1 parent a96e65d commit ef5266b

File tree

3 files changed

+97
-6
lines changed

3 files changed

+97
-6
lines changed

comfy/ldm/flux/model.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,50 @@ def block_wrap(args):
195195
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
196196
return img
197197

198-
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
198+
def process_img(self, x, index=0, h_offset=0, w_offset=0):
    """Patchify a latent image and build its 3-channel position ids.

    Args:
        x: latent tensor of shape (bs, c, h, w).
        index: value written into id channel 0; 0 for the main image,
            1 for reference ("kontext") latents, so the two can be told
            apart in RoPE position space.
        h_offset, w_offset: offsets (in latent pixels) added to the
            row/column ids so reference latents occupy a distinct region
            of the id grid.

    Returns:
        (img, img_ids): img is (bs, tokens, c*ph*pw) patch tokens;
        img_ids is (bs, tokens, 3) with channels (index, row, col).
    """
    bs, c, h, w = x.shape
    patch_size = self.patch_size
    x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
    h_len = ((h + (patch_size // 2)) // patch_size)
    w_len = ((w + (patch_size // 2)) // patch_size)

    # Convert pixel offsets to token-grid offsets the same way h/w are.
    h_offset = ((h_offset + (patch_size // 2)) // patch_size)
    w_offset = ((w_offset + (patch_size // 2)) // patch_size)

    img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
    # Fix: write the index into channel 0 directly. The original read
    # channel 1 here (img_ids[:, :, 1] + index), which produced the right
    # value only because the tensor is still all zeros at this point.
    img_ids[:, :, 0] = index
    img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
    img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
    return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
215+
216+
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
    """Run the Flux transformer on latent ``x``.

    When ``ref_latents`` is given (Flux Kontext), each reference latent is
    patchified with index=1 position ids, offset so it does not overlap the
    main image (or previously placed references) in the id grid, and
    concatenated to the image token sequence. The extra tokens are stripped
    from the output before un-patchifying, so the returned latent always
    matches ``x``'s spatial size.
    """
    bs, c, h_orig, w_orig = x.shape
    patch_size = self.patch_size

    # Token-grid size of the *main* image only; used to un-patchify the
    # output after the reference tokens have been dropped.
    h_len = ((h_orig + (patch_size // 2)) // patch_size)
    w_len = ((w_orig + (patch_size // 2)) // patch_size)
    img, img_ids = self.process_img(x)
    img_tokens = img.shape[1]
    if ref_latents is not None:
        # (h, w) track the extent of the id-grid area used so far.
        h = 0
        w = 0
        for ref in ref_latents:
            # Place each reference either below (h_offset) or to the right
            # (w_offset) of the occupied area, choosing the direction that
            # keeps the combined extent smaller.
            h_offset = 0
            w_offset = 0
            if ref.shape[-2] + h > ref.shape[-1] + w:
                w_offset = w
            else:
                h_offset = h

            kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
            img = torch.cat([img, kontext], dim=1)
            img_ids = torch.cat([img_ids, kontext_ids], dim=1)
            h = max(h, ref.shape[-2] + h_offset)
            w = max(w, ref.shape[-1] + w_offset)

    txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
    out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
    # Keep only the main-image tokens; reference tokens are context only.
    out = out[:, :img_tokens]
    return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]

comfy/model_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,7 @@ def extra_conds(self, **kwargs):
816816
class Flux(BaseModel):
817817
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
    super().__init__(model_config, model_type, device=device, unet_model=unet_model)
    # Conds whose size should be factored into inference-memory estimation
    # (reference latents add extra image tokens).
    # NOTE(review): extra_conds()/extra_conds_shapes() register the cond
    # under the key 'ref_latents', not 'kontext' — verify this tuple names
    # the key the memory estimator actually looks up.
    self.memory_usage_factor_conds = ("kontext",)
819820

820821
def concat_cond(self, **kwargs):
821822
try:
@@ -876,8 +877,23 @@ def extra_conds(self, **kwargs):
876877
guidance = kwargs.get("guidance", 3.5)
877878
if guidance is not None:
878879
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
880+
881+
ref_latents = kwargs.get("reference_latents", None)
882+
if ref_latents is not None:
883+
latents = []
884+
for lat in ref_latents:
885+
latents.append(self.process_latent_in(lat))
886+
out['ref_latents'] = comfy.conds.CONDList(latents)
879887
return out
880888

889+
def extra_conds_shapes(self, **kwargs):
    """Report the shape of the optional reference-latent cond.

    Used for memory estimation: the reference latents add image tokens,
    so their total size must be known up front.

    Returns:
        dict: {'ref_latents': [1, 16, total_elements // 16]} when
        "reference_latents" is present in kwargs, else {}.
    """
    out = {}
    ref_latents = kwargs.get("reference_latents", None)
    if ref_latents is not None:
        # Total element count across all reference latents, expressed as a
        # 16-channel shape. (Dropped the redundant list() wrapper around
        # the list literal; map/lambda replaced with a generator expression.)
        out['ref_latents'] = [1, 16, sum(math.prod(a.size()) for a in ref_latents) // 16]
    return out
895+
896+
881897
class GenmoMochi(BaseModel):
882898
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
883899
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)

comfy_extras/nodes_flux.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import node_helpers
2+
import comfy.utils
23

34
class CLIPTextEncodeFlux:
45
@classmethod
@@ -56,8 +57,52 @@ def append(self, conditioning):
5657
return (c, )
5758

5859

60+
# (width, height) resolution buckets used by FluxKontextImageScale; each is
# roughly one megapixel, covering aspect ratios from tall (672x1568) to wide
# (1568x672). The list is ordered by ascending width.
# Spelling "PREFERED" [sic] is the identifier's actual name — do not "fix" it
# without updating all references.
PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
79+
80+
81+
class FluxKontextImageScale:
    """Node that snaps an image to the nearest preferred Kontext resolution."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE", )}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "scale"

    CATEGORY = "advanced/conditioning/flux"
    DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."

    def scale(self, image):
        # image is laid out (batch, height, width, channels).
        src_height = image.shape[1]
        src_width = image.shape[2]
        ratio = src_width / src_height
        # Pick the bucket whose aspect ratio is closest to the input's.
        # Ties resolve to the smaller width, matching the tuple-min of the
        # original since the bucket list is ordered by ascending width.
        best_width, best_height = min(PREFERED_KONTEXT_RESOLUTIONS, key=lambda wh: abs(ratio - wh[0] / wh[1]))
        resized = comfy.utils.common_upscale(image.movedim(-1, 1), best_width, best_height, "lanczos", "center").movedim(1, -1)
        return (resized, )
101+
102+
59103
# Node-name -> implementation-class registry for this module.
NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
    "FluxGuidance": FluxGuidance,
    "FluxDisableGuidance": FluxDisableGuidance,
    "FluxKontextImageScale": FluxKontextImageScale,
}

0 commit comments

Comments
 (0)