Conversation

@galopyz commented Jul 22, 2025

This is a PR to fix a bug in #746.

I found out that the output of RoPE did not match torchtune's RotaryPositionalEmbeddings or llama's RoPE.

After making changes to the RoPE implementation, the output from both Hugging Face and the notebook matched.

Here are the changes. Instead of dividing the head dimension into a first half and a second half by indexing up to the halfway point, I index by even and odd positions. This is because the even indices are multiplied by cos and the odd indices by sin. Then I apply the following formula:

(a + bi) * (cos(θ) + i*sin(θ)) = (a*cos(θ) - b*sin(θ)) + i*(a*sin(θ) + b*cos(θ))

to apply the RoPE transformation.

I tried to follow your style of keeping cos and sin instead of converting them into complex numbers.
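
Here is a minimal sketch of the interleaved indexing described above (illustrative names, not the exact notebook code), keeping cos and sin as real tensors:

import torch

def apply_rope_interleaved(x, cos, sin):
    # x: (batch, n_heads, seq_len, head_dim); cos/sin: (seq_len, head_dim // 2)
    a = x[..., 0::2]   # even indices -> "real" part
    b = x[..., 1::2]   # odd indices  -> "imaginary" part
    out = torch.empty_like(x)
    out[..., 0::2] = a * cos - b * sin   # a*cos(θ) - b*sin(θ)
    out[..., 1::2] = a * sin + b * cos   # a*sin(θ) + b*cos(θ)
    return out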

I also found out that the instruction-finetuned model weights were not getting loaded, so I added that as well.

Please let me know if some parts are unclear or need changes. Thank you.

Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

@casinca (Contributor) commented Jul 22, 2025

I wonder if the origin of the problem isn't, more broadly, the difference in weight loading and the RoPE variant used, explained in this issue: huggingface/transformers#25199

Sebastian is loading the weights directly, whereas HF permutes Q and K when converting and then uses the two-halves variant.

That could explain why switching to your interleaved variant is the right one to use here, if I understand correctly.
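
For contrast, here is a minimal sketch of that half-split ("rotate_half") variant applied after permuting Q and K (illustrative code, not HF's actual implementation):

import torch

def apply_rope_half_split(x, cos, sin):
    # x: (batch, n_heads, seq_len, head_dim); cos/sin: (seq_len, head_dim),
    # i.e. the per-frequency values repeated across both halves
    half = x.shape[-1] // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)   # the "rotate_half" trick
    return x * cos + rotated * sin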

@rasbt (Owner) commented Jul 22, 2025

Thanks for the PR. I agree that there could be a bug. But what's weird is that the unit test comparing the RoPE calculations to two reference implementations (LitGPT and HF transformers) gave the same results. Maybe there was an edge case.

It looks like there is now some issue with the tests after the fix:

E AssertionError: The values for attribute 'shape' do not match: torch.Size([4096, 8]) != torch.Size([4096, 16]).

Could you update the PR?

@galopyz (Author) commented Jul 22, 2025

Sorry, I was not aware of the tests.

After going through the tests, I found out that the LitGPT implementation matches the Hugging Face implementation. However, they do not match the torchtune or llama2 implementations. Here is a Google Colab notebook with comparisons.

I am not sure why they do not match. But using torchtune's RotaryPositionalEmbeddings gave me the same output as Hugging Face.

Would it be okay to change the test to compare the implementation with torchtune?
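
For example, such a comparison could look roughly like this (just a sketch: apply_rope_interleaved is the helper from the snippet above, and the cos/sin precomputation is inlined here instead of the notebook's actual helper):

import torch
from torchtune.modules import RotaryPositionalEmbeddings

batch, n_heads, seq_len, head_dim = 2, 4, 16, 32
x = torch.randn(batch, n_heads, seq_len, head_dim)

# reference: torchtune expects inputs of shape (batch, seq_len, n_heads, head_dim)
tt_rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len, base=10_000)
expected = tt_rope(x.transpose(1, 2)).transpose(1, 2)

# candidate: interleaved RoPE with the same frequencies
inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq   # (seq_len, head_dim // 2)
actual = apply_rope_interleaved(x, angles.cos(), angles.sin())

torch.testing.assert_close(actual, expected)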

@rasbt (Owner) commented Jul 22, 2025

Thanks for looking into that! Honestly, I really appreciate your time here fixing the RoPE issues. I remember spending a lot of time debugging things back then...

I haven't had time to carefully double-check the reference implementations this morning (the last time I checked was about a year ago when I wrote the original Llama code here), so I may be missing something or not understanding correctly yet.

But that being said, regarding

I found out that LitGPT implementation matches with Hugging Face implementation. However, they do not match with torchtune or llama2 implementation.

I think the reason is that the LitGPT implementation is a general-purpose implementation (developed before torchtune) that works with all kinds of LLMs, not just Llama. In fact, torchtune copied many aspects from LitGPT, but they may have implemented the RoPE in their own way. (By copied, I mean that LitGPT was around first, and torchtune was developed 1-2 years later trying to mimic the LitGPT API; you can see it when searching for Lit-GPT in the torchtune PRs.)

So maybe torchtune has a correct implementation here whereas the LitGPT project hasn't.

But using torchtune's RotaryPositionalEmbeddings gave me the same output from Hugging Face.

So if I understand correctly,

  1. torchtune's and Hugging Face's RoPE match?
  2. torchtune's and LitGPT's RoPE don't match?

That part I find a bit confusing, because how can the torchtune RoPE match the Hugging Face one but not the LitGPT one, even though LitGPT and Hugging Face both match the RoPE in this repository in the tests?

Would it be okay to change the test to compare the implementation with torchtune?

That would be okay with me, but I think we always need a 2nd reference here like Hugging Face to ensure consensus.

@galopyz (Author) commented Jul 22, 2025

Sorry about the confusion. I changed the access to the Google Colab notebook so you can read it.

Here is a summary of the results from the notebook.

  • torchtune and llama2 match RoPE outputs.
  • LitGPT and Hugging Face match RoPE outputs.
  • However, torchtune and Hugging Face RoPE do NOT match.
  • torchtune and LitGPT RoPE do NOT match.
  • Your original implementation matches the output from LitGPT and Hugging Face RoPE.

Regarding

But using torchtune's RotaryPositionalEmbeddings gave me the same output from Hugging Face.

What I meant is that after changing the RoPE implementation to match torchtune, the output from the model (the generated text) matches Hugging Face's generated text. I should have been more explicit about what kind of output I was referring to.

@rasbt (Owner) commented Jul 22, 2025

Thanks for clarifying, I think I understand now.
That's an interesting conundrum... I wonder if the current RoPE implementation here is correct (because it matches Hugging Face and LitGPT) but it's not correctly applied in grouped query attention. (Maybe it could also be related to precision.)

@galopyz (Author) commented Jul 22, 2025

As @casinca mentioned, Hugging Face permutes the query and key weights for the purpose of the sliced rotary. However, we are using the original Llama weights (meta-llama/Llama-2-7b), not the Hugging Face Transformers version of the weights (meta-llama/Llama-2-7b-hf).

Llama 2 7B does not use grouped query attention, so we can rule out that case.

I have tried loading weights from meta-llama/Llama-2-7b-hf instead of meta-llama/Llama-2-7b with the original code, but the model does not generate coherent text.

It could be that LitGPT uses the Hugging Face Transformers version of the model weights.

@rasbt (Owner) commented Jul 22, 2025

Ohhh, I see now. Yes, LitGPT uses the Hugging Face weights.

But in this case, could we not just permute the queries and keys instead of swapping the RoPE? Similar to what Hugging Face did:

import torch

n_heads = LLAMA2_CONFIG_7B["n_heads"]
dim = LLAMA2_CONFIG_7B["emb_dim"]

def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    # Reorder the wq/wk rows within each head (even-indexed rows first, then odd-indexed ones)
    # so the original interleaved-RoPE checkpoint works with the half-split RoPE used here,
    # mirroring Hugging Face's Llama weight-conversion script
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)



def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["tok_embeddings.weight"])

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            permute(params[f"layers.{l}.attention.wq.weight"])     # NEW
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            permute(params[f"layers.{l}.attention.wk.weight"])     # NEW
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"layers.{l}.attention.wv.weight"]
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"layers.{l}.attention.wo.weight"]
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"layers.{l}.attention_norm.weight"]
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"layers.{l}.feed_forward.w1.weight"]
        )
        # For some reason w2 and w3 are provided in the wrong order in the weights file
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"layers.{l}.feed_forward.w3.weight"]
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"layers.{l}.feed_forward.w2.weight"]
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"layers.{l}.ffn_norm.weight"]
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["norm.weight"])
    model.out_head.weight = assign(model.out_head.weight, params["output.weight"])


load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)
model.to(device);
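
For intuition, here is a tiny self-contained check (toy sizes, illustrative code only). It shows that the half-split RoPE applied to the evens-first-then-odds reordering, which is what permute does within each head, reproduces the interleaved result up to that same reordering; attention scores are unaffected when the same reordering is applied to both queries and keys.

import torch

torch.manual_seed(0)
head_dim, seq_len = 8, 5

inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq        # (seq_len, head_dim // 2)
x = torch.randn(seq_len, head_dim)                                 # a single head, for brevity

# interleaved variant on the original features
out_inter = torch.empty_like(x)
out_inter[:, 0::2] = x[:, 0::2] * angles.cos() - x[:, 1::2] * angles.sin()
out_inter[:, 1::2] = x[:, 0::2] * angles.sin() + x[:, 1::2] * angles.cos()

# half-split ("rotate_half") variant on the reordered features (evens first, then odds)
perm = torch.cat([torch.arange(0, head_dim, 2), torch.arange(1, head_dim, 2)])
xp = x[:, perm]
cos2, sin2 = angles.cos().repeat(1, 2), angles.sin().repeat(1, 2)
out_half = xp * cos2 + torch.cat((-xp[:, head_dim // 2:], xp[:, :head_dim // 2]), dim=-1) * sin2

# identical up to the same reordering
torch.testing.assert_close(out_half, out_inter[:, perm])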

@rasbt (Owner) commented Jul 23, 2025

I just gave it a quick try and it seems to work. I added it as a separate PR in #750 so you can check out the file diffs via ReviewNB.

It looks like we are now getting almost identical results:

Base model:

  1. The original code (before the PR):

Output text:
Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication

  2. Your new RoPE implementation:

Output text:
Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is provided

  3. Before-PR RoPE with permute fix:

Output text:
Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not

Note that the last word is different in 2 & 3.

Chat model:

  1. The original code (before the PR):

Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass

  2. Your new RoPE implementation:

Llamas are herbivores, which means they eat plants for their food. They feed on a variety

  3. Before-PR RoPE with permute fix:

Llamas are herbivores, which means they eat plants for their food. They feed on a variety

Still, there is a one-word difference in the base model.

@galopyz (Author) commented Jul 23, 2025

That worked nicely!

The one-word difference might be from the Hugging Face tokenizer adding a 1 at the beginning of the sequence. Adding 1 manually at the beginning and generating text resulted in exactly the same text as the Hugging Face transformer. Here are the outputs with and without the BOS token vs. the Hugging Face model output.

With 1 as the BOS token:

Output text:
 Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided

Without BOS token:

Output text:
 Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not intended to be a substitute for professional advice.
The information contained in this

HF output:

Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty, express or implied, as to the accuracy
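
For reference, the manual BOS prepend mentioned above could look like this (the prompt token IDs shown are hypothetical placeholders; Llama 2's BOS token ID is 1):

import torch

bos_token_id = 1                                   # Llama 2 BOS token ID
token_ids = [6804, 7225, 756, 1063, 1754]          # hypothetical encoded prompt IDs
input_ids = torch.tensor([bos_token_id] + token_ids).unsqueeze(0)   # shape: (1, len + 1)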

@rasbt (Owner) commented Jul 23, 2025

The one-word difference might be from the Hugging Face tokenizer adding a 1 at the beginning of the sequence. Adding 1 manually at the beginning and generating text resulted in exactly the same text as the Hugging Face transformer.

Awesome. Glad that it all works correctly now. Thanks so much for the valuable discussion and contribution!

(I will merge the other PR then; it's a bit easier this way than changing the RoPE code, and this way the RoPE code can still be reused for Llama 3 etc.)

@galopyz (Author) commented Jul 23, 2025

That's great. I am glad we fixed this very subtle bug. I appreciate your feedback and discussions. I learned a lot.

@galopyz closed this Jul 23, 2025
@galopyz deleted the fix_rope branch July 23, 2025 14:44