Skip to content

Conversation

@sbrugman
Copy link
Contributor

@sbrugman sbrugman commented Sep 9, 2024

Found using torchfix.

The apply_scaling function used a for-loop with three conditions. The same result can be vectorised using the torch.where function twice (if/else).

Benchmark / test (extremely basic):

import math
from time import time

import torch


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def apply_scaling_fast(freqs: torch.Tensor) -> torch.Tensor:
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * torch.pi / freqs
    new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
        (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
        new_freqs,
    )


theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res1 = apply_scaling(freqs)
end = time()
print(res1)
print(end - start)

theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res2 = apply_scaling_fast(freqs)
end = time()
print(res2)
print(end - start)

print((res1 == res2).all())

Output:

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
0.0006110668182373047
tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
7.009506225585938e-05
tensor(True)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 9, 2024
Copy link
Contributor

@ashwinb ashwinb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

@ashwinb ashwinb merged commit cfd14a0 into meta-llama:main Jan 28, 2025
@sbrugman sbrugman deleted the patch-1 branch January 29, 2025 18:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants