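"""Benchmark CogVideoX text-to-video generation, optionally accelerated with
onediffx (nexfort compilation and/or quantization).

The script pre-computes prompt embeddings, runs warm-up iterations so that
compilation/autotuning cost is excluded from timing, then reports the
wall-clock latency of one timed generation and the peak CUDA memory used.
"""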
import argparse
import json
import tempfile
import time
from typing import List, Union

import imageio
import numpy as np
import PIL.Image
import torch

from diffusers import CogVideoXPipeline
from onediffx import compile_pipe, quantize_pipe


def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
    """
    Export video frames to a video file with imageio, which avoids the
    "green screen" artifact that other writers can produce (e.g. with CogVideoX output).
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path
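
# Example call (hypothetical frames, just to show the expected shapes):
#   export_to_video_imageio([np.zeros((480, 720, 3), dtype=np.uint8)] * 8, "out.mp4", fps=8)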


def parse_args():
    parser = argparse.ArgumentParser(
        description="Use onediff to accelerate video generation with CogVideoX"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="THUDM/CogVideoX-2b",
        help="Model path or identifier.",
    )
    parser.add_argument(
        "--compiler",
        type=str,
        default="none",
        help="Compiler backend to use. Options: 'none', 'nexfort'",
    )
    parser.add_argument(
        "--compiler-config", type=str, help="JSON string for compiler config."
    )
    parser.add_argument(
        "--quantize-config", type=str, help="JSON string for quantization config."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.",
        help="Prompt for the video generation.",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=6.5,
        help="The scale factor for classifier-free guidance.",
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50, help="Number of inference steps."
    )
    parser.add_argument(
        "--num-videos-per-prompt",
        type=int,
        default=1,
        help="Number of videos to generate per prompt.",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./output.mp4",
        help="The path where the generated video will be saved.",
    )
    parser.add_argument(
        "--seed", type=int, default=66, help="Seed for random number generation."
    )
    parser.add_argument(
        "--warmup-iterations",
        type=int,
        default=1,
        help="Number of warm-up iterations before actual inference.",
    )
    return parser.parse_args()


args = parse_args()

device = torch.device("cuda")

class CogVideoGenerator:
    def __init__(
        self, model, compiler_config=None, quantize_config=None, compiler="none"
    ):
        self.pipe = CogVideoXPipeline.from_pretrained(
            model, torch_dtype=torch.float16, variant="fp16"
        ).to(device)

        # Filled in by encode_prompt() and reused for warmup and the timed run.
        self.prompt_embeds = None

        if compiler == "nexfort":
            if compiler_config:
                print("nexfort backend compile...")
                self.pipe = self.compile_pipe(self.pipe, compiler_config)

            if quantize_config:
                print("nexfort backend quant...")
                self.pipe = self.quantize_pipe(self.pipe, quantize_config)

    def encode_prompt(self, prompt, num_videos_per_prompt):
        # Pre-compute the T5 text embeddings once so the text encoder does not
        # run on every pipeline call.
        self.prompt_embeds, _ = self.pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=None,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=226,
            device=device,
            dtype=torch.float16,
        )

    def warmup(self, gen_args, warmup_iterations):
        # Warm-up iterations trigger compilation/autotuning so that cost is
        # excluded from the timed run in generate().
        warmup_args = gen_args.copy()

        warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)

        print("Starting warmup...")
        start_time = time.time()

        for _ in range(warmup_iterations):
            self.pipe(**warmup_args)

        end_time = time.time()
        print("Warmup complete.")
        print(f"Warmup time: {end_time - start_time:.2f} seconds")

    def generate(self, gen_args):
        gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed)

        # Timed run: generate the frames of the first (and only) video.
        start_time = time.time()
        video = self.pipe(**gen_args).frames[0]
        end_time = time.time()

        export_to_video_imageio(video, args.output_path, fps=8)

        return video, end_time - start_time

    def compile_pipe(self, pipe, compiler_config):
        options = compiler_config
        pipe = compile_pipe(
            pipe, backend="nexfort", options=options, fuse_qkv_projections=True
        )
        return pipe

    def quantize_pipe(self, pipe, quantize_config):
        pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
        return pipe


def main():
    nexfort_compiler_config = (
        json.loads(args.compiler_config) if args.compiler_config else None
    )
    nexfort_quantize_config = (
        json.loads(args.quantize_config) if args.quantize_config else None
    )

    cog_video = CogVideoGenerator(
        args.model,
        nexfort_compiler_config,
        nexfort_quantize_config,
        compiler=args.compiler,
    )

    cog_video.encode_prompt(args.prompt, args.num_videos_per_prompt)

    gen_args = {
        "prompt_embeds": cog_video.prompt_embeds,
        "num_inference_steps": args.num_inference_steps,
        "guidance_scale": args.guidance_scale,
        # Negative prompts are not supported here; pass zeroed embeddings instead.
        "negative_prompt_embeds": torch.zeros_like(cog_video.prompt_embeds),
        "num_frames": 8,
    }

    cog_video.warmup(gen_args, args.warmup_iterations)

    _, inference_time = cog_video.generate(gen_args)
    print(
        f"Generated video saved to {args.output_path} in {inference_time:.2f} seconds."
    )
    cuda_mem_used = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Max used CUDA memory: {cuda_mem_used:.3f} GiB")


if __name__ == "__main__":
    main()
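
# Example invocation (the nexfort config values below are illustrative
# assumptions, not the only valid settings -- consult the onediffx/nexfort
# documentation; the script filename is a placeholder):
#
#   python3 cogvideox_text_to_video.py \
#       --compiler nexfort \
#       --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}' \
#       --quantize-config '{"quant_type": "fp8_e4m3_e4m3_dynamic"}'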