From 4bb03ed1a42da702ce7204eefee3d68a649b23de Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Thu, 24 Jul 2025 11:39:42 -0700
Subject: [PATCH] [compile] workaround for compile error FakeTensor no op for
 builtin.-

[ghstack-poisoned]
---
 torchtune/models/llama4/_position_embeddings.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/torchtune/models/llama4/_position_embeddings.py b/torchtune/models/llama4/_position_embeddings.py
index 3adfc32738..f031dce2ae 100644
--- a/torchtune/models/llama4/_position_embeddings.py
+++ b/torchtune/models/llama4/_position_embeddings.py
@@ -179,10 +179,14 @@ def forward(
         # tensor has shape [b, s, n_h, h_d // 2, 2]
         x_out = torch.stack(
             [
-                xshaped[..., 0] * rope_cache[..., 0]
-                - xshaped[..., 1] * rope_cache[..., 1],
-                xshaped[..., 1] * rope_cache[..., 0]
-                + xshaped[..., 0] * rope_cache[..., 1],
+                torch.sub(
+                    xshaped[..., 0] * rope_cache[..., 0],
+                    xshaped[..., 1] * rope_cache[..., 1],
+                ),
+                torch.add(
+                    xshaped[..., 1] * rope_cache[..., 0],
+                    xshaped[..., 0] * rope_cache[..., 1],
+                ),
             ],
             -1,
         )
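
Note (not part of the patch): the change keeps the RoPE rotation math identical and only swaps the infix `-` / `+` operators for explicit `torch.sub` / `torch.add` calls, as a workaround for the FakeTensor no-op error hit under compile. Below is a minimal, self-contained sketch of that rewrite; the shapes and the broadcastable `rope_cache` layout are illustrative assumptions, not the exact tensors from `_position_embeddings.py`.

```python
import torch

b, s, n_h, h_d = 2, 4, 1, 8  # hypothetical batch / seq / head sizes

x = torch.randn(b, s, n_h, h_d)
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [b, s, n_h, h_d // 2, 2]
rope_cache = torch.randn(1, s, 1, h_d // 2, 2)     # broadcastable cos/sin cache (assumed layout)

# Original formulation with infix operators.
ref = torch.stack(
    [
        xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
        xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
    ],
    -1,
)

# Workaround formulation with explicit torch.sub / torch.add, as in the patch.
out = torch.stack(
    [
        torch.sub(
            xshaped[..., 0] * rope_cache[..., 0],
            xshaped[..., 1] * rope_cache[..., 1],
        ),
        torch.add(
            xshaped[..., 1] * rope_cache[..., 0],
            xshaped[..., 0] * rope_cache[..., 1],
        ),
    ],
    -1,
)

torch.testing.assert_close(out, ref)  # numerically identical results
```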