Eagle3: add norm_before_fc for gpt-oss draft models #337

Merged: shanjiaz merged 3 commits into main from gpt_oss_norm_fix on Mar 10, 2026

@shubhra shubhra commented Mar 9, 2026

Collaborator

When training Eagle3 draft models for gpt-oss, we observed exploding hidden states in the draft path. This PR adds an optional RMSNorm before the fc (on the concatenated 3× aux hidden states) to stabilize training. The behavior is gated by norm_before_fc so only gpt-oss (or models that need it) use it; others are unchanged.

Description

The Eagle3 fusion path concatenates three aux hidden states and projects via fc. gpt-oss exhibits exploding states in this path; we add an optional RMSNorm (input_norm) before the fc to stabilize. The norm runs at train time (speculators) and inference (vLLM), so there is no train–serve mismatch.
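
The fusion path described above can be sketched without any framework. This is a minimal illustration, not the code from this PR: the helper names (`rms_norm`, `linear`, `fuse`), the toy sizes, and the `eps` value are assumptions, and it assumes `fc` is a bias-free projection from the 3×-wide concat down to the hidden width.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: rescale x to unit root-mean-square, then apply a per-channel gain."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def linear(x, rows):
    """Bias-free projection: one output value per weight row."""
    return [sum(w * v for w, v in zip(row, x)) for row in rows]

def fuse(aux_states, fc_rows, norm_weight=None):
    """Concatenate the aux hidden states, optionally normalize, then project with fc."""
    concat = [v for state in aux_states for v in state]
    if norm_weight is not None:  # corresponds to norm_before_fc=True
        concat = rms_norm(concat, norm_weight)
    return linear(concat, fc_rows)

# Toy sizes (assumed): hidden width 2, so the concat is 6 wide and fc maps 6 -> 2.
aux = [[1000.0, -2000.0], [3000.0, 500.0], [-1500.0, 2500.0]]
fc_rows = [[0.1] * 6, [-0.1, 0.1, -0.1, 0.1, -0.1, 0.1]]

raw = fuse(aux, fc_rows)                            # fc sees values in the thousands
normed = fuse(aux, fc_rows, norm_weight=[1.0] * 6)  # fc sees unit-RMS input
```

With the norm enabled, the magnitude of the input to `fc` no longer depends on how large the aux hidden states have grown, which is the stabilizing effect the description refers to.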

  • Config: norm_before_fc: bool = False on Eagle3SpeculatorConfig (same style as norm_before_residual). When True, the draft model uses the pre-fc norm.
  • Core: Create/apply input_norm only when config.norm_before_fc; otherwise fc gets the raw concat as before.
  • Loading: "input_norm.weight" in _keys_to_ignore_on_load_missing so old checkpoints without the norm still load.
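
A rough sketch of how the three points above could fit together follows. `DraftConfig`, `init_params`, and `load_checkpoint` are hypothetical stand-ins, not the speculators API; only the field name `norm_before_fc`, its default, and the key `"input_norm.weight"` come from this PR. The ignore list is modeled as exact-match keys that are not reported as missing and keep their initial value, which is an assumption about the loader's semantics.

```python
from dataclasses import dataclass

@dataclass
class DraftConfig:  # hypothetical stand-in for Eagle3SpeculatorConfig
    hidden_size: int = 4
    norm_before_fc: bool = False  # field name and default from the PR

# Keys that older checkpoints may lack (from the PR description).
KEYS_TO_IGNORE_ON_LOAD_MISSING = ["input_norm.weight"]

def init_params(config):
    """Always create fc; create input_norm only when the flag is set."""
    concat_width = 3 * config.hidden_size
    params = {"fc.weight": [[0.0] * concat_width for _ in range(config.hidden_size)]}
    if config.norm_before_fc:
        params["input_norm.weight"] = [1.0] * concat_width  # identity gain
    return params

def load_checkpoint(params, checkpoint, ignore_missing=KEYS_TO_IGNORE_ON_LOAD_MISSING):
    """Copy matching entries; return the missing keys that are not on the ignore list."""
    reported_missing = [
        key for key in params
        if key not in checkpoint and key not in ignore_missing
    ]
    for key in params:
        if key in checkpoint:
            params[key] = checkpoint[key]
    return params, reported_missing

# An old checkpoint without the norm loads into a model that has it.
old_checkpoint = {"fc.weight": [[1.0] * 12 for _ in range(4)]}
params, missing = load_checkpoint(init_params(DraftConfig(norm_before_fc=True)), old_checkpoint)
```

In this sketch the missing `input_norm.weight` stays at its initial all-ones gain, so the loaded model still applies the normalization but with no learned per-channel scaling.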

Tests

  • Training with norm_before_fc=True uses input_norm; with False (default) behavior matches pre-PR.

Related: Inference support in vLLM: vllm-project/vllm#36545

@shubhra shubhra force-pushed the gpt_oss_norm_fix branch from 582b6d8 to ed449e9 Compare March 9, 2026 19:51
@shubhra shubhra requested review from fynnsu and shanjiaz March 9, 2026 19:52
github-actions Bot commented Mar 9, 2026

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/22872426973/artifacts/5837286616.
They will be retained for up to 30 days.
Commit: fd27b1b

@shubhra shubhra force-pushed the gpt_oss_norm_fix branch from 15a9fe7 to c69cbd7 Compare March 9, 2026 20:00
shubhra added 3 commits March 9, 2026 20:05
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>

@shanjiaz shanjiaz left a comment

Thanks for adding the fix!

@shanjiaz shanjiaz merged commit ac6db62 into main Mar 10, 2026
12 checks passed
@shanjiaz shanjiaz deleted the gpt_oss_norm_fix branch March 10, 2026 13:30
shubhra added a commit that referenced this pull request Mar 16, 2026
Follow-up to #337. Expose --norm-before-fc in train.py,
add norm_before_fc to TrainArgs in gen_and_train.py, and
set norm_before_fc=True in the gpt-oss example.

Made-with: Cursor
shanjiaz pushed a commit that referenced this pull request Mar 17, 2026
…#347)

Follow-up to
[#337](#337), which
added config + core support for `norm_before_fc` but did not expose it
in the training CLI or the combined pipeline. Without this, users
running `train.py` or `gen_and_train.py` had no way to enable the pre-FC
norm, so gpt-oss models could not be trained correctly via the standard
scripts.

## Changes

- **train.py:** Add `--norm-before-fc` flag so the training script can
pass `norm_before_fc=True` into the Eagle3 config.
- **gen_and_train.py:** Add `norm_before_fc` to `TrainArgs` so the
combined pipeline forwards the flag to `train.py`.
- **gpt_oss_20b_ultrachat_5k.py:** Set `norm_before_fc=True` in the
gpt-oss example so it trains with the stabilizing norm out of the box.

With these changes, gpt-oss models train correctly (`--norm-before-fc`),
and all other models continue to train as before (flag defaults to off).
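
One way the flag and its forwarding could look is sketched below. The flag name `--norm-before-fc` and the `norm_before_fc` field on `TrainArgs` come from this commit message; everything else (`build_parser`, the `to_argv` method, the help text) is a hypothetical illustration, and the real `train.py` and `gen_and_train.py` may be structured differently.

```python
import argparse
from dataclasses import dataclass

def build_parser():
    """Toy stand-in for the argument parser in train.py."""
    parser = argparse.ArgumentParser(description="Toy stand-in for train.py")
    # store_true: flag absent -> False (previous behavior), flag present -> True.
    parser.add_argument(
        "--norm-before-fc",
        action="store_true",
        help="Apply RMSNorm to the concatenated aux hidden states before fc.",
    )
    return parser

@dataclass
class TrainArgs:  # hypothetical; the real TrainArgs has other fields
    norm_before_fc: bool = False

    def to_argv(self):
        """Translate fields into CLI arguments for the training script."""
        return ["--norm-before-fc"] if self.norm_before_fc else []

# The combined pipeline forwards the field; the training script parses it.
args = build_parser().parse_args(TrainArgs(norm_before_fc=True).to_argv())
default_args = build_parser().parse_args([])
```

Because `argparse` maps `--norm-before-fc` to the attribute `norm_before_fc`, the parsed value can be passed straight into the Eagle3 config field of the same name.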

## Tests

- `train.py --norm-before-fc` creates the draft model with `input_norm`;
omitting the flag matches pre-#337 behavior.
- gpt-oss example runs end-to-end with the norm enabled.

## Related

- Core + config:
[#337](#337)
- Inference support in vLLM:
[vllm-project/vllm#36545](vllm-project/vllm#36545)

Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
YzTongNiar pushed a commit to YzTongNiar/speculators that referenced this pull request Apr 10, 2026
YzTongNiar pushed a commit to YzTongNiar/speculators that referenced this pull request Apr 10, 2026
…vllm-project#347)

3 participants