Skip to content

Conversation

xu-song
Copy link

@xu-song xu-song commented Feb 22, 2025

Overview

Weight absorption is an important feature for reducing memory use and improving speed, but it is not implemented in the current releases.

Image

Core code

    # Reinterpret the transposed output-projection weight as one (v_head_dim, dim) slab per head.
    wo = wo.transpose(0,1).view(self.n_heads, self.v_head_dim, self.dim)
    # Contract over d (the v_head_dim axis) so the sliced wkv_b rows are folded into wo, giving (h, c, i).
    # NOTE(review): the slice keeps the last v_head_dim rows of wkv_b, presumably the value
    # up-projection (w_uv) rather than w_uk; the demo uses w_uv in this position. Confirm.
    # NOTE(review): the result is bound to a local `wo_absorb`, while the next line reads
    # `self.wo_absorb`; confirm the surrounding code stores it on self.
    wo_absorb = torch.einsum("hdc,hdi->hci", wkv_b[:, -self.v_head_dim:], wo)  # absorb w_uv into wo (see note above)
    # Apply the absorbed weight per head: (b, s, h, c) x (h, c, i) -> (b, s, h, i).
    x = torch.einsum("bshc,hci->bshi", x, self.wo_absorb)
    # Sum the per-head contributions (axis 2 is h in "bshi").
    x = torch.sum(x, dim=2)

Simple demo

import torch
from torch import nn

class AbsorbDemo:
    """Toy check that absorbing ``w_uv`` into ``wo`` leaves the output unchanged.

    The instance holds random stand-ins for the tensors involved in the
    attention output projection:

    * ``scores``   -- ``(bsz, q_len, n_heads, kv_len)``
    * ``kv_cache`` -- ``(bsz, kv_len, kv_lora_rank)``
    * ``w_uv``     -- ``(n_heads, v_head_dim, kv_lora_rank)``
    * ``wo``       -- bias-free ``nn.Linear(n_heads * v_head_dim, dim)``

    ``wo_absorb`` caches the fused ``(n_heads, kv_lora_rank, dim)`` weight; it
    is ``None`` until the first ``run(absorb=True)`` call and is not refreshed
    if ``wo`` or ``w_uv`` are modified afterwards.
    """

    def __init__(self, bsz=1, q_len=1, kv_len=4, dim=7168, kv_lora_rank=512, n_heads=128, v_head_dim=128):
        """Create random inputs and weights with the given sizes."""
        self.n_heads = n_heads
        self.v_head_dim = v_head_dim
        self.dim = dim

        # Creation order is kept as-is so results under a fixed seed do not change.
        self.scores = torch.rand(bsz, q_len, n_heads, kv_len)
        self.kv_cache = torch.rand(bsz, kv_len, kv_lora_rank)
        self.w_uv = torch.rand(n_heads, v_head_dim, kv_lora_rank)
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False)
        self.wo_absorb = None  # filled lazily by run(absorb=True)

    def run(self, absorb=False):
        """Compute the projected attention output.

        Args:
            absorb: When true, use the fused ``w_uv``/``wo`` weight (built on
                first use and cached in ``self.wo_absorb``). When false, apply
                ``w_uv`` and then ``wo`` as two separate steps.

        Returns:
            Tensor of shape ``(bsz, q_len, dim)``. Both paths compute the same
            quantity up to floating-point rounding.
        """
        # Weighted sum of the cached latents per head: (b, s, h, t) x (b, t, c) -> (b, s, h, c).
        x = torch.einsum("bsht,btc->bshc", self.scores, self.kv_cache)

        if not absorb:
            # Up-project to per-head values first (the PR author flags this step as memory-hungry),
            # then flatten the heads and apply the output projection.
            x = torch.einsum("bshc,hdc->bshd", x, self.w_uv)
            return self.wo(x.flatten(2))

        if self.wo_absorb is None:
            # nn.Linear stores weight as (dim, n_heads * v_head_dim); transpose and split the
            # input axis into (n_heads, v_head_dim) so it lines up with w_uv.
            wo = self.wo.weight.transpose(0, 1).view(self.n_heads, self.v_head_dim, self.dim)
            # Contract over v_head_dim: (h, d, c) x (h, d, i) -> (h, c, i).
            self.wo_absorb = torch.einsum("hdc,hdi->hci", self.w_uv, wo)

        # Contract over heads and latent channels in one einsum. This equals
        # einsum("bshc,hci->bshi").sum(dim=2) but does not build the
        # (bsz, q_len, n_heads, dim) intermediate.
        return torch.einsum("bshc,hci->bsi", x, self.wo_absorb)


# One shared instance, so both paths see the same random inputs and weights.
demo = AbsorbDemo()

# Run the reference (two-step) path first, then the absorbed path.
tensor1, tensor2 = demo.run(absorb=False), demo.run(absorb=True)

# Print both results side by side for a visual comparison.
for label, result in (("w/o absorb:", tensor1), ("w   absorb:", tensor2)):
    print(label, result.data)

# The two paths should agree up to floating-point rounding.
outputs_match = torch.allclose(tensor1.data, tensor2.data, atol=1e-03)
print(outputs_match)

output:

w/o absorb: tensor([[[ -25.6272,   30.9115,   12.3565,  ...,  -30.2628, -120.6479, 118.7872]]])
w   absorb: tensor([[[ -25.6272,   30.9115,   12.3564,  ...,  -30.2630, -120.6479, 118.7871]]])
True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant