@@ -78,32 +78,19 @@ def __init__(self, job_config: JobConfig):
         self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
-        ft_manager = ft.init_ft_manager(job_config)
 
         # init distributed
         world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
-        if not ft_manager.enabled:
-            self.parallel_dims = parallel_dims = ParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-            )
-        else:
-            self.parallel_dims = parallel_dims = ft.FTParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-                ft_manager=ft_manager,
-            )
+        self.parallel_dims = parallel_dims = ParallelDims(
+            dp_shard=parallelism_config.data_parallel_shard_degree,
+            dp_replicate=parallelism_config.data_parallel_replicate_degree,
+            cp=parallelism_config.context_parallel_degree,
+            tp=parallelism_config.tensor_parallel_degree,
+            pp=parallelism_config.pipeline_parallel_degree,
+            world_size=world_size,
+            enable_loss_parallel=not parallelism_config.disable_loss_parallel,
+        )
         dist_utils.init_distributed(job_config)
 
         # build meshes
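Note: the removed branch built `ft.FTParallelDims` solely to thread the fault-tolerance manager through mesh construction; with the manager now initialized after the meshes, a single `ParallelDims` suffices. Independent of TorchFT, the constructor's key invariant is that the degrees multiply out to the world size. A minimal sketch of that check (not the torchtitan source; `-1` meaning "infer this degree" follows the documented config convention):

import math

# Sketch: validate/infer parallelism degrees against world_size.
def resolve_degrees(dp_shard, dp_replicate, cp, tp, pp, world_size):
    dims = [dp_shard, dp_replicate, cp, tp, pp]
    known = math.prod(d for d in dims if d != -1)
    dims = [world_size // known if d == -1 else d for d in dims]
    assert math.prod(dims) == world_size, "degrees must multiply to world_size"
    return dims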
@@ -114,6 +101,12 @@ def __init__(self, job_config: JobConfig):
         else:
             dp_degree, dp_rank = 1, 0
 
+        self.ft_manager = ft.init_ft_manager(job_config)
+        # If TorchFT is enabled, the dp_rank and dp_degree, which are used by
+        # the dataloader, must be changed.
+        if self.ft_manager.enabled:
+            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
         dist_utils.set_determinism(
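With several TorchFT replica groups training concurrently, each group must read a disjoint slice of the data, so the manager rescales the data-parallel coordinates before the dataloader is built. A plausible sketch of `get_dp_info`, assuming the manager exposes a replica-group count (`group_size`) and its own index (`replica_id`), neither confirmed by this diff:

# Hypothetical sketch: stretch the dp world across replica groups so
# group 0 owns dp ranks [0, dp_degree) and group 1 owns the next slice.
def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
    return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank

For example, with two replica groups of 4 shards each, group 1's local rank 2 would load shard 6 of 8.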
@@ -131,11 +124,6 @@ def __init__(self, job_config: JobConfig):
             else None
         )
 
-        # If TorchFT is enabled, the dp_rank and dp_degree, which are used by
-        # the dataloader, must be changed.
-        if ft_manager.enabled:
-            dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)
-
         self.dataloader = self.train_spec.build_dataloader_fn(
             dp_world_size=dp_degree,
             dp_rank=dp_rank,
@@ -241,6 +229,9 @@ def __init__(self, job_config: JobConfig):
 
         self.model_parts = [model]
 
+        if self.ft_manager.enabled:
+            self.ft_manager.set_all_reduce_hook(self.model_parts)
+
         # initialize device memory monitor and get peak flops for MFU calculation
         device_memory_monitor = self.metrics_processor.device_memory_monitor
         gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
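`set_all_reduce_hook` is new here, and the diff shows only the call site. Conceptually, once the fault-tolerant replicate dimension no longer lives in the device mesh, something still has to average gradients across replica groups after backward; a per-parameter hook is one way to do that. A minimal illustrative sketch (the hook body and the use of `register_post_accumulate_grad_hook` are assumptions, not the actual TorchFT implementation):

import torch
import torch.distributed as dist

def set_all_reduce_hook(self, model_parts):
    # Average each gradient across TorchFT replica groups once it has
    # been fully accumulated locally.
    def average_across_replicas(param: torch.Tensor) -> None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=self.replicate_pg)

    for model in model_parts:
        for param in model.parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(average_across_replicas)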
@@ -254,7 +245,7 @@ def __init__(self, job_config: JobConfig):
 
         # build optimizer after applying parallelisms to the model
         self.optimizers = self.train_spec.build_optimizers_fn(
-            self.model_parts, job_config, ft_manager
+            self.model_parts, job_config, self.ft_manager
         )
         self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
             self.optimizers, job_config
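Passing `self.ft_manager` into the optimizer builder lets the built optimizers gate their updates on replica-group health: a step should only commit if the fault-tolerant collectives behind it succeeded on every participating rank. A hedged sketch of that wiring, using the `start_quorum`/`should_commit` methods from TorchFT's public Manager API (the wrapper class itself is illustrative):

class FaultTolerantOptimizer:
    """Illustrative wrapper: skip the step if the quorum did not commit."""

    def __init__(self, ft_manager, optimizer):
        self.ft_manager = ft_manager
        self.optimizer = optimizer

    def zero_grad(self, set_to_none: bool = True):
        self.ft_manager.start_quorum()  # (re)establish the replica group
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def step(self):
        if self.ft_manager.should_commit():  # all ranks finished cleanly
            self.optimizer.step()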
@@ -280,7 +271,7 @@ def __init__(self, job_config: JobConfig):
             lr_schedulers=self.lr_schedulers,
             states={"train_state": self},
             job_config=job_config,
-            ft_manager=ft_manager,
+            ft_manager=self.ft_manager,
         )
 
         self.train_context = dist_utils.get_train_context(
@@ -384,11 +375,13 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             parallel_dims.dp_replicate_enabled
             or parallel_dims.dp_shard_enabled
             or parallel_dims.cp_enabled
+            or self.ft_manager.enabled
         ):
             loss = loss.detach()
+            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
             global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
+                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
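Because the fault-tolerant replicate dimension is no longer part of `world_mesh`, reducing the loss over `dp_cp` alone would give each replica group its own average; the extra `ft_pg` argument closes that gap. A minimal sketch of the extended reduction (the real `dist_utils` helpers may instead use DTensor or functional collectives):

import torch
import torch.distributed as dist

def dist_mean(x: torch.Tensor, mesh, ft_pg=None) -> float:
    # First average within the device mesh, then across TorchFT replica
    # groups so every replica logs the same global value.
    dist.all_reduce(x, op=dist.ReduceOp.AVG, group=mesh.get_group())
    if ft_pg is not None:
        dist.all_reduce(x, op=dist.ReduceOp.AVG, group=ft_pg)
    return x.item()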