
Commit 923747a

initial ruff pass

1 parent: 8dc30b8


60 files changed: +11216 additions, -4364 deletions

diffusers_helper/bucket_tools.py

Lines changed: 12 additions & 10 deletions
@@ -63,35 +63,37 @@
 def find_nearest_bucket(h, w, resolution=640):
     # Use the provided resolution or find the closest available bucket size
     # print(f"find_nearest_bucket called with h={h}, w={w}, resolution={resolution}")
-
+
     # Convert resolution to int if it's not already
     resolution = int(resolution) if not isinstance(resolution, int) else resolution
-
+
     if resolution not in bucket_options:
         # Find the closest available resolution
         available_resolutions = list(bucket_options.keys())
-        closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
+        closest_resolution = min(
+            available_resolutions, key=lambda x: abs(x - resolution)
+        )
         # print(f"Resolution {resolution} not found in bucket options, using closest available: {closest_resolution}")
         resolution = closest_resolution
     # else:
-    #     print(f"Resolution {resolution} found in bucket options")
-
+    #     print(f"Resolution {resolution} found in bucket options")
+
     # Calculate the aspect ratio of the input image
     input_aspect_ratio = w / h if h > 0 else 1.0
     # print(f"Input aspect ratio: {input_aspect_ratio:.4f}")
-
-    min_diff = float('inf')
+
+    min_diff = float("inf")
     best_bucket = None
-
+
     # Find the bucket size with the closest aspect ratio to the input image
-    for (bucket_h, bucket_w) in bucket_options[resolution]:
+    for bucket_h, bucket_w in bucket_options[resolution]:
         bucket_aspect_ratio = bucket_w / bucket_h if bucket_h > 0 else 1.0
         # Calculate the difference in aspect ratios
         diff = abs(bucket_aspect_ratio - input_aspect_ratio)
         if diff < min_diff:
             min_diff = diff
             best_bucket = (bucket_h, bucket_w)
         # print(f"  Checking bucket ({bucket_h}, {bucket_w}), aspect ratio={bucket_aspect_ratio:.4f}, diff={diff:.4f}, current best={best_bucket}")
-
+
     # print(f"Using resolution {resolution}, selected bucket: {best_bucket}")
     return best_bucket
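The hunk above is formatting-only; the selection logic is unchanged. As a reading aid, here is a compressed, self-contained restatement of that logic against a hypothetical bucket_options table (the real table lives elsewhere in bucket_tools.py and is not shown in this hunk):

# Hypothetical table for illustration; the real bucket_options maps each
# supported resolution to many (height, width) pairs.
bucket_options = {
    640: [(416, 960), (512, 768), (640, 640), (768, 512), (960, 416)],
}


def find_nearest_bucket(h, w, resolution=640):
    resolution = int(resolution) if not isinstance(resolution, int) else resolution
    if resolution not in bucket_options:
        # Snap to the closest configured resolution first...
        resolution = min(bucket_options, key=lambda x: abs(x - resolution))
    # ...then pick the bucket whose aspect ratio best matches the input.
    input_aspect_ratio = w / h if h > 0 else 1.0
    return min(
        bucket_options[resolution],
        key=lambda hw: abs(hw[1] / hw[0] - input_aspect_ratio),
    )


print(find_nearest_bucket(720, 1280))        # 16:9 input -> (512, 768)
print(find_nearest_bucket(1000, 1000, 999))  # 999 snaps to 640 -> (640, 640)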

diffusers_helper/clip_vision.py

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,9 @@ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
     assert image.ndim == 3 and image.shape[2] == 3
     assert image.dtype == np.uint8

-    preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
+    preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(
+        device=image_encoder.device, dtype=image_encoder.dtype
+    )
     image_encoder_output = image_encoder(**preprocessed)

     return image_encoder_output
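No behavior change here either: the helper asserts an HWC uint8 image, preprocesses it, and moves the batch to the encoder's device and dtype before the forward pass. A hedged usage sketch; the checkpoint name is an assumption made to keep the example runnable, not something this commit specifies:

import numpy as np
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers_helper.clip_vision import hf_clip_vision_encode

# Any CLIP-style vision checkpoint with a matching processor should work here.
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC, uint8: what the asserts require
output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
print(output.last_hidden_state.shape)  # (1, sequence_length, hidden_size)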

diffusers_helper/dit_common.py

Lines changed: 9 additions & 2 deletions
@@ -1,14 +1,21 @@
 import torch
 import accelerate.accelerator

-from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
+from diffusers.models.normalization import (
+    RMSNorm,
+    LayerNorm,
+    FP32LayerNorm,
+    AdaLayerNormContinuous,
+)


 accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x


 def LayerNorm_forward(self, x):
-    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
+    return torch.nn.functional.layer_norm(
+        x, self.normalized_shape, self.weight, self.bias, self.eps
+    ).to(x)


 LayerNorm.forward = LayerNorm_forward
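Two patches live in this module, and the reformat preserves both: accelerate's convert_outputs_to_fp32 is replaced by the identity so module outputs keep their native precision, and LayerNorm.forward ends in .to(x), i.e. .to(dtype=x.dtype, device=x.device), pinning the result to the input tensor's dtype. A standalone sketch of that dtype behavior:

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)
ln = torch.nn.LayerNorm(8).to(torch.bfloat16)

# Mirrors the patched LayerNorm_forward above.
out = torch.nn.functional.layer_norm(
    x, ln.normalized_shape, ln.weight, ln.bias, ln.eps
).to(x)
assert out.dtype == x.dtype  # bfloat16 in, bfloat16 out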

diffusers_helper/gradio/progress_bar.py

Lines changed: 5 additions & 5 deletions
@@ -1,14 +1,14 @@
-progress_html = '''
+progress_html = """
 <div class="loader-container">
   <div class="loader"></div>
   <div class="progress-container">
     <progress value="*number*" max="100"></progress>
   </div>
   <span>*text*</span>
 </div>
-'''
+"""

-css = '''
+css = """
 .loader-container {
     display: flex; /* Use flex to align items horizontally */
     align-items: center; /* Center items vertically within the container */
@@ -75,11 +75,11 @@
     display: none !important;
 }

-'''
+"""


 def make_progress_bar_html(number, text):
-    return progress_html.replace('*number*', str(number)).replace('*text*', text)
+    return progress_html.replace("*number*", str(number)).replace("*text*", text)


 def make_progress_bar_css():
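The quote-style changes are cosmetic; make_progress_bar_html still does plain placeholder substitution on the template string:

from diffusers_helper.gradio.progress_bar import make_progress_bar_html

html = make_progress_bar_html(42, "Sampling frame 3/9")
assert '<progress value="42" max="100">' in html
assert "<span>Sampling frame 3/9</span>" in html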

diffusers_helper/hf_login.py

Lines changed: 3 additions & 3 deletions
@@ -8,14 +8,14 @@ def login(token):
     while True:
         try:
             login(token)
-            print('HF login ok.')
+            print("HF login ok.")
             break
         except Exception as e:
-            print(f'HF login failed: {e}. Retrying')
+            print(f"HF login failed: {e}. Retrying")
             time.sleep(0.5)


-hf_token = os.environ.get('HF_TOKEN', None)
+hf_token = os.environ.get("HF_TOKEN", None)

 if hf_token is not None:
     login(hf_token)
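Usage note, unchanged by this hunk: the module logs in as an import side effect when HF_TOKEN is set, retrying every 0.5 s until login succeeds, so an invalid token keeps it looping. A sketch:

import os

os.environ["HF_TOKEN"] = "hf_xxx"  # placeholder; a bad token makes the loop retry forever

import diffusers_helper.hf_login  # noqa: F401 (login(hf_token) runs here, at import time)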

diffusers_helper/hunyuan.py

Lines changed: 56 additions & 28 deletions
@@ -1,69 +1,87 @@
 import torch

-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
-from diffusers_helper.utils import crop_or_pad_yield_mask
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import (
+    DEFAULT_PROMPT_TEMPLATE,
+)


 @torch.no_grad()
-def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
+def encode_prompt_conds(
+    prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256
+):
     assert isinstance(prompt, str)

     prompt = [prompt]

     # LLAMA
-
+
     # Check if there's a custom system prompt template in settings
     custom_template = None
     try:
         from modules.settings import Settings
+
         settings = Settings()
         override_system_prompt = settings.get("override_system_prompt", False)
         custom_template_str = settings.get("system_prompt_template")
-
+
         if override_system_prompt and custom_template_str:
             try:
                 # Convert the string representation to a dictionary
                 # Extract template and crop_start directly from the string using regex
                 import re
-
+
                 # Try to extract the template value
-                template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL)
-                crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)
-
+                template_match = re.search(
+                    r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})",
+                    custom_template_str,
+                    re.DOTALL,
+                )
+                crop_start_match = re.search(
+                    r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str
+                )
+
                 if template_match and crop_start_match:
                     template_value = template_match.group(1)
                     crop_start_value = int(crop_start_match.group(1))
-
+
                     # Unescape any escaped characters in the template
-                    template_value = template_value.replace("\\n", "\n").replace("\\\"", "\"").replace("\\'", "'")
-
+                    template_value = (
+                        template_value.replace("\\n", "\n")
+                        .replace('\\"', '"')
+                        .replace("\\'", "'")
+                    )
+
                     custom_template = {
                         "template": template_value,
-                        "crop_start": crop_start_value
+                        "crop_start": crop_start_value,
                     }
-                    print(f"Using custom system prompt template from settings: {custom_template}")
+                    print(
+                        f"Using custom system prompt template from settings: {custom_template}"
+                    )
                 else:
-                    print(f"Could not extract template or crop_start from system prompt template string")
-                    print(f"Falling back to default template")
+                    print(
+                        "Could not extract template or crop_start from system prompt template string"
+                    )
+                    print("Falling back to default template")
                     custom_template = None
             except Exception as e:
                 print(f"Error parsing custom system prompt template: {e}")
-                print(f"Falling back to default template")
+                print("Falling back to default template")
                 custom_template = None
         else:
             if not override_system_prompt:
-                print(f"Override system prompt is disabled, using default template")
+                print("Override system prompt is disabled, using default template")
             elif not custom_template_str:
-                print(f"No custom system prompt template found in settings")
+                print("No custom system prompt template found in settings")
             custom_template = None
     except Exception as e:
         print(f"Error loading settings: {e}")
-        print(f"Falling back to default template")
+        print("Falling back to default template")
         custom_template = None
-
+
     # Use custom template if available, otherwise use default
     template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE
-
+
     prompt_llama = [template["template"].format(p) for p in prompt]
     crop_start = template["crop_start"]

@@ -105,7 +123,9 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
         return_length=False,
         return_tensors="pt",
     ).input_ids
-    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+    clip_l_pooler = text_encoder_2(
+        clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False
+    ).pooler_output

     return llama_vec, clip_l_pooler
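The two regexes above survive the reformat unchanged; they pull template and crop_start out of a str(dict)-shaped settings value instead of eval-ing it. A standalone check against a made-up settings string:

import re

# Hypothetical settings value, shaped like str() of a dict.
custom_template_str = "{'template': 'You are an assistant.\\n<|user|>{}', 'crop_start': 36}"

template_match = re.search(
    r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})",
    custom_template_str,
    re.DOTALL,
)
crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)

print(repr(template_match.group(1).replace("\\n", "\n")))  # 'You are an assistant.\n<|user|>{}'
print(int(crop_start_match.group(1)))                      # 36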

@@ -128,15 +148,21 @@ def vae_decode_fake(latents):
         [-0.2315, -0.1920, -0.1355],
         [-0.0270, 0.0401, -0.0821],
         [-0.0616, -0.0997, -0.0727],
-        [0.0249, -0.0469, -0.1703]
+        [0.0249, -0.0469, -0.1703],
     ]  # From comfyui

     latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]

-    weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
-    bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+    weight = torch.tensor(
+        latent_rgb_factors, device=latents.device, dtype=latents.dtype
+    ).transpose(0, 1)[:, :, None, None, None]
+    bias = torch.tensor(
+        latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype
+    )

-    images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+    images = torch.nn.functional.conv3d(
+        latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1
+    )
     images = images.clamp(0.0, 1.0)

     return images

@@ -158,6 +184,8 @@ def vae_decode(latents, vae, image_mode=False):

 @torch.no_grad()
 def vae_encode(image, vae):
-    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+    latents = vae.encode(
+        image.to(device=vae.device, dtype=vae.dtype)
+    ).latent_dist.sample()
     latents = latents * vae.config.scaling_factor
     return latents
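A last reading aid for vae_decode_fake: because the conv3d weight has a 1x1x1 kernel, the call is just a per-position linear map from latent channels to RGB (the comfyui factor table). A quick equivalence check with illustrative channel counts (the real latent channel count is not visible in this hunk):

import torch

c_in, c_out = 4, 3  # illustrative only
latents = torch.randn(1, c_in, 2, 8, 8)
weight = torch.randn(c_out, c_in, 1, 1, 1)
bias = torch.randn(c_out)

a = torch.nn.functional.conv3d(latents, weight, bias=bias)
b = torch.einsum("bcthw,oc->bothw", latents, weight[:, :, 0, 0, 0])
b = b + bias[None, :, None, None, None]
assert torch.allclose(a, b, atol=1e-5)  # pointwise linear map == 1x1x1 conv3d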
