diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py
index d7697500a..2db12b7fa 100644
--- a/onediff_comfy_nodes/__init__.py
+++ b/onediff_comfy_nodes/__init__.py
@@ -12,6 +12,7 @@
     ControlNetGraphSaver,
     SVDSpeedup,
     ModuleDeepCacheSpeedup,
+    OneDiffCheckpointLoaderSimple,
 )
 from ._compare_node import CompareModel, ShowImageDiff
 
@@ -30,6 +31,7 @@
     "ControlNetGraphSaver": ControlNetGraphSaver,
     "SVDSpeedup": SVDSpeedup,
     "ModuleDeepCacheSpeedup": ModuleDeepCacheSpeedup,
+    "OneDiffCheckpointLoaderSimple": OneDiffCheckpointLoaderSimple,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -46,6 +48,7 @@
     "ControlNetGraphSaver": "ControlNet Graph Saver",
     "SVDSpeedup": "SVD Speedup",
     "ModuleDeepCacheSpeedup": "Model DeepCache Speedup",
+    "OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff",
 }
 
 if _USE_UNET_INT8:
diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py
index 67bca5ce7..a1113f066 100644
--- a/onediff_comfy_nodes/_nodes.py
+++ b/onediff_comfy_nodes/_nodes.py
@@ -14,6 +14,7 @@
 import folder_paths
 from comfy import model_management
 from comfy.cli_args import args
+from folder_paths import get_input_directory
 
 from .utils import (
     OneFlowSpeedUpModelPatcher,
@@ -577,3 +578,57 @@ def apply_model(model_function, kwargs):
 
     oneflow_model.set_model_unet_function_wrapper(apply_model)
     return (oneflow_model,)
+
+
+def generate_graph_path(ckpt_name, model):
+    input_dir = get_input_directory()
+    input_dir = Path(input_dir)
+    graph_dir = input_dir / "graphs" / ckpt_name
+    graph_file_path = graph_dir / (type(model).__name__ + ".graph")
+    return graph_file_path
+
+
+from nodes import CheckpointLoaderSimple
+
+
+class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
+                "vae_speedup": (["disable", "enable"],),
+            }
+        }
+
+    CATEGORY = "OneDiff"
+
+    def load_checkpoint(
+        self, ckpt_name, output_vae=True, output_clip=True, vae_speedup="disable"
+    ):
+        model, clip, vae = super().load_checkpoint(ckpt_name, output_vae, output_clip)
+        offload_device = model_management.unet_offload_device()
+
+        diffusion_model = model.model.diffusion_model
+        file_path = generate_graph_path(ckpt_name, diffusion_model)
+        print(f" OneDiffCheckpointLoaderSimple load_checkpoint file_path {file_path}")
+
+        oneflow_model = OneFlowSpeedUpModelPatcher(
+            model.model,
+            load_device=model_management.get_torch_device(),
+            offload_device=offload_device,
+            use_graph=True,
+            graph_path=file_path,
+            graph_device=model_management.get_torch_device(),
+        )
+
+        if vae_speedup == "enable":
+            file_path = generate_graph_path(ckpt_name, vae.first_stage_model)
+            vae.first_stage_model = oneflow_compile(
+                vae.first_stage_model,
+                use_graph=True,
+                graph_path=file_path,
+                graph_device=model_management.get_torch_device(),
+            )
+        return oneflow_model, clip, vae
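The generate_graph_path helper above keys compiled-graph files by checkpoint name and module class under ComfyUI's input directory. A minimal sketch of the resulting layout, assuming get_input_directory() resolves to ComfyUI/input and a checkpoint named sd_v1-5.safetensors whose UNet class is UNetModel (all placeholder names):

    from pathlib import Path

    # Stand-in for folder_paths.get_input_directory().
    input_dir = Path("ComfyUI/input")
    # Mirrors generate_graph_path("sd_v1-5.safetensors", unet) when
    # type(unet).__name__ == "UNetModel".
    graph_file = input_dir / "graphs" / "sd_v1-5.safetensors" / "UNetModel.graph"
    print(graph_file)  # ComfyUI/input/graphs/sd_v1-5.safetensors/UNetModel.graph
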
diff --git a/onediff_comfy_nodes/utils/model_patcher.py b/onediff_comfy_nodes/utils/model_patcher.py
index 8575272e5..4c64be3c8 100644
--- a/onediff_comfy_nodes/utils/model_patcher.py
+++ b/onediff_comfy_nodes/utils/model_patcher.py
@@ -30,6 +30,8 @@ def __init__(
         weight_inplace_update=False,
         *,
         use_graph=None,
+        graph_path=None,
+        graph_device=None,
     ):
         from onediff.infer_compiler import oneflow_compile
         from onediff.infer_compiler.with_oneflow_compile import DeployableModule
@@ -46,7 +48,10 @@
             ] = self.model.diffusion_model
         else:
             self.model.__dict__["_modules"]["diffusion_model"] = oneflow_compile(
-                self.model.diffusion_model, use_graph=use_graph
+                self.model.diffusion_model,
+                use_graph=use_graph,
+                graph_path=graph_path,
+                graph_device=graph_device,
             )
         self.model._register_state_dict_hook(state_dict_hook)
         self.patches = {}
@@ -495,7 +500,6 @@ def __init__(
     ):
         from onediff.infer_compiler import oneflow_compile
         from onediff.infer_compiler.with_oneflow_compile import DeployableModule
-
         self.weight_inplace_update = weight_inplace_update
 
         self.object_patches = {}
@@ -504,10 +508,14 @@
         self.model = copy.copy(model)
         self.model.__dict__["_modules"] = copy.copy(model.__dict__["_modules"])
         self.deep_cache_unet = oneflow_compile(
-            DeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id), use_graph=use_graph
+            DeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id),
+            use_graph=use_graph,
         )
-        self.fast_deep_cache_unet =oneflow_compile(
-            FastDeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id), use_graph=use_graph
+        self.fast_deep_cache_unet = oneflow_compile(
+            FastDeepCacheUNet(
+                self.model.diffusion_model, cache_layer_id, cache_block_id
+            ),
+            use_graph=use_graph,
         )
         self.model._register_state_dict_hook(state_dict_hook)
         self.patches = {}
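The patcher change above only threads the two new keyword arguments through to oneflow_compile for the diffusion model; the DeepCache patcher in the same file keeps use_graph alone and gets no on-disk caching. A minimal caller-side sketch mirroring the node code in _nodes.py (the graph path is a placeholder):

    from comfy import model_management

    patcher = OneFlowSpeedUpModelPatcher(
        model.model,
        load_device=model_management.get_torch_device(),
        offload_device=model_management.unet_offload_device(),
        use_graph=True,
        graph_path="input/graphs/my_ckpt/UNetModel.graph",  # placeholder cache file
        graph_device=model_management.get_torch_device(),
    )
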
diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py
index 859cdc40f..40ee1771d 100644
--- a/src/onediff/infer_compiler/with_oneflow_compile.py
+++ b/src/onediff/infer_compiler/with_oneflow_compile.py
@@ -56,10 +56,12 @@ def _align_tensor(torch_module, oneflow_module):
         [x for x, _ in oneflow_module.named_parameters()]
         + [x for x, _ in oneflow_module.named_buffers()]
     )
-    for name, tensor in chain.from_iterable([
-        torch_module.named_parameters(),
-        torch_module.named_buffers(),
-    ]):
+    for name, tensor in chain.from_iterable(
+        [
+            torch_module.named_parameters(),
+            torch_module.named_buffers(),
+        ]
+    ):
         if name not in oneflow_tensor_list:
             tensor.data = tensor.to(*args, **kwargs)
         else:
@@ -76,7 +78,6 @@
         else:
             _align_tensor(module, self._oneflow_module.get_submodule(name))
 
-
     def __getattr__(self, name):
         if name == "_torch_module":
             return self._modules[name]
@@ -124,7 +125,11 @@ def __init__(self, torch_modules, oneflow_modules):
         for torch_module, oneflow_module in zip(
             self._torch_modules, self._oneflow_modules
         ):
-            dual_modules.append(get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module))
+            dual_modules.append(
+                get_mixed_dual_module(torch_module.__class__)(
+                    torch_module, oneflow_module
+                )
+            )
         # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules
         self._modules.clear()
         self += dual_modules
@@ -147,6 +152,7 @@ def __setattr__(self, key, value):
             setattr(self._oneflow_modules, key, value)
         return object.__setattr__(self, key, value)
 
+
 def get_mixed_dual_module(module_cls):
     class MixedDualModule(DualModule, module_cls):
         def __init__(self, torch_module, oneflow_module):
@@ -155,32 +161,81 @@ def __init__(self, torch_module, oneflow_module):
 
     return MixedDualModule
 
 
+def load_graph_from_config(self: "DeployableModule"):
+    try:
+        if self._graph_config is not None:
+            graph_path = self._graph_config[0]
+            if not os.path.exists(graph_path):
+                logger.warning(
+                    f"Graph file {graph_path} not exists, will generate graph."
+                )
+                return
+            graph_device = torch2oflow(self._graph_config[1])
+            self.load_graph(graph_path, graph_device)
+            self._graph_config = None
+    except Exception as e:
+        logger.error(f"Exception in load_graph_from_config: {e=}")
+
+
+def save_graph_to_config(self: "DeployableModule"):
+    try:
+        if self._graph_config is not None:
+            graph_file = self._graph_config[0]
+            os.makedirs(os.path.dirname(graph_file), exist_ok=True)
+            self.save_graph(self._graph_config[0])
+            logger.info(f"Save graph to {self._graph_config[0]} done!")
+    except Exception as e:
+        logger.error(f"Exception in save_graph_to_config: {e=}")
+    finally:
+        self._graph_config = None
+
+
 def handle_deployable_exception(func):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
+        def _run_func():
+            load_graph_from_config(self)
+            result = func(self, *args, **kwargs)
+            save_graph_to_config(self)
+            return result
+
         if transform_mgr.debug_mode:
-            return func(self, *args, **kwargs)
+            return _run_func()
         else:
             try:
-                return func(self, *args, **kwargs)
+                return _run_func()
             except Exception as e:
                 logger.error(f"Exception in {func.__name__}: {e=}")
                 logger.warning("Recompile oneflow module ...")
                 del self._deployable_module_model.oneflow_module
                 self._deployable_module_dpl_graph = None
-                return func(self, *args, **kwargs)
+                return _run_func()
 
     return wrapper
 
 
 class DeployableModule(torch.nn.Module):
-    def __init__(self, torch_module, oneflow_module, use_graph=True, options={}):
+    def __init__(
+        self,
+        torch_module,
+        oneflow_module,
+        use_graph=True,
+        options={},
+        graph_path=None,
+        graph_device=None,
+    ):
         torch.nn.Module.__init__(self)
-        self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module)
+        self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(
+            torch_module, oneflow_module
+        )
         self._deployable_module_use_graph = use_graph
         self._deployable_module_options = options
         self._deployable_module_dpl_graph = None
         self._is_raw_deployable_module = True
+        if graph_path is not None:
+            self._graph_config = (graph_path, graph_device)
+        else:
+            self._graph_config = None
 
     @classmethod
     def from_existing(cls, existing_module, use_graph=None, options=None):
@@ -190,6 +245,7 @@ def from_existing(cls, existing_module, use_graph=None, options=None):
         instance._deployable_module_dpl_graph = (
             existing_module._deployable_module_dpl_graph if use_graph else None
         )
+        instance._graph_config = existing_module._graph_config
         return instance
 
     def get_graph(self):
@@ -245,7 +301,10 @@ def to(self, *args, **kwargs):
 
             # assert the target device is same as graph device
             target_device = parse_device(args, kwargs)
-            if target_device is not None and len(self._deployable_module_dpl_graph._blocks) > 0:
+            if (
+                target_device is not None
+                and len(self._deployable_module_dpl_graph._blocks) > 0
+            ):
                 current_device = next(self._deployable_module_dpl_graph._state()).device
                 if not check_device(current_device, target_device):
                     raise RuntimeError(
@@ -357,11 +416,25 @@ def state_dict_hook(module, state_dict, prefix, local_metadata):
 
 
 # Return a DeployableModule that using module_cls as it's parent class.
-def get_mixed_deployable_module(module_cls):
+def get_mixed_deployable_module(module_cls, graph_path=None, graph_device=None):
     class MixedDeployableModule(DeployableModule, module_cls):
-        def __init__(self, torch_module, oneflow_module, use_graph=True, options={}):
+        def __init__(
+            self,
+            torch_module,
+            oneflow_module,
+            use_graph=True,
+            options={},
+            graph_path=None,
+            graph_device=None,
+        ):
             DeployableModule.__init__(
-                self, torch_module, oneflow_module, use_graph, options
+                self,
+                torch_module,
+                oneflow_module,
+                use_graph,
+                options,
+                graph_path,
+                graph_device,
             )
             self._is_raw_deployable_module = False
 
@@ -378,7 +451,14 @@ def from_existing(cls, existing_module, use_graph=None, options=None):
 
     return MixedDeployableModule
 
 
-def oneflow_compile(torch_module: torch.nn.Module, *, use_graph=True, options={}):
+def oneflow_compile(
+    torch_module: torch.nn.Module,
+    *,
+    use_graph=True,
+    options={},
+    graph_path=None,
+    graph_device=None,
+):
     set_default_registry()
 
     def wrap_module(module):
@@ -387,7 +467,7 @@ def wrap_module(module):
             return module.__class__.from_existing(module, use_graph, options)
         else:
             return get_mixed_deployable_module(module.__class__)(
-                module, None, use_graph, options
+                module, None, use_graph, options, graph_path, graph_device
             )
 
     model = wrap_module(torch_module)
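Taken together, the with_oneflow_compile.py changes make graph caching lazy: handle_deployable_exception now runs load_graph_from_config before the wrapped call (loading graph_path if the file already exists, warning otherwise) and save_graph_to_config after it (writing the graph out and clearing _graph_config so the save happens at most once). A minimal end-to-end sketch, assuming a toy module and a writable cache path (both placeholders):

    import torch
    from onediff.infer_compiler import oneflow_compile

    unet = MyUNet().to("cuda")  # placeholder torch.nn.Module
    compiled = oneflow_compile(
        unet,
        use_graph=True,
        graph_path="graphs/my_ckpt/MyUNet.graph",  # loaded if present, written after the first call
        graph_device="cuda",
    )
    # The first call compiles (or restores) the OneFlow graph, then persists it to graph_path.
    out = compiled(torch.randn(1, 4, 64, 64, device="cuda"))  # placeholder input shape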