Skip to content

Commit 073be7c

Browse files
committed
refactor(config): prune megatron trainer defaults and preserve CLI override compatibility
- Remove redundant/null entries from megatron/trainer_base.yaml to simplify the preset and reduce maintenance noise.
- Update launcher parsing to allow Megatron override passthrough, keeping CLI-driven configuration backward-compatible after preset key cleanup.
1 parent 3098213 commit 073be7c

File tree

2 files changed

+18
-95
lines changed

2 files changed

+18
-95
lines changed

primus/configs/modules/megatron/trainer_base.yaml

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,14 @@ trainable: true
1111
# training, optimizer, checkpoint, loss, distributed, recompute, data, profile, logging
1212

1313
# training
14-
yaml_cfg: null # not support
15-
spec: null
1614
micro_batch_size: 2
17-
batch_size: null # deprecated
1815
global_batch_size: 128
19-
rampup_batch_size: null
2016
decrease_batch_size_if_needed: false
2117
check_for_nan_in_loss_and_grad: true
2218
check_for_spiky_loss: false
2319
check_for_large_grads: false
2420
make_vocab_size_divisible_by: 128
2521
exit_signal_handler: false
26-
exit_duration_in_mins: null
27-
exit_interval: null
28-
onnx_safe: null
2922
bert_binary_head: true
3023

3124
use_flash_attn: false
@@ -39,7 +32,6 @@ fp16: false
3932
bf16: true
4033
grad_reduce_in_bf16: false
4134
calculate_per_token_loss: false
42-
loss_scale: null
4335
initial_loss_scale: 4294967296
4436
min_loss_scale: 1.0
4537
loss_scale_window: 1000
@@ -48,7 +40,6 @@ accumulate_allreduce_grads_in_fp32: false
4840
fp16_lm_cross_entropy: false
4941

5042
# fp8
51-
fp8: null # e4m3, hybrid
5243
fp8_margin: 0
5344
fp8_recipe: delayed
5445
fp8_interval: 1 # deprecated
@@ -60,7 +51,6 @@ te_rng_tracker: false
6051
inference_rng_tracker: false
6152

6253
# fp4
63-
fp4: null
6454
fp4_recipe: nvfp4
6555
fp4_param: false
6656

@@ -72,20 +62,13 @@ num_layers_at_end_in_bf16: 1
7262
optimizer: adam
7363
lr: 2.5e-4
7464
lr_decay_style: cosine
75-
lr_decay_iters: null
76-
lr_decay_samples: null
77-
lr_warmup_fraction: null
7865
lr_warmup_iters: 0
7966
lr_warmup_samples: 0
8067
lr_warmup_init: 0.0
8168
min_lr: 2.5e-5
8269
lr_wsd_decay_style: exponential
83-
lr_wsd_decay_samples: null
84-
lr_wsd_decay_iters: null
8570
head_lr_mult: 1.0
8671
weight_decay: 0.01
87-
start_weight_decay: null
88-
end_weight_decay: null
8972
weight_decay_incr_style: constant
9073
clip_grad: 1.0
9174
adam_beta1: 0.9
@@ -94,9 +77,6 @@ adam_eps: 1.0e-08
9477
sgd_momentum: 0.9
9578
override_opt_param_scheduler: false
9679
use_checkpoint_opt_param_scheduler: false
97-
warmup: null
98-
decoupled_lr: null
99-
decoupled_min_lr: null
10080
# muon
10181
muon_extra_scale_factor: 1.0
10282
muon_scale_mode: "spectral"
@@ -117,41 +97,22 @@ pin_cpu_grads: true
11797
pin_cpu_params: true
11898

11999
# checkpointing arguments
120-
save: null
121100
save_interval: 20000
122-
save_retain_interval: null
123-
no_save_optim: null
124-
no_save_rng: null
125-
load: null
126101
load_main_params_from_ckpt: false
127-
no_load_optim: null
128-
no_load_rng: null
129102
finetune: false
130103
use_checkpoint_args: false
131104
use_mp_args_from_checkpoint_args: false
132105
use_tokenizer_model_from_checkpoint_args: true
133106
exit_on_missing_checkpoint: true
134-
non_persistent_save_interval: null # int
135-
non_persistent_ckpt_type: null # 'global', 'local', 'in_memory', null
136-
non_persistent_global_ckpt_dir: null # str
137-
non_persistent_local_ckpt_dir: null # str
138107
non_persistent_local_ckpt_algo: "fully_parallel" # 'fully_parallel', 'atomic'
139-
dist_ckpt_save_pre_mcore_014: null
140-
dist_ckpt_optim_fully_reshardable: null
141108

142-
pretrained_checkpoint: null
143-
ckpt_step: null
144109
use_dist_ckpt_deprecated: false
145110
use_persistent_ckpt_worker: false
146111
auto_detect_ckpt_format: false
147-
dist_ckpt_format_deprecated: null
148112
ckpt_format: torch_dist # 'torch', 'torch_dist', 'zarr'
149-
ckpt_convert_format: null # 'torch', 'torch_dist', 'zarr'
150-
ckpt_convert_save: null
151113
ckpt_convert_update_legacy_dist_opt_format: false
152114
ckpt_fully_parallel_save_deprecated: false
153115
ckpt_fully_parallel_save: true
154-
async_save: null
155116
ckpt_fully_parallel_load: false
156117
ckpt_assume_constant_structure: false
157118
dist_ckpt_strictness: assume_ok_unexpected
@@ -163,8 +124,6 @@ distributed_timeout_minutes: 10
163124
defer_embedding_wgrad_compute: false
164125
wgrad_deferral_limit: 0 # int
165126
align_grad_reduce: true
166-
ddp_num_buckets: null # int
167-
ddp_bucket_size: null # int
168127
ddp_pad_buckets_for_high_nccl_busbw: false
169128
ddp_average_in_collective: false
170129
overlap_grad_reduce: false
@@ -173,15 +132,12 @@ overlap_param_gather_with_optimizer_step: false
173132
align_param_gather: true
174133
scatter_gather_tensors_in_pipeline: true
175134
use_ring_exchange_p2p: false
176-
local_rank: null
177-
lazy_mpu_init: null
178135
account_for_embedding_in_pipeline_split: false
179136
account_for_loss_in_pipeline_split: false
180137
empty_unused_memory_level: 0
181138
standalone_embedding_stage: false
182139
use_distributed_optimizer: false
183140
use_sharp: false
184-
sharp_enabled_group: null # options: [dp, dp_replica]
185141
use_custom_fsdp: false
186142
use_megatron_fsdp: false
187143
init_model_with_meta_device: false
@@ -191,31 +147,22 @@ suggested_communication_unit_size: 400000000 # int
191147
keep_fp8_transpose_cache_when_using_custom_fsdp: false
192148
num_distributed_optimizer_instances: 1 # int
193149
use_torch_fsdp2: false
194-
nccl_communicator_config_path: null
195150
use_tp_pp_dp_mapping: false
196151
replication: false
197-
replication_jump: null # int
198-
replication_factor: null # int
199152
deterministic_mode: false
200-
check_weight_hash_across_dp_replicas_interval: null
201153
overlap_moe_expert_parallel_comm: false
202154

203-
train_iters: null
204155
eval_iters: 32
205156
full_validation: false
206157
multiple_validation_sets: false
207158
eval_interval: 2000
208159
skip_train: false
209-
train_sync_interval: null # int
210160

211161
adlr_autoresume: false
212162
adlr_autoresume_interval: 1000
213163

214164
# activation recomputation
215165
recompute_activations: false
216-
recompute_granularity: null # full, selective
217-
recompute_method: null # uniform, block
218-
recompute_num_layers: null # int
219166
distribute_saved_activations: false
220167
checkpoint_activations: false # deprecated
221168

@@ -225,20 +172,10 @@ manual_gc_interval: 1 # int, default 0
225172
manual_gc_eval: false
226173

227174
#data
228-
data_path: null
229175
data_sharding: true
230176
split: "99,1,0"
231-
train_data_path: null
232-
valid_data_path: null
233-
test_data_path: null
234-
data_args_path: null # str
235-
per_split_data_args_path: null # str
236-
data_cache_path: null
237177
mock_data: false
238-
merge_file: null
239178
seq_length: 4096
240-
encoder_seq_length: null
241-
decoder_seq_length: null
242179
retriever_seq_length: 256
243180
sample_rate: 1.0
244181
mask_prob: 0.15
@@ -247,8 +184,6 @@ num_workers: 8
247184
reset_position_ids: false
248185
reset_attention_mask: false
249186
eod_mask_loss: false
250-
train_samples: null
251-
dataloader_type: null
252187
mmap_bin_files: true
253188

254189
#profile:
@@ -257,8 +192,6 @@ use_pytorch_profiler: false
257192
profile_ranks: [0]
258193
profile_step_end: 12
259194
profile_step_start: 10
260-
iterations_to_skip: null
261-
result_rejected_tracker_filename: null
262195
enable_gloo_process_groups: true
263196
record_memory_history: false
264197
memory_snapshot_path: snapshot.pickle # str
@@ -281,20 +214,12 @@ log_validation_ppl_to_tensorboard: false
281214
log_memory_to_tensorboard: false
282215
log_world_size_to_tensorboard: false
283216
log_loss_scale_to_tensorboard: true
284-
wandb_project: null
285-
wandb_exp_name: null
286-
wandb_save_dir: null
287-
wandb_entity: null
288217
enable_one_logger: true
289218
one_logger_project: megatron-lm
290-
one_logger_run_name: null
291219
log_interval: 100
292-
tensorboard_dir: null
293-
logging_level: null # int
294220
config_logger_dir: ""
295221

296222
one_logger_async: false
297-
app_tag_run_name: null
298223
app_tag_run_version: 0.0.0
299224

300225
# rerun
@@ -338,9 +263,7 @@ classes_fraction: 1.0
338263
data_per_class_fraction: 1.0
339264

340265
# others
341-
retro_project_dir: null
342266
retro_add_retriever: false
343-
retro_cyclic_train_iters: null
344267
retro_encoder_layers: 2
345268
retro_encoder_hidden_dropout: 0.1
346269
retro_encoder_attention_dropout: 0.1
@@ -370,9 +293,6 @@ inference_batch_times_seqlen_threshold: -1
370293
inference_dynamic_batching: false
371294
inference_dynamic_batching_buffer_size_gb: 40.0 # float
372295
inference_dynamic_batching_buffer_guaranteed_fraction: 0.2 # float
373-
inference_dynamic_batching_buffer_overflow_factor: null # float
374-
inference_dynamic_batching_max_requests_override: null # int
375-
inference_dynamic_batching_max_tokens_override: null # int
376296
max_tokens_to_oom: 12000
377297
output_bert_embeddings: false
378298
bert_embedder_type: megatron # "megatron", "huggingface"
@@ -386,28 +306,19 @@ inference_max_seq_length: 2560 # int, (prefill + decode)
386306

387307
create_attention_mask_in_dataloader: true
388308
num_dataset_builder_threads: 1
389-
ict_head_size: null
390309
biencoder_projection_dim: 0
391310
biencoder_shared_query_context_model: false
392-
ict_load: null
393-
bert_load: null
394-
titles_data_path: null
395311
query_in_block_prob: 0.1
396312
use_one_sent_docs: false
397-
evidence_data_path: null
398313
retriever_report_topk_accuracies: []
399314
retriever_score_scaling: false
400-
block_data_path: null
401-
embedding_path: null
402315
indexer_batch_size: 128
403316
indexer_log_interval: 1000
404317

405318
enable_ft_package: false
406319
calc_ft_timeouts: false
407320
run_workload_inspector_server: false
408321

409-
heterogeneous_layers_config_path: null
410-
heterogeneous_layers_config_encoded_json: null
411322
inprocess_restart: false
412323

413324
# rl_args
@@ -424,13 +335,10 @@ grpo_filter_groups_with_same_reward: false
424335
grpo_default_temperature: 1.0
425336
grpo_default_top_p: 0
426337
langrl_inference_server_type: inplace_megatron
427-
langrl_inference_server_conversation_template: null
428-
langrl_env_config: null
429338
rl_offload_optimizer_during_inference: false
430339
rl_offload_kv_cache_during_training: false
431340
rl_remove_kv_cache_during_training: false
432341
rl_reset_cuda_graphs: false
433342
rl_partial_rollouts: false
434343
rl_inference_logprobs_is_correction: false
435-
rl_importance_sampling_truncation_coef: null
436344
rl_calculate_intra_group_similarity: false

primus/core/launcher/parser.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ def _split_known_unknown(ns: SimpleNamespace, overrides: dict) -> Tuple[dict, di
127127
return known, unknown
128128

129129

130+
def _allow_backend_override_passthrough(pre_trainer_cfg: SimpleNamespace, framework_name: str) -> bool:
131+
"""
132+
Decide whether CLI override keys should be merged directly into module config
133+
even when those keys are not declared in the module YAML preset.
134+
"""
135+
framework = yaml_utils.get_value_by_key(pre_trainer_cfg, "framework")
136+
return framework == framework_name
137+
138+
130139
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
131140
args, unknown_args = _parse_args(extra_args_provider, ignore_unknown_args=True)
132141

@@ -135,7 +144,9 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
135144

136145
overrides = parse_cli_overrides(unknown_args, type_mode="legacy")
137146
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
138-
_check_keys_exist(pre_trainer_cfg, overrides)
147+
# For megatron, allow passthrough so YAML key cleanup does not break CLI overrides.
148+
if not _allow_backend_override_passthrough(pre_trainer_cfg, "megatron"):
149+
_check_keys_exist(pre_trainer_cfg, overrides)
139150
_deep_merge_namespace(pre_trainer_cfg, overrides)
140151

141152
return primus_config
@@ -163,9 +174,13 @@ def _load_legacy_primus_config(args: argparse.Namespace, overrides: List[str]) -
163174
# 3 Apply overrides to pre_trainer module config
164175
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
165176
# _check_keys_exist(pre_trainer_cfg, override_ns)
166-
# _deep_merge_namespace(pre_trainer_cfg, override_ns)
167177

168-
# return primus_config
178+
if _allow_backend_override_passthrough(pre_trainer_cfg, "megatron"):
179+
# Megatron legacy flow should not depend on YAML key presence. This keeps
180+
# overrides stable after removing redundant/null keys from module presets.
181+
_deep_merge_namespace(pre_trainer_cfg, override_ns)
182+
return primus_config, {}
183+
169184
known_overrides, unknown_overrides = _split_known_unknown(pre_trainer_cfg, override_ns)
170185

171186
if known_overrides:

0 commit comments

Comments (0)