5 changes: 2 additions & 3 deletions examples/configs/dpo.yaml
@@ -138,9 +138,8 @@ policy:
       start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       weight_decay_incr_style: "constant"
-      lr_decay_style: "linear"
-      lr_decay_iters: 1000000000
-      lr_warmup_iters: 2
+      lr_decay_style: "constant"
+      lr_warmup_iters: 1
       lr_warmup_init: 0.00000001

     distributed_data_parallel_config:
@@ -92,9 +92,8 @@ policy:
       start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       weight_decay_incr_style: "constant"
-      lr_decay_style: "linear"
-      lr_decay_iters: 1000000000
-      lr_warmup_iters: 2
+      lr_decay_style: "constant"
+      lr_warmup_iters: 1
       lr_warmup_init: 0.00000001

     distributed_data_parallel_config:
@@ -92,9 +92,8 @@ policy:
       start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
       weight_decay_incr_style: "constant"
-      lr_decay_style: "linear"
-      lr_decay_iters: 1000000000
-      lr_warmup_iters: 2
+      lr_decay_style: "constant"
+      lr_warmup_iters: 1
       lr_warmup_init: 0.00000001

     distributed_data_parallel_config:
12 changes: 12 additions & 0 deletions nemo_rl/algorithms/dpo.py
@@ -201,6 +201,18 @@ def setup(
     # Training
     # ==========================
     print("\n▶ Setting up model...")
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        total_train_iters = min(
+            dpo_config["max_num_steps"],
+            dpo_config["max_num_epochs"] * len(train_dataloader),
+        )
+        ## NOTE: we double the train_iters because effective batch size is doubled
+        ## for (chosen, rejected) pairs
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters * 2
+        if "scheduler" in policy_config["megatron_cfg"]:
+            for k in policy_config["megatron_cfg"]["scheduler"]:
+                if "iters" in k:
+                    policy_config["megatron_cfg"]["scheduler"][k] *= 2
     policy = Policy(
         cluster=cluster,
         config=policy_config,
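To make the doubling above concrete, here is a small illustrative sketch (not part of the diff; the toy dicts and numbers such as `num_batches_per_epoch` are made up) of how the DPO setup rewrites a Megatron config when `megatron_cfg.enabled` is true:

```python
# Toy stand-ins for dpo_config, policy_config, and len(train_dataloader);
# the values are hypothetical and used purely for illustration.
dpo_config = {"max_num_steps": 500, "max_num_epochs": 1}
num_batches_per_epoch = 1000  # plays the role of len(train_dataloader)

policy_config = {
    "megatron_cfg": {
        "enabled": True,
        "scheduler": {"lr_warmup_iters": 1, "lr_decay_style": "constant"},
    }
}

# Total number of DPO steps the run will actually take.
total_train_iters = min(
    dpo_config["max_num_steps"],
    dpo_config["max_num_epochs"] * num_batches_per_epoch,
)  # -> 500

# As the NOTE in the diff says, the effective batch size is doubled for
# (chosen, rejected) pairs, so train_iters and every scheduler key whose
# name contains "iters" are doubled to match.
policy_config["megatron_cfg"]["train_iters"] = total_train_iters * 2  # -> 1000
for key in policy_config["megatron_cfg"]["scheduler"]:
    if "iters" in key:
        policy_config["megatron_cfg"]["scheduler"][key] *= 2  # lr_warmup_iters -> 2
```

The same pattern appears in rm.py below, since reward-model training also consumes (chosen, rejected) pairs.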
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -341,6 +341,11 @@ def setup(
     weights_path = None
     optimizer_path = None

+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        ## NOTE: this is equal to the total number of scheduler steps
+        total_train_iters = min(grpo_config["max_num_steps"], len(dataloader))
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+
     policy = Policy(
         cluster=train_cluster,
         config=policy_config,
12 changes: 12 additions & 0 deletions nemo_rl/algorithms/rm.py
@@ -182,6 +182,18 @@ def setup(
     # Training
     # ==========================
     print("\n▶ Setting up model...")
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        total_train_iters = min(
+            rm_config["max_num_steps"],
+            rm_config["max_num_epochs"] * len(train_dataloader),
+        )
+        ## NOTE: we double the train_iters because effective batch size is doubled
+        ## for (chosen, rejected) pairs
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters * 2
+        if "scheduler" in policy_config["megatron_cfg"]:
+            for k in policy_config["megatron_cfg"]["scheduler"]:
+                if "iters" in k:
+                    policy_config["megatron_cfg"]["scheduler"][k] *= 2
     policy = Policy(
         cluster=cluster,
         config=policy_config,
6 changes: 6 additions & 0 deletions nemo_rl/algorithms/sft.py
@@ -171,6 +171,12 @@ def setup(
     # Training
     # ==========================
     print("\n▶ Setting up model...")
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        total_train_iters = min(
+            sft_config["max_num_steps"],
+            sft_config["max_num_epochs"] * len(train_dataloader),
+        )
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
     # check if tokenizer is a processor (e.g. for VLMs)
     processor = None
     if not isinstance(tokenizer, PreTrainedTokenizerBase):
10 changes: 9 additions & 1 deletion nemo_rl/models/policy/megatron_policy_worker.py
@@ -568,14 +568,22 @@ def __init__(
             fully_parallel_load=True,  # Enable fully parallel load
             load_rng=False,
         )
+
+        assert "train_iters" in self.cfg["megatron_cfg"], (
+            "train_iters must be set in megatron_cfg. For an example, see "
+            "https://github.com/NVIDIA-NeMo/RL/blob/bccbc377705a81a1f4b3c31ad9767bcc15f735a8/nemo_rl/algorithms/sft.py#L175-L179."
+        )
+
         self.megatron_cfg = ConfigContainer(
             model_config=model_cfg,
             checkpoint_config=checkpoint_config,
             logger_config=LoggerConfig(logging_level=0),
             train_config=TrainingConfig(
                 micro_batch_size=1,  # ignored
                 global_batch_size=self.cfg["train_global_batch_size"],  # ignored
-                train_iters=1000,  # Default value for inference
+                train_iters=self.cfg["megatron_cfg"][
+                    "train_iters"
+                ],  # Set by algorithm setup
             ),
             optimizer_config=OptimizerConfig(
                 **self.cfg["megatron_cfg"]["optimizer"],
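Putting the pieces together, here is a minimal sketch of the contract the new assertion enforces: each algorithm's `setup` must write `train_iters` into `megatron_cfg` before the worker builds its `TrainingConfig`. The helper below is hypothetical and not part of the repository; it just mirrors the per-algorithm setup changes above.

```python
# Hypothetical helper (not from the PR) mirroring the sft.py/grpo.py setup logic above.
def prepare_megatron_train_iters(policy_config, max_num_steps, max_num_epochs, num_batches):
    """Populate megatron_cfg["train_iters"] before the Policy/worker is constructed."""
    if not policy_config.get("megatron_cfg", {}).get("enabled", False):
        return policy_config  # non-Megatron backends are untouched
    total_train_iters = min(max_num_steps, max_num_epochs * num_batches)
    policy_config["megatron_cfg"]["train_iters"] = total_train_iters
    return policy_config

# With the key in place, the worker-side assert above passes and TrainingConfig
# receives the real number of optimizer steps instead of a hard-coded 1000.
cfg = prepare_megatron_train_iters(
    {"megatron_cfg": {"enabled": True}},
    max_num_steps=100,
    max_num_epochs=1,
    num_batches=250,
)
assert cfg["megatron_cfg"]["train_iters"] == 100
```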