Pull request #2742: Add RL support for MOEs

Open — jon-barker wants to merge 24 commits into NVIDIA:main from jon-barker:jbarker/rl_nanov3_support (base: main)
Commits (24, all by jon-barker):
- 0432284 wip: changes to get nanov3 running
- 64a3bcf upgrade math verifier eos token logic
- dd68d9a add grpo top-k arg
- a730ff0 wip: reinitialize zero norm embeddings
- fba7563 wip: debugging nanov3 support
- dbbfc50 update moe configs
- 7c031a2 remove core inference change
- f446e87 remove core inference change
- 9234506 remove core inference change
- 1f1d774 remove debugging code
- d5de98b remove debug change
- 68d2bcb remove debug code
- bbd49a4 revert changes
- b8e2bda update nemotron6_3b_moe config
- 9f9a7d0 update nemotron6 config
- b3077d1 remove debug comments
- 4f008bd moe support updates
- aa8fa15 cleanup
- 5734288 linting
- 75762a8 linting
- c746b87 revert move of positional arg
- 2c31264 fix failing unit test
- 02b3634 update rl arg
- 773b983 fix rl unit test
Files changed:
Modified file — GSM8K math environment agent config (@@ -1,5 +1,6 @@):

```yaml
- agent_type: examples.rl.environments.math.gsm8k_agent.GSM8KAgent
  agent_args:
    answer_format: "boxed"
    format_reward: 0.5
    weight: 1.0
  evaluation_only: false
```
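The `format_reward: 0.5` and `answer_format: "boxed"` fields suggest a shaped reward: partial credit for emitting a `\boxed{...}` answer at all, full credit for a correct one. A minimal sketch of that scheme — the function name and exact scoring are assumptions for illustration, not the actual GSM8KAgent code:

```python
import re

def score_response(response: str, gold: str, format_reward: float = 0.5) -> float:
    """Hypothetical shaped reward mirroring the config above.

    Full reward for a correct boxed answer, `format_reward` for a
    well-formatted but wrong one, nothing for a missing box.
    """
    match = re.search(r"\\boxed\{([^}]*)\}", response)
    if match is None:
        return 0.0                      # no \boxed{...}: no reward at all
    answer = match.group(1).strip()
    if answer == gold.strip():
        return 1.0                      # correct boxed answer: full reward
    return format_reward                # formatted but wrong: partial credit

print(score_response(r"The answer is \boxed{42}", "42"))   # 1.0
print(score_response(r"The answer is \boxed{41}", "42"))   # 0.5
print(score_response("The answer is 42", "42"))            # 0.0
```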
New file (115 lines) — dsv2_lite RL launch script:

```bash
#!/bin/bash
TP=${TP:-2}
PP=${PP:-1}
EP=${EP:-1}
NODES_REQUIRED=${NODES_REQUIRED:-1}
LLM="dsv2_lite"

echo "Using Deepseek-v2-lite model checkpoint (not the exact model weights..)"
SCRIPT_PATH="${BASH_SOURCE[0]}"
source "$(dirname "$SCRIPT_PATH")"/common.sh

# In all cases, one can override these values.
# However, running without envs will give you some
# good perf out of the box for established envs.
if [ "$(basename "$ENV_CONFIG")" = "dapo.yaml" ]; then
    echo "Using DAPO environment config"
    GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
    GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
    MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-32}
    GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16}
    GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64}
    GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
    GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
    TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024}
    MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
    MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-11999}
    EXIT_INTERVAL=${EXIT_INTERVAL:-16}
    CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-16}
else
    # Some default values if config is unsupported.
    echo "Undetected environment config, using default values"
    GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
    GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.4}
    MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64}
    GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-8}
    GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-8}
    GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
    GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
    TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-64}
    MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
    MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-8192}
    EXIT_INTERVAL=${EXIT_INTERVAL:-20}
    CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
fi

ENV_DEPENDENT="\
    --micro-batch-size $MICRO_BATCH_SIZE \
    --global-batch-size $TRAINING_BATCH_SIZE \
    --grpo-group-size $GRPO_GROUP_SIZE \
    --grpo-prompts-per-step $GRPO_PROMPTS_PER_STEP \
    --grpo-iterations $GRPO_ITERATIONS \
    --grpo-clamp-eps-lower $GRPO_CLAMP_EPS_LOWER \
    --grpo-clamp-eps-upper $GRPO_CLAMP_EPS_UPPER \
    --grpo-kl-beta $GRPO_KL_BETA \
    --langrl-env-config $ENV_CONFIG "

MODEL_OPTIONS="\
    --use-checkpoint-args \
    --enable-experimental \
    --cross-entropy-loss-fusion \
    --cross-entropy-fusion-impl native \
    --moe-aux-loss-coeff 0.0 \
    --moe-router-dtype fp64 \
    --moe-router-load-balancing-type none \
    --moe-token-dispatcher-type alltoall \
    --attention-backend flash \
    --disable-gloo-process-groups \
    --grpo-default-temperature 1.2 \
    --grpo-default-top-p 0.95 \
    --disable-chunked-prefill \
    --calculate-per-token-loss \
    --seq-length $MAX_SEQ_LENGTH \
    --inference-max-seq-length $MAX_SEQ_LENGTH \
    --inference-max-batch-size $MAX_INFERENCE_BS \
    --pretrained-checkpoint $CHECKPOINT \
    --distributed-timeout-minutes 60 \
    --use-mcore-models \
    --no-mmap-bin-files \
    --disable-bias-linear \
    --norm-epsilon 1e-5 \
    --init-method-std 0.014 \
    --exit-duration-in-mins 5750 \
    --max-position-embeddings 8192 \
    --tensor-model-parallel-size $TP \
    --pipeline-model-parallel-size $PP \
    --expert-model-parallel-size $EP \
    --no-masked-softmax-fusion \
    --attention-softmax-in-fp32 \
    --weight-decay 0.01 \
    --clip-grad 0.1 \
    --tiktoken-pattern v2 \
    --tokenizer-type TikTokenizer \
    --tokenizer-model ${TOKENIZER_MODEL} \
    --no-use-tokenizer-model-from-checkpoint-args \
    --dist-ckpt-strictness log_unexpected \
    --ckpt-format torch_dist \
    --ckpt-fully-parallel-save \
    --ckpt-fully-parallel-load \
    --use-distributed-optimizer \
    --overlap-grad-reduce \
    --overlap-param-gather \
    --no-create-attention-mask-in-dataloader \
    --lr 1e-7 \
    --lr-warmup-samples 0 \
    --no-load-optim \
    --decode-only-cuda-graphs \
    --rl-inference-logprobs-is-correction \
    --rl-importance-sampling-truncation-coef 5.0 \
"

# 1. remove importance sampling

# 2. removed any form of load balancing loss
```
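The `--grpo-clamp-eps-lower 0.2` / `--grpo-clamp-eps-upper 0.28` pair reads as the asymmetric ("clip-higher") clipping that DAPO proposes on top of the PPO/GRPO surrogate: the importance ratio is clamped to `[1 - eps_lower, 1 + eps_upper]`, giving positive-advantage tokens more upside headroom than a symmetric 0.2 clip. A sketch of the per-token objective under that reading (the function name is ours, not the trainer's):

```python
def clipped_token_objective(ratio: float, advantage: float,
                            eps_lower: float = 0.2, eps_upper: float = 0.28) -> float:
    """PPO/GRPO-style clipped surrogate with asymmetric clamp bounds.

    ratio = pi_theta(token) / pi_old(token); the asymmetric bounds follow
    DAPO's "clip-higher": more room above 1 than below.
    """
    clamped = min(max(ratio, 1.0 - eps_lower), 1.0 + eps_upper)
    # Pessimistic bound, as in PPO: take the min of the two surrogates.
    return min(ratio * advantage, clamped * advantage)

# With a positive advantage, a ratio of 1.25 is NOT clipped (1.25 < 1.28),
# whereas a symmetric 0.2 clamp would have cut it to 1.2.
print(clipped_token_objective(1.25, 1.0))   # 1.25
print(clipped_token_objective(1.35, 1.0))   # 1.28 (clipped at 1 + eps_upper)
print(clipped_token_objective(0.5, 1.0))    # 0.5 (the unclipped term is smaller)
```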
New file (128 lines) — nemotron6_3b_moe RL launch script:

```bash
#!/bin/bash
TP=${TP:-2}
PP=${PP:-1}
EP=${EP:-32}
NODES_REQUIRED=${NODES_REQUIRED:-4}
LLM="nemotron6_3b_moe"

ROOT_DIR="/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/nemotron6"

CHECKPOINT="${ROOT_DIR}/3b_hybrid_moe/checkpoints/phase2_lc_reinit_emb/"

TOKENIZER_MODEL="${ROOT_DIR}/tokenizers/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json"

echo "Using Nemotron6 3B MOE model checkpoint"
SCRIPT_PATH="${BASH_SOURCE[0]}"
source "$(dirname "$SCRIPT_PATH")"/common.sh

# In all cases, one can override these values.
# However, running without envs will give you some
# good perf out of the box for established envs.
if [ "$(basename "$ENV_CONFIG")" = "dapo.yaml" ]; then
    echo "Using DAPO environment config"
    GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
    GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
    MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-32}
    GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16}
    GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64}
    GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
    GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
    TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024}
    MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
    MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-11999}
    EXIT_INTERVAL=${EXIT_INTERVAL:-20}
    CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
else
    # Some default values if config is unsupported.
    echo "Undetected environment config, using default values"
    GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
    GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
    MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64}
    GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-2}
    GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-16}
    GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
    GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
    TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-32}
    MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
    MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-1024}
    EXIT_INTERVAL=${EXIT_INTERVAL:-20}
    CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
fi

ENV_DEPENDENT="\
    --micro-batch-size $MICRO_BATCH_SIZE \
    --global-batch-size $TRAINING_BATCH_SIZE \
    --grpo-group-size $GRPO_GROUP_SIZE \
    --grpo-prompts-per-step $GRPO_PROMPTS_PER_STEP \
    --grpo-iterations $GRPO_ITERATIONS \
    --grpo-clamp-eps-lower $GRPO_CLAMP_EPS_LOWER \
    --grpo-clamp-eps-upper $GRPO_CLAMP_EPS_UPPER \
    --grpo-kl-beta $GRPO_KL_BETA \
    --langrl-env-config $ENV_CONFIG "

MODEL_OPTIONS="\
    --rl-skip-bos-token \
    --no-rl-use-sequence-packing \
    --rl-partial-rollouts \
    --rl-offload-optimizer-during-inference \
    --moe-pad-experts-for-cuda-graph-inference \
    --inference-dynamic-batching-max-tokens 8192 \
    --inference-dynamic-batching-max-requests 128 \
    --inference-dynamic-batching-num-cuda-graphs 2 \
    --decode-only-cuda-graphs \
    --cuda-graph-impl local \
    --cuda-graph-scope full \
    --use-checkpoint-args \
    --enable-experimental \
    --cross-entropy-loss-fusion \
    --cross-entropy-fusion-impl native \
    --moe-aux-loss-coeff 0.0 \
    --moe-router-dtype fp64 \
    --moe-router-load-balancing-type aux_loss \
    --moe-router-score-function sigmoid \
    --moe-token-dispatcher-type alltoall \
    --moe-router-enable-expert-bias \
    --moe-router-topk-scaling-factor 2.5 \
    --disable-gloo-process-groups \
    --grpo-default-top-k -1 \
    --grpo-default-temperature 1.0 \
    --grpo-default-top-p 1.0 \
    --rl-inference-logprobs-is-correction \
    --rl-importance-sampling-truncation-coef 10.0 \
    --seq-length $MAX_SEQ_LENGTH \
    --inference-max-seq-length $MAX_SEQ_LENGTH \
    --inference-max-batch-size $MAX_INFERENCE_BS \
    --pretrained-checkpoint $CHECKPOINT \
    --distributed-timeout-minutes 60 \
    --use-mcore-models \
    --no-mmap-bin-files \
    --disable-bias-linear \
    --norm-epsilon 1e-5 \
    --init-method-std 0.014 \
    --exit-duration-in-mins 5750 \
    --max-position-embeddings $MAX_SEQ_LENGTH \
    --tensor-model-parallel-size $TP \
    --pipeline-model-parallel-size $PP \
    --expert-model-parallel-size $EP \
    --expert-tensor-parallel-size 1 \
    --weight-decay 0.01 \
    --clip-grad 1.0 \
    --tiktoken-pattern v2 \
    --tokenizer-type TikTokenizer \
    --tokenizer-model ${TOKENIZER_MODEL} \
    --dist-ckpt-strictness log_unexpected \
    --ckpt-format torch_dist \
    --ckpt-fully-parallel-save \
    --ckpt-fully-parallel-load \
    --use-distributed-optimizer \
    --overlap-grad-reduce \
    --overlap-param-gather \
    --no-create-attention-mask-in-dataloader \
    --lr 3e-6 \
    --min-lr 3e-6 \
    --lr-decay-style constant \
    --lr-warmup-samples 640 \
    --lr-warmup-init 0.3e-7 \
    --no-load-optim \
    --no-load-rng \
"
```
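`--rl-inference-logprobs-is-correction` together with `--rl-importance-sampling-truncation-coef 10.0` points at a truncated importance-sampling correction for the mismatch between the inference engine's log-probs and the training forward pass — a known MoE pain point, since routing decisions can differ between the two code paths. A sketch of how such a truncated correction is typically computed (names and exact form are assumptions, not this PR's implementation):

```python
import math

def is_correction_weight(train_logprob: float, inference_logprob: float,
                         truncation_coef: float = 10.0) -> float:
    """Truncated importance weight between trainer and inference policies.

    The raw ratio exp(logp_train - logp_infer) corrects for the two engines
    assigning different probabilities to the sampled token; truncating the
    ratio at a fixed coefficient caps the variance of the correction.
    """
    ratio = math.exp(train_logprob - inference_logprob)
    return min(ratio, truncation_coef)

# Small disagreements pass through almost unchanged...
print(is_correction_weight(-1.00, -1.05))  # exp(0.05) ~ 1.051
# ...while a large one (e.g. a different routing decision) is capped.
print(is_correction_weight(-0.5, -4.5))    # exp(4.0) ~ 54.6 -> capped at 10.0
```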
Review conversation:

Comment: It might be helpful to clarify that this is only architecturally identical to DSv2-lite. The model weights are different.

Reply: We're not actually pointing to any weights here. Is it the case that if someone did have the genuine dsv2_lite weights converted to mcore format, we'd expect them to work correctly though?