Chroma as a FLUX.1 variant #11566

Closed · wants to merge 4 commits
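This PR extends the FLUX.1 single-file converter, the embedding modules, and the normalization layers so that Chroma checkpoints, which replace the per-block modulation projections with a pruned distilled_guidance_layer approximator, can be mapped onto the diffusers FLUX transformer layout. A minimal loading sketch, assuming the updated converter is reachable through FluxTransformer2DModel.from_single_file; the checkpoint path is a placeholder:

import torch
from diffusers import FluxTransformer2DModel

# Placeholder path: any Chroma single-file checkpoint whose state dict carries
# "distilled_guidance_layer.in_proj.weight", the key the converter below uses
# to pick the "chroma" variant instead of plain "flux".
transformer = FluxTransformer2DModel.from_single_file(
    "path/to/chroma-checkpoint.safetensors",
    torch_dtype=torch.bfloat16,
)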
135 changes: 79 additions & 56 deletions src/diffusers/loaders/single_file_utils.py
@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())

variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"

for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
if variant == "chroma" and "distilled_guidance_layer." in k:
new_key = k
if k.startswith("distilled_guidance_layer.norms"):
new_key = k.replace(".scale", ".weight")
elif k.startswith("distilled_guidance_layer.layer"):
new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
converted_state_dict[new_key] = checkpoint.pop(k)

num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
@@ -2153,40 +2162,49 @@ def swap_scale_shift(weight):
new_weight = torch.cat([scale, shift], dim=0)
return new_weight

## time_text_embed.timestep_embedder <- time_in
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"time_in.in_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"time_in.out_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")

## time_text_embed.text_embedder <- vector_in
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
"vector_in.out_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")

# guidance
has_guidance = any("guidance" in k for k in checkpoint)
if has_guidance:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
"guidance_in.in_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
"guidance_in.in_layer.bias"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
"guidance_in.out_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
"guidance_in.out_layer.bias"
)

if variant == "flux":
## time_text_embed.timestep_embedder <- time_in
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"time_in.in_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
"time_in.in_layer.bias"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"time_in.out_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
"time_in.out_layer.bias"
)

## time_text_embed.text_embedder <- vector_in
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
"vector_in.in_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
"vector_in.out_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
"vector_in.out_layer.bias"
)

# guidance
has_guidance = any("guidance" in k for k in checkpoint)
if has_guidance:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
"guidance_in.in_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
"guidance_in.in_layer.bias"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
"guidance_in.out_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
"guidance_in.out_layer.bias"
)

# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
Expand All @@ -2199,20 +2217,21 @@ def swap_scale_shift(weight):
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# norms.
## norm1
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.bias"
)
## norm1_context
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.bias"
)
if variant == "flux":
## norm1
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.bias"
)
## norm1_context
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.bias"
)
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def swap_scale_shift(weight):
# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.bias"
)

if variant == "flux":
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.bias"
)
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
@@ -2320,12 +2341,14 @@ def swap_scale_shift(weight):

converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)

if variant == "flux":
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)

return converted_state_dict
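
For a sense of what the Chroma-specific branch above does, here is a standalone sketch of the distilled_guidance_layer key renaming on a toy state dict; the tensor shapes are placeholders, not the real Chroma dimensions:

import torch

# Toy Chroma-style keys; real checkpoints carry the full distilled_guidance_layer.
checkpoint = {
    "distilled_guidance_layer.in_proj.weight": torch.zeros(8, 4),
    "distilled_guidance_layer.norms.0.scale": torch.ones(8),
    "distilled_guidance_layer.layers.0.in_layer.weight": torch.zeros(8, 8),
}

converted_state_dict = {}
for k in list(checkpoint):
    new_key = k
    if k.startswith("distilled_guidance_layer.norms"):
        new_key = k.replace(".scale", ".weight")  # RMSNorm scale -> diffusers weight
    elif k.startswith("distilled_guidance_layer.layer"):
        new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    converted_state_dict[new_key] = checkpoint.pop(k)

print(sorted(converted_state_dict))
# ['distilled_guidance_layer.in_proj.weight',
#  'distilled_guidance_layer.layers.0.linear_1.weight',
#  'distilled_guidance_layer.norms.0.weight']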

52 changes: 50 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

@@ -1327,7 +1327,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shif
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale

def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
@@ -1637,6 +1637,35 @@ def forward(self, timestep, guidance, pooled_projection):
return conditioning


class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
super().__init__()

self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)

self.register_buffer(
"mod_proj",
get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
persistent=False,
)

def forward(
self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
) -> torch.Tensor:
mod_index_length = self.mod_proj.shape[0]
timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)

mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
timestep_guidance = (
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)
input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)

return input_vec


class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
@@ -2230,6 +2259,25 @@ def forward(self, caption):
return hidden_states


class ChromaApproximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
)
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)

def forward(self, x):
x = self.in_proj(x)

for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))

return self.out_proj(x)


class IPAdapterPlusImageProjectionBlock(nn.Module):
def __init__(
self,
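Taken together, the two classes added above stand in for FLUX.1's per-block modulation projections: CombinedTimestepTextProjChromaEmbeddings builds one input vector per modulation slot from the timestep and a zeroed guidance embedding, and ChromaApproximator maps those vectors to the modulation embeddings consumed by the pruned norm layers in the next file. A shape-only sketch; the sizes (factor 16, 344 modulation slots, 5120-wide approximator, 3072-dim output) are assumptions taken from the published Chroma checkpoint rather than values fixed by this diff:

import torch

from diffusers.models.embeddings import (
    ChromaApproximator,
    CombinedTimestepTextProjChromaEmbeddings,
)

# Assumed Chroma sizes; note ChromaApproximator uses nn.RMSNorm, which needs PyTorch 2.4+.
embedder = CombinedTimestepTextProjChromaEmbeddings(
    factor=16, hidden_dim=5120, out_dim=344, n_layers=5, embedding_dim=3072
)
# in_dim = 4 * factor: [timestep | guidance | mod_proj] features concatenated per slot.
approximator = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)

timestep = torch.tensor([1.0])
input_vec = embedder(timestep, guidance=None, pooled_projections=None)
print(input_vec.shape)                # torch.Size([1, 344, 64])
print(approximator(input_vec).shape)  # torch.Size([1, 344, 3072])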
113 changes: 113 additions & 0 deletions src/diffusers/models/normalization.py
@@ -171,6 +171,46 @@ def forward(
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None

if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)

def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Expand Down Expand Up @@ -203,6 +243,35 @@ def forward(
return x, gate_msa


class AdaLayerNormZeroSinglePruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""

def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()

if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
)

def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa


class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
@@ -305,6 +374,50 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
return x


class AdaLayerNormContinuousPruned(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""

def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x


class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
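The *Pruned variants above keep the shift/scale/gate arithmetic of their unpruned counterparts but drop the internal SiLU + Linear projection of the conditioning, so the caller passes in modulation embeddings that were computed elsewhere (in Chroma, slices of the approximator output). A minimal sketch with an arbitrary hidden size, assuming this branch is installed:

import torch

from diffusers.models.normalization import AdaLayerNormZeroSinglePruned

dim = 3072
norm = AdaLayerNormZeroSinglePruned(dim)

hidden_states = torch.randn(2, 128, dim)  # (batch, sequence, dim)
# Three precomputed modulation vectors (shift, scale, gate) for this block,
# shared across the batch; in Chroma they would come from the approximator output.
emb = torch.randn(1, 3, dim)

out, gate_msa = norm(hidden_states, emb=emb)
print(out.shape, gate_msa.shape)  # torch.Size([2, 128, 3072]) torch.Size([1, 3072])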