-
Notifications
You must be signed in to change notification settings - Fork 225
Expand file tree
/
Copy pathtrain.py
More file actions
1284 lines (1067 loc) · 52.8 KB
/
train.py
File metadata and controls
1284 lines (1067 loc) · 52.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/LICENSE
#
# Unless and only to the extent required by applicable law, the Tencent Hunyuan works and any
# output and results therefrom are provided "AS IS" without any express or implied warranties of
# any kind including any warranties of title, merchantability, noninfringement, course of dealing,
# usage of trade, or fitness for a particular purpose. You are solely responsible for determining the
# appropriateness of using, reproducing, modifying, performing, displaying or distributing any of
# the Tencent Hunyuan works or outputs and assume any and all risks associated with your or a
# third party's use or distribution of any of the Tencent Hunyuan works or outputs and your exercise
# of rights and permissions under this agreement.
# See the License for the specific language governing permissions and limitations under the License.
"""
HunyuanVideo-1.5 Training Script
This script provides a complete training pipeline for HunyuanVideo-1.5 model.
Quick Start:
1. Implement your own dataloader:
- Replace the `create_dummy_dataloader()` function with your own implementation
- Your dataset's __getitem__ method should return a single sample:
* "pixel_values": torch.Tensor - Video: [C, F, H, W] or Image: [C, H, W]
Pixel values must be in range [-1, 1]
Note: For video data, temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, ...)
* "text": str - Text prompt for this sample
* "data_type": str - "video" or "image"
* Optional: "latents" - Pre-encoded VAE latents for faster training
* Optional: "byt5_text_ids" and "byt5_text_mask" - Pre-tokenized byT5 inputs
- See `create_dummy_dataloader()` function for detailed format documentation
2. Configure training parameters:
- Set `--pretrained_model_root` to your pretrained model path
- Adjust training hyperparameters (learning_rate, batch_size, etc.)
- Configure distributed training settings (sp_size, enable_fsdp, etc.)
3. Run training:
- Single GPU: python train.py --pretrained_model_root <path> [other args]
- Multi-GPU: torchrun --nproc_per_node=N train.py --pretrained_model_root <path> [other args]
4. Monitor training:
- Checkpoints are saved to `output_dir` at intervals specified by `--save_interval`
- Validation videos are generated at intervals specified by `--validation_interval`
- Training logs are printed to console at intervals specified by `--log_interval`
5. Resume training:
- Use `--resume_from_checkpoint <checkpoint_dir>` to resume from a saved checkpoint
For detailed format requirements, see the docstring of `create_dummy_dataloader()` function.
"""
import os
import random
import math
import argparse
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from enum import Enum
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
)
from diffusers.optimization import get_scheduler
from loguru import logger
import einops
import imageio
from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
from hyvideo.commons.parallel_states import get_parallel_state, initialize_parallel_state
from hyvideo.optim.muon import get_muon_optimizer
from torch.distributed._composable.fsdp import (
MixedPrecisionPolicy,
fully_shard,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
class SNRType(str, Enum):
    """Timestep sampling strategy used by ``TimestepSampler``.

    Subclasses ``str``, so members compare equal to their plain string values
    (``SNRType.LOGNORM == "lognorm"``) and a string can be converted with
    ``SNRType(value)``.
    """

    UNIFORM = "uniform"  # t drawn uniformly from [eps, 1 - eps]
    LOGNORM = "lognorm"  # sigmoid of a standard-normal draw, rescaled to [eps, 1 - eps]
    MIX = "mix"  # per-sample mixture of lognorm and uniform draws
    MODE = "mode"  # "mode" sampling with a fixed mode scale (see TimestepSampler.sample)
def str_to_bool(value):
    """Parse a command-line style boolean value.

    Accepts ``true/false``, ``1/0``, ``yes/no`` and ``on/off`` (case-insensitive,
    surrounding whitespace ignored). ``None`` -- what arrives when the flag is
    given without a value -- counts as ``True``; real ``bool`` values are
    returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if the value cannot be interpreted.
    """
    truthy = ('true', '1', 'yes', 'on')
    falsy = ('false', '0', 'no', 'off')
    if value is None:
        return True
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.lower().strip()
        if normalized in truthy:
            return True
        if normalized in falsy:
            return False
        # Report the normalized spelling in the error message.
        value = normalized
    raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}")
def save_video(video: torch.Tensor, path: str, fps: int = 24):
    """Write a video tensor to ``path`` with imageio.

    Args:
        video: ``[C, F, H, W]`` tensor, or ``[1, C, F, H, W]`` with a batch
            dimension of exactly one. Values are multiplied by 255 and
            clamped, so they are expected in ``[0, 1]``.
        path: Output file path.
        fps: Frame rate of the written file (default 24).
    """
    if video.ndim == 5:
        assert video.shape[0] == 1, f"Expected batch size 1, got {video.shape[0]}"
        video = video[0]
    # Upcast before scaling (low-precision dtypes such as bf16 cannot hold
    # every intermediate value exactly) and round instead of truncating, so
    # e.g. 0.999 maps to 255 rather than 254.
    vid = (video.float() * 255).round().clamp(0, 255).to(torch.uint8)
    vid = einops.rearrange(vid, 'c f h w -> f h w c')
    imageio.mimwrite(path, vid.cpu().numpy(), fps=fps)
@dataclass
class TrainingConfig:
    """Hyperparameters and runtime options for ``HunyuanVideoTrainer``.

    Only ``pretrained_model_root`` is required; every other field has a default.
    """

    # Model paths
    pretrained_model_root: str  # passed as pretrained_model_name_or_path to create_pipeline
    pretrained_transformer_version: str = "720p_t2v"  # passed as transformer_version to create_pipeline
    # Training parameters
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    max_steps: int = 10000
    warmup_steps: int = 500  # NOTE(review): handed to a "constant" scheduler, which presumably ignores warmup -- verify
    gradient_accumulation_steps: int = 1  # the optimizer steps when (global_step + 1) is a multiple of this
    max_grad_norm: float = 1.0  # gradient clipping threshold; a value <= 0 disables clipping
    use_muon: bool = True  # Muon optimizer when True, otherwise AdamW
    # Diffusion parameters
    num_train_timesteps: int = 1000  # T used by the noise schedule and the timestep sampler
    train_timestep_shift: float = 3.0  # shift applied to sampled training timesteps
    validation_timestep_shift: float = 5.0  # passed as flow_shift to the pipeline
    snr_type: SNRType = SNRType.LOGNORM # Timestep sampling strategy: uniform, lognorm, mix, or mode
    # Task configuration
    task_type: str = "t2v" # "t2v" or "i2v"
    i2v_prob: float = 0.3 # Probability of using i2v task when data_type is video (default: 0.3 for video training)
    # FSDP configuration
    enable_fsdp: bool = True # Enable FSDP for distributed training
    enable_gradient_checkpointing: bool = True # Enable gradient checkpointing
    sp_size: int = 8 # Sequence parallelism size (must divide world_size evenly)
    dp_replicate: int = 1 # Data parallelism replicate size (must divide world_size evenly)
    # Data configuration
    batch_size: int = 1  # presumably used when building the dataloader (not shown in this section)
    num_workers: int = 4  # presumably used when building the dataloader (not shown in this section)
    # Output configuration
    output_dir: str = "./outputs"  # checkpoints go to checkpoint-<step>/, validation samples to samples/
    save_interval: int = 1000
    log_interval: int = 10
    # Device configuration
    dtype: str = "bf16" # "bf16" or "fp32"
    # Seed
    seed: int = 42  # base seed; each rank seeds with seed + dp_rank
    # Validation configuration
    validation_interval: int = 100 # Run validation every N steps
    validation_prompts: Optional[List[str]] = None # Prompts for validation (default: single prompt)
    validate_video_length: int = 121 # Video length (number of frames) for validation
    # Resume training configuration
    resume_from_checkpoint: Optional[str] = None # Path to checkpoint directory to resume from
    # LoRA configuration (r/alpha/dropout/target_modules are passed to peft.LoraConfig)
    use_lora: bool = False
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    lora_target_modules: Optional[List[str]] = None # Target modules for LoRA (default: all Linear layers)
    pretrained_lora_path: Optional[str] = None  # when set, LoRA weights are loaded from here instead of creating a new adapter
class LinearInterpolationSchedule:
    """Straight-line path between clean latents and noise, used for flow matching.

    A timestep ``t`` in ``[0, T]`` selects ``x_t = (1 - t / T) * x0 + (t / T) * x1``,
    so ``t = 0`` yields ``x0`` and ``t = T`` yields ``x1``.
    """

    def __init__(self, T: int = 1000):
        self.T = T

    def forward(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the point between ``x0`` and ``x1`` selected by ``t``.

        Args:
            x0: starting point (clean latents); dim 0 is the batch.
            x1: ending point (noise), broadcast-compatible with ``x0``.
            t: timesteps, typically one per batch element of ``x0``.
        """
        # Reshape t to (-1, 1, 1, ...) so it broadcasts over every non-batch dim.
        trailing_ones = (1,) * (x0.ndim - 1)
        alpha = (t / self.T).view(-1, *trailing_ones)
        return (1 - alpha) * x0 + alpha * x1
class TimestepSampler:
    """Sample training timesteps in ``[T * eps, T * (1 - eps)]``.

    ``sample`` draws normalized ``t`` in ``[t0, t1] = [TRAIN_EPS, 1 - TRAIN_EPS]``
    with the strategy chosen by ``snr_type`` and returns ``t * T``:

    * ``uniform``: ``t ~ U(t0, t1)``.
    * ``lognorm``: ``t = sigmoid(u)`` rescaled to ``[t0, t1]``, ``u ~ N(0, 1)``.
    * ``mix``: per sample, a lognorm draw with probability ``MIX_LOGNORM_PROB``
      and a (clipped) uniform draw otherwise.
    * ``mode``: ``t = 1 - u - MODE_SCALE * (cos(pi * u / 2)^2 - 1 + u)`` with
      ``u ~ U(0, 1)``, rescaled to ``[t0, t1]``.
    """

    TRAIN_EPS = 1e-5
    SAMPLE_EPS = 1e-3
    # Share of lognorm draws in the "mix" strategy; the remainder is uniform.
    MIX_LOGNORM_PROB = 0.3
    # Shape parameter of the "mode" strategy.
    MODE_SCALE = 1.29

    def __init__(
        self,
        T: int = 1000,
        device: torch.device = None,
        snr_type: SNRType = SNRType.LOGNORM,
    ):
        self.T = T
        self.device = device
        # Accept plain strings (e.g. from the command line) as well as members.
        self.snr_type = SNRType(snr_type) if isinstance(snr_type, str) else snr_type

    def _check_interval(self, eval: bool = False):
        """Return the ``(t0, t1)`` interval in normalized time.

        Uses ``[eps, 1 - eps]`` with ``SAMPLE_EPS`` when ``eval`` is true and
        ``TRAIN_EPS`` otherwise.
        """
        # For ICPlan-like path with velocity model, use [eps, 1-eps]
        eps = self.SAMPLE_EPS if eval else self.TRAIN_EPS
        t0 = eps
        t1 = 1.0 - eps
        return t0, t1

    def sample(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
        """Draw ``batch_size`` timesteps in ``[0, T]``.

        ``device`` defaults to the sampler's device, then to CUDA.

        Raises:
            ValueError: for an unknown ``snr_type``.
        """
        if device is None:
            device = self.device if self.device is not None else torch.device("cuda")
        t0, t1 = self._check_interval(eval=False)
        if self.snr_type == SNRType.UNIFORM:
            # Uniform sampling: t = rand() * (t1 - t0) + t0
            t = torch.rand((batch_size,), device=device) * (t1 - t0) + t0
        elif self.snr_type == SNRType.LOGNORM:
            # Log-normal sampling: t = 1 / (1 + exp(-u)) * (t1 - t0) + t0
            u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device)
            t = 1.0 / (1.0 + torch.exp(-u)) * (t1 - t0) + t0
        elif self.snr_type == SNRType.MIX:
            # Mix sampling: MIX_LOGNORM_PROB (30%) lognorm + the rest (70%) clipped uniform
            u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device)
            t_lognorm = 1.0 / (1.0 + torch.exp(-u)) * (t1 - t0) + t0
            # Clipped uniform: delta = 0.0 (0.0~0.01 clip)
            delta = 0.0
            t0_clip = t0 + delta
            t1_clip = t1 - delta
            t_clip_uniform = torch.rand((batch_size,), device=device) * (t1_clip - t0_clip) + t0_clip
            # Per-sample selector: 1 -> lognorm (probability MIX_LOGNORM_PROB),
            # 0 -> clipped uniform.
            use_lognorm = (torch.rand((batch_size,), device=device) < self.MIX_LOGNORM_PROB).float()
            t = use_lognorm * t_lognorm + (1 - use_lognorm) * t_clip_uniform
        elif self.snr_type == SNRType.MODE:
            # Mode sampling: t = 1 - u - mode_scale * (cos(pi * u / 2)^2 - 1 + u)
            u = torch.rand(size=(batch_size,), device=device)
            t = 1.0 - u - self.MODE_SCALE * (torch.cos(math.pi * u / 2.0) ** 2 - 1.0 + u)
            # Scale to [t0, t1] range
            t = t * (t1 - t0) + t0
        else:
            raise ValueError(f"Unknown SNR type: {self.snr_type}")
        # Scale to [0, T] range
        return t * self.T
def timestep_transform(timesteps: torch.Tensor, T: int, shift: float = 1.0) -> torch.Tensor:
    """Shift timesteps with ``s * t / (1 + (s - 1) * t)`` in normalized time.

    ``timesteps`` in ``[0, T]`` are divided by ``T``, shifted and scaled back.
    ``shift == 1.0`` is the identity and returns the input object itself.
    """
    if shift != 1.0:
        frac = timesteps / T
        return shift * frac / (1 + (shift - 1) * frac) * T
    return timesteps
def is_src(src, group_src, group):
    """Return whether the calling rank is the broadcast source.

    Exactly one of ``src`` (a global rank) or ``group_src`` (a rank inside
    ``group``, translated to its global rank) must be given.
    """
    assert src is not None or group_src is not None
    assert src is None or group_src is None
    if src is None:
        if group_src is None:
            raise RuntimeError("src and group_src cannot be both None")
        return dist.get_rank() == dist.get_global_rank(group, group_src)
    return dist.get_rank() == src
def broadcast_object(
    obj,
    src=None,
    group=None,
    device=None,
    group_src=None,
):
    """Broadcast a picklable object from the source rank to every rank in ``group``.

    Non-source ranks ignore ``obj``; every rank returns the source's object.
    ``src`` / ``group_src`` follow the same rules as ``is_src``.
    """
    payload = [obj if is_src(src, group_src, group) else None]
    dist.broadcast_object_list(payload, src=src, group_src=group_src, group=group, device=device)
    return payload[0]
def broadcast_tensor(
    tensor,
    src=None,
    group=None,
    async_op: bool = False,
    group_src=None,
):
    """Shape- and dtype-safe broadcast of a tensor.

    Non-source ranks do not need a correctly shaped ``tensor`` (it is ignored):
    they first receive the source's shape and dtype and allocate a matching
    CUDA buffer for the payload. Every rank returns the broadcast tensor; on
    the source rank it is the contiguous CUDA version of ``tensor``.
    ``src`` / ``group_src`` follow the same rules as ``is_src``.

    NOTE(review): with ``async_op=True`` the work handle returned by
    ``dist.broadcast`` is discarded, so callers cannot wait on it and may
    presumably read the buffer before the broadcast completes. The visible
    caller (``sync_tensor_for_sp``) uses the default ``async_op=False``.
    """
    source = is_src(src, group_src, group)
    if source:
        tensor = tensor.cuda().contiguous()
    # Ship shape and dtype in one object broadcast (a single collective
    # instead of two) so receivers can allocate a matching buffer.
    meta = (tensor.shape, tensor.dtype) if source else None
    shape, dtype = broadcast_object(meta, src=src, group_src=group_src, group=group)
    buffer = tensor if source else torch.empty(shape, device='cuda', dtype=dtype)
    dist.broadcast(buffer, src=src, group_src=group_src, group=group, async_op=async_op)
    return buffer
def sync_tensor_for_sp(tensor: torch.Tensor, sp_group) -> torch.Tensor:
    """Give every rank of a sequence-parallel group the value held by group rank 0.

    ``tensor`` may be a ``torch.Tensor`` (sent with ``broadcast_tensor``) or any
    picklable object such as a string or a list of prompts (sent with
    ``dist.broadcast_object_list``). With ``sp_group=None`` the input is
    returned unchanged.
    """
    if sp_group is None:
        return tensor
    if isinstance(tensor, torch.Tensor):
        return broadcast_tensor(tensor, group_src=0, group=sp_group)
    holder = [tensor]
    dist.broadcast_object_list(holder, group_src=0, group=sp_group)
    return holder[0]
class HunyuanVideoTrainer:
def __init__(self, config: TrainingConfig):
    """Set up distributed state, models, optimizer and diffusion helpers.

    Rank information is read from the ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK``
    environment variables; without ``RANK`` the trainer runs as a single
    process. The parallel state and CUDA device are initialized first, then
    the RNGs are seeded, and only then are the models and optimizer built.

    Args:
        config: Training configuration. ``config.validation_prompts`` is filled
            with a default prompt when it is ``None`` (the config is mutated).

    Raises:
        ValueError: if ``config.sp_size`` is larger than, or does not evenly
            divide, the world size.
    """
    self.config = config
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Distributed launchers (e.g. torchrun) export RANK/WORLD_SIZE/LOCAL_RANK.
    if "RANK" in os.environ:
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.is_main_process = self.rank == 0
    else:
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.is_main_process = True
    if config.sp_size > self.world_size:
        raise ValueError(
            f"sp_size ({config.sp_size}) cannot be greater than world_size ({self.world_size})"
        )
    if self.world_size % config.sp_size != 0:
        raise ValueError(
            f"sp_size ({config.sp_size}) must evenly divide world_size ({self.world_size}). "
            f"world_size % sp_size = {self.world_size % config.sp_size}"
        )
    initialize_parallel_state(sp=config.sp_size, dp_replicate=config.dp_replicate)
    # NOTE(review): called unconditionally, so this presumably fails on a
    # CPU-only machine even though self.device falls back to "cpu" above.
    torch.cuda.set_device(self.local_rank)
    self.parallel_state = get_parallel_state()
    self.dp_rank = self.parallel_state.world_mesh['dp'].get_local_rank()
    self.dp_size = self.parallel_state.world_mesh['dp'].size()
    self.sp_enabled = self.parallel_state.sp_enabled
    self.sp_group = self.parallel_state.sp_group if self.sp_enabled else None
    # Seed per data-parallel rank. Presumably all ranks of one SP group share a
    # dp_rank and hence a seed -- confirm against hyvideo.commons.parallel_states.
    self._set_seed(config.seed + self.dp_rank)
    self._build_models()
    self._build_optimizer()
    self.noise_schedule = LinearInterpolationSchedule(T=config.num_train_timesteps)
    self.timestep_sampler = TimestepSampler(
        T=config.num_train_timesteps,
        device=self.device,
        snr_type=config.snr_type,
    )
    self.global_step = 0
    self.current_epoch = 0
    # Only the main process creates output directories.
    if self.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
    self.validation_output_dir = os.path.join(config.output_dir, "samples")
    if self.is_main_process:
        os.makedirs(self.validation_output_dir, exist_ok=True)
    if config.validation_prompts is None:
        config.validation_prompts = ["A beautiful sunset over the ocean with waves gently crashing on the shore"]
def _set_seed(self, seed: int):
    """Seed Python's ``random`` module and torch (CPU, plus every CUDA device if present)."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _build_models(self):
    """Load the HunyuanVideo-1.5 pipeline and prepare its transformer for training.

    Exposes the pipeline components as attributes, switches the transformer to
    train mode and then, in this order, optionally applies LoRA, activation
    checkpointing and FSDP2 (FSDP only when enabled and with more than one
    process). Logs a summary on the main process.

    Raises:
        ValueError: if ``config.dtype`` is neither ``"bf16"`` nor ``"fp32"``.
    """
    if self.config.dtype == "bf16":
        transformer_dtype = torch.bfloat16
    elif self.config.dtype == "fp32":
        transformer_dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {self.config.dtype}")
    # Don't create SR pipeline for training (validation uses enable_sr=False)
    self.pipeline = HunyuanVideo_1_5_Pipeline.create_pipeline(
        pretrained_model_name_or_path=self.config.pretrained_model_root,
        transformer_version=self.config.pretrained_transformer_version,
        transformer_dtype=transformer_dtype,
        enable_offloading=False,
        enable_group_offloading=False,
        overlap_group_offloading=False,
        create_sr_pipeline=False,
        flow_shift=self.config.validation_timestep_shift,
        device=self.device,
    )
    self.transformer = self.pipeline.transformer
    self.vae = self.pipeline.vae
    self.text_encoder = self.pipeline.text_encoder
    self.text_encoder_2 = self.pipeline.text_encoder_2
    self.vision_encoder = self.pipeline.vision_encoder
    self.byt5_kwargs = {
        "byt5_model": self.pipeline.byt5_model,
        "byt5_tokenizer": self.pipeline.byt5_tokenizer,
    }
    self.transformer.train()
    use_fsdp = self.config.enable_fsdp and self.world_size > 1
    # LoRA is attached before checkpoint wrapping and sharding, presumably so
    # the adapter weights are wrapped and sharded together with the base weights.
    if self.config.use_lora:
        self._apply_lora()
    if self.config.enable_gradient_checkpointing:
        self._apply_gradient_checkpointing()
    if use_fsdp:
        self._apply_fsdp()
    if self.is_main_process:
        logger.info(f"Models loaded. Transformer dtype: {transformer_dtype}")
        total_params = sum(p.numel() for p in self.transformer.parameters())
        trainable_params = sum(p.numel() for p in self.transformer.parameters() if p.requires_grad)
        logger.info(f"Transformer parameters: {total_params:,} (trainable: {trainable_params:,})")
        logger.info(f"LoRA enabled: {self.config.use_lora}")
        logger.info(f"FSDP enabled: {use_fsdp}")
        logger.info(f"Gradient checkpointing enabled: {self.config.enable_gradient_checkpointing}")
        # config.snr_type may be a plain string (TimestepSampler accepts both),
        # so normalize through SNRType before reading ``.value``.
        logger.info(f"Timestep sampling strategy: {SNRType(self.config.snr_type).value}")
def _apply_lora(self):
    """Attach LoRA adapters to the transformer.

    Loads adapter weights from ``config.pretrained_lora_path`` when it is set;
    otherwise creates a fresh adapter named ``"default"`` from the ``lora_*``
    config fields (targeting ``"all-linear"`` when no target modules are
    given). The main process then logs the trainable-parameter ratio.
    """
    cfg = self.config
    if self.is_main_process:
        logger.info("Applying LoRA to transformer using PeftAdapterMixin...")
    lora_path = cfg.pretrained_lora_path
    if lora_path is None:
        from peft import LoraConfig
        lora_config = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            target_modules="all-linear" if cfg.lora_target_modules is None else cfg.lora_target_modules,
            lora_dropout=cfg.lora_dropout,
            bias="none",
            task_type="FEATURE_EXTRACTION",
        )
        self.transformer.add_adapter(lora_config, adapter_name="default")
    else:
        if self.is_main_process:
            logger.info(f"Loading pretrained LoRA from {lora_path}")
        self.load_pretrained_lora(lora_path)
    if not self.is_main_process:
        return
    all_params = list(self.transformer.parameters())
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
    total_params = sum(p.numel() for p in all_params)
    logger.info(f"LoRA applied successfully. Trainable parameters: {trainable_params:,} / {total_params:,} "
                f"({100 * trainable_params / total_params:.2f}%)")
def _apply_fsdp(self):
    """Shard the transformer with FSDP2 (``fully_shard``).

    Every non-None double/single block is sharded on its own, then the root
    module. Parameters are cast to the dtype selected by ``config.dtype``
    (bf16 or fp32) and gradients are reduced in fp32. The FSDP mesh comes
    from the parallel state when available; otherwise a warning is logged and
    ``fully_shard`` is called without an explicit mesh.
    """
    if self.is_main_process:
        logger.info("Applying FSDP2 to transformer...")
    # Honour config.dtype: train_step disables autocast and feeds fp32 inputs
    # for "fp32", so forcing bf16 parameters here would silently override it.
    param_dtype = torch.bfloat16 if self.config.dtype == "bf16" else torch.float32
    reduce_dtype = torch.float32  # Reduce in float32 for stability
    self.transformer = self.transformer.to(dtype=param_dtype)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
    )
    fsdp_config = {"mp_policy": mp_policy}
    if self.world_size > 1:
        try:
            fsdp_config["mesh"] = get_parallel_state().fsdp_mesh
        except Exception as e:
            if self.is_main_process:
                logger.warning(f"Could not create DeviceMesh: {e}. FSDP will use process group instead.")
    for block in list(self.transformer.double_blocks) + list(self.transformer.single_blocks):
        if block is not None:
            fully_shard(block, **fsdp_config)
    fully_shard(self.transformer, **fsdp_config)
    if self.is_main_process:
        logger.info("FSDP2 applied successfully")
def _apply_gradient_checkpointing(self):
    """Wrap transformer blocks with non-reentrant activation checkpointing.

    The block class is taken from the first non-None entry of
    ``double_blocks`` (falling back to ``single_blocks``), and every submodule
    that is an instance of that class is wrapped. If no block is found, the
    model's own ``gradient_checkpointing_enable()`` is used when it exists.
    """
    if self.is_main_process:
        logger.info("Applying gradient checkpointing to transformer blocks...")
    block_type = next((type(b) for b in self.transformer.double_blocks if b is not None), None)
    if block_type is None:
        block_type = next((type(b) for b in self.transformer.single_blocks if b is not None), None)
    if block_type is None:
        logger.warning("Could not find block type for gradient checkpointing. Using fallback.")
        if hasattr(self.transformer, "gradient_checkpointing_enable"):
            self.transformer.gradient_checkpointing_enable()
        return
    apply_activation_checkpointing(
        self.transformer,
        checkpoint_wrapper_fn=lambda module: checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=lambda submodule: isinstance(submodule, block_type),
    )
    if self.is_main_process:
        logger.info("Gradient checkpointing applied successfully")
def _build_optimizer(self):
    """Create the optimizer and the learning-rate scheduler.

    Uses Muon (``get_muon_optimizer``) when ``config.use_muon`` is set.
    Otherwise it uses AdamW over the parameters that require gradients, so
    frozen base weights are not handed to AdamW when training LoRA adapters.

    NOTE(review): the scheduler is diffusers' ``"constant"`` schedule, which
    does not use ``num_warmup_steps`` / ``num_training_steps``; with it
    ``config.warmup_steps`` has no effect. Switch to e.g.
    ``"constant_with_warmup"`` if warmup is wanted.
    """
    if self.config.use_muon:
        self.optimizer = get_muon_optimizer(
            model=self.transformer,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
    else:
        trainable_params = [p for p in self.transformer.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.config.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config.weight_decay,
        )
    # lr_scheduler.step() runs once per optimizer step on each rank (see
    # train_step), so the step counts are per-process and not scaled by
    # world_size.
    self.lr_scheduler = get_scheduler(
        "constant",
        optimizer=self.optimizer,
        num_warmup_steps=self.config.warmup_steps,
        num_training_steps=self.config.max_steps,
    )
    if self.is_main_process:
        logger.info("Optimizer and scheduler initialized")
def encode_text(self, prompts, data_type: str = "image"):
    """Encode prompts with the primary and the optional secondary text encoder.

    ``data_type`` is forwarded only to the primary encoder.

    Returns:
        ``(text_emb, text_mask, text_emb_2, text_mask_2)``; the last two are
        ``None`` when there is no ``text_encoder_2``.
    """
    primary = self.text_encoder
    tokens = primary.text2tokens(prompts, data_type=data_type)
    encoded = primary.encode(tokens, data_type=data_type, device=self.device)
    text_emb, text_mask = encoded.hidden_state, encoded.attention_mask
    secondary = self.text_encoder_2
    if secondary is None:
        return text_emb, text_mask, None, None
    encoded_2 = secondary.encode(secondary.text2tokens(prompts), device=self.device)
    return text_emb, text_mask, encoded_2.hidden_state, encoded_2.attention_mask
def encode_byt5(self, text_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Run the byT5 encoder on pre-tokenized ids.

    Returns ``(embeddings, attention_mask)``, where ``embeddings`` is the first
    element of the model output and the mask is returned exactly as given.
    Returns ``(None, None)`` when no byT5 model is loaded.
    """
    model = self.byt5_kwargs["byt5_model"]
    if model is not None:
        # The model is called with a float copy of the mask.
        outputs = model(text_ids, attention_mask=attention_mask.float())
        return outputs[0], attention_mask
    return None, None
def encode_images(self, images):
    """Encode images with the vision encoder (used for i2v conditioning).

    Args:
        images: ``[B, C, H, W]`` tensor with values in ``[-1, 1]``.

    Returns:
        The encoder's ``last_hidden_state`` moved to ``self.device`` in the
        transformer's dtype, or ``None`` when no vision encoder is loaded.

    Raises:
        ValueError: if any value falls outside ``[-1, 1]``.
    """
    if self.vision_encoder is None:
        return None
    # Same check as encode_vae; raise instead of assert so it survives ``python -O``.
    if images.max() > 1.0 or images.min() < -1.0:
        raise ValueError(f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}")
    images = (images + 1) / 2  # [-1, 1] -> [0, 1]
    # Upcast before .numpy(): NumPy has no bfloat16, so bf16 inputs would fail.
    images_np = (images.float().cpu().permute(0, 2, 3, 1).numpy() * 255).clip(0, 255).astype("uint8")
    vision_states = self.vision_encoder.encode_images(images_np)
    return vision_states.last_hidden_state.to(device=self.device, dtype=self.transformer.dtype)
def encode_vae(self, images: torch.Tensor) -> torch.Tensor:
    """Encode pixels in ``[-1, 1]`` into scaled VAE latents.

    4-D inputs get a singleton axis inserted at dim 2 before encoding. A sample
    is drawn from the VAE's latent distribution (without gradients) and then
    normalized as ``(z - shift_factor) * scaling_factor`` when the VAE config
    has a truthy ``shift_factor``, otherwise as ``z * scaling_factor``.

    Raises:
        ValueError: if any value falls outside ``[-1, 1]``.
    """
    out_of_range = images.max() > 1.0 or images.min() < -1.0
    if out_of_range:
        raise ValueError(f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}")
    pixels = images.unsqueeze(2) if images.ndim == 4 else images
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16), self.vae.memory_efficient_context():
        latents = self.vae.encode(pixels).latent_dist.sample()
    if getattr(self.vae.config, "shift_factor", None):
        return (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
    return latents * self.vae.config.scaling_factor
def get_condition(self, latents: torch.Tensor, task_type: str) -> torch.Tensor:
    """Build the conditioning tensor that is concatenated to the noisy latents.

    Returns a tensor of shape ``[B, C + 1, F, H, W]`` on the latents' device
    and dtype. For ``"t2v"`` it is all zeros. For ``"i2v"``, frame 0 of the
    first ``C`` channels holds the clean first-frame latents and frame 0 of
    the extra (last) channel is set to 1.

    Raises:
        ValueError: for any other ``task_type``.
    """
    batch, channels, frames, height, width = latents.shape
    cond = torch.zeros([batch, channels + 1, frames, height, width], device=latents.device, dtype=latents.dtype)
    if task_type not in ("t2v", "i2v"):
        raise ValueError(f"Unsupported task type: {task_type}")
    if task_type == "i2v":
        cond[:, :-1, :1] = latents[:, :, :1]
        cond[:, -1, 0] = 1
    return cond
def sample_task(self, data_type: str) -> str:
    """
    Sample the task type for a batch from its data type.

    For video data: returns "i2v" with probability ``config.i2v_prob`` and
    "t2v" (text-to-video) otherwise, using one ``random.random()`` draw.
    For image data and any other data type: always returns "t2v"; no random
    number is drawn.
    """
    # Short-circuit keeps the RNG untouched for non-video data.
    if data_type == "video" and random.random() < self.config.i2v_prob:
        return "i2v"
    return "t2v"
def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """
    Prepare batch for training.
    Expected batch format:
    {
        "pixel_values": torch.Tensor,  # [B, C, F, H, W] for video or [B, C, H, W] for image
                                       # Pixel values must be in range [-1, 1]
        "latents": Optional[torch.Tensor],  # pre-encoded VAE latents; if present, encode_vae is skipped
        "text": List[str],
        "data_type": str,  # "image" or "video"
        "byt5_text_ids": Optional[torch.Tensor],
        "byt5_text_mask": Optional[torch.Tensor],
    }
    Note: For video data, the temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, ...)
    to satisfy VAE requirements. The dataset should ensure this before returning data.

    With sequence parallelism enabled, every rank of an SP group must feed the
    transformer identical inputs. Latents, pixel values, data type, task type,
    prompts, byT5 inputs, noise and timesteps are therefore broadcast from SP
    rank 0 before use.

    Returns:
        Dict with ``latents_noised``, ``cond_latents``, ``timesteps``,
        ``target`` (``noise - latents``), the text / byT5 / vision embeddings
        and masks, ``task_type`` and ``data_type``.
    """
    pixel_values = batch.get("pixel_values", None)
    if pixel_values is not None:
        pixel_values = pixel_values.to(self.device)
    if 'latents' in batch:
        latents = batch['latents'].to(self.device)
    else:
        latents = self.encode_vae(pixel_values)
    if self.sp_enabled:
        latents = sync_tensor_for_sp(latents, self.sp_group)
        if pixel_values is not None:
            pixel_values = sync_tensor_for_sp(pixel_values, self.sp_group)

    # Presumably the collate function may deliver a per-sample list; the first
    # entry is used for the whole batch.
    data_type_raw = batch.get("data_type", "image")
    if isinstance(data_type_raw, list):
        data_type = data_type_raw[0]
    elif isinstance(data_type_raw, str):
        data_type = data_type_raw
    else:
        data_type = str(data_type_raw) if data_type_raw is not None else "image"
    if self.sp_enabled:
        # data_type is forwarded to encode_text and drives sample_task, so it
        # must agree with the synced latents and prompts.
        data_type = sync_tensor_for_sp(data_type, self.sp_group)
    task_type = self.sample_task(data_type)
    if self.sp_enabled:
        task_type = sync_tensor_for_sp(task_type, self.sp_group)
    cond_latents = self.get_condition(latents, task_type)

    prompts = batch["text"]
    if self.sp_enabled:
        prompts = sync_tensor_for_sp(prompts, self.sp_group)
    text_emb, text_mask, text_emb_2, text_mask_2 = self.encode_text(prompts, data_type=data_type)

    byt5_text_states = None
    byt5_text_mask = None
    if self.byt5_kwargs["byt5_model"] is not None:
        if "byt5_text_ids" in batch and batch["byt5_text_ids"] is not None:
            byt5_text_ids = batch["byt5_text_ids"].to(self.device)
            byt5_text_mask = batch["byt5_text_mask"].to(self.device)
            if self.sp_enabled:
                byt5_text_ids = sync_tensor_for_sp(byt5_text_ids, self.sp_group)
                byt5_text_mask = sync_tensor_for_sp(byt5_text_mask, self.sp_group)
            byt5_text_states, byt5_text_mask = self.encode_byt5(byt5_text_ids, byt5_text_mask)
        else:
            # No pre-tokenized ids: let the pipeline tokenize and encode each prompt.
            byt5_embeddings_list = []
            byt5_mask_list = []
            for prompt in prompts:
                emb, mask = self.pipeline._process_single_byt5_prompt(prompt, self.device)
                byt5_embeddings_list.append(emb)
                byt5_mask_list.append(mask)
            byt5_text_states = torch.cat(byt5_embeddings_list, dim=0)
            byt5_text_mask = torch.cat(byt5_mask_list, dim=0)

    vision_states = None
    if task_type == "i2v":
        assert pixel_values is not None, '`pixel_values` must be provided for i2v task'
        if pixel_values.ndim == 5:
            first_frame = pixel_values[:, :, 0, :, :]
        else:
            first_frame = pixel_values
        vision_states = self.encode_images(first_frame)

    noise = torch.randn_like(latents)
    timesteps = self.timestep_sampler.sample(latents.shape[0], device=self.device)
    if self.sp_enabled:
        # RNG state may drift between SP ranks (e.g. VAE sampling of differently
        # shaped local batches). Share rank 0's noise and raw timesteps so that
        # latents_noised and target agree across the group without a further
        # broadcast.
        noise = sync_tensor_for_sp(noise, self.sp_group)
        timesteps = sync_tensor_for_sp(timesteps, self.sp_group)
    timesteps = timestep_transform(timesteps, self.config.num_train_timesteps, self.config.train_timestep_shift)
    latents_noised = self.noise_schedule.forward(latents, noise, timesteps)
    target = noise - latents
    return {
        "latents_noised": latents_noised,
        "cond_latents": cond_latents,
        "timesteps": timesteps,
        "target": target,
        "text_emb": text_emb,
        "text_emb_2": text_emb_2,
        "text_mask": text_mask,
        "text_mask_2": text_mask_2,
        "byt5_text_states": byt5_text_states,
        "byt5_text_mask": byt5_text_mask,
        "vision_states": vision_states,
        "task_type": task_type,
        "data_type": data_type,
    }
def train_step(self, batch: Dict[str, Any]) -> Dict[str, float]:
    """Run one forward/backward micro-step, plus an optimizer step on accumulation boundaries.

    Optimizer, scheduler and ``zero_grad`` run when ``(self.global_step + 1)``
    is a multiple of ``config.gradient_accumulation_steps``. This method does
    not advance ``self.global_step``; the caller does.

    Returns:
        ``{"loss", "grad_norm", "lr"}``:
        - ``loss``: the unscaled loss of this micro-batch.
        - ``grad_norm``: the norm returned by ``clip_grad_norm_`` (0.0 on
          non-update micro-batches or when clipping is disabled).
        - ``lr``: the current learning rate.
    """
    inputs = self.prepare_batch(batch)
    # Noisy latents and conditioning latents are stacked along the channel dim.
    latents_input = torch.cat([inputs["latents_noised"], inputs["cond_latents"]], dim=1)
    model_dtype = torch.bfloat16 if self.config.dtype == "bf16" else torch.float32
    extra_kwargs = {}
    if inputs["byt5_text_states"] is not None:
        extra_kwargs["byt5_text_states"] = inputs["byt5_text_states"].to(dtype=model_dtype)
        extra_kwargs["byt5_text_mask"] = inputs["byt5_text_mask"]
    with torch.autocast(device_type="cuda", dtype=model_dtype, enabled=(model_dtype == torch.bfloat16)):
        model_pred = self.transformer(
            latents_input.to(dtype=model_dtype),
            inputs["timesteps"],
            text_states=inputs["text_emb"].to(dtype=model_dtype),
            text_states_2=inputs["text_emb_2"].to(dtype=model_dtype) if inputs["text_emb_2"] is not None else None,
            encoder_attention_mask=inputs["text_mask"].to(dtype=model_dtype),
            vision_states=inputs["vision_states"].to(dtype=model_dtype) if inputs["vision_states"] is not None else None,
            mask_type=inputs["task_type"],
            extra_kwargs=extra_kwargs if extra_kwargs else None,
            return_dict=False,
        )[0]
    # Compute the MSE in float32. The loss is taken outside autocast, and under
    # bf16 training the prediction is typically bf16; without the upcast the
    # squared errors and the loss itself would be bf16.
    loss = nn.functional.mse_loss(model_pred.float(), inputs["target"].float())
    # Scale so accumulated gradients average over the micro-batches.
    loss = loss / self.config.gradient_accumulation_steps
    loss.backward()
    grad_norm = torch.tensor(0.0)
    if (self.global_step + 1) % self.config.gradient_accumulation_steps == 0:
        if self.config.max_grad_norm > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.transformer.parameters(),
                self.config.max_grad_norm
            )
        self.optimizer.step()
        self.lr_scheduler.step()
        self.optimizer.zero_grad()
    metrics = {
        "loss": loss.item() * self.config.gradient_accumulation_steps,
        "grad_norm": grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
        "lr": self.lr_scheduler.get_last_lr()[0] if hasattr(self.lr_scheduler, "get_last_lr") else self.config.learning_rate,
    }
    return metrics
def save_checkpoint(self, step: int):
    """Write a checkpoint to ``<config.output_dir>/checkpoint-<step>``.

    Contents written:

    * ``lora/<adapter_name>/``: one ``save_lora_adapter`` export (safetensors)
      per key of ``self.transformer.peft_config``. Only written when
      ``config.use_lora`` is set and the transformer has
      ``save_lora_adapter``.
    * ``transformer/``: ``dcp.save`` of ``get_model_state_dict``.
    * ``optimizer/``: ``dcp.save`` of ``get_optimizer_state_dict``.
    * ``training_state.pt``: LR-scheduler state and ``global_step=step``,
      written by the main process only.

    Every rank is expected to call this, since ``dcp.save`` and the
    ``dist.barrier`` calls (when ``world_size > 1``) involve all ranks.

    Args:
        step: Step number used in the directory name and stored as
            ``global_step``.

    Raises:
        RuntimeError: LoRA saving applies but the transformer has no
            non-empty ``peft_config``.
    """
    multi_rank = self.world_size > 1
    ckpt_root = os.path.join(self.config.output_dir, f"checkpoint-{step}")

    # Only the main process creates the root directory. The barrier ensures
    # it exists before any other rank writes into it.
    if self.is_main_process:
        os.makedirs(ckpt_root, exist_ok=True)
    if multi_rank:
        dist.barrier()

    if self.config.use_lora and hasattr(self.transformer, "save_lora_adapter"):
        lora_root = os.path.join(ckpt_root, "lora")
        os.makedirs(lora_root, exist_ok=True)
        peft_config = getattr(self.transformer, "peft_config", None)
        if not peft_config:
            raise RuntimeError("No LoRA adapter found in the model")
        names = list(peft_config.keys())
        if self.is_main_process:
            logger.info(f"Saving {len(names)} LoRA adapter(s): {names}")
        # NOTE(review): every rank calls save_lora_adapter on the same target
        # directory. Confirm whether that is required (e.g. sharded params
        # need all ranks) or whether only the main process should write.
        for name in names:
            target_dir = os.path.join(lora_root, name)
            os.makedirs(target_dir, exist_ok=True)
            self.transformer.save_lora_adapter(
                save_directory=target_dir,
                adapter_name=name,
                safe_serialization=True,
            )
            if self.is_main_process:
                logger.info(f"LoRA adapter '{name}' saved to {target_dir}")
        if multi_rank:
            dist.barrier()

    # Distributed-checkpoint saves of the full model and optimizer state.
    dcp.save(
        state_dict={"model": get_model_state_dict(self.transformer)},
        checkpoint_id=os.path.join(ckpt_root, "transformer"),
    )
    dcp.save(
        state_dict={
            "optimizer": get_optimizer_state_dict(self.transformer, self.optimizer)
        },
        checkpoint_id=os.path.join(ckpt_root, "optimizer"),
    )

    # Scheduler and step counter are small; one copy from the main process.
    if self.is_main_process:
        torch.save(
            {"lr_scheduler": self.lr_scheduler.state_dict(), "global_step": step},
            os.path.join(ckpt_root, "training_state.pt"),
        )
    if multi_rank:
        dist.barrier()
    if self.is_main_process:
        logger.info(f"Checkpoint saved at step {step} to {ckpt_root}")
def load_pretrained_lora(self, lora_dir: str):
    """Load LoRA weights from ``lora_dir`` into the transformer.

    The weights are registered under the adapter name ``"default"``. Loading
    uses safetensors, with no key prefix and hot-swapping disabled.

    Args:
        lora_dir: Path (or identifier) passed to ``load_lora_adapter`` as
            ``pretrained_model_name_or_path_or_dict``.
    """
    adapter_options = dict(
        prefix=None,
        adapter_name="default",
        use_safetensors=True,
        hotswap=False,
    )
    self.transformer.load_lora_adapter(
        pretrained_model_name_or_path_or_dict=lora_dir, **adapter_options
    )
def load_checkpoint(self, checkpoint_path: str):
    """Restore model, optimizer, LR scheduler and step counter from a checkpoint.

    Expects the directory layout written by ``save_checkpoint``:
    ``transformer/`` (DCP), ``optimizer/`` (DCP) and ``training_state.pt``.
    A missing piece is skipped, with a warning from the main process.

    Every rank must call this. ``dcp.load``, the ``set_*_state_dict`` calls
    and ``dist.barrier``/``dist.broadcast`` (when ``world_size > 1``) involve
    all ranks. The checkpoint directory must therefore be readable from every
    rank.

    Args:
        checkpoint_path: A ``checkpoint-<step>`` directory.

    Raises:
        ValueError: ``checkpoint_path`` does not exist.
    """
    # Function-scope import: the setters are only needed when resuming.
    from torch.distributed.checkpoint.state_dict import (
        set_model_state_dict,
        set_optimizer_state_dict,
    )

    if not os.path.exists(checkpoint_path):
        raise ValueError(f"Checkpoint path does not exist: {checkpoint_path}")
    if self.is_main_process:
        logger.info(f"Loading checkpoint from {checkpoint_path}")
    if self.world_size > 1:
        dist.barrier()

    transformer_dir = os.path.join(checkpoint_path, "transformer")
    if os.path.exists(transformer_dir):
        model_state_dict = get_model_state_dict(self.transformer)
        dcp.load(
            state_dict={"model": model_state_dict},
            checkpoint_id=transformer_dir,
        )
        # dcp.load fills ``model_state_dict`` in place. Push it back into the
        # module explicitly so wrappers whose state dict is a copy (not a view
        # of the live parameters) are restored as well.
        set_model_state_dict(self.transformer, model_state_dict=model_state_dict)
        if self.is_main_process:
            logger.info("Transformer model state loaded")
    elif self.is_main_process:
        logger.warning(f"Transformer dcp checkpoint not found from {checkpoint_path}")

    optimizer_dir = os.path.join(checkpoint_path, "optimizer")
    if os.path.exists(optimizer_dir):
        optimizer_state_dict = get_optimizer_state_dict(
            self.transformer,
            self.optimizer,
        )
        dcp.load(
            state_dict={"optimizer": optimizer_state_dict},
            checkpoint_id=optimizer_dir,
        )
        # The dict returned by get_optimizer_state_dict is keyed by parameter
        # name and holds copies of param_groups. Without this call the loaded
        # hyper-parameters (e.g. lr) never reach the optimizer.
        set_optimizer_state_dict(
            self.transformer,
            self.optimizer,
            optim_state_dict=optimizer_state_dict,
        )
        if self.is_main_process:
            logger.info("Optimizer state loaded")
    elif self.is_main_process:
        logger.warning(f"Optimizer dcp checkpoint not found from {checkpoint_path}")

    training_state_path = os.path.join(checkpoint_path, "training_state.pt")
    if os.path.exists(training_state_path):
        # Every rank owns its own scheduler instance, so every rank must
        # restore it. Restoring only on the main process leaves the other
        # ranks on the initial LR schedule and the replicas diverge.
        training_state = torch.load(training_state_path, map_location=self.device)
        self.lr_scheduler.load_state_dict(training_state["lr_scheduler"])
        self.global_step = training_state.get("global_step", 0)
        if self.world_size > 1:
            # Keep rank 0's step counter authoritative across ranks.
            global_step_tensor = torch.tensor(self.global_step, device=self.device)
            dist.broadcast(global_step_tensor, src=0)
            self.global_step = global_step_tensor.item()
        if self.is_main_process:
            logger.info(f"Training state loaded: global_step={self.global_step}")
    elif self.is_main_process:
        logger.warning(f"Training state not found at {training_state_path}")

    if self.world_size > 1:
        dist.barrier()
    if self.is_main_process:
        logger.info(f"Checkpoint loaded successfully. Resuming from step {self.global_step}")
def train(self, dataloader):
if self.is_main_process:
logger.info("Starting training...")
logger.info(f"Max steps: {self.config.max_steps}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.learning_rate}")
if self.config.resume_from_checkpoint is not None:
self.load_checkpoint(self.config.resume_from_checkpoint)
self.transformer.train()
while self.global_step < self.config.max_steps:
for batch in dataloader:
if self.global_step >= self.config.max_steps:
break
metrics = self.train_step(batch)
if self.global_step % self.config.log_interval == 0 and self.is_main_process:
logger.info(
f"Step {self.global_step}/{self.config.max_steps} | "
f"Loss: {metrics['loss']:.6f} | "
f"Grad Norm: {metrics['grad_norm']:.4f} | "
f"LR: {metrics['lr']:.2e}"
)
if self.global_step >= 0 and self.global_step % self.config.validation_interval == 0:
self.validate(self.global_step)
if (self.global_step + 1) % self.config.save_interval == 0:
self.save_checkpoint(self.global_step + 1)
if self.world_size > 1:
dist.barrier()
self.global_step += 1