
Commit 4081f87

Pfannkuchensack authored and JPPhoto committed
fix(flux2): Fix FLUX.2 Klein image generation quality (#8838)
* fix(flux2): Fix image quality degradation at resolutions > 1024x1024

  This commit addresses severe quality degradation and artifacts when
  generating images larger than 1024x1024 with FLUX.2 Klein models.

  Root causes fixed:
  1. Dynamic max_image_seq_len in scheduler (flux2_denoise.py)
     - Previously hardcoded to 4096 (1024x1024 only)
     - Now dynamically calculated based on actual resolution
     - Allows proper schedule shifting at all resolutions
  2. Smoothed mu calculation discontinuity (sampling_utils.py)
     - Eliminated 40-50% mu value drop at seq_len 4300 threshold
     - Implemented smooth cosine interpolation (4096-4500 transition zone)
     - Gradual blend between low-res and high-res formulas

  Impact:
  - FLUX.2 Klein 9B: Major quality improvement at high resolutions
  - FLUX.2 Klein 4B: Improved quality at high resolutions
  - Baseline 1024x1024: Unchanged (no regression)
  - All generation modes: T2I and Kontext (reference images)

  Fixes: Community-reported quality degradation issue
  See: Discord discussions in #garbage-bin and #devchat

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(flux2): Fix high-resolution quality degradation for FLUX.2 Klein

  Fixes grid/diamond artifacts and color loss at resolutions > 1024x1024.

  Root causes identified and fixed:
  - BN normalization was incorrectly applied to random noise input
    (diffusers only normalizes image latents from VAE.encode)
  - BN denormalization must be applied to output before VAE decode
  - mu parameter was resolution-dependent, causing over-shifted schedules
    at high resolutions (now fixed to 2.02, matching ComfyUI)

  Changes:
  - Remove BN normalization on noise input (not needed for N(0,1) noise)
  - Preserve BN denormalization on denoised output (required for VAE)
  - Fix mu to constant 2.02 for all resolutions (matches ComfyUI)

  Tested at 2048x2048 with FLUX.2 Klein 4B

* Chore Ruff

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
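The "over-shifted schedule" failure mode described above can be illustrated with the shift formula used by FLUX-style flow-matching schedules (a sketch with the sigma exponent taken as 1.0; the repo's actual schedule code lives in sampling_utils.py):

```python
import math

def time_shift(mu: float, t: float) -> float:
    # FLUX-style schedule shift: sigma' = exp(mu) / (exp(mu) + (1/t - 1)).
    # Larger mu pushes more of the schedule toward high noise levels.
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

# At the mid-schedule point t = 0.5:
low = time_shift(2.02, 0.5)   # fixed value used by this fix, ~0.88
high = time_shift(3.23, 0.5)  # old resolution-dependent value at 2048x2048, ~0.96
print(f"mu=2.02: {low:.3f}, mu=3.23: {high:.3f}")
```

With mu=3.23, even the midpoint of the schedule sits above sigma 0.9, which matches the commit's observation that nearly all denoising work gets compressed into the final steps.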
1 parent 5649b60 commit 4081f87

File tree

3 files changed: +26 −45 lines changed

invokeai/app/invocations/flux2_denoise.py

Lines changed: 7 additions & 9 deletions

@@ -329,15 +329,13 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         noise_packed = pack_flux2(noise)
         x = pack_flux2(x)
 
-        # Apply BN normalization BEFORE denoising (as per diffusers Flux2KleinPipeline)
-        # BN normalization: y = (x - mean) / std
-        # This transforms latents to normalized space for the transformer
-        # IMPORTANT: Also normalize init_latents and noise for inpainting to maintain consistency
-        if bn_mean is not None and bn_std is not None:
-            x = self._bn_normalize(x, bn_mean, bn_std)
-            if init_latents_packed is not None:
-                init_latents_packed = self._bn_normalize(init_latents_packed, bn_mean, bn_std)
-            noise_packed = self._bn_normalize(noise_packed, bn_mean, bn_std)
+        # BN normalization for txt2img:
+        # - DO NOT normalize random noise (it's already N(0,1) distributed)
+        # - Diffusers only normalizes image latents from VAE (for img2img/kontext)
+        # - Output MUST be denormalized after denoising before VAE decode
+        #
+        # For img2img with init_latents, we should normalize init_latents on unpacked
+        # shape (B, 128, H/16, W/16) - this is handled by _bn_normalize_unpacked below
 
         # Verify packed dimensions
         assert packed_h * packed_w == x.shape[1]
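For reference, the BN normalize/denormalize pair these comments describe is a plain affine transform. A minimal sketch in scalar Python (the real code operates on packed torch tensors with per-channel statistics):

```python
def bn_normalize(x, mean, std):
    # y = (x - mean) / std: map VAE-encoded latents into the
    # normalized space the transformer expects.
    return [(v - mean) / std for v in x]

def bn_denormalize(y, mean, std):
    # Inverse transform, x = y * std + mean: must be applied to the
    # denoised output before handing it to the VAE decoder.
    return [v * std + mean for v in y]

# Round trip recovers the original latents exactly:
latents = [1.0, -2.0, 3.5]
restored = bn_denormalize(bn_normalize(latents, 0.5, 2.0), 0.5, 2.0)

# Random noise drawn from N(0, 1) is already in the normalized space,
# so applying bn_normalize to it (the bug fixed here) would wrongly
# shift and rescale its distribution.
```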

invokeai/app/invocations/flux2_vae_decode.py

Lines changed: 0 additions & 14 deletions

@@ -57,20 +57,6 @@ def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Ima
         # Decode using diffusers API
         decoded = vae.decode(latents, return_dict=False)[0]
 
-        # Debug: Log decoded output statistics
-        print(
-            f"[FLUX.2 VAE] Decoded output: shape={decoded.shape}, "
-            f"min={decoded.min().item():.4f}, max={decoded.max().item():.4f}, "
-            f"mean={decoded.mean().item():.4f}"
-        )
-        # Check per-channel statistics to diagnose color issues
-        for c in range(min(3, decoded.shape[1])):
-            ch = decoded[0, c]
-            print(
-                f"[FLUX.2 VAE] Channel {c}: min={ch.min().item():.4f}, "
-                f"max={ch.max().item():.4f}, mean={ch.mean().item():.4f}"
-            )
-
         # Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
         img = (decoded / 2 + 0.5).clamp(0, 1)
         img = rearrange(img[0], "c h w -> h w c")
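The conversion kept below the deleted debug block can be sketched per scalar (the real code does this with tensor ops, `clamp`, and `rearrange`):

```python
def to_pixel(v: float) -> int:
    # The VAE decoder emits values in [-1, 1]; map to [0, 1],
    # clamp out-of-range values, then scale to an 8-bit channel.
    v01 = min(max(v / 2 + 0.5, 0.0), 1.0)
    return int(v01 * 255 + 0.5)

print([to_pixel(v) for v in (-1.5, -1.0, 0.0, 1.0)])  # [0, 0, 128, 255]
```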

invokeai/backend/flux2/sampling_utils.py

Lines changed: 19 additions & 22 deletions

@@ -108,33 +108,27 @@ def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
 
 
 def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
-    """Compute empirical mu for FLUX.2 schedule shifting.
+    """Compute mu for FLUX.2 schedule shifting.
 
-    This matches the diffusers Flux2Pipeline implementation.
-    The mu value controls how much the schedule is shifted towards higher timesteps.
+    Uses a fixed mu value of 2.02, matching ComfyUI's proven FLUX.2 configuration.
+
+    The previous implementation (from diffusers' FLUX.1 pipeline) computed mu as a
+    linear function of image_seq_len, which produced excessively high values at
+    high resolutions (e.g., mu=3.23 at 2048x2048). This over-shifted the sigma
+    schedule, compressing almost all values above 0.9 and forcing the model to
+    denoise everything in the final 1-2 steps, causing severe grid/diamond artifacts.
+
+    ComfyUI uses a fixed shift=2.02 for FLUX.2 Klein at all resolutions and produces
+    artifact-free images even at 2048x2048.
 
     Args:
-        image_seq_len: Number of image tokens (packed_h * packed_w).
-        num_steps: Number of denoising steps.
+        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused.
+        num_steps: Number of denoising steps. Currently unused.
 
     Returns:
-        The empirical mu value.
+        The mu value (fixed at 2.02).
     """
-    a1, b1 = 8.73809524e-05, 1.89833333
-    a2, b2 = 0.00016927, 0.45666666
-
-    if image_seq_len > 4300:
-        mu = a2 * image_seq_len + b2
-        return float(mu)
-
-    m_200 = a2 * image_seq_len + b2
-    m_10 = a1 * image_seq_len + b1
-
-    a = (m_200 - m_10) / 190.0
-    b = m_200 - 200.0 * a
-    mu = a * num_steps + b
-
-    return float(mu)
+    return 2.02
 
 
 def get_schedule_flux2(

@@ -169,11 +163,14 @@ def get_schedule_flux2(
 
 
 def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
-    """Generate tensor of image position ids for FLUX.2.
+    """Generate tensor of image position ids for FLUX.2 with RoPE scaling.
 
     FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
     This is different from FLUX.1 which uses 3D coordinates.
 
+    RoPE Scaling: For resolutions >1536x1536, position IDs are scaled down using
+    Position Interpolation to prevent RoPE degradation and diamond/grid artifacts.
+
     IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
     Using floating point dtype for position IDs can cause NaN in rotary embeddings.
