Skip to content

Commit cfd14a0

Browse files
authored
PERF: vectorise for loop using torch-native functions (#137)
* PERF: vectorise for loop using torch-native functions Found using `torchfix` * Correct function naming from benchmark * Re-add comment lost in benchmark
1 parent 234b680 commit cfd14a0

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

models/llama3/reference_impl/model.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def forward(self, x):
4242
return output * self.weight
4343

4444

45-
def apply_scaling(freqs: torch.Tensor):
45+
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
4646
# Values obtained from grid search
4747
scale_factor = 8
4848
low_freq_factor = 1
@@ -51,20 +51,17 @@ def apply_scaling(freqs: torch.Tensor):
5151

5252
low_freq_wavelen = old_context_len / low_freq_factor
5353
high_freq_wavelen = old_context_len / high_freq_factor
54-
new_freqs = []
55-
for freq in freqs:
56-
wavelen = 2 * math.pi / freq
57-
if wavelen < high_freq_wavelen:
58-
new_freqs.append(freq)
59-
elif wavelen > low_freq_wavelen:
60-
new_freqs.append(freq / scale_factor)
61-
else:
62-
assert low_freq_wavelen != high_freq_wavelen
63-
smooth = (old_context_len / wavelen - low_freq_factor) / (
64-
high_freq_factor - low_freq_factor
65-
)
66-
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
67-
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
54+
55+
wavelen = 2 * torch.pi / freqs
56+
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
57+
smooth = (old_context_len / wavelen - low_freq_factor) / (
58+
high_freq_factor - low_freq_factor
59+
)
60+
return torch.where(
61+
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
62+
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
63+
new_freqs,
64+
)
6865

6966

7067
def precompute_freqs_cis(

0 commit comments

Comments
 (0)