README.md (7 additions, 3 deletions)
@@ -92,10 +92,14 @@ model = PyramidDiTForVideoGeneration(
model_variant='diffusion_transformer_768p', # 'diffusion_transformer_384p'
)

model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")

model.vae.enable_tiling()
# model.vae.to("cuda")
# model.dit.to("cuda")
# model.text_encoder.to("cuda")

# if you're not using sequential offloading below, uncomment the lines above ^
model.enable_sequential_cpu_offload()
```

Then, you can try text-to-video generation on your own prompts:
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py (38 additions, 17 deletions)
@@ -16,7 +16,7 @@
from torchvision import transforms
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
from accelerate import Accelerator
from accelerate import Accelerator, cpu_offload
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
from video_vae.modeling_causal_vae import CausalVideoVAE

@@ -135,6 +135,19 @@ def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=Fa
self.cfg_rate = 0.1
self.return_log = return_log
self.use_flash_attn = use_flash_attn
self.sequential_offload_enabled = False

def _enable_sequential_cpu_offload(self, model):
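# With accelerate's cpu_offload, this module's weights stay in CPU RAM and are copied to cuda:0 only while each submodule's forward runs, then offloaded again.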
self.sequential_offload_enabled = True
torch_device = torch.device("cuda")
device_type = torch_device.type
device = torch.device(f"{device_type}:0")
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)

def enable_sequential_cpu_offload(self):
self._enable_sequential_cpu_offload(self.text_encoder)
self._enable_sequential_cpu_offload(self.dit)

def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
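For readers unfamiliar with accelerate's offloading: `cpu_offload` wraps a module with hooks so its weights live in CPU RAM and are streamed to the execution device one submodule at a time during the forward pass. Below is a minimal sketch of that mechanism; the toy module, layer sizes, and tensor shapes are illustrative assumptions, not part of this PR.

```python
# Minimal sketch of the offloading mechanism used above, applied to a toy module.
# The toy model, layer sizes, and tensor shapes are illustrative assumptions, not part of this PR.
import torch
from torch import nn
from accelerate import cpu_offload

toy = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
cpu_offload(toy, execution_device=torch.device("cuda:0"))  # weights remain on CPU between calls

x = torch.randn(2, 1024, device="cuda:0")
with torch.no_grad():
    y = toy(x)  # each Linear's weights are copied to cuda:0 for its forward pass, then offloaded again
print(y.device)  # cuda:0, activations stay on the execution device
```

This is also why the rest of the diff guards the explicit `.to("cuda")` / `.to("cpu")` moves with `self.sequential_offload_enabled`: once the hooks are installed, manually relocating `dit` or `text_encoder` would only fight the offloading.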
@@ -322,10 +335,11 @@ def generate_i2v(
dtype = self.dtype
if cpu_offloading:
# skip caring about the text encoder here as it's about to be used anyway.
if str(self.dit.device) != "cpu":
print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.dit.to("cpu")
torch.cuda.empty_cache()
if not self.sequential_offload_enabled:
if str(self.dit.device) != "cpu":
print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.dit.to("cpu")
torch.cuda.empty_cache()
if str(self.vae.device) != "cpu":
print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.vae.to("cpu")
@@ -350,12 +364,14 @@
negative_prompt = negative_prompt or ""

# Get the text embeddings
if cpu_offloading:
self.text_encoder.to("cuda")
if cpu_offloading and not self.sequential_offload_enabled:
self.text_encoder.to("cuda")
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

if cpu_offloading:
self.text_encoder.to("cpu")
if not self.sequential_offload_enabled:
self.text_encoder.to("cpu")
self.vae.to("cuda")
torch.cuda.empty_cache()

@@ -425,7 +441,8 @@ generate_i2v

if cpu_offloading:
self.vae.to("cpu")
self.dit.to("cuda")
if not self.sequential_offload_enabled:
self.dit.to("cuda")
torch.cuda.empty_cache()

for unit_index in tqdm(range(1, num_units)):
@@ -524,15 +541,17 @@ def generate(
dtype = self.dtype
if cpu_offloading:
# skip caring about the text encoder here as it's about to be used anyway.
if str(self.dit.device) != "cpu":
print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.dit.to("cpu")
torch.cuda.empty_cache()
if not self.sequential_offload_enabled:
if str(self.dit.device) != "cpu":
print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.dit.to("cpu")
torch.cuda.empty_cache()
if str(self.vae.device) != "cpu":
print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
self.vae.to("cpu")
torch.cuda.empty_cache()


assert (temp - 1) % self.frame_per_unit == 0, "temp - 1 should be divisible by frame_per_unit"

if isinstance(prompt, str):
@@ -552,13 +571,14 @@
negative_prompt = negative_prompt or ""

# Get the text embeddings
if cpu_offloading:
if cpu_offloading and not self.sequential_offload_enabled:
self.text_encoder.to("cuda")
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
if cpu_offloading:
self.text_encoder.to("cpu")
self.dit.to("cuda")
if not self.sequential_offload_enabled:
self.text_encoder.to("cpu")
self.dit.to("cuda")
torch.cuda.empty_cache()

if use_linear_guidance:
@@ -689,7 +709,8 @@ def generate(
image = generated_latents
else:
if cpu_offloading:
self.dit.to("cpu")
if not self.sequential_offload_enabled:
self.dit.to("cpu")
self.vae.to("cuda")
torch.cuda.empty_cache()
image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
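Combined with the README change above, the intended calling pattern is roughly the sketch below. Only parameters that appear somewhere in this diff are used; the prompt text is a placeholder, and other generation settings (resolution, step counts, temp, and so on) are omitted rather than guessed.

```python
# Rough usage sketch (not part of this PR): sequential offload plus generation.
# Only parameters visible in this diff are shown; everything else keeps its defaults.
model.vae.enable_tiling()
model.enable_sequential_cpu_offload()  # installs the accelerate hooks on text_encoder and dit

frames = model.generate(
    prompt="a placeholder prompt",  # placeholder text
    cpu_offloading=True,            # the pipeline still moves the vae; dit/text_encoder are handled by the hooks
    save_memory=True,
)
```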