Opening this issue to track adding on-policy knowledge distillation to NeMo RL, a distillation variant used by Qwen3, which follows this procedure:
"On-policy Distillation: In this phase, the student model generates on-policy sequences for fine-tuning. Specifically, prompts are sampled, and the student model produces responses in either /think or /no think mode. The student model is then fine-tuned by aligning its logits with those of a teacher model (Qwen3-32B or Qwen3-235B-A22B) to minimize the KL divergence."
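The quoted objective can be written out explicitly. The notation below is assumed rather than taken from the source, and the direction shown is forward KL (teacher to student); the quote does not specify the direction, and on-policy setups often use reverse KL instead:

$$
\mathcal{L}(\theta) \;=\; \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)} \sum_{t=1}^{|y|} D_{\mathrm{KL}}\!\left( \pi_{\text{teacher}}(\cdot \mid x, y_{<t}) \,\middle\|\, \pi_\theta(\cdot \mid x, y_{<t}) \right)
$$

Here $\pi_\theta$ is the student, trajectories $y$ are sampled on-policy from the student itself, and the per-token KL is computed over the vocabulary distribution at each position $t$.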
Training and program flow:

*(flow diagram omitted from the extracted text)*

Core methodology (high level)
- Student generates on-policy rollouts for sampled prompts.
- Teacher(s) provide token-level targets (logits or probabilities) on the student’s generated trajectories.
- Optimize student with a distillation loss on those trajectories.
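The three steps above can be sketched as a single training step. This is a hypothetical pure-Python stand-in, not the actual NeMo RL API: the generator and log-probability functions below fake their outputs so the control flow is visible end to end.

```python
# Hypothetical sketch of one on-policy distillation step; all function
# names are stand-ins, not the actual NeMo RL APIs.

def student_generate(prompt, length=4):
    """Stand-in for the student policy's sampler (on-policy rollout)."""
    # Deterministic fake "token ids" derived from the prompt text.
    seed = sum(ord(c) for c in prompt)
    return [(seed + i) % 8 for i in range(length)]

def student_logprobs(tokens):
    """Stand-in: student's log-probabilities on its own trajectory."""
    return [-(t + 1) * 0.10 for t in tokens]

def teacher_logprobs(tokens):
    """Stand-in: teacher's log-probabilities on the student's trajectory."""
    return [-(t + 1) * 0.15 for t in tokens]

def distillation_step(prompts):
    """Rollout -> teacher scoring -> averaged per-token distillation loss.

    Uses a Monte Carlo estimate of reverse KL on the student's own samples:
    E_student[log p_student - log p_teacher].
    """
    total, count = 0.0, 0
    for prompt in prompts:
        trajectory = student_generate(prompt)      # step 1: on-policy rollout
        s_lp = student_logprobs(trajectory)        # student scores itself
        t_lp = teacher_logprobs(trajectory)        # step 2: teacher targets
        total += sum(s - t for s, t in zip(s_lp, t_lp))  # step 3: loss
        count += len(trajectory)
    return total / count
```

In a real run, `student_generate` would be the rollout engine, both log-probability calls would be forward passes, and the loss would be backpropagated through the student only.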
Implementation design - changes to core components
- New Distillation Algorithm Class (`OnPolicyDistillation`)
  - Similar in structure to the GRPO algorithm in `nemo_rl/algorithms/grpo.py`
  - Manages the overall training loop and orchestrates student-teacher interaction
  - Leverages the existing GRPO infrastructure for distributed training
- Distillation Loss Function (`DistillationLossFn`)
  - Implements a KL-divergence loss between student and teacher logits
  - Similar to `ClippedPGLossFn` in `nemo_rl/algorithms/loss_functions.py`
  - Supports token-level masking to exclude prompt tokens from the loss calculation
- Dual Policy Management
  - Student policy: trainable, inherits from the existing `Policy` class
  - Teacher policy: fixed/frozen, either a separate instance or shared infrastructure
  - Both can use the same underlying worker architecture (DTensor/Megatron)
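A minimal pure-Python sketch of the token-masked KL loss described in the `DistillationLossFn` bullet. Helper names are hypothetical; a real implementation would operate on batched GPU tensors rather than per-position lists of logits.

```python
import math

# Hypothetical sketch of a token-masked KL distillation loss; a real
# DistillationLossFn would work on batched tensors, not Python lists.

def softmax(logits):
    """Numerically stable softmax over one vocabulary slice."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def token_kl(teacher_logits, student_logits):
    """KL(teacher || student) for a single token position."""
    p, q = softmax(teacher_logits), softmax(student_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

def masked_distillation_loss(teacher_logits, student_logits, token_mask):
    """Mean per-token KL over response tokens only.

    token_mask holds 1 for generated (response) tokens and 0 for prompt
    tokens, so prompt positions are excluded from the loss.
    """
    numer = sum(m * token_kl(t, s)
                for t, s, m in zip(teacher_logits, student_logits, token_mask))
    denom = max(sum(token_mask), 1)  # avoid division by zero on empty masks
    return numer / denom
```

The loss is zero when student and teacher logits agree, positive otherwise, and fully masked prompt positions contribute nothing regardless of how much the two models disagree there.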
https://nvidia.slack.com/archives/C0271E234TB/p1751960108820759
terrykong and yuki-97