From 81a4c801751d839491422fd3240a3443d32f980e Mon Sep 17 00:00:00 2001
From: Rodrigo Antonio de Araujo
Date: Sat, 12 Oct 2024 23:04:38 -0300
Subject: [PATCH 1/2] Add support for sequential CPU offload (8GB VRAM, maybe)

---
 .../pyramid_dit_for_video_gen_pipeline.py | 53 +++++++++++++------
 1 file changed, 37 insertions(+), 16 deletions(-)

diff --git a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
index 5ca79bd..fa7ca9d 100644
--- a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
+++ b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
@@ -135,6 +135,19 @@ def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=Fa
         self.cfg_rate = 0.1
         self.return_log = return_log
         self.use_flash_attn = use_flash_attn
+        self.sequential_offload_enabled = False
+
+    def _enable_sequential_cpu_offload(self, model):
+        self.sequential_offload_enabled = True
+        torch_device = torch.device("cuda")
+        device_type = torch_device.type
+        device = torch.device(f"{device_type}:0")
+        offload_buffers = len(model._parameters) > 0
+        cpu_offload(model, device, offload_buffers=offload_buffers)
+
+    def enable_sequential_cpu_offload(self):
+        self._enable_sequential_cpu_offload(self.text_encoder)
+        self._enable_sequential_cpu_offload(self.dit)
 
     def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
@@ -322,10 +335,11 @@ def generate_i2v(
         dtype = self.dtype
         if cpu_offloading:
             # skip caring about the text encoder here as its about to be used anyways.
-            if str(self.dit.device) != "cpu":
-                print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
-                self.dit.to("cpu")
-                torch.cuda.empty_cache()
+            if not self.sequential_offload_enabled:
+                if str(self.dit.device) != "cpu":
+                    print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
+                    self.dit.to("cpu")
+                    torch.cuda.empty_cache()
             if str(self.vae.device) != "cpu":
                 print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                 self.vae.to("cpu")
@@ -350,12 +364,14 @@ def generate_i2v(
         negative_prompt = negative_prompt or ""
 
         # Get the text embeddings
-        if cpu_offloading:
-            self.text_encoder.to("cuda")
+        if cpu_offloading and not self.sequential_offload_enabled:
+            self.text_encoder.to("cuda")
         prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
         negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
+
         if cpu_offloading:
-            self.text_encoder.to("cpu")
+            if not self.sequential_offload_enabled:
+                self.text_encoder.to("cpu")
             self.vae.to("cuda")
             torch.cuda.empty_cache()
 
@@ -425,7 +441,8 @@ def generate_i2v(
 
         if cpu_offloading:
             self.vae.to("cpu")
-            self.dit.to("cuda")
+            if not self.sequential_offload_enabled:
+                self.dit.to("cuda")
             torch.cuda.empty_cache()
 
         for unit_index in tqdm(range(1, num_units)):
@@ -524,15 +541,17 @@ def generate(
         dtype = self.dtype
         if cpu_offloading:
             # skip caring about the text encoder here as its about to be used anyways.
-            if str(self.dit.device) != "cpu":
-                print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
-                self.dit.to("cpu")
-                torch.cuda.empty_cache()
+            if not self.sequential_offload_enabled:
+                if str(self.dit.device) != "cpu":
+                    print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
+                    self.dit.to("cpu")
+                    torch.cuda.empty_cache()
             if str(self.vae.device) != "cpu":
                 print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
                 self.vae.to("cpu")
                 torch.cuda.empty_cache()
 
+
         assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
 
         if isinstance(prompt, str):
@@ -552,13 +571,14 @@ def generate(
         negative_prompt = negative_prompt or ""
 
         # Get the text embeddings
-        if cpu_offloading:
+        if cpu_offloading and not self.sequential_offload_enabled:
             self.text_encoder.to("cuda")
         prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
         negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
         if cpu_offloading:
-            self.text_encoder.to("cpu")
-            self.dit.to("cuda")
+            if not self.sequential_offload_enabled:
+                self.text_encoder.to("cpu")
+                self.dit.to("cuda")
             torch.cuda.empty_cache()
 
         if use_linear_guidance:
@@ -689,7 +709,8 @@ def generate(
             image = generated_latents
         else:
             if cpu_offloading:
-                self.dit.to("cpu")
+                if not self.sequential_offload_enabled:
+                    self.dit.to("cpu")
                 self.vae.to("cuda")
                 torch.cuda.empty_cache()
             image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)

From 6d0f7e60e03406deb927afe1a5aefee324ea875d Mon Sep 17 00:00:00 2001
From: Rodrigo Antonio de Araujo
Date: Sun, 13 Oct 2024 07:22:01 -0300
Subject: [PATCH 2/2] Add missing import

---
 README.md                                         | 10 +++++++---
 pyramid_dit/pyramid_dit_for_video_gen_pipeline.py |  2 +-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index a451c93..39c9143 100644
--- a/README.md
+++ b/README.md
@@ -92,10 +92,14 @@ model = PyramidDiTForVideoGeneration(
     model_variant='diffusion_transformer_768p',     # 'diffusion_transformer_384p'
 )
 
-model.vae.to("cuda")
-model.dit.to("cuda")
-model.text_encoder.to("cuda")
+
 model.vae.enable_tiling()
+# model.vae.to("cuda")
+# model.dit.to("cuda")
+# model.text_encoder.to("cuda")
+
+# if you're not using sequential offloading below, uncomment the lines above ^
+model.enable_sequential_cpu_offload()
 ```
 
 Then, you can try text-to-video generation on your own prompts:
diff --git a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
index fa7ca9d..3ab5254 100644
--- a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
+++ b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
@@ -16,7 +16,7 @@
 from torchvision import transforms
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Union
-from accelerate import Accelerator
+from accelerate import Accelerator, cpu_offload
 
 from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
 from video_vae.modeling_causal_vae import CausalVideoVAE