
Commit d1fc224 (parent: 88d735c)

fix train_ddp_cp

Signed-off-by: Peter St. John <[email protected]>

File tree

2 files changed: +7 −8 lines

bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py (4 additions, 6 deletions)

@@ -109,24 +109,22 @@ def main(args: DictConfig) -> float | None:
         output_device=dist_config.local_rank,
         process_group=group_fsdp_cp,
     )
-    cp_group = device_mesh["cp"].get_group()
-    cp_rank = device_mesh.get_local_rank("cp")

     if args.cp_size > 1:
         for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
             logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}")
             transformer_layer.set_context_parallel_group(
-                cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
+                device_mesh["cp"].get_group(),
+                torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
+                torch.cuda.Stream(),
             )

     # Context Parallelism requires THD Sequence Packing.
     assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing."

     train_dataloader, dataset_or_sampler = create_cp_dataloader(
         dist_config,
-        cp_world_size=torch.distributed.get_world_size(group=cp_group),
-        cp_group=cp_group,
-        cp_rank=cp_rank,
+        cp_mesh=device_mesh["cp"],
         **args.dataset,
     )
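The dataloader change replaces three separate CP keywords with the single "cp" sub-mesh. A minimal sketch of how a helper like create_cp_dataloader could recover the old values from that mesh follows; the recipe's actual implementation is not shown in this commit, so the function body here is assumed:

    from torch.distributed.device_mesh import DeviceMesh


    def create_cp_dataloader(dist_config, cp_mesh: DeviceMesh, **dataset_kwargs):
        # The 1-D "cp" sub-mesh carries everything the removed keywords did.
        cp_group = cp_mesh.get_group()          # previously cp_group=...
        cp_rank = cp_mesh.get_local_rank()      # previously cp_rank=...
        cp_world_size = cp_mesh.size()          # previously cp_world_size=...
        # ... build the THD-packed dataloader using these values (assumed) ...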

bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py (3 additions, 2 deletions)

@@ -114,14 +114,15 @@ def main(args: DictConfig) -> float | None:
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
     # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension).
     # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension.
-    cp_group = device_mesh["cp"].get_group()
     for layer in transformer_stack:
         fully_shard(layer, mesh=cp_dp_mesh)
         # Set CP group for layer if CP is enabled.
         if args.cp_size > 1:
             logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {layer}")
             layer.set_context_parallel_group(
-                cp_group, torch.distributed.get_process_group_ranks(cp_group), torch.cuda.Stream()
+                device_mesh["cp"].get_group(),
+                torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
+                torch.cuda.Stream(),
             )
     fully_shard(model, mesh=cp_dp_mesh)
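For reference, a hypothetical sketch of how a 2-D (cp, dp) DeviceMesh like the one used above can be built with torch.distributed.device_mesh. The recipe's actual mesh construction is outside this diff, so names such as cp_dp_mesh and the value of cp_size are matched to the code above by assumption:

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    # Assumes torch.distributed is already initialized (e.g. via torchrun).
    cp_size = 2  # stands in for args.cp_size
    dp_size = torch.distributed.get_world_size() // cp_size
    device_mesh = init_device_mesh("cuda", (cp_size, dp_size), mesh_dim_names=("cp", "dp"))

    # fully_shard() shards parameters across the last ("dp") dimension and
    # replicates them across the first ("cp") dimension of this 2-D mesh.
    cp_dp_mesh = device_mesh

    # The per-layer CP process group passed to set_context_parallel_group() above.
    cp_group = device_mesh["cp"].get_group()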
