Hotfix: Set float32 as default dtype for testing tiny models#4770
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
After investigating, here are a few elements that can help understand what's happening here: Transformers dtype default behavior changedWith from transformers import AutoModelForCausalLM # v4.57.2
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model.dtype # torch.float32
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype="auto")
model.dtype # torch.bfloat16Starting with transformers v5, from transformers import AutoModelForCausalLM # v5.0.0
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model.dtype # torch.bfloat16
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype="auto")
model.dtype # torch.bfloat16This explains why some tests that previously ran in float32 now run in float16/bfloat16, and as you pointed out, this can lead to situations where some parameters are not updated. Longer-term: TRL should provide training-oriented defaultsMore broadly, I think TRL should aim to provide safe and stable defaults for training. In particular, we should distinguish between:
From a training stability perspective, the most robust default is usually:
The second point is already aligned with TRL defaults (e.g. enabling mixed precision in configs): Lines 124 to 131 in 2337cc9 Line 277 in 2337cc9 However, it looks like the load dtype often follows the model dtype, which can implicitly put users/tests into fp16/bf16 without intent: Line 1145 in 2337cc9 ProposalA longer-term solution could be:
The key idea is: we should not end up training in the model dtype unless it’s intentional, especially in tests that are not meant to validate this specific (and likely unstable) case. |
|
Thanks for your review, @qgallouedec: I totally agree. In this PR I was preliminary testing that setting As an alignment with your long-term proposal, I agree we should set float32 as the default precision at loading time. |
…uggingface#4770)" This reverts commit ca16441.
Set float32 as default dtype for testing tiny models, after the merge in
transformersof this PR:Fix #4748.