Skip to content

[pull] master from comfyanonymous:master #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
469 changes: 469 additions & 0 deletions comfy/ldm/omnigen/omnigen2.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2

import comfy.model_management
import comfy.patcher_extension
Expand Down Expand Up @@ -1230,3 +1231,33 @@ def extra_conds(self, **kwargs):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out

class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out

def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
20 changes: 20 additions & 0 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

return dit_config

if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config

if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

Expand Down
8 changes: 8 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2

import comfy.model_patcher
import comfy.lora
Expand Down Expand Up @@ -754,6 +755,7 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
Expand All @@ -773,6 +775,7 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10

def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
Expand All @@ -793,6 +796,8 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
Expand Down Expand Up @@ -894,6 +899,9 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
Expand Down
3 changes: 2 additions & 1 deletion comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]

if pad_token is not None:
self.pad_token = pad_token
Expand Down
33 changes: 32 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2

from . import supported_models_base
from . import latent_formats
Expand Down Expand Up @@ -1181,6 +1182,36 @@ def get_model(self, state_dict, prefix="", device=None):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}

sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}

memory_usage_factor = 1.65 #TODO

unet_extra_config = {}
latent_format = latent_formats.Flux

supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out

def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

models += [SVD_img2vid]
33 changes: 30 additions & 3 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False

@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True

@dataclass
class Gemma2_2B_Config:
Expand All @@ -40,6 +58,7 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
Expand Down Expand Up @@ -98,9 +117,9 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
self.inner_size = self.num_heads * self.head_dim

ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

def forward(
Expand Down Expand Up @@ -320,6 +339,14 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
Expand Down
44 changes: 44 additions & 0 deletions comfy/text_encoders/omnigen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os


class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)


class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'

def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)

class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)


def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_
Loading
Loading