
Commit ec70ed6

Omnigen2 model implementation. (comfyanonymous#8669)
1 parent: 7a13f74

File tree: 13 files changed (+152295, -7 lines)

comfy/ldm/omnigen/omnigen2.py

Lines changed: 469 additions & 0 deletions
Large diffs are not rendered by default.

comfy/model_base.py

Lines changed: 31 additions & 0 deletions
@@ -41,6 +41,7 @@
 import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.ace.model
+import comfy.ldm.omnigen.omnigen2

 import comfy.model_management
 import comfy.patcher_extension
@@ -1230,3 +1231,33 @@ def extra_conds(self, **kwargs):
         out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
         out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out
+
+class Omnigen2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
+        self.memory_usage_factor_conds = ("ref_latents",)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            if torch.numel(attention_mask) != attention_mask.sum():
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
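A note on extra_conds_shapes above: since memory_usage_factor_conds includes "ref_latents", the returned shape presumably feeds ComfyUI's memory estimate for the reference-image latents. A standalone sketch of the arithmetic (the latent sizes here are hypothetical, for illustration only):

import math
import torch

# Sketch, not repo code: all reference latents collapse into one [1, 16, N]
# placeholder shape, where N is the total element count divided by the
# 16 latent channels.
ref_latents = [torch.zeros(1, 16, 64, 64), torch.zeros(1, 16, 32, 48)]
total = sum(math.prod(lat.size()) for lat in ref_latents)
print([1, 16, total // 16])  # [1, 16, 5632] -> 64*64 + 32*48 spatial positions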

comfy/model_detection.py

Lines changed: 20 additions & 0 deletions
@@ -459,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

         return dit_config

+    if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
+        dit_config = {}
+        dit_config["image_model"] = "omnigen2"
+        dit_config["axes_dim_rope"] = [40, 40, 40]
+        dit_config["axes_lens"] = [1024, 1664, 1664]
+        dit_config["ffn_dim_multiplier"] = None
+        dit_config["hidden_size"] = 2520
+        dit_config["in_channels"] = 16
+        dit_config["multiple_of"] = 256
+        dit_config["norm_eps"] = 1e-05
+        dit_config["num_attention_heads"] = 21
+        dit_config["num_kv_heads"] = 7
+        dit_config["num_layers"] = 32
+        dit_config["num_refiner_layers"] = 2
+        dit_config["out_channels"] = None
+        dit_config["patch_size"] = 2
+        dit_config["text_feat_dim"] = 2048
+        dit_config["timestep_scale"] = 1000.0
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None

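A quick sanity check on the hard-coded Omnigen2 config above (illustrative arithmetic, not repo code): the hidden size, head counts, and RoPE axes are mutually consistent.

# hidden_size / num_attention_heads gives the per-head dim, which the three
# RoPE axes must span exactly; KV heads must divide query heads (GQA).
hidden_size, num_heads, num_kv_heads = 2520, 21, 7
head_dim = hidden_size // num_heads        # 120
assert head_dim * num_heads == hidden_size
assert num_heads % num_kv_heads == 0       # 3 query heads per KV head
assert sum([40, 40, 40]) == head_dim       # axes_dim_rope covers the head dim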

comfy/sd.py

Lines changed: 8 additions & 0 deletions
@@ -44,6 +44,7 @@
 import comfy.text_encoders.wan
 import comfy.text_encoders.hidream
 import comfy.text_encoders.ace
+import comfy.text_encoders.omnigen2

 import comfy.model_patcher
 import comfy.lora
@@ -754,6 +755,7 @@ class CLIPType(Enum):
     HIDREAM = 14
     CHROMA = 15
     ACE = 16
+    OMNIGEN2 = 17


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +775,7 @@ class TEModel(Enum):
     LLAMA3_8 = 7
     T5_XXL_OLD = 8
     GEMMA_2_2B = 9
+    QWEN25_3B = 10

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +796,8 @@ def detect_te_model(sd):
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
         return TEModel.GEMMA_2_2B
+    if 'model.layers.0.self_attn.k_proj.bias' in sd:
+        return TEModel.QWEN25_3B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         return TEModel.LLAMA3_8
     return None
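One detail in the hunk above: Qwen2.5 checkpoints also contain model.layers.0.post_attention_layernorm.weight, so the new k_proj.bias probe has to run before the generic Llama probe. A toy illustration of the ordering constraint (standalone, fake state-dict keys; not the real detect_te_model):

def toy_detect(sd):
    if 'model.layers.0.self_attn.k_proj.bias' in sd:  # Qwen2.5 uses QKV bias
        return "QWEN25_3B"
    if 'model.layers.0.post_attention_layernorm.weight' in sd:
        return "LLAMA3_8"
    return None

sd = {'model.layers.0.self_attn.k_proj.bias': 0,
      'model.layers.0.post_attention_layernorm.weight': 0}
assert toy_detect(sd) == "QWEN25_3B"  # swapped order would misdetect as Llama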
@@ -894,6 +899,9 @@ class EmptyClass:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                                                                         clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+        elif te_model == TEModel.QWEN25_3B:
+            clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:

comfy/sd1_clip.py

Lines changed: 2 additions & 1 deletion
@@ -482,7 +482,8 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         if end_token is not None:
             self.end_token = end_token
         else:
-            self.end_token = empty[0]
+            if has_end_token:
+                self.end_token = empty[0]

         if pad_token is not None:
             self.pad_token = pad_token
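This guard matters for the tokenizer added below: Qwen25_3BTokenizer is constructed with has_end_token=False, so end_token must not be back-filled from the empty prompt. A toy version of the guarded pattern (not the real SDTokenizer):

class ToyTokenizer:
    def __init__(self, empty, end_token=None, has_end_token=True):
        if end_token is not None:
            self.end_token = end_token
        elif has_end_token:
            self.end_token = empty[0]  # before this fix, ran unconditionally

assert not hasattr(ToyTokenizer([151643], has_end_token=False), "end_token")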

comfy/supported_models.py

Lines changed: 32 additions & 1 deletion
@@ -18,6 +18,7 @@
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.ace
+import comfy.text_encoders.omnigen2

 from . import supported_models_base
 from . import latent_formats
@@ -1181,6 +1182,36 @@ def get_model(self, state_dict, prefix="", device=None):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
+class Omnigen2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "omnigen2",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 2.6,
+    }
+
+    memory_usage_factor = 1.65 #TODO
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux
+
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Omnigen2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

 models += [SVD_img2vid]

Note: clip_target here referenced comfy.text_encoders.omnigen2.LuminaTokenizer, but the omnigen2 text encoder module added in this commit only defines Omnigen2Tokenizer; the reference is corrected above to the name that actually exists.
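For intuition on "shift": 2.6, a hedged sketch: this assumes the discrete-flow timestep shift sigma' = s*sigma / (1 + (s - 1)*sigma) that ComfyUI's flow-matching models commonly use (verify against comfy/model_sampling.py before relying on it):

def shift_sigma(sigma, s=2.6):
    # Assumed shift formula; not taken from this diff.
    return s * sigma / (1 + (s - 1) * sigma)

print(shift_sigma(0.5))  # ~0.722: mid-schedule steps move toward higher noise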

comfy/text_encoders/llama.py

Lines changed: 30 additions & 3 deletions
@@ -24,6 +24,24 @@ class Llama2Config:
     head_dim = 128
     rms_norm_add = False
     mlp_activation = "silu"
+    qkv_bias = False
+
+@dataclass
+class Qwen25_3BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 11008
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 2
+    max_position_embeddings: int = 128000
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = True

 @dataclass
 class Gemma2_2B_Config:
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
     head_dim = 256
     rms_norm_add = True
     mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False

 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +117,9 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
         self.inner_size = self.num_heads * self.head_dim

         ops = ops or nn
-        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
-        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
-        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
+        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

     def forward(
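A worked example of what qkv_bias plus the Qwen25_3BConfig values imply for the projection layers above (illustration only, not repo code):

# Grouped-query attention shapes for Qwen2.5-3B:
hidden_size, num_heads, num_kv_heads, head_dim = 2048, 16, 2, 128
q_out = num_heads * head_dim      # 2048 -> q_proj: Linear(2048, 2048, bias=True)
kv_out = num_kv_heads * head_dim  # 256  -> k_proj/v_proj: Linear(2048, 256, bias=True)
assert (q_out, kv_out) == (2048, 256)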
@@ -320,6 +339,14 @@ def __init__(self, config_dict, dtype, device, operations):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Qwen25_3B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen25_3BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype

 class Gemma2_2B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):

comfy/text_encoders/omnigen2.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from transformers import Qwen2Tokenizer
+from comfy import sd1_clip
+import comfy.text_encoders.llama
+import os
+
+
+class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
+        self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
+
+class Qwen25_3BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Omnigen2Model(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None):
+    class Omnigen2TEModel_(Omnigen2Model):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Omnigen2TEModel_
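Note that te() returns a class rather than an instance: the loader constructs it later with its own device/dtype/model_options, while the detected dtype_llama and llama_scaled_fp8 stay baked in through the closure. A minimal standalone sketch of that factory pattern:

def make_te(dtype_override=None):
    class TE:
        def __init__(self, dtype=None):
            # the closure value wins over whatever the caller passes later
            self.dtype = dtype_override if dtype_override is not None else dtype
    return TE

TE16 = make_te(dtype_override="float16")
assert TE16(dtype="float32").dtype == "float16"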
