diff --git a/torchtune/models/llama4/_position_embeddings.py b/torchtune/models/llama4/_position_embeddings.py
index 3adfc32738..f031dce2ae 100644
--- a/torchtune/models/llama4/_position_embeddings.py
+++ b/torchtune/models/llama4/_position_embeddings.py
@@ -179,10 +179,14 @@ def forward(
         # tensor has shape [b, s, n_h, h_d // 2, 2]
         x_out = torch.stack(
             [
-                xshaped[..., 0] * rope_cache[..., 0]
-                - xshaped[..., 1] * rope_cache[..., 1],
-                xshaped[..., 1] * rope_cache[..., 0]
-                + xshaped[..., 0] * rope_cache[..., 1],
+                torch.sub(
+                    xshaped[..., 0] * rope_cache[..., 0],
+                    xshaped[..., 1] * rope_cache[..., 1],
+                ),
+                torch.add(
+                    xshaped[..., 1] * rope_cache[..., 0],
+                    xshaped[..., 0] * rope_cache[..., 1],
+                ),
             ],
             -1,
         )
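
For context, the hunk above rewrites the rotary position embedding (RoPE) rotation, which multiplies each (real, imag) channel pair by (cos, sin), from infix `-`/`+` operators to the equivalent `torch.sub`/`torch.add` calls. Below is a minimal, self-contained sketch of that computation, not torchtune's actual module: the tensor names `xshaped` and `rope_cache` follow the diff, but the shapes, the base frequency of 10000, and the cache construction here are assumptions for illustration.

```python
import torch

b, s, n_h, h_d = 2, 16, 4, 64  # batch, seq len, num heads, head dim (assumed)

x = torch.randn(b, s, n_h, h_d)

# Build a (cos, sin) cache per position and frequency: [s, h_d // 2, 2].
# Base 10000 is the common RoPE default, assumed here.
theta = 1.0 / (10000 ** (torch.arange(0, h_d, 2).float() / h_d))
angles = torch.outer(torch.arange(s).float(), theta)  # [s, h_d // 2]
rope_cache = torch.stack([angles.cos(), angles.sin()], dim=-1)

# View adjacent channel pairs as (real, imag): [b, s, n_h, h_d // 2, 2],
# and broadcast the cache over batch and heads.
xshaped = x.float().reshape(b, s, n_h, h_d // 2, 2)
cache = rope_cache.view(1, s, 1, h_d // 2, 2)

# Complex multiply (a + bi)(cos t + i sin t), using the functional forms
# torch.sub/torch.add exactly as the diff does.
x_out = torch.stack(
    [
        torch.sub(
            xshaped[..., 0] * cache[..., 0],
            xshaped[..., 1] * cache[..., 1],
        ),
        torch.add(
            xshaped[..., 1] * cache[..., 0],
            xshaped[..., 0] * cache[..., 1],
        ),
    ],
    -1,
)
x_out = x_out.flatten(3)  # back to [b, s, n_h, h_d]
```

The functional and infix forms compute identical values; a plausible motivation for `torch.sub`/`torch.add` is friendlier behavior under graph tracing or export tooling, though the diff itself does not state the rationale.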