
Commit a16f066

Add cogvideo draft
1 parent 92c10ec commit a16f066

2 files changed: +196 -0 lines changed

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
RUN:

python3 onediff_diffusers_extensions/examples/cog/text_to_image_cog.py --model /data0/hf_models/CogVideoX-2b --compiler nexfort --compiler-config '{"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'
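
Note: the --compiler-config value is a plain JSON string; the script below parses it with json.loads in main() and passes the resulting dict to compile_pipe unchanged. For illustration only, the parsed form of the config used above:

import json

compiler_config = json.loads(
    '{"mode": "max-optimize:max-autotune:max-autotune",'
    ' "memory_format": "channels_last",'
    ' "options": {"inductor.optimize_linear_epilogue": false,'
    ' "triton.fuse_attention_allow_fp16_reduction": false}}'
)
# JSON booleans become Python booleans:
# {'mode': 'max-optimize:max-autotune:max-autotune',
#  'memory_format': 'channels_last',
#  'options': {'inductor.optimize_linear_epilogue': False,
#              'triton.fuse_attention_allow_fp16_reduction': False}}
print(compiler_config["options"])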
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
import argparse
import json
import tempfile  # needed by export_to_video_imageio below; missing from the draft
import time
from typing import Union, List

import PIL
import imageio
import numpy as np
import torch

from diffusers import CogVideoXPipeline
from onediffx import compile_pipe, quantize_pipe

def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
    """
    Export the video frames to a video file using the imageio library, which avoids the
    "green screen" issue seen with some exporters (for example with CogVideoX output).
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path

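# Usage sketch (illustration only): export a few synthetic uint8 frames with the helper
# above. Writing .mp4 assumes an ffmpeg-capable imageio backend (e.g. imageio-ffmpeg)
# is installed; the output file name is a placeholder.
#
#   frames = [np.zeros((480, 720, 3), dtype=np.uint8) for _ in range(8)]
#   export_to_video_imageio(frames, "demo.mp4", fps=8)
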
def parse_args():
    parser = argparse.ArgumentParser(
        description="Use onediff to accelerate video generation with CogVideoX"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="THUDM/CogVideoX-2b",
        help="Model path or identifier.",
    )
    parser.add_argument(
        "--compiler",
        type=str,
        default="none",
        help="Compiler backend to use. Options: 'none', 'nexfort'",
    )
    parser.add_argument(
        "--compiler-config", type=str, help="JSON string for compiler config."
    )
    parser.add_argument(
        "--quantize-config", type=str, help="JSON string for quantization config."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default='In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.',
        help="Prompt for the video generation.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6.5,
        help="The scale factor for the guidance.",
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50, help="Number of inference steps."
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument(
        "--seed", type=int, default=66, help="Seed for random number generation."
    )
    parser.add_argument(
        "--warmup-iterations",
        type=int,
        default=1,
        help="Number of warm-up iterations before actual inference.",
    )
    return parser.parse_args()


args = parse_args()

device = torch.device("cuda")

class CogVideoGenerator:
    def __init__(
        self, model, compiler_config=None, quantize_config=None, compiler="none"
    ):
        self.pipe = CogVideoXPipeline.from_pretrained(
            model, torch_dtype=torch.float16, variant="fp16"
        ).to(device)

        self.prompt_embeds = None

        if compiler == "nexfort":
            if compiler_config:
                print("nexfort backend compile...")
                self.pipe = self.compile_pipe(self.pipe, compiler_config)

            if quantize_config:
                print("nexfort backend quant...")
                self.pipe = self.quantize_pipe(self.pipe, quantize_config)

    def encode_prompt(self, prompt, num_videos_per_prompt):
        self.prompt_embeds, _ = self.pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=None,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=226,
            device=device,
            dtype=torch.float16,
        )

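    # The pipeline call is warmed up before timing: with the nexfort backend the first
    # call typically triggers compilation/autotuning, so running warm-up iterations first
    # keeps the reported generation time representative of steady-state inference.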
    def warmup(self, gen_args, warmup_iterations):
        warmup_args = gen_args.copy()

        warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)

        print("Starting warmup...")
        start_time = time.time()

        for _ in range(warmup_iterations):
            self.pipe(**warmup_args)

        end_time = time.time()
        print("Warmup complete.")
        print(f"Warmup time: {end_time - start_time:.2f} seconds")

    def generate(self, gen_args):
        gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed)

        # Run the model
        start_time = time.time()
        video = self.pipe(**gen_args).frames[0]
        end_time = time.time()

        export_to_video_imageio(video, args.output_path, fps=8)

        return video, end_time - start_time

    def compile_pipe(self, pipe, compiler_config):
        options = compiler_config
        pipe = compile_pipe(
            pipe, backend="nexfort", options=options, fuse_qkv_projections=True
        )
        return pipe

    def quantize_pipe(self, pipe, quantize_config):
        pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
        return pipe

def main():
    nexfort_compiler_config = (
        json.loads(args.compiler_config) if args.compiler_config else None
    )
    nexfort_quantize_config = (
        json.loads(args.quantize_config) if args.quantize_config else None
    )

    CogVideo = CogVideoGenerator(
        args.model,
        nexfort_compiler_config,
        nexfort_quantize_config,
        compiler=args.compiler,
    )

    CogVideo.encode_prompt(args.prompt, args.num_videos_per_prompt)

    gen_args = {
        "prompt_embeds": CogVideo.prompt_embeds,
        "num_inference_steps": args.num_inference_steps,
        "guidance_scale": args.guidance_scale,
        # Negative prompts are not supported here; pass zero embeddings instead.
        "negative_prompt_embeds": torch.zeros_like(CogVideo.prompt_embeds),
        "num_frames": 8,
    }

    CogVideo.warmup(gen_args, args.warmup_iterations)

    _, inference_time = CogVideo.generate(gen_args)
    print(
        f"Generated video saved to {args.output_path} in {inference_time:.2f} seconds."
    )
    cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Max used CUDA memory: {cuda_mem_after_used:.3f} GiB")


if __name__ == "__main__":
    main()
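
If the CLI harness is not needed, the core pattern the script wraps is simply load, compile, call. A minimal sketch, assuming a CUDA GPU with diffusers, onediffx, and the nexfort backend installed; the prompt string is a placeholder, and the model id and options mirror the script's defaults and the RUN command above:

import torch
from diffusers import CogVideoXPipeline
from onediffx import compile_pipe

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Same kind of options dict as the --compiler-config JSON, passed directly as a dict.
pipe = compile_pipe(
    pipe,
    backend="nexfort",
    options={"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last"},
    fuse_qkv_projections=True,
)

frames = pipe(
    prompt="a small boat drifting on a calm lake at sunrise",
    num_inference_steps=50,
    guidance_scale=6.5,
    num_frames=8,
).frames[0]

The resulting frames can then be written out with a helper like export_to_video_imageio above.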
