Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 0 additions & 92 deletions primus/configs/modules/megatron/trainer_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,14 @@ trainable: true
# training, optimizer, checkpoint, loss, distributed, recompute, data, profile, logging

# training
yaml_cfg: null # not support
spec: null
micro_batch_size: 2
batch_size: null # deprecated
global_batch_size: 128
rampup_batch_size: null
decrease_batch_size_if_needed: false
check_for_nan_in_loss_and_grad: true
check_for_spiky_loss: false
check_for_large_grads: false
make_vocab_size_divisible_by: 128
exit_signal_handler: false
exit_duration_in_mins: null
exit_interval: null
onnx_safe: null
bert_binary_head: true

use_flash_attn: false
Expand All @@ -39,7 +32,6 @@ fp16: false
bf16: true
grad_reduce_in_bf16: false
calculate_per_token_loss: false
loss_scale: null
initial_loss_scale: 4294967296
min_loss_scale: 1.0
loss_scale_window: 1000
Expand All @@ -48,7 +40,6 @@ accumulate_allreduce_grads_in_fp32: false
fp16_lm_cross_entropy: false

# fp8
fp8: null # e4m3, hybrid
fp8_margin: 0
fp8_recipe: delayed
fp8_interval: 1 # deprecated
Expand All @@ -60,7 +51,6 @@ te_rng_tracker: false
inference_rng_tracker: false

# fp4
fp4: null
fp4_recipe: nvfp4
fp4_param: false

Expand All @@ -72,20 +62,13 @@ num_layers_at_end_in_bf16: 1
optimizer: adam
lr: 2.5e-4
lr_decay_style: cosine
lr_decay_iters: null
lr_decay_samples: null
lr_warmup_fraction: null
lr_warmup_iters: 0
lr_warmup_samples: 0
lr_warmup_init: 0.0
min_lr: 2.5e-5
lr_wsd_decay_style: exponential
lr_wsd_decay_samples: null
lr_wsd_decay_iters: null
head_lr_mult: 1.0
weight_decay: 0.01
start_weight_decay: null
end_weight_decay: null
weight_decay_incr_style: constant
clip_grad: 1.0
adam_beta1: 0.9
Expand All @@ -94,9 +77,6 @@ adam_eps: 1.0e-08
sgd_momentum: 0.9
override_opt_param_scheduler: false
use_checkpoint_opt_param_scheduler: false
warmup: null
decoupled_lr: null
decoupled_min_lr: null
# muon
muon_extra_scale_factor: 1.0
muon_scale_mode: "spectral"
Expand All @@ -117,41 +97,22 @@ pin_cpu_grads: true
pin_cpu_params: true

# checkpointing arguments
save: null
save_interval: 20000
save_retain_interval: null
no_save_optim: null
no_save_rng: null
load: null
load_main_params_from_ckpt: false
no_load_optim: null
no_load_rng: null
finetune: false
use_checkpoint_args: false
use_mp_args_from_checkpoint_args: false
use_tokenizer_model_from_checkpoint_args: true
exit_on_missing_checkpoint: true
non_persistent_save_interval: null # int
non_persistent_ckpt_type: null # 'global', 'local', 'in_memory', null
non_persistent_global_ckpt_dir: null # str
non_persistent_local_ckpt_dir: null # str
non_persistent_local_ckpt_algo: "fully_parallel" # 'fully_parallel', 'atomic'
dist_ckpt_save_pre_mcore_014: null
dist_ckpt_optim_fully_reshardable: null

pretrained_checkpoint: null
ckpt_step: null
use_dist_ckpt_deprecated: false
use_persistent_ckpt_worker: false
auto_detect_ckpt_format: false
dist_ckpt_format_deprecated: null
ckpt_format: torch_dist # 'torch', 'torch_dist', 'zarr'
ckpt_convert_format: null # 'torch', 'torch_dist', 'zarr'
ckpt_convert_save: null
ckpt_convert_update_legacy_dist_opt_format: false
ckpt_fully_parallel_save_deprecated: false
ckpt_fully_parallel_save: true
async_save: null
ckpt_fully_parallel_load: false
ckpt_assume_constant_structure: false
dist_ckpt_strictness: assume_ok_unexpected
Expand All @@ -163,8 +124,6 @@ distributed_timeout_minutes: 10
defer_embedding_wgrad_compute: false
wgrad_deferral_limit: 0 # int
align_grad_reduce: true
ddp_num_buckets: null # int
ddp_bucket_size: null # int
ddp_pad_buckets_for_high_nccl_busbw: false
ddp_average_in_collective: false
overlap_grad_reduce: false
Expand All @@ -173,15 +132,12 @@ overlap_param_gather_with_optimizer_step: false
align_param_gather: true
scatter_gather_tensors_in_pipeline: true
use_ring_exchange_p2p: false
local_rank: null
lazy_mpu_init: null
account_for_embedding_in_pipeline_split: false
account_for_loss_in_pipeline_split: false
empty_unused_memory_level: 0
standalone_embedding_stage: false
use_distributed_optimizer: false
use_sharp: false
sharp_enabled_group: null # options: [dp, dp_replica]
use_custom_fsdp: false
use_megatron_fsdp: false
init_model_with_meta_device: false
Expand All @@ -191,31 +147,22 @@ suggested_communication_unit_size: 400000000 # int
keep_fp8_transpose_cache_when_using_custom_fsdp: false
num_distributed_optimizer_instances: 1 # int
use_torch_fsdp2: false
nccl_communicator_config_path: null
use_tp_pp_dp_mapping: false
replication: false
replication_jump: null # int
replication_factor: null # int
deterministic_mode: false
check_weight_hash_across_dp_replicas_interval: null
overlap_moe_expert_parallel_comm: false

train_iters: null
eval_iters: 32
full_validation: false
multiple_validation_sets: false
eval_interval: 2000
skip_train: false
train_sync_interval: null # int

adlr_autoresume: false
adlr_autoresume_interval: 1000

# activation recomputation
recompute_activations: false
recompute_granularity: null # full, selective
recompute_method: null # uniform, block
recompute_num_layers: null # int
distribute_saved_activations: false
checkpoint_activations: false # deprecated

Expand All @@ -225,20 +172,10 @@ manual_gc_interval: 1 # int, default 0
manual_gc_eval: false

#data
data_path: null
data_sharding: true
split: "99,1,0"
train_data_path: null
valid_data_path: null
test_data_path: null
data_args_path: null # str
per_split_data_args_path: null # str
data_cache_path: null
mock_data: false
merge_file: null
seq_length: 4096
encoder_seq_length: null
decoder_seq_length: null
retriever_seq_length: 256
sample_rate: 1.0
mask_prob: 0.15
Expand All @@ -247,8 +184,6 @@ num_workers: 8
reset_position_ids: false
reset_attention_mask: false
eod_mask_loss: false
train_samples: null
dataloader_type: null
mmap_bin_files: true

#profile:
Expand All @@ -257,8 +192,6 @@ use_pytorch_profiler: false
profile_ranks: [0]
profile_step_end: 12
profile_step_start: 10
iterations_to_skip: null
result_rejected_tracker_filename: null
enable_gloo_process_groups: true
record_memory_history: false
memory_snapshot_path: snapshot.pickle # str
Expand All @@ -281,20 +214,12 @@ log_validation_ppl_to_tensorboard: false
log_memory_to_tensorboard: false
log_world_size_to_tensorboard: false
log_loss_scale_to_tensorboard: true
wandb_project: null
wandb_exp_name: null
wandb_save_dir: null
wandb_entity: null
enable_one_logger: true
one_logger_project: megatron-lm
one_logger_run_name: null
log_interval: 100
tensorboard_dir: null
logging_level: null # int
config_logger_dir: ""

one_logger_async: false
app_tag_run_name: null
app_tag_run_version: 0.0.0

# rerun
Expand Down Expand Up @@ -338,9 +263,7 @@ classes_fraction: 1.0
data_per_class_fraction: 1.0

# others
retro_project_dir: null
retro_add_retriever: false
retro_cyclic_train_iters: null
retro_encoder_layers: 2
retro_encoder_hidden_dropout: 0.1
retro_encoder_attention_dropout: 0.1
Expand Down Expand Up @@ -370,9 +293,6 @@ inference_batch_times_seqlen_threshold: -1
inference_dynamic_batching: false
inference_dynamic_batching_buffer_size_gb: 40.0 # float
inference_dynamic_batching_buffer_guaranteed_fraction: 0.2 # float
inference_dynamic_batching_buffer_overflow_factor: null # float
inference_dynamic_batching_max_requests_override: null # int
inference_dynamic_batching_max_tokens_override: null # int
max_tokens_to_oom: 12000
output_bert_embeddings: false
bert_embedder_type: megatron # "megatron", "huggingface"
Expand All @@ -386,28 +306,19 @@ inference_max_seq_length: 2560 # int, (prefill + decode)

create_attention_mask_in_dataloader: true
num_dataset_builder_threads: 1
ict_head_size: null
biencoder_projection_dim: 0
biencoder_shared_query_context_model: false
ict_load: null
bert_load: null
titles_data_path: null
query_in_block_prob: 0.1
use_one_sent_docs: false
evidence_data_path: null
retriever_report_topk_accuracies: []
retriever_score_scaling: false
block_data_path: null
embedding_path: null
indexer_batch_size: 128
indexer_log_interval: 1000

enable_ft_package: false
calc_ft_timeouts: false
run_workload_inspector_server: false

heterogeneous_layers_config_path: null
heterogeneous_layers_config_encoded_json: null
inprocess_restart: false

# rl_args
Expand All @@ -424,13 +335,10 @@ grpo_filter_groups_with_same_reward: false
grpo_default_temperature: 1.0
grpo_default_top_p: 0
langrl_inference_server_type: inplace_megatron
langrl_inference_server_conversation_template: null
langrl_env_config: null
rl_offload_optimizer_during_inference: false
rl_offload_kv_cache_during_training: false
rl_remove_kv_cache_during_training: false
rl_reset_cuda_graphs: false
rl_partial_rollouts: false
rl_inference_logprobs_is_correction: false
rl_importance_sampling_truncation_coef: null
rl_calculate_intra_group_similarity: false
21 changes: 18 additions & 3 deletions primus/core/launcher/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def _split_known_unknown(ns: SimpleNamespace, overrides: dict) -> Tuple[dict, di
return known, unknown


def _allow_backend_override_passthrough(pre_trainer_cfg: SimpleNamespace, framework_name: str) -> bool:
    """Report whether unknown CLI override keys may bypass the YAML-preset key check.

    Looks up the ``framework`` value stored on *pre_trainer_cfg* and compares it
    against *framework_name*; a match means override keys can be merged into the
    module config even when they are not declared in the module YAML preset.
    """
    configured = yaml_utils.get_value_by_key(pre_trainer_cfg, "framework")
    if configured == framework_name:
        return True
    return False


def parse_args(extra_args_provider=None, ignore_unknown_args=False):
args, unknown_args = _parse_args(extra_args_provider, ignore_unknown_args=True)

Expand All @@ -135,7 +144,9 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):

overrides = parse_cli_overrides(unknown_args, type_mode="legacy")
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
_check_keys_exist(pre_trainer_cfg, overrides)
# For megatron, allow passthrough so YAML key cleanup does not break CLI overrides.
if not _allow_backend_override_passthrough(pre_trainer_cfg, "megatron"):
_check_keys_exist(pre_trainer_cfg, overrides)
_deep_merge_namespace(pre_trainer_cfg, overrides)

return primus_config
Expand Down Expand Up @@ -163,9 +174,13 @@ def _load_legacy_primus_config(args: argparse.Namespace, overrides: List[str]) -
# 3 Apply overrides to pre_trainer module config
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
# _check_keys_exist(pre_trainer_cfg, override_ns)
# _deep_merge_namespace(pre_trainer_cfg, override_ns)

# return primus_config
if _allow_backend_override_passthrough(pre_trainer_cfg, "megatron"):
# Megatron legacy flow should not depend on YAML key presence. This keeps
# overrides stable after removing redundant/null keys from module presets.
_deep_merge_namespace(pre_trainer_cfg, override_ns)
return primus_config, {}

known_overrides, unknown_overrides = _split_known_unknown(pre_trainer_cfg, override_ns)

if known_overrides:
Expand Down
Loading