1
1
import torch
2
2
3
- from diffusers .pipelines .hunyuan_video .pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4
- from diffusers_helper .utils import crop_or_pad_yield_mask
3
+ from diffusers .pipelines .hunyuan_video .pipeline_hunyuan_video import (
4
+ DEFAULT_PROMPT_TEMPLATE ,
5
+ )
5
6
6
7
7
8
@torch .no_grad ()
8
- def encode_prompt_conds (prompt , text_encoder , text_encoder_2 , tokenizer , tokenizer_2 , max_length = 256 ):
9
+ def encode_prompt_conds (
10
+ prompt , text_encoder , text_encoder_2 , tokenizer , tokenizer_2 , max_length = 256
11
+ ):
9
12
assert isinstance (prompt , str )
10
13
11
14
prompt = [prompt ]
12
15
13
16
# LLAMA
14
-
17
+
15
18
# Check if there's a custom system prompt template in settings
16
19
custom_template = None
17
20
try :
18
21
from modules .settings import Settings
22
+
19
23
settings = Settings ()
20
24
override_system_prompt = settings .get ("override_system_prompt" , False )
21
25
custom_template_str = settings .get ("system_prompt_template" )
22
-
26
+
23
27
if override_system_prompt and custom_template_str :
24
28
try :
25
29
# Convert the string representation to a dictionary
26
30
# Extract template and crop_start directly from the string using regex
27
31
import re
28
-
32
+
29
33
# Try to extract the template value
30
- template_match = re .search (r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})" , custom_template_str , re .DOTALL )
31
- crop_start_match = re .search (r"['\"]crop_start['\"]\s*:\s*(\d+)" , custom_template_str )
32
-
34
+ template_match = re .search (
35
+ r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})" ,
36
+ custom_template_str ,
37
+ re .DOTALL ,
38
+ )
39
+ crop_start_match = re .search (
40
+ r"['\"]crop_start['\"]\s*:\s*(\d+)" , custom_template_str
41
+ )
42
+
33
43
if template_match and crop_start_match :
34
44
template_value = template_match .group (1 )
35
45
crop_start_value = int (crop_start_match .group (1 ))
36
-
46
+
37
47
# Unescape any escaped characters in the template
38
- template_value = template_value .replace ("\\ n" , "\n " ).replace ("\\ \" " , "\" " ).replace ("\\ '" , "'" )
39
-
48
+ template_value = (
49
+ template_value .replace ("\\ n" , "\n " )
50
+ .replace ('\\ "' , '"' )
51
+ .replace ("\\ '" , "'" )
52
+ )
53
+
40
54
custom_template = {
41
55
"template" : template_value ,
42
- "crop_start" : crop_start_value
56
+ "crop_start" : crop_start_value ,
43
57
}
44
- print (f"Using custom system prompt template from settings: { custom_template } " )
58
+ print (
59
+ f"Using custom system prompt template from settings: { custom_template } "
60
+ )
45
61
else :
46
- print (f"Could not extract template or crop_start from system prompt template string" )
47
- print (f"Falling back to default template" )
62
+ print (
63
+ "Could not extract template or crop_start from system prompt template string"
64
+ )
65
+ print ("Falling back to default template" )
48
66
custom_template = None
49
67
except Exception as e :
50
68
print (f"Error parsing custom system prompt template: { e } " )
51
- print (f "Falling back to default template" )
69
+ print ("Falling back to default template" )
52
70
custom_template = None
53
71
else :
54
72
if not override_system_prompt :
55
- print (f "Override system prompt is disabled, using default template" )
73
+ print ("Override system prompt is disabled, using default template" )
56
74
elif not custom_template_str :
57
- print (f "No custom system prompt template found in settings" )
75
+ print ("No custom system prompt template found in settings" )
58
76
custom_template = None
59
77
except Exception as e :
60
78
print (f"Error loading settings: { e } " )
61
- print (f "Falling back to default template" )
79
+ print ("Falling back to default template" )
62
80
custom_template = None
63
-
81
+
64
82
# Use custom template if available, otherwise use default
65
83
template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE
66
-
84
+
67
85
prompt_llama = [template ["template" ].format (p ) for p in prompt ]
68
86
crop_start = template ["crop_start" ]
69
87
@@ -105,7 +123,9 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
105
123
return_length = False ,
106
124
return_tensors = "pt" ,
107
125
).input_ids
108
- clip_l_pooler = text_encoder_2 (clip_l_input_ids .to (text_encoder_2 .device ), output_hidden_states = False ).pooler_output
126
+ clip_l_pooler = text_encoder_2 (
127
+ clip_l_input_ids .to (text_encoder_2 .device ), output_hidden_states = False
128
+ ).pooler_output
109
129
110
130
return llama_vec , clip_l_pooler
111
131
@@ -128,15 +148,21 @@ def vae_decode_fake(latents):
128
148
[- 0.2315 , - 0.1920 , - 0.1355 ],
129
149
[- 0.0270 , 0.0401 , - 0.0821 ],
130
150
[- 0.0616 , - 0.0997 , - 0.0727 ],
131
- [0.0249 , - 0.0469 , - 0.1703 ]
151
+ [0.0249 , - 0.0469 , - 0.1703 ],
132
152
] # From comfyui
133
153
134
154
latent_rgb_factors_bias = [0.0259 , - 0.0192 , - 0.0761 ]
135
155
136
- weight = torch .tensor (latent_rgb_factors , device = latents .device , dtype = latents .dtype ).transpose (0 , 1 )[:, :, None , None , None ]
137
- bias = torch .tensor (latent_rgb_factors_bias , device = latents .device , dtype = latents .dtype )
156
+ weight = torch .tensor (
157
+ latent_rgb_factors , device = latents .device , dtype = latents .dtype
158
+ ).transpose (0 , 1 )[:, :, None , None , None ]
159
+ bias = torch .tensor (
160
+ latent_rgb_factors_bias , device = latents .device , dtype = latents .dtype
161
+ )
138
162
139
- images = torch .nn .functional .conv3d (latents , weight , bias = bias , stride = 1 , padding = 0 , dilation = 1 , groups = 1 )
163
+ images = torch .nn .functional .conv3d (
164
+ latents , weight , bias = bias , stride = 1 , padding = 0 , dilation = 1 , groups = 1
165
+ )
140
166
images = images .clamp (0.0 , 1.0 )
141
167
142
168
return images
@@ -158,6 +184,8 @@ def vae_decode(latents, vae, image_mode=False):
158
184
159
185
@torch.no_grad()
def vae_encode(image, vae):
    """Encode ``image`` with ``vae`` and return the scaled latent sample.

    The input is first moved to the VAE's own device and dtype, then passed
    through ``vae.encode``. A latent is drawn from the returned
    ``latent_dist`` via ``sample()`` and multiplied by
    ``vae.config.scaling_factor``.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.
    """
    # Match the encoder's placement/precision before calling it.
    pixels = image.to(device=vae.device, dtype=vae.dtype)
    posterior = vae.encode(pixels).latent_dist
    return posterior.sample() * vae.config.scaling_factor
0 commit comments