fix(rt-detr): use float32 on MPS in build_2d_sinusoidal_position_embedding#46513
Shylin26 wants to merge 1 commit into
…soidal_position_embedding for MPS compatibility

MPS (Apple Silicon) does not support float64. The hardcoded dtype=torch.float64 in build_2d_sinusoidal_position_embedding causes a TypeError when running RT-DETR and RT-DETRv2 on MPS devices. Replace with float32 which is supported on all backends (CPU, CUDA, MPS).

Fixes huggingface#46159

Signed-off-by: Shylin26 <parishachauhan26@gmail.com>
[For maintainers] Suggested jobs to run (before merge)

run-slow: rt_detr, rt_detr_v2
Submitted PR #46513 to fix this: uses …
Thanks for picking this up. We hit this exact MPS crash with RT-DETRv2 and have been patching it locally, so glad to see a fix in flight.

The device-conditional approach looks right: it keeps float64 precision on CUDA/CPU, where it's supported, and gracefully falls back to float32 on MPS, where it's not. Two suggestions:

On the CUDA perf concern raised on #46174: this is a one-time positional embedding computation (a few hundred positions), not a hot path. The device-conditional approach avoids the question entirely, since CUDA keeps float64 on-device as before: no CPU↔device transfer, no perf regression.

Happy to help test on MPS if needed. We use RT-DETRv2 via docling for layout detection on Apple Silicon daily.
Duplicate of #46174
What does this PR do?
Fixes #46159
build_2d_sinusoidal_position_embedding in both RT-DETR and RT-DETRv2 hardcodes dtype=torch.float64 for internal sinusoidal arithmetic. MPS (Apple Silicon) does not support float64, causing a TypeError when running inference on Apple Silicon devices.
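
For illustration only, here is a minimal sketch of a device-aware variant of such a helper. It is not the code in transformers or in this PR's diff: the names compute_dtype_name and sinusoidal_2d_position_embedding, the default arguments, and the exact tensor layout are all assumptions made for the example. The only behavior it is meant to show is choosing float32 when the target device is MPS and float64 otherwise.

```python
def compute_dtype_name(device_type: str) -> str:
    """Pick the float dtype (by name) for the embedding arithmetic.

    Hypothetical helper: MPS has no float64 support, so fall back to
    float32 there and keep float64 everywhere else (CPU, CUDA).
    """
    return "float32" if device_type == "mps" else "float64"


def sinusoidal_2d_position_embedding(width, height, embed_dim=256,
                                     temperature=10000.0, device="cpu"):
    """Hypothetical stand-in for build_2d_sinusoidal_position_embedding."""
    # Imported lazily so the dtype helper above stays free of torch.
    import torch

    if embed_dim % 4 != 0:
        raise ValueError("embed_dim must be divisible by 4")

    device = torch.device(device)
    dtype = getattr(torch, compute_dtype_name(device.type))

    # Build the coordinate grid directly on the target device.
    grid_w = torch.arange(width, dtype=dtype, device=device)
    grid_h = torch.arange(height, dtype=dtype, device=device)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")

    # One frequency per quarter of the embedding dimension.
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
    omega = 1.0 / (temperature ** omega)

    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]

    # Shape: (1, width * height, embed_dim), in the compute dtype.
    emb = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)
    return emb[None, :, :]
```

Because every tensor is created on the target device in a dtype that device supports, a sketch like this needs no CPU↔device transfer; a caller would cast the result to the model's working dtype afterwards.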
This PR uses float64 on CUDA/CPU for precision, and falls back to float32 on MPS devices only.

Code Agent Policy
Before submitting
Who can review?
@yonigozlan @molbap