From 084dd5579c75e7a1803510461c6d97a69c0777b3 Mon Sep 17 00:00:00 2001 From: arledesma Date: Wed, 16 Jul 2025 11:54:39 -0500 Subject: [PATCH 01/17] Import kohya-ss LoRA loader Brings support from kohya-ss implementations in their FramePack-LoraReady fork as well as their contributions to FramePack-eichi (that do not seem to be correctly attributed to kohya-ss in thei primary eichi repo) https://gist.github.com/kohya-ss/fa4b7ae7119c10850ae7d70c90a59277 https://github.com/kohya-ss/FramePack-LoRAReady/blob/3613b67366b0bbf4a719c85ba9c3954e075e0e57 https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470 --- .../lora_utils_kohya_ss/__init__.py | 7 + .../fp8_optimization_utils.py | 407 ++++++++++++++++++ .../lora_utils_kohya_ss/lora_check_helper.py | 96 +++++ .../lora_utils_kohya_ss/lora_loader.py | 67 +++ .../lora_utils_kohya_ss/lora_utils.py | 376 ++++++++++++++++ .../lora_utils_kohya_ss/safetensors_utils.py | 105 +++++ 6 files changed, 1058 insertions(+) create mode 100644 diffusers_helper/lora_utils_kohya_ss/__init__.py create mode 100644 diffusers_helper/lora_utils_kohya_ss/fp8_optimization_utils.py create mode 100644 diffusers_helper/lora_utils_kohya_ss/lora_check_helper.py create mode 100644 diffusers_helper/lora_utils_kohya_ss/lora_loader.py create mode 100644 diffusers_helper/lora_utils_kohya_ss/lora_utils.py create mode 100644 diffusers_helper/lora_utils_kohya_ss/safetensors_utils.py diff --git a/diffusers_helper/lora_utils_kohya_ss/__init__.py b/diffusers_helper/lora_utils_kohya_ss/__init__.py new file mode 100644 index 00000000..37f8e2d8 --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/__init__.py @@ -0,0 +1,7 @@ +from .lora_utils import merge_lora_to_state_dict +from .lora_loader import load_and_apply_lora + +__all__ = [ + "merge_lora_to_state_dict", + "load_and_apply_lora", +] diff --git a/diffusers_helper/lora_utils_kohya_ss/fp8_optimization_utils.py b/diffusers_helper/lora_utils_kohya_ss/fp8_optimization_utils.py new file mode 100644 index 00000000..78b4bba3 --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/fp8_optimization_utils.py @@ -0,0 +1,407 @@ +# Original https://github.com/kohya-ss/FramePack-LoRAReady/blob/3613b67366b0bbf4a719c85ba9c3954e075e0e57/utils/fp8_optimization_utils.py +# Updates for eichi https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470/webui/lora_utils/fp8_optimization_utils.py + +import torch.nn.functional as F +import torch +import torch.nn as nn +import os + +from tqdm import tqdm +from typing import Literal, cast + +# Flags to track whether a warning message was displayed. +FP8_E4M3_WARNING_SHOWN = False +FP8_DIMENSIONS_WARNING_SHOWN = False + +# cSpell: ignore maxval, dequantized + + +def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1) -> float: + """ + Calculates the maximum value that can be expressed in FP8 format. + The default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). + + Args: + exp_bits (int): Number of bits in exponent + mantissa_bits (int): Number of bits in mantissa + sign_bits (int): Number of bits in sign (0 or 1) + + Returns: + float: Maximum value that can be expressed in FP8 format. 
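+
+        For the default E4M3 format this works out to 1.75 * 2**8 = 448.0,
+        matching torch.finfo(torch.float8_e4m3fn).max. Illustrative check:
+
+            >>> calculate_fp8_maxval()
+            448.0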
+ """ + assert ( + exp_bits + mantissa_bits + sign_bits == 8 + ), f"The total number of bits for FP8 must be 8, but got {exp_bits + mantissa_bits + sign_bits} bits (E{exp_bits} M{mantissa_bits} S{sign_bits})" + + # Calculate the exponent bias + bias: int = 2 ** (exp_bits - 1) - 1 + + # Calculate the maximum mantissa value + # Maybe this can be an int? + mantissa_max: float = 1.0 + for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + + # Calculate the maximum value + max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + + return cast(float, max_value) + + +def quantize_tensor_to_fp8( + tensor: torch.Tensor, + scale: float | torch.Tensor, + exp_bits: int = 4, + mantissa_bits: int = 3, + sign_bits: int = 1, + max_value: float | None = None, + min_value: float | None = None, +): + """ + Quantize the tensor to FP8 format + + Args: + tensor (torch.Tensor): The tensor to quantize. + scale (float or torch.Tensor): Scale factor. + exp_bits (int): Number of bits in exponent. + mantissa_bits (int): Number of bits in mantissa. + sign_bits (int): Number of bits in sign. + max_value (float, optional): Maximum value (automatically calculated if None). + min_value (float, optional): Minimum value (automatically calculated if None). + + Returns: + tuple: (quantized tensor, scale factor) + """ + # スケーリングされたテンソルを作成 + scaled_tensor = tensor / scale + + # FP8パラメータを計算 + bias: int = 2 ** (exp_bits - 1) - 1 + + if max_value is None: + # 最大値と最小値を計算 + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits) + min_value = -max_value if sign_bits > 0 else 0.0 + + # テンソルを範囲内に制限 + clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value) + + # 量子化プロセス + abs_values = torch.abs(clamped_tensor) + nonzero_mask = abs_values > 0 + + # logFスケールを計算(非ゼロ要素のみ) + log_scales = torch.zeros_like(clamped_tensor) + if nonzero_mask.any(): + log_scales[nonzero_mask] = torch.floor( + torch.log2(abs_values[nonzero_mask]) + bias + ).detach() + + # logスケールを制限し、量子化係数を計算 + log_scales = torch.clamp(log_scales, min=1.0) + quant_factor = 2.0 ** (log_scales - mantissa_bits - bias) + + # 量子化と逆量子化 + quantized = torch.round(clamped_tensor / quant_factor) * quant_factor + + return quantized, scale + + +def optimize_state_dict_with_fp8_on_the_fly( + model_files, + calc_device, + target_layer_keys=None, + exclude_layer_keys=None, + exp_bits: Literal[4, 5] = 4, + mantissa_bits: Literal[2, 3] = 3, + move_to_device=False, + weight_hook=None, +): + """ + Optimize linear layer weights in model state dictionary to FP8 format + + Args: + model_files (list): List of model files to optimize (updates as they are read) + calc_device (str): Device to quantize tensors to + target_layer_keys (list, optional): Pattern of layer keys to target (all linear layers if None) + exclude_layer_keys (list, optional): Pattern of layer keys to exclude + exp_bits (int): Number of exponent bits. Valid values are 4 or 5. If 4 then mantissa_bits must be 3, if 5 then mantissa_bits must be 2. + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Whether to move optimized tensors to compute device + weight_hook (callable, optional): Weight hook function (None if not used), applied to all weights before FP8 optimization, regardless of whether they are FP8 optimized or not. 
+ + Returns: + dict: FP8 optimized state dictionary + """ + # Select FP8 data type + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 formats: E{exp_bits} M{mantissa_bits}") + + # Calculate the maximum value of FP8 + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # この関数は符号付きFP8のみサポート + + # Create an optimized state dictionary + def is_target_key(key): + # Check if weight key matches include pattern and doesn't match exclude pattern + is_target = ( + target_layer_keys is None + or any(pattern in key for pattern in target_layer_keys) + ) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any( + pattern in key for pattern in exclude_layer_keys + ) + is_target = is_target and not is_excluded + return is_target + + # Optimized layer counter + optimized_count = 0 + + from diffusers_helper.lora_utils_kohya_ss.safetensors_utils import ( + MemoryEfficientSafeOpen, + ) + + state_dict = {} + + # Process each model file using MemoryEfficientSafeOpen + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + keys = f.keys() + for key in tqdm( + keys, desc=f"Loading {os.path.basename(model_file)}", unit="key" + ): + value = f.get_tensor(key) + if weight_hook is not None: + # If a weight hook is specified, apply the hook + value = weight_hook(key, value) + + if not is_target_key(key): + state_dict[key] = value + continue + + # Preserve original device and data type + original_device = value.device + original_dtype = value.dtype + + # Move to specified compute device + if calc_device is not None: + value = value.to(calc_device) + + # Calculate the scale factor + scale = torch.max(torch.abs(value.flatten())) / max_value + + # Quantize weights to FP8 + quantized_weight, _ = quantize_tensor_to_fp8( + value, scale, exp_bits, mantissa_bits, 1, max_value, min_value + ) + + # Use original key for weight, new key for scale + fp8_key = key + scale_key = key.replace(".weight", ".scale_weight") + + # Convert to FP8 data type + quantized_weight = quantized_weight.to(fp8_dtype) + + # If no device is specified, revert to original device + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + # Create a scale tensor + scale_tensor = torch.tensor( + [scale], dtype=original_dtype, device=quantized_weight.device + ) + + # Add to state dictionary + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + # Periodically free up memory + if calc_device is not None and optimized_count % 10 == 0: + torch.cuda.empty_cache() + + print(f"Optimized Linear Layer Count: {optimized_count}") + return state_dict + + +def fp8_linear_forward_patch( + self: nn.Linear, + x: torch.Tensor, + use_scaled_mm: bool = False, + max_value: float | None = None, +): + """ + Patched forward method for linear layers with FP8 weights + + Args: + self: Linear layer instance + x (torch.Tensor): Input tensor + use_scaled_mm (bool): Whether to use scaled_mm for FP8 linear layers (requires SM 8.9+, RTX 40 series) + max_value (float): Maximum FP8 quantization (if None, no quantization is applied to the input tensor) + + Returns: + torch.Tensor: Result of the linear transformation + """ + if use_scaled_mm: + # If you use scaled_mm (only works with RTX >= 40 series GPUs) + input_dtype = x.dtype + original_weight_dtype = cast(torch.dtype, 
self.scale_weight.dtype) + weight_dtype = self.weight.dtype + target_dtype = torch.float8_e5m2 + + # Falls back to normal method if not E4M3FN + # scaled_mm is only compatible with E4M3FN format even in FP8, so cannot be used with other formats + if weight_dtype != torch.float8_e4m3fn: + # may be noisy + print( + f"WARNING: scaled_mm requires FP8 E4M3FN format but {weight_dtype} was detected, falling back to regular method." + ) + + # fallback to normal method + return fp8_linear_forward_patch(self, x, False, max_value) + + # Check the dimensions of the input tensor + # scaled_mm expects a 3-dimensional tensor (batch_size, seq_len, hidden_dim), otherwise it will not work + if x.ndim != 3: + # may be noisy + print( + f"Warning: scaled_mm expects 3D input but found {x.ndim} dimensions. Falling back to normal method." + ) + + # fallback to normal method + return fp8_linear_forward_patch(self, x, False, max_value) + + if max_value is None: + # No input quantization, use scale of 1.0 + scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + # Calculate the scale factor of the input tensor + scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) + + # Quantize input tensors to FP8 (can be memory intensive) + x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value) + + original_shape = x.shape + # Change the shape of the tensor to 2D + x = x.reshape(-1, x.shape[2]).to(target_dtype) + + # Transpose the weights + weight = self.weight.t() + scale_weight = cast(torch.Tensor, self.scale_weight.to(torch.float32)) + + # separate processing with and without biasing + if self.bias is not None: + # If biased then float32 is not supported + o = torch._scaled_mm( + x, + weight, + out_dtype=original_weight_dtype, + bias=self.bias, + scale_a=scale_x, + scale_b=scale_weight, + ) + else: + o = torch._scaled_mm( + x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight + ) + + # Return original shape + return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype) + else: + # calculate by inverse quantization of weights + original_dtype = cast(torch.dtype, self.scale_weight.dtype) + dequantized_weight = self.weight.to(original_dtype) * cast( + torch.Tensor, self.scale_weight + ) + + # Perform a linear transformation + if self.bias is not None: + output = F.linear(x, dequantized_weight, self.bias) + else: + output = F.linear(x, dequantized_weight) + + return output + + +def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False): + """ + Apply monkey patching to a model using FP8 optimized state dict. 
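+
+    Illustrative call sequence (assumes model and model_files are already
+    defined elsewhere; not part of this module):
+
+        state_dict = optimize_state_dict_with_fp8_on_the_fly(model_files, "cuda")
+        model = apply_fp8_monkey_patch(model, state_dict)
+        model.load_state_dict(state_dict, strict=True, assign=True)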
+ + Args: + model (nn.Module): Model instance to patch + optimized_state_dict (dict): FP8 optimized state dict + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + + Returns: + nn.Module: The patched model (same instance, modified in-place) + """ + + # Calculate FP8 float8_e5m2 max value + max_value = None + + # Find all scale keys to identify FP8-optimized layers + scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")] + + # Enumerate patched layers + patched_module_paths = set() + for scale_key in scale_keys: + # Extract module path from scale key (remove .scale_weight) + module_path = scale_key.rsplit(".scale_weight", 1)[0] + patched_module_paths.add(module_path) + + patched_count = 0 + + # Apply monkey patch to each layer with FP8 weights + for name, module in model.named_modules(): + # Check if this module has a corresponding scale_weight + has_scale = name in patched_module_paths + + # Apply patch if it's a Linear layer with FP8 scale + if isinstance(module, nn.Linear) and has_scale: + # register the scale_weight as a buffer to load the state_dict + module.register_buffer( + "scale_weight", torch.tensor(1.0, dtype=module.weight.dtype) + ) + + # Create a new forward method with the patched version. + def new_forward(self, x): + return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value) + + # Bind method to module + module.forward = new_forward.__get__(module, type(module)) + + patched_count += 1 + + print(f"Number of monkey-patched Linear layers: {patched_count}") + setattr(model, "_fp8_optimized", True) + return model + + +def check_fp8_support(): + """ + Checks if the current PyTorch version supports FP8 formats and scaled_mm. + + Returns: + tuple[bool, bool, bool]: (E4M3 support, E5M2 support, scaled_mm support) + """ + + has_e4m3 = hasattr(torch, "float8_e4m3fn") + has_e5m2 = hasattr(torch, "float8_e5m2") + + has_scaled_mm = hasattr(torch, "_scaled_mm") + + if has_e4m3 and has_e5m2: + print("FP8 support detected: E4M3 and E5M2 formats available") + if has_scaled_mm: + print( + "scaled_mm support detected: FP8 acceleration possible on RTX >=40 series GPUs" + ) + else: + print("WARNING: No FP8 support detected. PyTorch 2.1 or higher required") + + return has_e4m3, has_e5m2, has_scaled_mm diff --git a/diffusers_helper/lora_utils_kohya_ss/lora_check_helper.py b/diffusers_helper/lora_utils_kohya_ss/lora_check_helper.py new file mode 100644 index 00000000..62c2ea4b --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/lora_check_helper.py @@ -0,0 +1,96 @@ +# https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470/webui/lora_utils/lora_check_helper.py +# FramePack-eichi LoRA Check Helper +# +# LoRAの適用状態確認のための機能を提供します。 + +import torch + + +def check_lora_applied(model): + """ + Check if the model has LoRA applied. + This function checks if LoRA is applied to the model either through a direct flag or by checking for LoRA hooks in the model's modules. + + Args: + model: Target model to check. + + Returns: + (bool, str): If the model has a '_lora_applied' flag, it returns True and the source as 'direct_application'. If LoRA hooks are found in the model's named modules, it returns True and the source as 'hooks'. Otherwise, it returns False and 'none'. 
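+
+        Example (illustrative, with transformer being any loaded model):
+            has_lora, source = check_lora_applied(transformer)
+            # -> (True, "direct_application"), (True, "hooks") or (False, "none")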
+ """ + + has_flag = hasattr(model, "_lora_applied") and model._lora_applied + + if has_flag: + return True, "direct_application" + + # Check the named modules of the model for LoRA hooks + has_hooks = False + for name, module in model.named_modules(): + if hasattr(module, "_lora_hooks"): + has_hooks = True + break + + if has_hooks: + return True, "hooks" + + return False, "none" + + +def analyze_lora_application(model): + """ + モデルのLoRA適用率と影響を詳細に分析 + + Args: + model: 分析対象のモデル + + Returns: + dict: 分析結果の辞書 + """ + total_params = 0 + lora_affected_params = 0 + + # トータルパラメータ数とLoRAの影響を受けるパラメータ数をカウント + for name, module in model.named_modules(): + if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor): + param_count = module.weight.numel() + total_params += param_count + + # LoRA適用されたモジュールかチェック + if hasattr(module, "_lora_hooks") or hasattr(module, "_lora_applied"): + lora_affected_params += param_count + + application_rate = 0.0 + if total_params > 0: + application_rate = lora_affected_params / total_params * 100.0 + + return { + "total_params": total_params, + "lora_affected_params": lora_affected_params, + "application_rate": application_rate, + "has_lora": lora_affected_params > 0, + } + + +def print_lora_status(model): + """ + モデルのLoRA適用状況を出力 + + Args: + model: 出力対象のモデル + """ + has_lora, source = check_lora_applied(model) + + if has_lora: + print("LoRA status: applied") + print(f"LoRA model: {source}") + + # 詳細な分析 + analysis = analyze_lora_application(model) + application_rate = analysis["application_rate"] + + print( + f'LoRA conditions: {analysis["lora_affected_params"]}/{analysis["total_params"]} parameters ({application_rate:.2f}%)' + ) + else: + print("LoRA status: Not applicable") + print("LoRA model: not applicable") diff --git a/diffusers_helper/lora_utils_kohya_ss/lora_loader.py b/diffusers_helper/lora_utils_kohya_ss/lora_loader.py new file mode 100644 index 00000000..d8b1bbdb --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/lora_loader.py @@ -0,0 +1,67 @@ +# https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470/webui/lora_utils/lora_loader.py +# FramePack-eichi LoRA Loader +# +# LoRAモデルの読み込みと適用のための機能を提供します。 + +import os +import torch +from tqdm import tqdm +from .lora_utils import merge_lora_to_state_dict + + +def load_and_apply_lora( + model_files: list[str], + lora_paths: list[str], + lora_scales=None, + fp8_enabled=False, + device=None, +) -> dict[str, torch.Tensor]: + """ + LoRA重みをロードして重みに適用する + + Args: + model_files: List of model files to load + lora_paths: List of LoRA file paths + lora_scales: List if LoRA weight scales + fp8_enabled: Whether to enable FP8 optimization. Default is False. + device: Device used for loading the model. If None, defaults to CPU. + + Returns: + State dictionary with LoRA weights applied. 
+ """ + if lora_paths is None: + lora_paths = [] + + if device is None: + device = torch.device("cpu") # CPU fall back + + for lora_path in lora_paths: + if not os.path.exists(lora_path): + raise FileNotFoundError(f"LoRA file not found: {lora_path}") + + if lora_scales is None: + lora_scales = [0.8] * len(lora_paths) + if len(lora_scales) > len(lora_paths): + lora_scales = lora_scales[: len(lora_paths)] + if len(lora_scales) < len(lora_paths): + lora_scales += [0.8] * (len(lora_paths) - len(lora_scales)) + + for lora_path, lora_scale in zip(lora_paths, lora_scales): + print(f"LoRA loading: {os.path.basename(lora_path)} (scale: {lora_scale})") + + print(f"Model architecture: HunyuanVideo") + + # Merge the LoRA weighs into the state dictionary + merged_state_dict = merge_lora_to_state_dict( + model_files, lora_paths, lora_scales, fp8_enabled, device + ) + + print(f"LoRA loading complete") + return merged_state_dict + + +def check_lora_applied(model): + from lora_check_helper import check_lora_applied as check_lora_applied_helper + + # passthrough to the helper function - this function should be removed + return check_lora_applied_helper(model) diff --git a/diffusers_helper/lora_utils_kohya_ss/lora_utils.py b/diffusers_helper/lora_utils_kohya_ss/lora_utils.py new file mode 100644 index 00000000..bbf20c91 --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/lora_utils.py @@ -0,0 +1,376 @@ +# https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470/webui/lora_utils/lora_utils.py +import os +from typing import Callable +import torch +from safetensors.torch import load_file +from tqdm import tqdm + +# cSpell: ignore hunyuan, unet, musubi, framepack, conved + + +def merge_lora_to_state_dict( + model_files: list[str], + lora_files: list[str], + multipliers: list[float], + fp8_enabled: bool, + device: torch.device, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model. + """ + list_of_lora_sd = [] + for lora_file in lora_files: + # Load LoRA safetensors file + lora_sd = load_file(lora_file) + + # Check the format of the LoRA file + keys = list(lora_sd.keys()) + if keys[0].startswith("lora_unet_"): + print("Musubi Tuner LoRA detected") + else: + transformer_prefixes = [ + "diffusion_model", + "transformer", + ] # to ignore Text Encoder modules + lora_suffix = None + prefix = None + for key in keys: + if lora_suffix is None and "lora_A" in key: + lora_suffix = "lora_A" + if prefix is None: + pfx = key.split(".")[0] + if pfx in transformer_prefixes: + prefix = pfx + if lora_suffix is not None and prefix is not None: + break + + if lora_suffix == "lora_A" and prefix is not None: + print("Diffusion-pipe (?) 
LoRA detected") + lora_sd = convert_from_diffusion_pipe_or_something( + lora_sd, "lora_unet_" + ) + + else: + print(f"LoRA file format not recognized: {os.path.basename(lora_file)}") + lora_sd = None + + if lora_sd is not None: + # Check LoRA is for FramePack or for HunyuanVideo + is_hunyuan = False + for key in lora_sd.keys(): + if "double_blocks" in key or "single_blocks" in key: + is_hunyuan = True + break + if is_hunyuan: + print(f"HunyuanVideo LoRA detected, converting to FramePack format") + lora_sd = convert_hunyuan_to_framepack(lora_sd) + + if lora_sd is not None: + list_of_lora_sd.append(lora_sd) + + if len(list_of_lora_sd) == 0: + # no LoRA files found, just load the model + return load_safetensors_with_fp8_optimization( + model_files, fp8_enabled, device, weight_hook=None + ) + + return load_safetensors_with_lora_and_fp8( + model_files, list_of_lora_sd, multipliers, fp8_enabled, device + ) + + +def convert_from_diffusion_pipe_or_something( + lora_sd: dict[str, torch.Tensor], prefix: str +) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner. + Copy from Musubi Tuner repo. + """ + # convert from diffusers(?) to default LoRA + # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...} + # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...} + + # note: Diffusers has no alpha, so alpha is set to rank + new_weights_sd = {} + lora_dims = {} + for key, weight in lora_sd.items(): + diffusers_prefix, key_body = key.split(".", 1) + if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer": + print(f"unexpected key: {key} in diffusers format") + continue + + new_key = ( + f"{prefix}{key_body}".replace(".", "_") + .replace("_lora_A_", ".lora_down.") + .replace("_lora_B_", ".lora_up.") + ) + new_weights_sd[new_key] = weight + + lora_name = new_key.split(".")[0] # before first dot + if lora_name not in lora_dims and "lora_down" in new_key: + lora_dims[lora_name] = weight.shape[0] + + # add alpha with rank + for lora_name, dim in lora_dims.items(): + new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim) + + return new_weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: list[str], + list_of_lora_sd: list[dict[str, torch.Tensor]], + multipliers: list[float], + fp8_optimization: bool, + device: torch.device, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. + """ + list_of_lora_weight_keys = [] + for lora_sd in list_of_lora_sd: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + # Merge LoRA weights into the state dict + print(f"Merging LoRA weights into state dict. 
multiplier: {multipliers}") + + # make hook for LoRA merging + def weight_hook(model_weight_key, model_weight): + nonlocal list_of_lora_weight_keys, list_of_lora_sd, multipliers + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != device: + model_weight = model_weight.to(device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip( + list_of_lora_weight_keys, list_of_lora_sd, multipliers + ): + # check if this weight has LoRA weights + lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + lora_name = "lora_unet_" + lora_name.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key not in lora_weight_keys or up_key not in lora_weight_keys: + continue + + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(device) + up_weight = up_weight.to(device) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = ( + model_weight + multiplier * (up_weight @ down_weight) * scale + ) + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + model_weight = ( + model_weight + + multiplier + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + + model_weight = model_weight + multiplier * conved * scale + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + state_dict = load_safetensors_with_fp8_optimization( + model_files, fp8_optimization, device, weight_hook=weight_hook + ) + + for lora_weight_keys in list_of_lora_weight_keys: + # check if all LoRA keys are used + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not used in the model + # this is a warning, not an error + print(f'Warning: not all LoRA keys are used: {", ".join(lora_weight_keys)}') + + return state_dict + + +# cSpell: ignore QKV, QKVM + + +def convert_hunyuan_to_framepack( + lora_sd: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Convert HunyuanVideo LoRA weights to FramePack format. + """ + new_lora_sd = {} + for key, weight in lora_sd.items(): + # if key.startswith("lora_unet_"): + # # hack? 
remove prefix from musubi tuner format + # key = key.replace("lora_unet_", "") + if "double_blocks" in key: + # print(f"Converting double_blocks HunyuanVideo LoRA key: {key}") + key = key.replace("double_blocks", "transformer_blocks") + key = key.replace("img_mod_linear", "norm1_linear") + key = key.replace("img_attn_qkv", "attn_to_QKV") # split later + key = key.replace("img_attn_proj", "attn_to_out_0") + key = key.replace("img_mlp_fc1", "ff_net_0_proj") + key = key.replace("img_mlp_fc2", "ff_net_2") + key = key.replace("txt_mod_linear", "norm1_context_linear") + key = key.replace("txt_attn_qkv", "attn_add_QKV_proj") # split later + key = key.replace("txt_attn_proj", "attn_to_add_out") + key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj") + key = key.replace("txt_mlp_fc2", "ff_context_net_2") + # print(f"Converted double_blocks HunyuanVideo LoRA key: {key}") + elif "single_blocks" in key: + # print(f"Converting single_blocks HunyuanVideo LoRA key: {key}") + key = key.replace("single_blocks", "single_transformer_blocks") + key = key.replace("linear1", "attn_to_QKVM") # split later + key = key.replace("linear2", "proj_out") + key = key.replace("modulation_linear", "norm_linear") + # print(f"Converted single_blocks HunyuanVideo LoRA key: {key}") + else: + print( + f"Unsupported module name: {key}, only double_blocks and single_blocks are supported" + ) + continue + + if "QKVM" in key: + # print(f"Converting QKVM HunyuanVideo LoRA key: {key}") + # split QKVM into Q, K, V, M + key_q = key.replace("QKVM", "q") + key_k = key.replace("QKVM", "k") + key_v = key.replace("QKVM", "v") + key_m = key.replace("attn_to_QKVM", "proj_mlp") + if "_down" in key or "alpha" in key: + # copy QKVM weight or alpha to Q, K, V, M + assert ( + "alpha" in key or weight.size(1) == 3072 + ), f"QKVM weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight + new_lora_sd[key_k] = weight + new_lora_sd[key_v] = weight + new_lora_sd[key_m] = weight + elif "_up" in key: + # split QKVM weight into Q, K, V, M + assert ( + weight.size(0) == 21504 + ), f"QKVM weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight[:3072] + new_lora_sd[key_k] = weight[3072 : 3072 * 2] + new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3] + new_lora_sd[key_m] = weight[3072 * 3 :] # 21504 - 3072 * 3 = 12288 + else: + print(f"Unsupported module name: {key}") + continue + # print(f"Converted QKVM HunyuanVideo LoRA key: {key}") + elif "QKV" in key: + # print(f"Converting QKV HunyuanVideo LoRA key: {key}") + # split QKV into Q, K, V + key_q = key.replace("QKV", "q") + key_k = key.replace("QKV", "k") + key_v = key.replace("QKV", "v") + if "_down" in key or "alpha" in key: + # copy QKV weight or alpha to Q, K, V + assert ( + "alpha" in key or weight.size(1) == 3072 + ), f"QKV weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight + new_lora_sd[key_k] = weight + new_lora_sd[key_v] = weight + elif "_up" in key: + # split QKV weight into Q, K, V + assert ( + weight.size(0) == 3072 * 3 + ), f"QKV weight size mismatch: {key}. 
{weight.size()}" + new_lora_sd[key_q] = weight[:3072] + new_lora_sd[key_k] = weight[3072 : 3072 * 2] + new_lora_sd[key_v] = weight[3072 * 2 :] + else: + print(f"Unsupported module name: {key}") + continue + # print(f"Converted QKV HunyuanVideo LoRA key: {key}") + else: + # no split needed + new_lora_sd[key] = weight + + return new_lora_sd + + +def load_safetensors_with_fp8_optimization( + model_files: list[str], + fp8_optimization: bool, + device: torch.device, + weight_hook: Callable | None = None, +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. + """ + state_dict = {} + if fp8_optimization: + raise RuntimeWarning("FP8 optimization is not yet supported in this version.") + from .fp8_optimization_utils import ( + optimize_state_dict_with_fp8_on_the_fly, + ) + + # Optimization targets and exclusion keys + TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"] + EXCLUDE_KEYS = [ + "norm" + ] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8 + + print(f"FP8: Optimizing state dictionary on the fly") + # Optimized state dictionary in FP8 format + state_dict = optimize_state_dict_with_fp8_on_the_fly( + model_files, + device, + TARGET_KEYS, + EXCLUDE_KEYS, + move_to_device=False, + weight_hook=weight_hook, + ) + else: + from .safetensors_utils import MemoryEfficientSafeOpen + + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + for key in tqdm( + f.keys(), + desc=f"Loading {os.path.basename(model_file)}", + leave=False, + ): + value = f.get_tensor(key) + if weight_hook is not None: + value = weight_hook(key, value) + state_dict[key] = value + + return state_dict diff --git a/diffusers_helper/lora_utils_kohya_ss/safetensors_utils.py b/diffusers_helper/lora_utils_kohya_ss/safetensors_utils.py new file mode 100644 index 00000000..d982f998 --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/safetensors_utils.py @@ -0,0 +1,105 @@ +# https://github.com/kohya-ss/FramePack-eichi/blob/4085a24baf08d6f1c25e2de06f376c3fc132a470/webui/lora_utils/safetensors_utils.py +# original: https://gist.github.com/kohya-ss/fa4b7ae7119c10850ae7d70c90a59277 +from typing import Dict +import json +import struct +import torch + + +class MemoryEfficientSafeOpen: + """ + A class to read tensors from a .safetensors file in a memory-efficient way. 
+ """ + + # does not support metadata loading + def __init__(self, filename): + self.filename = filename + self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def keys(self): + return [k for k in self.header.keys() if k != "__metadata__"] + + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + + def get_tensor(self, key): + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + + if offset_start == offset_end: + tensor_bytes = None + else: + # adjust offset by header size + self.file.seek(self.header_size + 8 + offset_start) + tensor_bytes = self.file.read(offset_end - offset_start) + + return self._deserialize_tensor(tensor_bytes, metadata) + + def _read_header(self): + header_size = struct.unpack(" torch.dtype: + dtype_map = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + } + # add float8 types if available + if hasattr(torch, "float8_e5m2"): + dtype_map["F8_E5M2"] = torch.float8_e5m2 + if hasattr(torch, "float8_e4m3fn"): + dtype_map["F8_E4M3"] = torch.float8_e4m3fn + if dtype_str not in dtype_map: + raise ValueError(f"Unsupported dtype: {dtype_str}") + return dtype_map[dtype_str] + + @staticmethod + def _convert_float8(byte_tensor, dtype_str, shape): + if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"): + return byte_tensor.view(torch.float8_e5m2).reshape(shape) + elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"): + return byte_tensor.view(torch.float8_e4m3fn).reshape(shape) + else: + # # convert to float16 if float8 is not supported + # print(f"Warning: {dtype_str} is not supported in this PyTorch version. 
Converting to float16.") + # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) + raise ValueError( + f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)" + ) From ff6c542d91522f134394fe595c3500ea31802750 Mon Sep 17 00:00:00 2001 From: arledesma Date: Wed, 16 Jul 2025 18:39:05 -0500 Subject: [PATCH 02/17] VideoJobQueue singleton We do not manage multiple VideoJobQueue's, so this singleton can be imported and used anywhere that we need access to the queue --- modules/video_queue.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/modules/video_queue.py b/modules/video_queue.py index 52b09d1a..e52e2473 100644 --- a/modules/video_queue.py +++ b/modules/video_queue.py @@ -6,6 +6,7 @@ import zipfile import shutil from dataclasses import dataclass, field +from collections.abc import Callable from enum import Enum from typing import Dict, Any, Optional, List import queue as queue_module # Renamed to avoid conflicts @@ -70,7 +71,7 @@ class Job: result: Optional[str] = None progress_data: Optional[Dict] = None queue_position: Optional[int] = None - stream: Optional[Any] = None + stream: Optional[AsyncStream] = None input_image: Optional[np.ndarray] = None latent_type: Optional[str] = None thumbnail: Optional[str] = None @@ -288,17 +289,25 @@ def __post_init__(self): class VideoJobQueue: + _instance: Optional['VideoJobQueue'] = None + def __init__(self): self.queue = queue_module.Queue() # Using standard Queue instead of LifoQueue - self.jobs = {} + self.jobs: dict[str, Job] = {} self.current_job = None self.lock = threading.Lock() self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True) self.worker_thread.start() - self.worker_function = None # Will be set from outside + self.worker_function: Callable[..., None] | None = None # Will be set from outside self.is_processing = False # Flag to track if we're currently processing a job + + def __new__(cls): + if cls._instance is None: + print('Creating the VideoJobQueue instance') + cls._instance = super(VideoJobQueue, cls).__new__(cls) + return cls._instance - def set_worker_function(self, worker_function): + def set_worker_function(self, worker_function: Callable[..., None]): """Set the worker function to use for processing jobs""" self.worker_function = worker_function @@ -684,7 +693,10 @@ def synchronize_queue_images(self): def add_job(self, params, job_type=JobType.SINGLE, child_job_params_list=None, parent_job_id=None): """Add a job to the queue and return its ID""" - job_id = str(uuid.uuid4()) + # sortable time UUID + # get back the datetime with datetime.datetime(1582, 10, 15) + datetime.timedelta(microseconds=uuid.UUID(str(job_id)).time//10) + # which should roughly correspond to the time stored in created_at + job_id = str(uuid.uuid1()) # For grid jobs, create child jobs first child_job_ids = [] From 75eea90f60d9357b1bf11f28e85e1945eaf3c129 Mon Sep 17 00:00:00 2001 From: arledesma Date: Wed, 16 Jul 2025 18:41:49 -0500 Subject: [PATCH 03/17] Add LoraLoader enum Enable switching between known lora loaders Includes StrEnum implementation for python 3.10 (or older 3.x) users, otherwise use the builtin StrEnum from python >= 3.11 --- diffusers_helper/lora_utils_kohya_ss/enums.py | 85 +++++++++++++++++++ .../lora_utils_kohya_ss/lora_loader.py | 2 +- 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 diffusers_helper/lora_utils_kohya_ss/enums.py diff --git a/diffusers_helper/lora_utils_kohya_ss/enums.py 
b/diffusers_helper/lora_utils_kohya_ss/enums.py new file mode 100644 index 00000000..3eb887ec --- /dev/null +++ b/diffusers_helper/lora_utils_kohya_ss/enums.py @@ -0,0 +1,85 @@ +import sys + +if sys.version_info >= (3, 11): + from enum import StrEnum + # StrEnum is introduced in 3.11 while we support python 3.10 +else: + from enum import Enum, auto + from typing import Any + + # Fallback for Python 3.10 and earlier + class StrEnum(str, Enum): + def __new__(cls, value, *args, **kwargs): + if not isinstance(value, (str, auto)): + raise TypeError( + f"Values of StrEnums must be strings: {value!r} is a {type(value)}" + ) + return super().__new__(cls, value, *args, **kwargs) + + def __str__(self): + return str(self.value) + + @staticmethod + def _generate_next_value_( + name: str, start: int, count: int, last_values: list[Any] + ) -> str: + return name + + +class LoraLoader(StrEnum): + DIFFUSERS = "diffusers" + LORA_READY = "lora_ready" + DEFAULT = DIFFUSERS + + @staticmethod + def supported_values() -> list[str]: + """Returns a list of all supported LoraLoader values.""" + return [loader.value for loader in LoraLoader] + + @staticmethod + def safe_parse(value: "str | LoraLoader") -> "LoraLoader": + if isinstance(value, LoraLoader): + return value + try: + return LoraLoader(value) + except ValueError: + return LoraLoader.DEFAULT + + +if __name__ == "__main__": + # Test the StrEnum functionality + print("diffusers:", LoraLoader.DIFFUSERS) # Should print "diffusers" + print("lora_ready:", LoraLoader.LORA_READY) # Should print "lora_ready" + print("default:", LoraLoader.DEFAULT) # Should print "lora_ready" + print( # Should print all unique supported values (excludes aliases like DEFAULT) + "supported_values:", LoraLoader.supported_values() + ) + try: + print("fail:", LoraLoader("invalid")) # Should raise ValueError + except ValueError as e: + print("pass:", e) # Prints: Invalid LoraLoader value: invalid + try: + print("pass:", LoraLoader("diffusers")) # Should return LoraLoader.DIFFUSERS + except ValueError as e: + print("fail:", e) + try: + print("type of LoraLoader.DEFAULT:", type(LoraLoader.DEFAULT)) + default = LoraLoader.DEFAULT + print("type of default:", type(default)) # Should be LoraLoader, not str + except Exception as e: + print(f"fail: {e}") + + assert isinstance(LoraLoader("lora_ready"), StrEnum) + assert isinstance( + LoraLoader.DIFFUSERS, LoraLoader + ), "DIFFUSERS should be an instance of LoraLoader" + assert ( + LoraLoader.DEFAULT == LoraLoader.DIFFUSERS + ), "Default loader should be DIFFUSERS" + assert ( + LoraLoader.DIFFUSERS != LoraLoader.LORA_READY + ), "DIFFUSERS should not equal LORA_READY" + + assert ( + LoraLoader.LORA_READY.value == "lora_ready" + ), "lora_ready string should equal LoraLoader.LORA_READY" diff --git a/diffusers_helper/lora_utils_kohya_ss/lora_loader.py b/diffusers_helper/lora_utils_kohya_ss/lora_loader.py index d8b1bbdb..21fe7d15 100644 --- a/diffusers_helper/lora_utils_kohya_ss/lora_loader.py +++ b/diffusers_helper/lora_utils_kohya_ss/lora_loader.py @@ -23,7 +23,7 @@ def load_and_apply_lora( model_files: List of model files to load lora_paths: List of LoRA file paths lora_scales: List if LoRA weight scales - fp8_enabled: Whether to enable FP8 optimization. Default is False. + fp8_enabled: Whether to enable FP8 optimization. Default is False. This requires fp8 model_files. device: Device used for loading the model. If None, defaults to CPU. 
Returns: From f5cca71b5b2142c6f6997b323c62b877eff1f3f3 Mon Sep 17 00:00:00 2001 From: arledesma Date: Wed, 16 Jul 2025 18:46:00 -0500 Subject: [PATCH 04/17] Settings singleton We do not manage multiple Settings objects, so this singleton can be imported and used anywhere that we need access to the Settings --- modules/settings.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/settings.py b/modules/settings.py index e3bd25df..dc6e206e 100644 --- a/modules/settings.py +++ b/modules/settings.py @@ -4,6 +4,10 @@ import os class Settings: + """Singleton class to manage application settings.""" + + _instance: Optional['Settings'] = None + def __init__(self): # Get the project root directory (where settings.py is located) project_root = Path(__file__).parent.parent @@ -38,6 +42,12 @@ def __init__(self): } self.settings = self.load_settings() + def __new__(cls): + if cls._instance is None: + print('Creating the Settings instance') + cls._instance = super(Settings, cls).__new__(cls) + return cls._instance + def load_settings(self) -> Dict[str, Any]: """Load settings from file or return defaults""" if self.settings_file.exists(): From 5d53e3c33dcd36c016545fa71694e29372fc05ac Mon Sep 17 00:00:00 2001 From: arledesma Date: Wed, 16 Jul 2025 18:50:24 -0500 Subject: [PATCH 05/17] Add Settings - lora_loader, reuse_model_instance These are defaulted to continue existing behavior. diffusers lora loader and no reuse of model instance --- modules/settings.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/modules/settings.py b/modules/settings.py index dc6e206e..42067b3a 100644 --- a/modules/settings.py +++ b/modules/settings.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Dict, Any, Optional import os +from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader class Settings: """Singleton class to manage application settings.""" @@ -38,7 +39,9 @@ def __init__(self): User prompt: "{text_to_enhance}" -Enhanced prompt:""" +Enhanced prompt:""", + "lora_loader": LoraLoader.DEFAULT, # lora_loader options: diffusers, lora_ready. 
DEFAULT is existing behavior of diffusers + "reuse_model_instance": False, # Reuse model instance across generations - default of False is existing behavior } self.settings = self.load_settings() @@ -48,6 +51,23 @@ def __new__(cls): cls._instance = super(Settings, cls).__new__(cls) return cls._instance + @property + def lora_loader(self) -> LoraLoader: + return LoraLoader.safe_parse(self.settings.get("lora_loader", LoraLoader.DEFAULT)) + + @lora_loader.setter + def lora_loader(self, value: str | LoraLoader): + if not value: + value = LoraLoader.DEFAULT + if isinstance(value, str): + value = LoraLoader.safe_parse(value) + + self.set("lora_loader", value) + + @property + def reuse_model_instance(self) -> bool: + return self.settings.get("reuse_model_instance", False) + def load_settings(self) -> Dict[str, Any]: """Load settings from file or return defaults""" if self.settings_file.exists(): From 635e928809b610130f1ba31c623c085c4f65b59d Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:23:25 -0500 Subject: [PATCH 06/17] Expose lora_loader and reuse_model_instance in Settings UI --- modules/ui/settings.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/ui/settings.py b/modules/ui/settings.py index 69e8fe46..37bc7c90 100644 --- a/modules/ui/settings.py +++ b/modules/ui/settings.py @@ -1,4 +1,5 @@ import gradio as gr +from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader from modules.ui.generate import load_presets def create_settings_ui(settings, get_latents_display_top, model_type_choices): @@ -29,6 +30,11 @@ def create_settings_ui(settings, get_latents_display_top, model_type_choices): auto_save = gr.Checkbox(label="Auto-save settings", value=settings.get("auto_save_settings", True)) gradio_themes = ["default", "base", "soft", "glass", "mono", "origin", "citrus", "monochrome", "ocean", "NoCrypt/miku", "earneleh/paris", "gstaff/xkcd"] theme_dropdown = gr.Dropdown(label="Theme", choices=gradio_themes, value=settings.get("gradio_theme", "default"), info="Select the Gradio UI theme. Requires restart.") + with gr.Accordion("Experimental Settings", open=False): + gr.Markdown("These settings are for advanced users. Changing them may affect the performance or functionality of the application.") + lora_loader = gr.Dropdown(label="LoRA Loader", choices=[loader.value for loader in LoraLoader], value=settings.lora_loader.value, info="Select the LoRA loader to use. 'diffusers' for Diffusers format, 'lora_ready' for Kohya-ss format.", interactive=True) + __reuse_model_instance_warning__ = gr.Markdown("The *Reuse of Model Instance* option may be unstable for lower memory GPUs. If you experience memory pressure or crashes, disable this option.") + reuse_model_instance = gr.Checkbox(label="Reuse Model Instance", value=settings.get("reuse_model_instance", False), info="If checked, the model instance will be reused across generations to save reload time when no LoRA changes are detected and the same model is used. 
If unchecked, a new model instance will be created for each generation.") save_btn = gr.Button("💾 Save Settings") cleanup_btn = gr.Button("🗑️ Clean Up Temporary Files") status = gr.HTML("") @@ -42,9 +48,11 @@ def create_settings_ui(settings, get_latents_display_top, model_type_choices): "reset_system_prompt_btn": reset_system_prompt_btn, "system_prompt_template": system_prompt_template, "output_dir": output_dir, "metadata_dir": metadata_dir, "lora_dir": lora_dir, "gradio_temp_dir": gradio_temp_dir, "auto_save": auto_save, "theme_dropdown": theme_dropdown, - "save_btn": save_btn, "cleanup_btn": cleanup_btn, "status": status, "cleanup_output": cleanup_output + "save_btn": save_btn, "cleanup_btn": cleanup_btn, "status": status, "cleanup_output": cleanup_output, + "lora_loader": lora_loader, "reuse_model_instance": reuse_model_instance, } +# we can avoid passing around settings object into here, as the Settings class is now a singleton. Remove this comment if done. def connect_settings_events(s, g, settings, create_latents_layout_update, tb_processor): def save_settings_func(*args): keys = list(s.keys()) From 6d1dd7105e9c5eb7d1969000283bd477765e1c86 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:25:28 -0500 Subject: [PATCH 07/17] Map DynamicSwap_HunyuanVideoTransformer3DModelPacked to HunyuanVideoTransformer3DModel This was existing behavior that was found to not be mapped. --- diffusers_helper/lora_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diffusers_helper/lora_utils.py b/diffusers_helper/lora_utils.py index d726d53b..94b6f494 100644 --- a/diffusers_helper/lora_utils.py +++ b/diffusers_helper/lora_utils.py @@ -8,6 +8,7 @@ FALLBACK_CLASS_ALIASES = { "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel", + "DynamicSwap_HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel", } def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]: From 5b57b02c2a01a20ac8cb99db31acd596586e4985 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:30:22 -0500 Subject: [PATCH 08/17] Add ModelConfiguration to track of model settings and customizations --- modules/generators/model_configuration.py | 286 ++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 modules/generators/model_configuration.py diff --git a/modules/generators/model_configuration.py b/modules/generators/model_configuration.py new file mode 100644 index 00000000..2e76b510 --- /dev/null +++ b/modules/generators/model_configuration.py @@ -0,0 +1,286 @@ +import hashlib +import logging +import json +from typing import Optional, cast +from dataclasses import dataclass, field, asdict + +from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader + +DEFAULT_WEIGHT: float = 0.8 + +__logger = logging.getLogger(__name__) + +@dataclass(frozen=True, eq=True) +class ModelLoraSetting: + """ Represents a LoRA (Low-Rank Adaptation) setting for a model. + Attributes: + name: The name of the LoRA. + weight: The weight of the LoRA. Typically between 0.0 and 1.0, but may be used up to 2.0 (or potentially higher). + Default is 0.8 (DEFAULT_WEIGHT). + sequence: The sequence order of the LoRA when applied to the model. + Lower numbers indicate higher priority, loaded first. + The sequence is automatically assigned if not provided or if duplicates exist. + The order of application can affect the final model output depending on block interactions. 
+ exclude_blocks: The blocks to exclude from the LoRA represented as a single regex string. + include_blocks: The blocks to include in the LoRA represented as a single regex string. + """ + + name: str + weight: float = field(default=DEFAULT_WEIGHT) + sequence: int = field(default=0) + exclude_blocks: Optional[str] = field(kw_only=True, default=None) + include_blocks: Optional[str] = field(kw_only=True, default=None) + + def __post_init__(self): + if not self.name: + raise ValueError("ModelLoraSetting requires a 'name' attribute.") + if not isinstance(self.weight, float): + raise ValueError( + "ModelLoraSetting requires a 'weight' attribute with a type of float but got {0} of type {1}".format(self.weight, type(self.weight))) + if not isinstance(self.sequence, int): + raise ValueError("ModelLoraSetting requires a 'sequence' attribute with a type of int but got {0} of type {1}".format( + self.sequence, type(self.sequence))) + + @staticmethod + def parse_settings(settings: list["ModelLoraSetting"] | str | list[str] | dict[str, dict] | dict[str, float | int] | None = None, reverse_sequence: bool = False) -> list["ModelLoraSetting"]: + """Parses LoRA settings from various input formats into a list of ModelLoraSetting instances. + Args: + lora_settings: The LoRA settings to parse, which can be a list of ModelLoraSetting instances, + a list of strings (names), or a dictionary mapping names to settings defining at least weight and sequence. + reverse_sequence: Whether to sort the settings in reverse order based on their sequence. Default is False. + Returns: + A list of ModelLoraSetting instances. + """ + if settings is None or not settings: + return [] + + parsed_settings: list[ModelLoraSetting] = [] + if (isinstance(settings, str)): + parsed_settings = [ModelLoraSetting(name=settings, weight=DEFAULT_WEIGHT, sequence=0)] + elif (isinstance(settings, ModelLoraSetting)): + parsed_settings = [settings] + elif isinstance(settings, list) and all(isinstance(setting, ModelLoraSetting) for setting in settings): + parsed_settings = list(set(lora for lora in settings if isinstance( + lora, ModelLoraSetting)) if settings else []) + elif isinstance(settings, str): + parsed_settings = [ModelLoraSetting(name=settings, weight=DEFAULT_WEIGHT, sequence=0)] + elif isinstance(settings, list) and all(isinstance(setting, str) for setting in settings): + parsed_settings = [ModelLoraSetting(name=name, weight=DEFAULT_WEIGHT, sequence=sequence) + for sequence, name in enumerate(settings) if isinstance(name, str)] + elif isinstance(settings, dict): + if all(isinstance(k, str) and isinstance(v, float | int) for k, v in settings.items()): + parsed_settings = [ + ModelLoraSetting(name=name, weight=float(cast(float, weight)), sequence=sequence) + for sequence, (name, weight) in enumerate(settings.items()) + ] + if all(isinstance(k, str) and isinstance(v, dict) for k, v in settings.items()): + parsed_settings = [ + ModelLoraSetting(name=name, weight=float(cast(dict, details).get("weight", DEFAULT_WEIGHT)), + sequence=int(cast(dict, details).get("sequence", sequence))) + for sequence, (name, details) in enumerate(settings.items()) + ] + elif all(isinstance(v, str) for v in settings.values()): + parsed_settings = [ + ModelLoraSetting(name=name, weight=DEFAULT_WEIGHT, sequence=sequence) + for sequence, (name, _) in enumerate(settings.items()) + ] + else: + raise ValueError("Invalid lora_settings format") + + if not parsed_settings: + return [] + + # assign sequences to settings without valid sequence + sequences = 
[setting.sequence for setting in parsed_settings if setting.sequence is not None and setting.sequence >= 0] + unique_sequences: set[int] = set(sequences) + unique_sequences_len = len(unique_sequences) + magic_number = 1000 if unique_sequences_len != len(parsed_settings) else 0 + for setting_index, setting in enumerate(parsed_settings): + if unique_sequences_len == 0: + # no sequence set on any setting, assign based on index + setting = ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=setting_index, + exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks) + elif setting.sequence is None or setting.sequence < 0: + # sequence invalid or not set, assign a new unique sequence based on index + magic_number + setting = ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=setting_index + + magic_number, exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks) + + # update duplicate sequences + sequences = [setting.sequence for setting in parsed_settings if setting.sequence is not None and setting.sequence >= 0] + unique_sequences: set[int] = set(sequences) + unique_sequences_len = len(unique_sequences) + for setting_index, setting in enumerate(parsed_settings): + if sequences.count(setting.sequence) > 1: + # duplicate sequence, assign a new unique sequence based on max existing + 1 + max_sequence = max(unique_sequences, default=1000 + setting_index) + setting = ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=max_sequence + 1, + exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks) + + del unique_sequences, unique_sequences_len + settings_set: set[ModelLoraSetting] = set() + for setting_index, setting in enumerate(parsed_settings): + if setting.sequence is None or setting.sequence <= 0: + new_setting = setting if not any(setting.sequence == s.sequence for s in settings_set) else None + if new_setting: + settings_set.add(new_setting) + else: + max_sequence = max((s.sequence for s in settings_set if getattr(s, 'sequence', setting_index) is not None), + default=1000 + len(settings_set)) + settings_set.add(ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=max_sequence + 1, + exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks)) + + if not any(setting.sequence == setting_index for setting in parsed_settings): + new_sequence = setting_index + setting = ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=new_sequence, + exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks) + + else: + max_sequence = max((s.sequence for s in parsed_settings if hasattr(s, 'sequence')), + default=1000 + setting_index) + setting = ModelLoraSetting(name=setting.name, weight=float(setting.weight), sequence=max_sequence + 1, + exclude_blocks=setting.exclude_blocks, include_blocks=setting.include_blocks) + + settings = list(settings_set) + del parsed_settings, settings_set + + # always sort by sequence ascending or descending before returning + settings.sort(key=lambda x: x.sequence, reverse=reverse_sequence) + return settings + + @staticmethod + def from_names_and_weights(lora_names: list[str], lora_weights: Optional[list[float | int]] = None) -> list["ModelLoraSetting"]: + """Creates a list of ModelLoraSetting instances from lists of names and weights. + Args: + lora_names (list[str]): The list of LoRA names. + lora_weights (Optional[list[float | int]]): The list of LoRA weights. 
If None, defaults to an empty list. + Returns: + list[ModelLoraSetting]: A list of ModelLoraSetting instances. + """ + + if lora_weights is None: + lora_weights = [] + if len(lora_names) != len(lora_weights): + __logger.warning( + f"Warning: Mismatch in lengths of lora_names ({len(lora_names)}) and lora_weights ({len(lora_weights)}).") + additional_weights = len(lora_names) - len(lora_weights) + if additional_weights > 0: + __logger.info(f"Filling missing weights with default value {DEFAULT_WEIGHT}.") + lora_weights = (lora_weights) + [DEFAULT_WEIGHT] * additional_weights + else: + lora_weights = (lora_weights)[:len(lora_names)] + + lora_settings: list[ModelLoraSetting] = [] + for sequence, (name, weight) in enumerate(zip(lora_names, lora_weights or [DEFAULT_WEIGHT] * len(lora_names))): + lora_settings.append(ModelLoraSetting(name=name, weight=float(weight), sequence=sequence)) + return lora_settings + + +@dataclass +class ModelSettings(): + lora_settings: list[ModelLoraSetting] = field(default_factory=list) + lora_loader: str = field(default=LoraLoader.DEFAULT.value) + + def add_lora_setting(self, setting: ModelLoraSetting) -> None: + max_sequence = max((s.sequence for s in self.lora_settings if hasattr(s, 'sequence')), default=-1) + new_sequence = max_sequence + 1 + self.lora_settings.append(ModelLoraSetting(name=setting.name, weight=setting.weight, sequence=new_sequence, + include_blocks=setting.include_blocks, exclude_blocks=setting.exclude_blocks)) + + +@dataclass +class ModelConfiguration: + model_name: str + settings: ModelSettings = field(default_factory=ModelSettings) + + @property + def _hash(self) -> str: + return hashlib.md5(json.dumps(asdict(self), sort_keys=True).encode()).hexdigest() + + def add_lora_setting(self, setting: ModelLoraSetting) -> None: + self.settings.add_lora_setting(setting) + + def add_lora(self, name: str, weight: float = DEFAULT_WEIGHT) -> None: + self.add_lora_setting(ModelLoraSetting(name=name, weight=weight)) + + def validate(self) -> bool: + total_weights = sum([setting.weight for setting in self.settings.lora_settings]) + valid = 2 > total_weights > 0 + if not valid: + __logger.warning("Warning: total weight for all LoRA may not perform well with the model ({0}). 
Total weight: {1}".format( + self.model_name, total_weights)) + return valid + + @staticmethod + def from_settings(model_name: str, settings: ModelSettings | dict | None): + model_settings: ModelSettings | None = None + if settings is None: + model_settings = ModelSettings() + elif isinstance(settings, ModelSettings): + model_settings = settings + elif isinstance(settings, dict): + model_settings = ModelSettings(lora_settings=ModelLoraSetting.parse_settings(settings)) + + if model_settings is None: + raise ValueError("Invalid config type for ModelConfiguration") + + return ModelConfiguration(model_name=model_name, settings=model_settings) + + @staticmethod + def from_lora_names_and_weights( + model_name: str, + lora_names: list[str], + lora_weights: list[float | int], + lora_loader: str | LoraLoader, + ) -> "ModelConfiguration": + assert isinstance(model_name, str) and model_name, "model_name must be a non-empty string" + assert isinstance(lora_names, list) and all(isinstance(name, str) for name in lora_names), "lora_names must be a list of strings" + assert isinstance(lora_weights, list) and all(isinstance(weight, (float, int)) for weight in lora_weights), "lora_weights must be a list of floats or ints" + assert isinstance(lora_loader, (str, LoraLoader)), "lora_loader must be a string or LoraLoader enum" + + weights: list[float] = [float(weight) for weight in (lora_weights or [])] + lora_settings = ModelLoraSetting.from_names_and_weights(lora_names, lora_weights=weights) + model_settings = ModelSettings(lora_settings=lora_settings, lora_loader=str(lora_loader)) + return ModelConfiguration.from_settings(model_name=model_name, settings=model_settings) + + def set_model_name(self, model_name: str) -> "ModelConfiguration": + self.model_name = model_name + return self + + def set_settings(self, settings: ModelSettings) -> "ModelConfiguration": + self.settings = settings + return self + + def update_lora_setting(self, lora_settings: list[ModelLoraSetting] | str | list[str] | dict[str, dict]) -> "ModelConfiguration": + self.settings.lora_settings = ModelLoraSetting.parse_settings(lora_settings) + return self + + +if __name__ == '__main__': + import logging + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger(__name__) + config: ModelConfiguration = ModelConfiguration.from_lora_names_and_weights( + model_name="test-model", + lora_names=["lora1", "lora2", "lora3"], + lora_weights=[0.5, 0.5, 0.5], + lora_loader=LoraLoader.DIFFUSERS + ) + logger.info(f"Model Name: {config.model_name}") + logger.info(f"LoRA Settings: {config.settings.lora_settings}") + logger.debug(json.dumps(asdict(config), indent=4)) + logger.debug("hash: {0}".format(config._hash)) + config.model_name = "changed" + logger.debug("hash: {0}".format(config._hash)) + config.settings.lora_settings = ModelLoraSetting.from_names_and_weights( + lora_names=["lora_A", "lora_B", "lora_C"], + lora_weights=[1, 1.5, 2.5] + ) + logger.debug("hash: {0}".format(config._hash)) + config.settings.lora_settings.append(ModelLoraSetting(name="lora_D", weight=0.75)) + logger.debug(json.dumps(asdict(config), indent=4)) + logger.debug("hash: {0}".format(config._hash)) + config.add_lora("lora_E") + logger.debug(json.dumps(asdict(config), indent=4)) + logger.debug("hash: {0}".format(config._hash)) + valid = config.validate() + logger.debug("Config validation result: {0}".format(valid)) From 17ac0309e9242f70b7b6492e5f6d48665096f8f3 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:33:46 -0500 Subject: [PATCH 09/17] Wire up 
kohya_ss LoRAReady loader in BaseModelGenerator --- modules/generators/base_generator.py | 193 +++++++++++++++++++++++++-- 1 file changed, 184 insertions(+), 9 deletions(-) diff --git a/modules/generators/base_generator.py b/modules/generators/base_generator.py index 175dd08f..73153b71 100644 --- a/modules/generators/base_generator.py +++ b/modules/generators/base_generator.py @@ -1,10 +1,19 @@ import torch import os # required for os.path from abc import ABC, abstractmethod +from dataclasses import asdict from diffusers_helper import lora_utils -from typing import List, Optional +from typing import List, Optional, cast from pathlib import Path +from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked + +from ..settings import Settings +from .model_configuration import ModelConfiguration + +# cSpell: ignore loras + class BaseModelGenerator(ABC): """ Base class for model generators. @@ -21,7 +30,7 @@ def __init__(self, feature_extractor, high_vram=False, prompt_embedding_cache=None, - settings=None, + settings: Settings | None = None, offline=False): # NEW: offline flag """ Initialize the base model generator. @@ -39,6 +48,10 @@ def __init__(self, settings: Application settings offline: Whether to run in offline mode for model loading """ + self.model_name: str + self.model_path: str + self.model_repo_id_for_cache: str + self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 self.tokenizer = tokenizer @@ -48,29 +61,68 @@ def __init__(self, self.feature_extractor = feature_extractor self.high_vram = high_vram self.prompt_embedding_cache = prompt_embedding_cache or {} - self.settings = settings + self.settings: Settings = settings if settings is not None else Settings() self.offline = offline self.transformer = None self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.cpu = torch.device("cpu") + self.previous_model_hash: str = "" + self.previous_model_configuration: ModelConfiguration | None = None @abstractmethod - def load_model(self): + def load_model(self) -> HunyuanVideoTransformer3DModelPacked: """ Load the transformer model. - This method should be implemented by each specific model generator. """ + # this load_model function has the same implementation in all subclasses + # candidate for consolidation to directly implement here in the base class pass @abstractmethod - def get_model_name(self): + def get_model_name(self) -> str: """ Get the name of the model. This method should be implemented by each specific model generator. """ + # this get_model_name function has the same implementation in all subclasses + # candidate for consolidation to directly implement here in the base class pass + @abstractmethod + def get_latent_paddings(self, total_latent_sections) -> list[int]: + raise NotImplementedError( + "get_latent_paddings must be implemented by the specific model generator subclass.") + + @abstractmethod + def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt) -> str: + raise NotImplementedError( + "format_position_description must be implemented by the specific model generator subclass.") + + @abstractmethod + def get_real_history_latents(self, history_latents: torch.Tensor, total_generated_latent_frames: int) -> torch.Tensor: + """ + Get the real history latents by slicing the history latents tensor. 
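To make the contract of these new abstract methods concrete, a hypothetical minimal subclass is sketched here as an editorial aside (not part of the patch); the padding schedule, the slicing, and the concatenation axis are illustrative assumptions, not the behaviour of the real generators.

    import torch
    from modules.generators.base_generator import BaseModelGenerator

    class ToyModelGenerator(BaseModelGenerator):
        def load_model(self):
            raise NotImplementedError  # a real subclass builds and returns its packed transformer here

        def get_model_name(self) -> str:
            return "Toy"

        def get_latent_paddings(self, total_latent_sections) -> list[int]:
            # Illustrative: one padding value per section, counting down to zero.
            return list(reversed(range(total_latent_sections)))

        def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt) -> str:
            return f"{total_generated_latent_frames} latent frames, position {current_pos:.2f}s"

        def get_real_history_latents(self, history_latents: torch.Tensor, total_generated_latent_frames: int) -> torch.Tensor:
            # Illustrative: keep only the frames generated so far (frame axis assumed at dim 2).
            return history_latents[:, :, :total_generated_latent_frames]

        def update_history_latents(self, history_latents: torch.Tensor, generated_latents: torch.Tensor) -> torch.Tensor:
            # Illustrative: prepend the newly generated latents along the frame axis.
            return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)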
+ """ + raise NotImplementedError( + "get_real_history_latents must be implemented by the specific model generator subclass.") + + @abstractmethod + def update_history_latents(self, history_latents: torch.Tensor, generated_latents: torch.Tensor) -> torch.Tensor: + """ + Update the history latents with the generated latents. + This method should be implemented by each specific model generator. + + Args: + history_latents: The history latents + generated_latents: The generated latents + + Returns: + The updated history latents + """ + raise NotImplementedError( + "update_history_latents must be implemented by the specific model generator subclass.") + @staticmethod def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None: """ @@ -109,10 +161,12 @@ def _get_offline_load_path(self) -> str: if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache: print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.") # Fallback to model_path if it exists, otherwise None - return getattr(self, 'model_path', None) + return str(getattr(self, 'model_path', None)) if not hasattr(self, 'model_path') or not self.model_path: print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.") + # raise error instead of returning None? + # raise ValueError(f"{self.__class__.__name__} must set model_path for offline loading.") return None snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache) @@ -222,6 +276,25 @@ def move_lora_adapters_to_device(self, target_device): print(f"Moved all LoRA adapters to {target_device}") + def __compute_lora_state_hash(self, lora_config: ModelConfiguration) -> str: + """ + Compute a simple hash representing the current state of LoRA adapters in the transformer. + This can be used to detect changes in loaded LoRAs. + """ + import hashlib + # md5 should be sufficient for this purpose + m = hashlib.md5() + + if self.transformer is None: + # Should not happen - return a unique value + print("Warning: Transformer is None when computing LoRA state hash.") + from time import time + m.update(str(time() * 1000).encode('utf-8')) + + import json + m.update(json.dumps(asdict(lora_config), sort_keys=True).encode('utf-8')) + return m.hexdigest() + def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_names: List[str], lora_values: Optional[List[float]] = None): """ Load LoRAs into the transformer model and applies their weights. @@ -232,14 +305,116 @@ def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_na lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing. lora_values: A list of strength values corresponding to lora_loaded_names. 
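An editorial aside before the method body: the reload check in load_loras below relies on this digest being deterministic for identical selections, so equal configurations can skip a reload. A standalone illustration (LoRA names and weights invented, model_configuration module path assumed):

    from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader
    from modules.generators.model_configuration import ModelConfiguration  # assumed location

    cfg_a = ModelConfiguration.from_lora_names_and_weights("Original", ["style_lora"], [0.8], LoraLoader.DIFFUSERS)
    cfg_b = ModelConfiguration.from_lora_names_and_weights("Original", ["style_lora"], [0.8], LoraLoader.DIFFUSERS)

    # _hash is md5(json.dumps(asdict(config), sort_keys=True)), so identical
    # selections produce identical digests.
    assert cfg_a._hash == cfg_b._hash

    cfg_c = ModelConfiguration.from_lora_names_and_weights("Original", ["style_lora"], [0.9], LoraLoader.DIFFUSERS)
    assert cfg_c._hash != cfg_a._hash  # changing a weight changes the digest and forces a reload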
""" - self.unload_loras() - if not selected_loras: + # Only unload at this point if no LoRAs are selected + self.unload_loras() print("No LoRAs selected, skipping loading.") return + if self.transformer is None: + print("Transformer model is None, cannot load LoRAs.") + return + + if lora_values is None: + lora_values = [] + + selected_lora_values = [lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names] + print(f"Loading LoRAs: {selected_loras} with values: {selected_lora_values}") + + active_model_configuration = ModelConfiguration.from_lora_names_and_weights( + self.get_model_name(), + selected_loras, + selected_lora_values, + self.settings.lora_loader + ) + + active_model_hash = self.__compute_lora_state_hash(active_model_configuration) + if active_model_hash == self.previous_model_hash: + # This can only happen if the model is not changed + # When the model is loaded we will always have the default previous_model_hash value + # The only time that this can happen is when settings.reuse_model_instance is True + # and the model is not changed, and the LoRAs are not changed. + print("Model configuration unchanged, skipping reload.") + return + + print(f"Previous LoRA config: {self.previous_model_configuration}, Current LoRA config: {active_model_configuration}") + print(f"Previous LoRA hash: {self.previous_model_hash}, Current LoRA hash: {active_model_hash}") + + self.previous_model_hash = active_model_hash + self.previous_model_configuration = active_model_configuration + lora_dir = Path(lora_folder) + if self.settings.lora_loader == LoraLoader.LORA_READY: + from diffusers_helper.lora_utils_kohya_ss.lora_loader import load_and_apply_lora + from diffusers_helper.lora_utils_kohya_ss.lora_check_helper import print_lora_status + print(f"Loading LoRAs using kohya_ss LoRAReady loader from {lora_dir}") + + def _find_model_files(model_path): + """Get state dictionary file from specified model path + This is undesirable as it depends on Diffusers implementation.""" + import glob + model_root = os.environ['HF_HOME'] # './hf_download'? + subdir = os.path.join(model_root, 'hub', 'models--' + model_path.replace('/', '--')) + model_files = glob.glob(os.path.join(subdir, '**', '*.safetensors'), recursive=True) + glob.glob(os.path.join(subdir, '**', '*.pt'), recursive=True) + model_files.sort() + return model_files + try: + model_files = _find_model_files(self.model_path) + print(f"LoRA -> Found model files: {model_files}") + lora_paths = [ + # not sure why the full path is not passed around and potentially trimmed for the interface display + str(lora_dir / f"{lora_setting.name}.safetensors") + if Path(lora_dir / f"{lora_setting.name}.safetensors").exists() + else str(lora_dir / f"{lora_setting.name}.pt") # hopefully .pt is the correct extension. 
+ for lora_setting in active_model_configuration.settings.lora_settings + ] + lora_scales: list[float] = [lora_setting.weight for lora_setting in active_model_configuration.settings.lora_settings] + print(f'Lora paths: {lora_paths}') + if not lora_paths: + raise ValueError("No valid LoRA paths found for the selected LoRAs.") + + state_dict = load_and_apply_lora( + model_files=model_files, + lora_paths=lora_paths, + lora_scales=lora_scales, + fp8_enabled=cast(bool, self.settings.get("fp8", False)), + device=self.gpu if torch.cuda.is_available() else self.cpu + ) + print("Loading state dict into transformer...") + missing_keys, unexpected_keys = self.transformer.load_state_dict(state_dict, assign=True, strict=True) + + if missing_keys: + print(f"Warning: Missing keys when loading LoRA state dict: {missing_keys}") + if unexpected_keys: + print(f"Warning: Unexpected keys when loading LoRA state dict: {unexpected_keys}") + + state_dict_size: int = 0 + try: + state_dict_size = sum(param.numel() * param.element_size() + for param in state_dict.values() if hasattr(param, 'numel')) + print(f"State dictionary size: {state_dict_size / (1024**3):.2f} GB") + except: + pass + + try: + del state_dict + import gc + gc.collect() + print(f"Freed state dictionary size: {state_dict_size / (1024**3):.2f} GB") + except: + print("Could not free state dictionary from memory.") + + except Exception as e: + import traceback + print(f"Error loading LoRAs with kohya_ss LoRAReady loader: {e}") + traceback.print_exc() + return + + if self.settings.lora_loader != LoraLoader.DIFFUSERS: + raise NotImplementedError("Unsupported LoRA loader: {}".format(self.settings.lora_loader)) + + self.unload_loras() adapter_names = [] strengths = [] From d320f0fd69cdcb7770401359e0719b783e90315e Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:36:45 -0500 Subject: [PATCH 10/17] Return type from create_model_generator() in generators module --- modules/generators/__init__.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/modules/generators/__init__.py b/modules/generators/__init__.py index 20801fcd..390ffc6c 100644 --- a/modules/generators/__init__.py +++ b/modules/generators/__init__.py @@ -1,10 +1,13 @@ -from .original_generator import OriginalModelGenerator +from .base_generator import BaseModelGenerator from .f1_generator import F1ModelGenerator -from .video_generator import VideoModelGenerator -from .video_f1_generator import VideoF1ModelGenerator +from .original_generator import OriginalModelGenerator from .original_with_endframe_generator import OriginalWithEndframeModelGenerator +from .video_base_generator import VideoBaseModelGenerator +from .video_f1_generator import VideoF1ModelGenerator +from .video_generator import VideoModelGenerator +from .model_configuration import ModelConfiguration -def create_model_generator(model_type, **kwargs): +def create_model_generator(model_type, **kwargs) -> BaseModelGenerator | VideoBaseModelGenerator: """ Create a model generator based on the model type. 
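A hedged usage sketch of the newly annotated factory, as an editorial aside: the keyword arguments mirror the call worker.py makes later in this series, the component objects are stubbed with None here, and "Original" is assumed to be one of the model types the dispatch below handles.

    from modules.generators import create_model_generator, BaseModelGenerator

    generator = create_model_generator(
        "Original",
        text_encoder=None, text_encoder_2=None,
        tokenizer=None, tokenizer_2=None,
        vae=None, image_encoder=None, feature_extractor=None,
        high_vram=False,
    )
    assert isinstance(generator, BaseModelGenerator)  # the return annotation makes this checkable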
@@ -30,3 +33,16 @@ def create_model_generator(model_type, **kwargs): return VideoF1ModelGenerator(**kwargs) else: raise ValueError(f"Unsupported model type: {model_type}") + + +__all__ = [ + "BaseModelGenerator", + "create_model_generator", + "F1ModelGenerator", + "OriginalModelGenerator", + "OriginalWithEndframeModelGenerator", + "VideoBaseModelGenerator", + "VideoF1ModelGenerator", + "VideoModelGenerator", + "ModelConfiguration", +] From 797973f3143b524d6d6f1e5896c009b32c5087e3 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:38:04 -0500 Subject: [PATCH 11/17] Use StudioManager in worker with model reuse settings --- modules/pipelines/worker.py | 103 ++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 51 deletions(-) diff --git a/modules/pipelines/worker.py b/modules/pipelines/worker.py index 3ac25ba6..a0cf514b 100644 --- a/modules/pipelines/worker.py +++ b/modules/pipelines/worker.py @@ -1,5 +1,4 @@ import os -import json import time import traceback import einops @@ -7,14 +6,11 @@ import torch import datetime from PIL import Image -from PIL.PngImagePlugin import PngInfo from diffusers_helper.models.mag_cache import MagCache from diffusers_helper.utils import save_bcthw_as_mp4, generate_timestamp, resize_and_center_crop from diffusers_helper.memory import cpu, gpu, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, unload_complete_models, load_model_as_complete -from diffusers_helper.thread_utils import AsyncStream from diffusers_helper.gradio.progress_bar import make_progress_bar_html from diffusers_helper.hunyuan import vae_decode -from modules.video_queue import JobStatus from modules.prompt_handler import parse_timestamped_prompt from modules.generators import create_model_generator from modules.pipelines.video_tools import combine_videos_sequentially_from_tensors @@ -22,8 +18,9 @@ from modules.llm_captioner import unload_captioning_model from modules.llm_enhancer import unload_enhancing_model from . 
import create_pipeline +from modules.studio_manager import StudioManager -import __main__ as studio_module # Get a reference to the __main__ module object +# cSpell: disable hunyan, loras @torch.no_grad() def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device, prompt_embedding_cache): @@ -79,7 +76,7 @@ def worker( latent_type, selected_loras, has_input_image, - lora_values=None, + lora_values: list[float] = [], job_stream=None, output_dir=None, metadata_dir=None, @@ -118,8 +115,12 @@ def worker( print(f"Worker: Selected LoRAs for this worker: {selected_loras}") # Import globals from the main module - from __main__ import high_vram, args, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, image_encoder, feature_extractor, prompt_embedding_cache, settings, stream - + from __main__ import high_vram, args, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, image_encoder, feature_extractor, prompt_embedding_cache, stream + + studio_module = StudioManager() + settings = studio_module.settings + job_queue = studio_module.job_queue + # Ensure any existing LoRAs are unloaded from the current generator if studio_module.current_generator is not None: print("Worker: Unloading LoRAs from studio_module.current_generator") @@ -155,7 +156,6 @@ def worker( # Store initial progress data in the job object if using a job stream if job_stream is not None: try: - from __main__ import job_queue job = job_queue.get_job(job_id) if job: job.progress_data = initial_progress_data @@ -197,6 +197,7 @@ def worker( pipeline = create_pipeline(model_type, pipeline_settings) # Create job parameters dictionary + # job_params should be defined outside of the try/catch scope job_params = { 'model_type': model_type, 'input_image': input_image, @@ -251,7 +252,13 @@ def worker( # --- Model Loading / Switching --- print(f"Worker starting for model type: {model_type}") print(f"Worker: Before model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}") - + + # force_new_generator will be True in any of these conditions: + # - settings.reuse_model_instance is False + # - the model type has changed + # - the LoRAs have changed (i.e., selected_loras or lora_values are different) + force_new_generator = studio_module.is_reload_required(model_name=model_type, selected_loras=selected_loras, lora_values=lora_values, lora_loaded_names=lora_loaded_names) + # Create the appropriate model generator new_generator = create_model_generator( model_type, @@ -266,21 +273,27 @@ def worker( prompt_embedding_cache=prompt_embedding_cache, offline=args.offline, settings=settings - ) - - # Update the global generator - # This modifies the 'current_generator' attribute OF THE '__main__' MODULE OBJECT - studio_module.current_generator = new_generator - print(f"Worker: AFTER model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}") - if studio_module.current_generator: - print(f"Worker: studio_module.current_generator.transformer is {type(studio_module.current_generator.transformer)}") - - # Load the transformer model - studio_module.current_generator.load_model() - - # Ensure the model has no LoRAs loaded - print(f"Ensuring {model_type} model has no LoRAs loaded") - studio_module.current_generator.unload_loras() + ) if force_new_generator else None # settings is now a singleton and the Base Model is setup to get the instance if this is 
not provided - candidate for removal + + if new_generator is not None: + studio_module.current_generator = new_generator + # Load the transformer model + # load_model() should be called in the setter for current_generator + studio_module.current_generator.load_model() # type ignore + + # Ensure the generator is loaded + assert (studio_module.current_generator is not None), "current_generator should not be None after model assignment" + + # Ensure the transformer is loaded + assert (studio_module.current_generator.transformer is not None), "current_generator.transformer should not be None after model assignment. load_model() must be called once." + + # Update the model state with the generator and current model configuration, which is used for determining if a reload is needed in future calls + studio_module.update_model_state(selected_loras=selected_loras, lora_values=lora_values, lora_loaded_names=lora_loaded_names) + + if force_new_generator or new_generator is not None: + # Ensure the model has no LoRAs loaded + print(f"Ensuring {model_type} model has no LoRAs loaded") + studio_module.current_generator.unload_loras() # Preprocess inputs stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Preprocessing inputs...')))) @@ -296,6 +309,7 @@ def worker( # Import the save_job_start_image function from metadata_utils from modules.pipelines.metadata_utils import save_job_start_image, create_metadata + # metadata_dict is not used - candidate for removal as save_job_start_image calls the same function internally # Create comprehensive metadata for the job metadata_dict = create_metadata(job_params, job_id, settings) @@ -557,7 +571,7 @@ def worker( elif model_type == "Original" or model_type == "Original with Endframe": total_generated_latent_frames = 0 - history_pixels = None + history_pixels: torch.Tensor | None = None # Get latent paddings from the generator latent_paddings = studio_module.current_generator.get_latent_paddings(total_latent_sections) @@ -567,13 +581,19 @@ def worker( # Load LoRAs if selected if selected_loras: + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, f'Loading LoRA{"s" if len(selected_loras) > 1 else ""} ...')))) lora_folder_from_settings = settings.get("lora_dir") studio_module.current_generator.load_loras(selected_loras, lora_folder_from_settings, lora_loaded_names, lora_values) # --- Callback for progress --- def callback(d): - nonlocal last_step_time, step_durations - + nonlocal last_step_time, step_durations, history_pixels, job_params + studio_module = StudioManager() + + if studio_module.current_generator is None: + print("Worker callback: current_generator is None, cannot process progress update.") + return + # Check for cancellation signal if stream_to_use.input_queue.top() == 'end': print("Cancellation signal detected in callback") @@ -652,7 +672,7 @@ def fmt_eta(sec): # Store progress data in the job object if using a job stream if job_stream is not None: try: - from __main__ import job_queue + job_queue = studio_module.job_queue job = job_queue.get_job(job_id) if job: job.progress_data = progress_data @@ -838,6 +858,7 @@ def fmt_eta(sec): if not high_vram: # Unload VAE etc. 
before loading transformer unload_complete_models(vae, text_encoder, text_encoder_2, image_encoder) + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, f'Moving model to GPU ...')))) move_model_to_device_with_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation")) if selected_loras: studio_module.current_generator.move_lora_adapters_to_device(gpu) @@ -937,6 +958,7 @@ def fmt_eta(sec): magcache = None # Handle the results + # handle_results does not seem to do anything and we do not do anything with the result result = pipeline.handle_results(job_params, output_filename) # Unload all LoRAs after generation completed @@ -1124,30 +1146,9 @@ def get_frame_count_for_combine(filename): # Renamed to avoid conflict except Exception as e: print(f"Error creating combined video ({job_id}_combined.mp4): {e}") traceback.print_exc() - # Final verification of LoRA state if studio_module.current_generator and studio_module.current_generator.transformer: - # Verify LoRA state - has_loras = False - if hasattr(studio_module.current_generator.transformer, 'peft_config'): - adapter_names = list(studio_module.current_generator.transformer.peft_config.keys()) if studio_module.current_generator.transformer.peft_config else [] - if adapter_names: - has_loras = True - print(f"Transformer has LoRAs: {', '.join(adapter_names)}") - else: - print(f"Transformer has no LoRAs in peft_config") - else: - print(f"Transformer has no peft_config attribute") - - # Check for any LoRA modules - for name, module in studio_module.current_generator.transformer.named_modules(): - if hasattr(module, 'lora_A') and module.lora_A: - has_loras = True - if hasattr(module, 'lora_B') and module.lora_B: - has_loras = True - - if not has_loras: - print(f"No LoRA components found in transformer") + studio_module.current_generator.verify_lora_state() stream_to_use.output_queue.push(('end', None)) return From 74d6a6ab92baa96e1aabfbeb339624a313ae8503 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:41:02 -0500 Subject: [PATCH 12/17] Reduce complexity of studio with Studio Manager --- studio.py | 256 ++++++++++++++---------------------------------------- 1 file changed, 66 insertions(+), 190 deletions(-) diff --git a/studio.py b/studio.py index 14955e15..a43f8db0 100644 --- a/studio.py +++ b/studio.py @@ -1,58 +1,37 @@ -from diffusers_helper.hf_login import login - -import json +# -*- coding: utf-8 -*- +import argparse import os import shutil -from pathlib import PurePath, Path import time -import argparse -import traceback -import einops -import numpy as np -import torch -import datetime - -# Version information -from modules.version import APP_VERSION - -# Set environment variables -os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) -os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning +from pathlib import PurePath - - -import gradio as gr -from PIL import Image -from PIL.PngImagePlugin import PngInfo +# Site packages from diffusers import AutoencoderKLHunyuanVideo from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer -from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake -from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp -from 
diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked -from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan -from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete -from diffusers_helper.thread_utils import AsyncStream -from diffusers_helper.gradio.progress_bar import make_progress_bar_html from transformers import SiglipImageProcessor, SiglipVisionModel -from diffusers_helper.clip_vision import hf_clip_vision_encode -from diffusers_helper.bucket_tools import find_nearest_bucket -from diffusers_helper import lora_utils -from diffusers_helper.lora_utils import load_lora, unload_all_loras +import gradio as gr +import numpy as np +import torch -# Import model generators -from modules.generators import create_model_generator +# Studio Module imports +# Import from diffusers_helper +from diffusers_helper.gradio.progress_bar import make_progress_bar_html +from diffusers_helper.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller +from diffusers_helper.thread_utils import AsyncStream +from diffusers_helper.utils import generate_timestamp -# Global cache for prompt embeddings -prompt_embedding_cache = {} # Import from modules -from modules.video_queue import VideoJobQueue, JobStatus -from modules.prompt_handler import parse_timestamped_prompt -from modules.ui.queue import format_queue_status -from modules.interface import create_interface -from modules.settings import Settings from modules import DUMMY_LORA_NAME # Import the constant -from modules.pipelines.metadata_utils import create_metadata +from modules.interface import create_interface from modules.pipelines.worker import worker +from modules.studio_manager import StudioManager +from modules.ui.queue import format_queue_status +from modules.video_queue import JobStatus + +# Set environment variables +if not os.getenv('HF_HOME'): + os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) +os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning # Try to suppress annoyingly persistent Windows asyncio proactor errors if os.name == 'nt': # Windows only @@ -77,37 +56,6 @@ def wrapper(self, *args, **kwargs): if hasattr(asyncio.proactor_events._ProactorBasePipeTransport, '_call_connection_lost'): asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost = silence_event_loop_closed( asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost) - -# ADDED: Debug function to verify LoRA state -def verify_lora_state(transformer, label=""): - """Debug function to verify the state of LoRAs in a transformer model""" - if transformer is None: - print(f"[{label}] Transformer is None, cannot verify LoRA state") - return - - has_loras = False - if hasattr(transformer, 'peft_config'): - adapter_names = list(transformer.peft_config.keys()) if transformer.peft_config else [] - if adapter_names: - has_loras = True - print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}") - else: - print(f"[{label}] Transformer has no LoRAs in peft_config") - else: - print(f"[{label}] Transformer has no peft_config attribute") - - # Check for any LoRA modules - for name, module in transformer.named_modules(): - if hasattr(module, 'lora_A') and module.lora_A: - has_loras = 
True - # print(f"[{label}] Found lora_A in module {name}") - if hasattr(module, 'lora_B') and module.lora_B: - has_loras = True - # print(f"[{label}] Found lora_B in module {name}") - - if not has_loras: - print(f"[{label}] No LoRA components found in transformer") - parser = argparse.ArgumentParser() parser.add_argument('--share', action='store_true') @@ -116,7 +64,7 @@ def verify_lora_state(transformer, label=""): parser.add_argument("--inbrowser", action='store_true') parser.add_argument("--lora", type=str, default=None, help="Lora path (comma separated for multiple)") parser.add_argument("--offline", action='store_true', help="Run in offline mode") -args = parser.parse_args() +args, unknown = parser.parse_known_args() print(args) @@ -143,9 +91,6 @@ def verify_lora_state(transformer, label=""): feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor') image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu() -# Initialize model generator placeholder -current_generator = None # Will hold the currently active model generator - # Load models based on VRAM availability later # Configure models @@ -173,10 +118,6 @@ def verify_lora_state(transformer, label=""): lora_dir = os.path.join(os.path.dirname(__file__), 'loras') os.makedirs(lora_dir, exist_ok=True) -# Initialize LoRA support - moved scanning after settings load -lora_names = [] -lora_values = [] # This seems unused for population, might be related to weights later - script_dir = os.path.dirname(os.path.abspath(__file__)) # Define default LoRA folder path relative to the script directory (used if setting is missing) @@ -197,8 +138,16 @@ def verify_lora_state(transformer, label=""): outputs_folder = './outputs/' os.makedirs(outputs_folder, exist_ok=True) -# Initialize settings -settings = Settings() +# Initialize the StudioManager instance - this is a singleton class, accessible globally without importing from __main__ +studio_manager = StudioManager() +settings = studio_manager.settings + +# Set the worker function for the job queue - using the imported worker from modules/pipelines/worker.py +studio_manager.job_queue.set_worker_function(worker) +job_queue = studio_manager.job_queue + +# Global cache for prompt embeddings +prompt_embedding_cache = {} # NEW: auto-cleanup on start-up option in Settings if settings.get("auto_cleanup_on_startup", False): @@ -214,105 +163,31 @@ def verify_lora_state(transformer, label=""): print("--- Startup Cleanup Complete ---") # --- Populate LoRA names AFTER settings are loaded --- -lora_folder_from_settings: str = settings.get("lora_dir", default_lora_folder) # Use setting, fallback to default -print(f"Scanning for LoRAs in: {lora_folder_from_settings}") -if os.path.isdir(lora_folder_from_settings): - try: - for root, _, files in os.walk(lora_folder_from_settings): - for file in files: - if file.endswith('.safetensors') or file.endswith('.pt'): - lora_relative_path = os.path.relpath(os.path.join(root, file), lora_folder_from_settings) - lora_name = str(PurePath(lora_relative_path).with_suffix('')) - lora_names.append(lora_name) - print(f"Found LoRAs: {lora_names}") - # Temp solution for only 1 lora - if len(lora_names) == 1: - lora_names.append(DUMMY_LORA_NAME) - except Exception as e: - print(f"Error scanning LoRA directory '{lora_folder_from_settings}': {e}") -else: - print(f"LoRA directory not found: {lora_folder_from_settings}") -# --- End LoRA population --- - 
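One small behavioural change above that the diff does not explain: parse_args() became parse_known_args(), so unrecognised command-line flags no longer abort startup. A minimal editorial illustration of the standard argparse difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")

    # parser.parse_args(["--share", "--future-flag"]) would exit with an error on the unknown flag.
    # parse_known_args() keeps going and hands the leftovers back:
    args, unknown = parser.parse_known_args(["--share", "--future-flag"])
    print(args.share)   # True
    print(unknown)      # ['--future-flag']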
- -# Create job queue -job_queue = VideoJobQueue() - - - -# Function to load a LoRA file -def load_lora_file(lora_file: str | PurePath): - if not lora_file: - return None, "No file selected" - - try: - # Get the filename from the path - lora_path = PurePath(lora_file) - lora_name = lora_path.name - - # Copy the file to the lora directory - lora_dest = PurePath(lora_dir, lora_path) - import shutil - shutil.copy(lora_file, lora_dest) - - # Load the LoRA - global current_generator, lora_names - if current_generator is None: - return None, "Error: No model loaded to apply LoRA to. Generate something first." - - # Unload any existing LoRAs first - current_generator.unload_loras() - - # Load the single LoRA - selected_loras = [lora_path.stem] - current_generator.load_loras(selected_loras, lora_dir, selected_loras) - - # Add to lora_names if not already there - lora_base_name = lora_path.stem - if lora_base_name not in lora_names: - lora_names.append(lora_base_name) - - # Get the current device of the transformer - device = next(current_generator.transformer.parameters()).device - - # Move all LoRA adapters to the same device as the base model - current_generator.move_lora_adapters_to_device(device) - - print(f"Loaded LoRA: {lora_name} to {current_generator.get_model_name()} model") - - return gr.update(choices=lora_names), f"Successfully loaded LoRA: {lora_name}" - except Exception as e: - print(f"Error loading LoRA: {e}") - return None, f"Error loading LoRA: {e}" - -@torch.no_grad() -def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device): - """ - Retrieves prompt embeddings from cache or encodes them if not found. - Stores encoded embeddings (on CPU) in the cache. - Returns embeddings moved to the target_device. 
- """ - if prompt in prompt_embedding_cache: - print(f"Cache hit for prompt: {prompt[:60]}...") - llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt] - # Move cached embeddings (from CPU) to the target device - llama_vec = llama_vec_cpu.to(target_device) - llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None - clip_l_pooler = clip_l_pooler_cpu.to(target_device) - return llama_vec, llama_attention_mask, clip_l_pooler +def enumerate_lora_dir() -> list[str]: + lora_folder_from_settings: str = settings.get("lora_dir", default_lora_folder) # Use setting, fallback to default + print(f"Scanning for LoRAs in: {lora_folder_from_settings}") + found_files: list[str] = [] + if os.path.isdir(lora_folder_from_settings): + try: + for root, _, files in os.walk(lora_folder_from_settings): + for file in files: + if file.endswith('.safetensors') or file.endswith('.pt'): + lora_relative_path = os.path.relpath(os.path.join(root, file), lora_folder_from_settings) + lora_name = str(PurePath(lora_relative_path).with_suffix('')) + found_files.append(lora_name) + print(f"Found LoRAs: {len(found_files)}") + # Temp solution for only 1 lora + if len(found_files) == 1: + found_files.append(DUMMY_LORA_NAME) + except Exception as e: + print(f"Error scanning LoRA directory '{lora_folder_from_settings}': {e}") else: - print(f"Cache miss for prompt: {prompt[:60]}...") - llama_vec, clip_l_pooler = encode_prompt_conds( - prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2 - ) - llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) - # Store CPU copies in cache - prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu()) - # Return embeddings already on the target device (as encode_prompt_conds uses the model's device) - return llama_vec, llama_attention_mask, clip_l_pooler + print(f"LoRA directory not found: {lora_folder_from_settings}") + # --- End LoRA population --- + return found_files -# Set the worker function for the job queue - using the imported worker from modules/pipelines/worker.py -job_queue.set_worker_function(worker) + +lora_names = enumerate_lora_dir() def process( @@ -690,17 +565,18 @@ def get_preview_updates(preview_value): monitor_fn=monitor_job, end_process_fn=end_process, update_queue_status_fn=update_queue_status, - load_lora_file_fn=load_lora_file, + load_lora_file_fn=None, job_queue=job_queue, settings=settings, lora_names=lora_names # Explicitly pass the found LoRA names ) -# Launch the interface -interface.launch( - server_name=args.server, - server_port=args.port, - share=args.share, - inbrowser=args.inbrowser, - allowed_paths=[settings.get("output_dir"), settings.get("metadata_dir")], -) +if __name__ == "__main__": + # Launch the interface + interface.launch( + server_name=args.server, + server_port=args.port, + share=args.share, + inbrowser=args.inbrowser, + allowed_paths=[settings.get("output_dir"), settings.get("metadata_dir")], + ) From fe122d5c48ce52d9029cf23594da237ae4bbc69d Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:42:16 -0500 Subject: [PATCH 13/17] Add StudioManager --- modules/studio_manager.py | 222 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 modules/studio_manager.py diff --git a/modules/studio_manager.py b/modules/studio_manager.py new file mode 100644 index 00000000..24220203 --- /dev/null +++ b/modules/studio_manager.py 
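Before the class body that follows, an editorial sketch of how the rest of this series consumes the singleton (worker.py and studio.py in the earlier commits follow this pattern); the model and LoRA names are invented.

    from modules.studio_manager import StudioManager

    manager = StudioManager()
    assert manager is StudioManager()      # __new__ always returns the same instance

    settings = manager.settings            # shared Settings, no import from __main__ required
    job_queue = manager.job_queue          # shared VideoJobQueue

    # Worker-style check before deciding whether to rebuild the generator:
    needs_reload = manager.is_reload_required(
        model_name="Original",
        selected_loras=["style_lora"],
        lora_values=[0.8],
        lora_loaded_names=["style_lora"],
    )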
@@ -0,0 +1,222 @@ +import logging +from typing import ( + Optional, + Union, +) + +from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader +from modules.generators.model_configuration import ModelConfiguration +from .generators import BaseModelGenerator, VideoBaseModelGenerator +from .settings import Settings +from .video_queue import VideoJobQueue + +# cSpell: ignore loras + +class ModelState: + """ + Class to track the state of model configurations. + This class keeps track of the previous model configuration and its hash. + """ + + previous_model_configuration: Optional[ModelConfiguration] + active_model_configuration: Optional[ModelConfiguration] + __logger: logging.Logger = logging.getLogger(__name__) + + def __init__(self): + self.previous_model_configuration: Optional[ModelConfiguration] = None + self.active_model_configuration: Optional[ModelConfiguration] = None + + @property + def active_model_hash(self) -> str: + """ + Returns the hash of the active model configuration. + If no active model configuration is set, returns an empty string. + """ + return self.active_model_configuration._hash if self.active_model_configuration else "active_model_hash" + + @property + def previous_model_hash(self) -> str: + """ + Returns the hash of the previous model configuration. + If no previous model configuration is set, returns an empty string. + """ + return self.previous_model_configuration._hash if self.previous_model_configuration else "previous_model_hash" + + def is_reload_required( + self, + model_name: str, + selected_loras: list[str], + lora_values: list[float], + lora_loaded_names: list[str], + lora_loader: str | LoraLoader, + ) -> bool: + """ + Check if a reload is required based on the current model state. + This method checks if the current model configuration is different from the previous one. + Args: + model_name: The name of the model to check. + selected_loras: List of selected LoRA names. + lora_values: List of LoRA values corresponding to the selected LoRAs. + lora_loaded_names: List of names of LoRAs that are currently loaded. + Returns: + True if a reload is required, otherwise False. + """ + selected_lora_values = [ + lora_values[lora_loaded_names.index(name)] + for name in selected_loras + if name in lora_loaded_names + ] + + active_model_configuration: ModelConfiguration = ( + ModelConfiguration.from_lora_names_and_weights( + model_name=model_name, + lora_names=selected_loras, + lora_weights=selected_lora_values, + lora_loader=lora_loader, + ) + ) + + return active_model_configuration._hash != self.active_model_hash + + def update_model_state( + self, + current_generator: Optional[BaseModelGenerator], + selected_loras: list[str], + lora_values: list[float], + lora_loaded_names: list[str], + ) -> None: + """Update the model state with the current configuration. + This method checks if the current model configuration is different from the previous one. + If it is, it updates the model state and returns True. + If the configuration is unchanged, it returns False. 
+ """ + + assert current_generator is not None, "current_generator must be set when updating model state" + self.previous_model_configuration = self.active_model_configuration + + if not self.is_reload_required( + current_generator.model_name, + selected_loras, + lora_values, + lora_loaded_names, + current_generator.settings.lora_loader, + ): + self.__logger.debug("Model configuration unchanged, skipping reload.") + return + + selected_lora_values = [ + lora_values[lora_loaded_names.index(name)] + for name in selected_loras + if name in lora_loaded_names + ] + active_model_configuration: ModelConfiguration = ( + ModelConfiguration.from_lora_names_and_weights( + model_name=current_generator.model_name, + lora_names=selected_loras, + lora_weights=selected_lora_values, + lora_loader=current_generator.settings.lora_loader, + ) + ) + + self.active_model_configuration = active_model_configuration + + +class StudioManager: + """ + Singleton class to manage the current model instance and its state. + """ + + _instance: Optional["StudioManager"] = None + __current_generator: Optional[ + Union[BaseModelGenerator, VideoBaseModelGenerator] + ] = None + job_queue: VideoJobQueue = VideoJobQueue() + settings: Settings = Settings() + model_state: ModelState = ModelState() + __logger: logging.Logger = logging.getLogger(__name__) + + def __new__(cls): + if cls._instance is None: + cls.__logger.debug("Creating the StudioManager instance") + cls._instance = super(StudioManager, cls).__new__(cls) + return cls._instance + + @property + def current_generator( + self, + ) -> Optional[Union[BaseModelGenerator, VideoBaseModelGenerator]]: + """ + Property to get the current model generator instance. + Returns None if no generator is set. + """ + return self.__current_generator + + @current_generator.setter + def current_generator( + self, generator: BaseModelGenerator + ) -> None: + """ + Property to set the current model generator instance. + Raises TypeError if the generator is not an instance of BaseModelGenerator or VideoBaseModelGenerator. + """ + assert isinstance(generator, BaseModelGenerator), "Expected generator to be an instance of BaseModelGenerator" + + self.__current_generator = generator + + def unset_current_generator(self) -> None: + """ + Delete the current model generator instance. + This will set the current generator to None. + """ + self.__current_generator = None # Reset the current generator + self.model_state = ModelState() # Reset the model state + + def is_reload_required( + self, + model_name: str, + selected_loras: list[str], + lora_values: list[float], + lora_loaded_names: list[str], + ) -> bool: + """ + Check if a reload is required based on the current model state. + This method checks if the current model generator is None or if the settings require a reload. + It also checks if the model state has changed based on the provided parameters. + Currently it does not check against the base model, so it will always return True if the model type has changed at all. + + Args: + model_name: The name of the model to check. + selected_loras: List of selected LoRA names. + lora_values: List of LoRA values corresponding to the selected LoRAs. + lora_loaded_names: List of names of LoRAs that are currently loaded. + Returns: + True if a reload is required, otherwise False. 
+ """ + if self.current_generator is None: + self.__logger.debug("No current generator set, reload is required.") + return True + if not self.settings.reuse_model_instance: + self.__logger.debug("Model instance reuse is disabled, reload is required.") + return True + + return self.model_state.is_reload_required( + model_name=model_name, + selected_loras=selected_loras, + lora_values=lora_values, + lora_loaded_names=lora_loaded_names, + lora_loader=self.settings.lora_loader, + ) + + def update_model_state( + self, + selected_loras: list[str], + lora_values: list[float], + lora_loaded_names: list[str], + ) -> None: + assert self.current_generator is not None, "current_generator must be set when updating model state" + self.model_state.update_model_state( + current_generator=self.__current_generator, + selected_loras=selected_loras, + lora_values=lora_values, + lora_loaded_names=lora_loaded_names, + ) From 2041ed389347f6723e752bde2a911feeed437987 Mon Sep 17 00:00:00 2001 From: arledesma Date: Thu, 17 Jul 2025 22:47:55 -0500 Subject: [PATCH 14/17] Replace dynamic studio_module attributes with StudioManager in toolbox_app --- modules/toolbox_app.py | 76 +++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/modules/toolbox_app.py b/modules/toolbox_app.py index 2977470c..b86ceb72 100644 --- a/modules/toolbox_app.py +++ b/modules/toolbox_app.py @@ -46,7 +46,6 @@ def wrapper(self, *args, **kwargs): asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost) # --- Third-Party Library Imports --- -import devicetorch import gradio as gr import imageio # Added for reading frame dimensions import torch @@ -54,7 +53,7 @@ def wrapper(self, *args, **kwargs): # --- Patch for basicsr (must run after torchvision import) --- functional_tensor_mod = types.ModuleType('functional_tensor') -functional_tensor_mod.rgb_to_grayscale = rgb_to_grayscale +setattr(functional_tensor_mod, 'rgb_to_grayscale', rgb_to_grayscale) sys.modules.setdefault('torchvision.transforms.functional_tensor', functional_tensor_mod) # --- Local Application Imports --- @@ -65,6 +64,9 @@ def wrapper(self, *args, **kwargs): from modules.toolbox.setup_ffmpeg import setup_ffmpeg from modules.toolbox.system_monitor import SystemMonitor from modules.toolbox.toolbox_processor import VideoProcessor +from modules.generators.base_generator import BaseModelGenerator +from modules.studio_manager import StudioManager +from modules.video_queue import JobStatus # Attempt to import helper, with a fallback if it's missing. try: @@ -86,7 +88,7 @@ def wrapper(self, *args, **kwargs): print(f"Bundled FFmpeg not found in '{bin_dir}'. 
Running one-time setup...") setup_ffmpeg() - +studio_manager = StudioManager() # Get the StudioManager instance tb_message_mgr = MessageManager() settings_instance = Settings() tb_processor = VideoProcessor(tb_message_mgr, settings_instance) # Pass settings to VideoProcessor @@ -857,37 +859,50 @@ def tb_handle_delete_studio_transformer(): log_messages_from_action = [] studio_module_instance = None - if '__main__' in sys.modules and hasattr(sys.modules['__main__'], 'current_generator'): + if '__main__' in sys.modules: + # __main__ will always exist studio_module_instance = sys.modules['__main__'] print("Found studio context in __main__.") - elif 'studio' in sys.modules and hasattr(sys.modules['studio'], 'current_generator'): + elif 'studio' in sys.modules: studio_module_instance = sys.modules['studio'] print("Found studio context in sys.modules['studio'].") + # Ensure any existing LoRAs are unloaded from the current generator + if studio_manager.current_generator is None or studio_manager.current_generator.transformer is None: + print("ERROR: Current transformer is None.") + tb_message_mgr.add_message("ERROR: Current transformer is None. Nothing to delete.") + tb_message_mgr.add_warning("Deletion Failed: Current transformer is None.") + return tb_update_messages() + if studio_module_instance is None: print("ERROR: Could not find the 'studio' module's active context.") tb_message_mgr.add_message("ERROR: Could not find the 'studio' module's active context in sys.modules.") tb_message_mgr.add_error("Deletion Failed: Studio module context not found.") return tb_update_messages() - job_queue_instance = getattr(studio_module_instance, 'job_queue', None) - JobStatus_enum = getattr(studio_module_instance, 'JobStatus', None) + job_queue_instance = studio_manager.job_queue or getattr(studio_module_instance, 'job_queue', None) - if job_queue_instance and JobStatus_enum: - current_job_in_queue = getattr(job_queue_instance, 'current_job', None) - if current_job_in_queue and hasattr(current_job_in_queue, 'status') and current_job_in_queue.status == JobStatus_enum.RUNNING: + if job_queue_instance: + current_job_in_queue = job_queue_instance.current_job or getattr(job_queue_instance, 'current_job', None) + if current_job_in_queue and hasattr(current_job_in_queue, 'status') and current_job_in_queue.status == JobStatus.RUNNING: tb_message_mgr.add_warning("Cannot unload model: A video generation job is currently running.") tb_message_mgr.add_message("Please wait for the current job to complete or cancel it first using the main interface.") print("Cannot unload model: A job is currently running in the queue.") return tb_update_messages() - generator_object_to_delete = getattr(studio_module_instance, 'current_generator', None) + # should be able to remove the getattr() call here, since we are directly accessing the studio_manager + generator_object_to_delete = studio_manager.current_generator or getattr(studio_module_instance, 'current_generator', None) print(f"Direct access: generator_object_to_delete is {type(generator_object_to_delete)}, id: {id(generator_object_to_delete)}") if generator_object_to_delete is not None: model_name_str = "Unknown Model" + + try: - if hasattr(generator_object_to_delete, 'get_model_name') and callable(generator_object_to_delete.get_model_name): + if isinstance(generator_object_to_delete, BaseModelGenerator): + model_name_str = generator_object_to_delete.get_model_name() + # No need for attribute acrobatics here + elif hasattr(generator_object_to_delete, 'get_model_name') and 
callable(generator_object_to_delete.get_model_name): model_name_str = generator_object_to_delete.get_model_name() elif hasattr(generator_object_to_delete, 'transformer') and generator_object_to_delete.transformer is not None: model_name_str = generator_object_to_delete.transformer.__class__.__name__ @@ -901,12 +916,36 @@ def tb_handle_delete_studio_transformer(): print(f"Found active generator: {model_name_str}. Preparing for deletion.") try: - if hasattr(generator_object_to_delete, 'unload_loras') and callable(generator_object_to_delete.unload_loras): + if isinstance(generator_object_to_delete, BaseModelGenerator): + print(" - LoRAs: Unloading from transformer...") + generator_object_to_delete.unload_loras() + # No need for attribute acrobatics here + elif hasattr(generator_object_to_delete, 'unload_loras') and callable(generator_object_to_delete.unload_loras): print(" - LoRAs: Unloading from transformer...") generator_object_to_delete.unload_loras() else: log_messages_from_action.append(" - LoRAs: No unload method found or not applicable.") + if isinstance(generator_object_to_delete, BaseModelGenerator) and generator_object_to_delete.transformer is not None: + transformer_object_ref = generator_object_to_delete.transformer + transformer_name_for_log = transformer_object_ref.__class__.__name__ + print(f" - Transformer ({transformer_name_for_log}): Preparing for memory operations.") + if transformer_object_ref.device != cpu: + print(f" - Transformer ({transformer_name_for_log}): Moving to CPU...") + transformer_object_ref.to(cpu) # type: ignore + log_messages_from_action.append(" - Transformer moved to CPU.") + print(f" - Transformer ({transformer_name_for_log}): Moved to CPU.") + else: + log_messages_from_action.append(" - Transformer already on CPU.") + print(f" - Transformer ({transformer_name_for_log}): Already on CPU.") + + print(f" - Transformer ({transformer_name_for_log}): Removing attribute from generator...") + generator_object_to_delete.transformer = None # type: ignore + del transformer_object_ref + log_messages_from_action.append(" - Transformer reference deleted.") + print(f" - Transformer ({transformer_name_for_log}): Reference deleted.") + + # No need for attribute acrobatics here if hasattr(generator_object_to_delete, 'transformer') and generator_object_to_delete.transformer is not None: transformer_object_ref = generator_object_to_delete.transformer transformer_name_for_log = transformer_object_ref.__class__.__name__ @@ -945,9 +984,14 @@ def tb_handle_delete_studio_transformer(): generator_class_name_for_log = generator_object_to_delete.__class__.__name__ print(f" - Model Generator ({generator_class_name_for_log}): Setting global reference to None...") - setattr(studio_module_instance, 'current_generator', None) - log_messages_from_action.append(" - 'current_generator' in studio module set to None.") - print(" - Global 'current_generator' in studio module successfully set to None.") + if studio_manager.current_generator is not None: + studio_manager.unset_current_generator() + log_messages_from_action.append(" - 'current_generator' in studio manager set to None.") + print(" - Global 'current_generator' in studio manager successfully set to None.") + else: + setattr(studio_module_instance, 'current_generator', None) + log_messages_from_action.append(" - 'current_generator' in studio module set to None.") + print(" - Global 'current_generator' in studio module successfully set to None.") print(f" - Model Generator ({generator_class_name_for_log}): Deleting local Python 
From 6c0d41f04173e133d246b370c4c56d0c0a44bec8 Mon Sep 17 00:00:00 2001
From: arledesma
Date: Fri, 18 Jul 2025 18:25:26 -0500
Subject: [PATCH 15/17] Revert to overwriting HF_HOME environment

This was leading to the model being downloaded again for users who
already had HF_HOME set to a value.

We will need to document the migration path for existing users so they
can avoid redownloading all of the models.
---
 studio.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/studio.py b/studio.py
index a43f8db0..f9775797 100644
--- a/studio.py
+++ b/studio.py
@@ -29,8 +29,11 @@ from modules.video_queue import JobStatus
 
 # Set environment variables
-if not os.getenv('HF_HOME'):
-    os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
+STUDIO_HF_HOME = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
+# maybe only set HF_HOME if the directory exists, providing an opt-in migration path for users
+# make sure to document this behavior if the HF_HOME changes in the future
+# Set the HF_HOME to the studio's hf_download directory
+os.environ['HF_HOME'] = STUDIO_HF_HOME
 os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning
 
 # Try to suppress annoyingly persistent Windows asyncio proactor errors

From f8601c28d766c1440837d66df1e557e014f940d9 Mon Sep 17 00:00:00 2001
From: arledesma
Date: Fri, 18 Jul 2025 21:15:02 -0500
Subject: [PATCH 16/17] Handle LoRA loading from queue import

When loaded from a queue import, selected_loras and lora_values are
equal in length. When loaded from the Generate interface, lora_values is
the same length as lora_loaded_names and must be reduced.

This change assumes that when the two lists are the same length, they
are already in the correct order.
---
 modules/generators/base_generator.py | 4 +++-
 modules/studio_manager.py            | 8 ++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/modules/generators/base_generator.py b/modules/generators/base_generator.py
index 73153b71..b3e11d51 100644
--- a/modules/generators/base_generator.py
+++ b/modules/generators/base_generator.py
@@ -318,7 +318,9 @@ def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_na
         if lora_values is None:
             lora_values = []
 
-        selected_lora_values = [lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names]
+        selected_lora_values = lora_values if len(selected_loras) == len(lora_values) else [
+            lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names
+        ]
         print(f"Loading LoRAs: {selected_loras} with values: {selected_lora_values}")
 
         active_model_configuration = ModelConfiguration.from_lora_names_and_weights(
diff --git a/modules/studio_manager.py b/modules/studio_manager.py
index 24220203..1bac11aa 100644
--- a/modules/studio_manager.py
+++ b/modules/studio_manager.py
@@ -61,7 +61,9 @@ def is_reload_required(
         Returns:
             True if a reload is required, otherwise False.
         """
-        selected_lora_values = [
+        # queue load only sends the exact selected_loras and lora_values while other functions may send lora_values for all lora_loaded_names
+        # remove the condition if we update to always send the matching selected_loras and lora_values
+        selected_lora_values = lora_values if len(selected_loras) == len(lora_values) else [
             lora_values[lora_loaded_names.index(name)]
             for name in selected_loras
             if name in lora_loaded_names
@@ -104,7 +106,9 @@ def update_model_state(
             self.__logger.debug("Model configuration unchanged, skipping reload.")
             return
 
-        selected_lora_values = [
+        # queue load only sends the exact selected_loras and lora_values while other functions may send lora_values for all lora_loaded_names
+        # remove the condition if we update to always send the matching selected_loras and lora_values
+        selected_lora_values = lora_values if len(selected_loras) == len(lora_values) else [
             lora_values[lora_loaded_names.index(name)]
             for name in selected_loras
             if name in lora_loaded_names
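
[Editor's note] The length check this patch introduces can be illustrated with a small standalone sketch (not part of the patch): when the caller already supplies one weight per selected LoRA — the queue-import case — the values are used as-is and assumed to be in matching order; otherwise the Generate interface has sent one weight per loaded LoRA, so each weight is looked up by its position in lora_loaded_names. The function name and sample data below are hypothetical; only the selection expression mirrors the patch.

    from typing import List


    def reduce_lora_values(selected_loras: List[str],
                           lora_loaded_names: List[str],
                           lora_values: List[float]) -> List[float]:
        # Queue import: one value per selected LoRA, assumed to already be in order.
        if len(selected_loras) == len(lora_values):
            return lora_values
        # Generate interface: one value per *loaded* LoRA, so pick values by index.
        return [lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names]


    loaded = ["style_a", "style_b", "detailer"]
    # Generate interface: a weight for every loaded LoRA, reduced to the selection.
    print(reduce_lora_values(["detailer"], loaded, [0.8, 1.0, 0.6]))  # -> [0.6]
    # Queue import: only the weights for the selected LoRAs, used as-is.
    print(reduce_lora_values(["detailer"], loaded, [0.6]))            # -> [0.6]
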
""" - selected_lora_values = [ + # queue load only sends the exact selected_loras and lora_values while other functions may send lora_values for all lora_loaded_names + # remove the condition if we update to always send the matching selected_loras and lora_values + selected_lora_values = lora_values if len(selected_loras) == len(lora_values) else [ lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names @@ -104,7 +106,9 @@ def update_model_state( self.__logger.debug("Model configuration unchanged, skipping reload.") return - selected_lora_values = [ + # queue load only sends the exact selected_loras and lora_values while other functions may send lora_values for all lora_loaded_names + # remove the condition if we update to always send the matching selected_loras and lora_values + selected_lora_values = lora_values if len(selected_loras) == len(lora_values) else [ lora_values[lora_loaded_names.index(name)] for name in selected_loras if name in lora_loaded_names From a8cba730d71e1712f15ae6d886708ffbd9b57e34 Mon Sep 17 00:00:00 2001 From: arledesma Date: Fri, 18 Jul 2025 21:20:45 -0500 Subject: [PATCH 17/17] Set HF_HOME environment before importing from diffusers diffusers uses this environment variable to automatically downloads files on import. weird side effect to do that amount of actual work on import. --- studio.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/studio.py b/studio.py index f9775797..1765edec 100644 --- a/studio.py +++ b/studio.py @@ -5,6 +5,16 @@ import time from pathlib import PurePath + +# Set environment variables +STUDIO_HF_HOME = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) +# maybe only set HF_HOME if the directory exists, providing an opt-in migration path for users +# make sure to document this behavior if the HF_HOME changes in the future +# Set the HF_HOME to the studio's hf_download directory +# HF_HOME Must be set to its expected value prior to importing diffusers and transformers +os.environ['HF_HOME'] = STUDIO_HF_HOME +os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning + # Site packages from diffusers import AutoencoderKLHunyuanVideo from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer @@ -28,13 +38,7 @@ from modules.ui.queue import format_queue_status from modules.video_queue import JobStatus -# Set environment variables -STUDIO_HF_HOME = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) -# maybe only set HF_HOME if the directory exists, providing an opt-in migration path for users -# make sure to document this behavior if the HF_HOME changes in the future -# Set the HF_HOME to the studio's hf_download directory -os.environ['HF_HOME'] = STUDIO_HF_HOME -os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning + # Try to suppress annoyingly persistent Windows asyncio proactor errors if os.name == 'nt': # Windows only