bionemo-recipes/recipes/esm2_native_te: 2 files changed, +7 -8 lines

@@ -109,24 +109,22 @@ def main(args: DictConfig) -> float | None:
         output_device=dist_config.local_rank,
         process_group=group_fsdp_cp,
     )
-    cp_group = device_mesh["cp"].get_group()
-    cp_rank = device_mesh.get_local_rank("cp")

     if args.cp_size > 1:
         for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
             logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}")
             transformer_layer.set_context_parallel_group(
-                cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
+                device_mesh["cp"].get_group(),
+                torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
+                torch.cuda.Stream(),
             )

     # Context Parallelism requires THD Sequence Packing.
     assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing."

     train_dataloader, dataset_or_sampler = create_cp_dataloader(
         dist_config,
-        cp_world_size=torch.distributed.get_world_size(group=cp_group),
-        cp_group=cp_group,
-        cp_rank=cp_rank,
+        cp_mesh=device_mesh["cp"],
         **args.dataset,
     )

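The dataloader change above works because a 1-D DeviceMesh slice already carries the process group, local rank, and group size that were previously passed as three separate arguments. A minimal sketch of that DeviceMesh API, assuming a hypothetical 2 x 4 ("cp", "dp") mesh; none of the names below come from the recipe itself:

from torch.distributed.device_mesh import init_device_mesh

# Assumed layout for illustration: 2-way context parallelism x 4-way data parallelism.
device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cp", "dp"))

cp_mesh = device_mesh["cp"]         # 1-D sub-mesh over the "cp" dimension
cp_group = cp_mesh.get_group()      # ProcessGroup previously passed as cp_group
cp_rank = cp_mesh.get_local_rank()  # previously passed as cp_rank
cp_size = cp_mesh.size()            # previously torch.distributed.get_world_size(group=cp_group)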

@@ -114,14 +114,15 @@ def main(args: DictConfig) -> float | None:
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
     # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension).
     # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension.
-    cp_group = device_mesh["cp"].get_group()
     for layer in transformer_stack:
         fully_shard(layer, mesh=cp_dp_mesh)
         # Set CP group for layer if CP is enabled.
         if args.cp_size > 1:
             logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {layer}")
             layer.set_context_parallel_group(
-                cp_group, torch.distributed.get_process_group_ranks(cp_group), torch.cuda.Stream()
+                device_mesh["cp"].get_group(),
+                torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
+                torch.cuda.Stream(),
             )
     fully_shard(model, mesh=cp_dp_mesh)

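The comment in the hunk above describes HSDP-style semantics: when fully_shard receives a 2-D mesh, it shards parameters over the second dimension and replicates over the first. A hedged sketch under assumed mesh and module sizes; the public import path is torch.distributed.fsdp.fully_shard on recent PyTorch, while older releases expose it under torch.distributed._composable.fsdp:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # older releases: torch.distributed._composable.fsdp

# Assumed 2 x 4 mesh: dim 0 ("cp") replicates, dim 1 ("dp") shards, matching the comment above.
cp_dp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cp", "dp"))

layer = nn.Linear(1024, 1024, device="cuda")
fully_shard(layer, mesh=cp_dp_mesh)
# After this call, the layer's parameters are DTensors sharded across the "dp" ranks
# and replicated across the "cp" ranks.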