transformers implements the LLaMA model's Rotary Positional Embedding (RoPE) in transformers/src/transformers/models/llama/modeling_llama.py, lines 173 to 188 at commit e42587f:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
This is GPT-NeoX style RoPE. But Meta's official model implementation adopts GPT-J style RoPE, which rotates the query and key vectors in an interleaved way (pairing adjacent dimensions) instead of splitting them into two halves as the rotate_half method does.
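For contrast, here is a minimal real-valued sketch of the GPT-J style interleaved rotation. The function names and the cos/sin layout are illustrative only and not taken from either codebase:

import torch


def rotate_every_two(x):
    # Pair adjacent dims: (x0, x1), (x2, x3), ... -> (-x1, x0, -x3, x2, ...)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb_interleaved(q, k, cos, sin):
    # Here cos/sin must repeat each frequency twice along the last dim,
    # e.g. [cos(m*t0), cos(m*t0), cos(m*t1), cos(m*t1), ...],
    # so that both members of a pair share the same rotation angle.
    q_embed = (q * cos) + (rotate_every_two(q) * sin)
    k_embed = (k * cos) + (rotate_every_two(k) * sin)
    return q_embed, k_embed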
Meta's official repo implements RoPE as follows (full code link):
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
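Here freqs_cis is a table of complex exponentials e^{i·m·θ_j}, one entry per position m and per dimension pair j (reshape_for_broadcast, defined elsewhere in the same file, only reshapes it so it broadcasts against xq_). Roughly how such a table is built, paraphrased from memory rather than quoted, so treat it as a sketch:

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One frequency per dimension pair: theta_j = theta ** (-2j / dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)            # positions m = 0 .. end - 1
    angles = torch.outer(t, freqs)                        # [end, dim // 2]
    # polar(1, angle) = cos(angle) + i * sin(angle) = e^{i * angle}
    return torch.polar(torch.ones_like(angles), angles)   # complex64, [end, dim // 2]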
I'm confused by this difference. Since transformers.LlamaModel can directly load weights converted from the officially released checkpoint, won't this lead to inconsistent inference results? Is this difference expected?
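To make the concern concrete, here is a hypothetical check. It assumes rotate_half and apply_rotary_pos_emb are copied from the transformers snippet above; the Meta-style rotation is applied inline with view_as_complex so the snippet stays self-contained, and the frequency schedule is the standard θ_j = 10000^(-2j/d):

import torch

torch.manual_seed(0)
bs, n_heads, seq_len, head_dim = 1, 4, 8, 16

q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)

# Shared rotation angles: angles[m, j] = m * theta_j, theta_j = 10000 ** (-2j / head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)      # [seq_len, head_dim // 2]

# Half-split (GPT-NeoX style) layout used by the transformers snippet above
cos = torch.cat((angles, angles), dim=-1).cos()[None, None, :, :]  # [1, 1, seq_len, head_dim]
sin = torch.cat((angles, angles), dim=-1).sin()[None, None, :, :]
q_hf, _ = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# Interleaved (GPT-J style) layout, applied inline the way Meta's apply_rotary_emb does
freqs_cis = torch.polar(torch.ones_like(angles), angles)           # [seq_len, head_dim // 2]
q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
q_meta = torch.view_as_real(q_ * freqs_cis).flatten(-2)

# The two conventions pair dimensions differently, so on identical raw inputs
# the rotated outputs do not match.
print(torch.allclose(q_hf, q_meta, atol=1e-5))  # expected: False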