Supports block causal mask #1001
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from typing import Callable, ClassVar, Optional

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
)


class FlexAttention(torch.nn.Module):
    # We register flex_attention-related attributes as class variables to
    # amortize the cost of compilation.
    flex_attn: ClassVar[Callable] = torch.compile(
        flex_attention, mode="max-autotune-no-cudagraphs"
    )
    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
    used_attn_mask_types: ClassVar[set[str]] = set()
    # Maps an attention mask type to its created BlockMask.
    # This lets us keep track of the block masks created for each new batch
    # and update them when a new batch arrives. It also allows users to
    # create different block masks for different layers.
    block_masks: ClassVar[dict[str, BlockMask]] = {}

    # Instance variables.
    attn_mask_type: str

    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type not in ["causal", "block_causal"]:
            raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
        self.attn_mask_type = attn_mask_type
        FlexAttention.used_attn_mask_types.add(attn_mask_type)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        block_mask = FlexAttention.block_masks[self.attn_mask_type]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)

    @staticmethod
    def _get_causal_mask_fn() -> Callable:
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return causal_mask

    @staticmethod
    def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
        # batch is the token-id tensor with shape [b, s].
        mask = batch == eos_id
        mask[:, -1] = True
        acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
        seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
        seq_idx[:, 1:] = acc_mask[:, :-1]

        def block_causal_mask(b, h, q_idx, kv_idx):
            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

        return block_causal_mask

    @staticmethod
    @torch.no_grad()
    def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
        # batch is the token-id tensor with shape [b, s].
        for attn_mask_type in FlexAttention.used_attn_mask_types:
            match attn_mask_type:
                case "causal":
                    if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
                        continue
                    # We don't care about the batch dimension --
                    # all samples share the same lower-triangular mask.
                    batch_dimension = 1
                    mask_fn = FlexAttention._get_causal_mask_fn()
                case "block_causal":
                    if eos_id is None:
                        raise RuntimeError(
                            "eos_id must be provided for block_causal mask."
                        )
                    batch_dimension = batch.shape[0]
                    mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
                case _:
                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")

            seq_len = batch.shape[1]
            block_mask = FlexAttention.compiled_create_block_mask(
                mask_fn, batch_dimension, None, seq_len, seq_len
            )
            FlexAttention.block_masks[attn_mask_type] = block_mask


class ScaledDotProductAttention(torch.nn.Module):
    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type != "causal":
            raise ValueError(
                "TorchTitan with SDPA currently only supports causal mask."
            )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def build_attention(use_flex_attn: bool, attn_mask_type: str):
    if use_flex_attn:
        return FlexAttention(attn_mask_type)
    else:
        return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
    FlexAttention.init_attention_mask(batch, eos_id)
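For orientation, here is a minimal usage sketch of the new module (not part of the diff). The shapes, vocabulary size, eos_id, and dtype below are illustrative assumptions, and a CUDA device is assumed.

# Minimal usage sketch (illustrative, not part of the PR); sizes are hypothetical.
import torch

from torchtitan.models.attention import build_attention, init_attention_mask

bs, seq_len, n_heads, head_dim = 2, 256, 8, 64
vocab_size, eos_id = 1000, 2

# One attention module per transformer layer; "block_causal" requires FlexAttention.
attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")

# Once per batch, before the model forward: build the shared BlockMask from token ids.
tokens = torch.randint(0, vocab_size, (bs, seq_len), device="cuda")
init_attention_mask(tokens, eos_id=eos_id)

# Inside each layer, q/k/v are [bs, n_heads, seq_len, head_dim].
q = torch.randn(bs, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v)  # dispatches to flex_attention with the "block_causal" BlockMask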
@@ -8,20 +8,15 @@
from dataclasses import dataclass
from typing import Callable, ClassVar, Optional, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
)

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.models.norms import build_norm
from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol

@@ -45,12 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

Is this necessary? I think the model should not have to know about properties of the tokenizer like this.

Why is this an issue? It is configurable and is given by the tokenizer, not defined by the model. Otherwise, how will the attention module separate the tokens?

@ebsmothers

I don't have a strong opinion on which approach is better. Both are viable approaches. I just want to emphasize that the model has to know

@fegin maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader. You are already iterating over a bunch of samples that you want to pack, right? Why not just use this opportunity to construct a list of sequence lengths in the pack? Then this can be used to construct a BlockMask for flex without the model needing to know anything about eos_id. @tianyu-l iiuc this is kinda similar to what you're proposing, but passing

@fegin I also didn't get this. An alternative seems to be passing the

This is debatable. If you are building the block causal mask inside the dataloader, you are polluting the dataloader. There are options to use either Flex or Flash, and there can be more than just simple block causal. Why does the dataloader need to know whether the model is using Flex or Flash, or what kind of masking the model uses? It is just a question of which component you choose to expose to information about the other. We discussed internally whether to couple mask construction with the dataloader, and there was an opinion to decouple the dataloader from the attention implementation, since researchers can try different attention masks without changing the dataloader, and knowing EOS_ID is enough. I would prefer to keep that decision, even though it was not made specifically for TorchTitan; the discussion was about how to do CP + Flex for PTD.

Just one more quick point: this is our generic document mask re-writer: https://github.com/pytorch-labs/attention-gym/blob/001b36d625aceae8a03f59241113e4797122db1d/attn_gym/masks/document_mask.py#L33. It takes in two things, another mask_mod and a tensor specifying the boundaries. I find this decoupling pretty attractive: you separate the generation of the extra metadata from which mask_mod you want to perform; in this concrete case the choice is "causal". But someone has to own setting up your input tensor in the packed format, and the generation of the extra metadata likely should be colocated IMO.
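For illustration, the decoupling described in the last two comments might look roughly like the sketch below. This is a hypothetical example, not the PR's code and not the attention-gym API: the helpers doc_ids_from_lengths and with_document_mask are invented names, and the per-row document lengths are assumed to come from the dataloader.

# Hypothetical sketch of the "dataloader owns the packing metadata" approach.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def doc_ids_from_lengths(doc_lens_per_row: list[torch.Tensor]) -> torch.Tensor:
    # doc_lens_per_row[i] holds the lengths of the documents packed into row i;
    # each row's lengths must sum to the same sequence length.
    return torch.stack(
        [torch.repeat_interleave(torch.arange(len(l)), l) for l in doc_lens_per_row]
    )


def with_document_mask(base_mask_mod, doc_ids: torch.Tensor):
    # Wrap any base mask_mod so attention is additionally confined to tokens
    # that belong to the same packed document in the same row.
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc & base_mask_mod(b, h, q_idx, kv_idx)

    return mask_mod


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Two rows of length 128: one packs documents of length 48 and 80, one holds a
# single document of length 128. The lengths would come from the dataloader.
device = "cuda"
doc_ids = doc_ids_from_lengths([torch.tensor([48, 80]), torch.tensor([128])]).to(device)
block_mask = create_block_mask(
    with_document_mask(causal, doc_ids), 2, None, 128, 128, device=device
)

Either way, some component has to own the packing metadata; the PR chooses to derive it inside the model from eos_id.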

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn
        self.attn_mask_type = job_config.model.attn_mask_type

    def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int:
        l, h, q, t = (

@@ -123,7 +121,7 @@ def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

@@ -138,7 +136,7 @@ def apply_rotary_emb(
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

@@ -179,13 +177,6 @@ class Attention(nn.Module):

    """

    # We registered flex_attention related attributes as class variables as we
    # need to amortize the cost of compilation. Enabling per-instance flex_attention
    # is not supported.
    block_mask: ClassVar[Optional[BlockMask]] = None
    use_flex_attn: ClassVar[bool] = False
    flex_attn: ClassVar[Optional[Callable]] = None

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        self.n_heads = model_args.n_heads

@@ -205,7 +196,7 @@ def __init__(self, model_args: TransformerModelArgs):
        self.wo = nn.Linear(
            model_args.n_heads * self.head_dim, model_args.dim, bias=False
        )
        self.use_flex_attn = model_args.use_flex_attn
        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

    def init_weights(self, init_std: float):
        for linear in (self.wq, self.wk, self.wv):
@@ -249,35 +240,14 @@ def forward(
        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

        # we use causal mask for training
        if self.use_flex_attn:
            # assert False, (type(xq), type(xk), type(xv))
            self._init_flex_attn(seqlen=seqlen)
            output = self.flex_attn(xq, xk, xv, block_mask=self.block_mask)
        else:
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        output = self.sdpa(xq, xk, xv)

        output = output.transpose(
            1, 2
        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.wo(output)

    @torch.no_grad()
    def _init_flex_attn(self, seqlen: int) -> None:
        if self.block_mask is not None:
            return

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        compiled_create_block_mask = torch.compile(create_block_mask)
        self.block_mask = compiled_create_block_mask(
            causal_mask, None, None, seqlen, seqlen
        )
        self.flex_attn = torch.compile(
            flex_attention, mode="max-autotune-no-cudagraphs"
        )


class FeedForward(nn.Module):
    """
@@ -420,6 +390,7 @@ def __init__(self, model_args: TransformerModelArgs):
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers
        self.eos_id = model_args.eos_id

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

@@ -500,6 +471,11 @@ def forward(self, tokens: torch.Tensor):
            torch.Tensor: Output logits after applying the Transformer model.

        """
        # TODO: We will need to change the forward() signature to allow tokens
        # to always be passed in.
        if self.model_args.use_flex_attn:
            init_attention_mask(tokens, eos_id=self.eos_id)

        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
NOTE: this is a 2D mask, enforcing block causal attention within each row of a batch. There is no cross-row attention.
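To make the note concrete, here is a small hypothetical walkthrough (not from the PR) of how _get_block_causal_mask_fn segments one row: with eos_id = 0, each EOS token closes a document, the last position is force-closed, and the shifted cumulative sum assigns every token a per-row document id. The token values are made up.

import torch

eos_id = 0
batch = torch.tensor([[5, 7, 0, 3, 9, 0, 4, 8]])  # one row of packed token ids

mask = batch == eos_id
mask[:, -1] = True                                # force-close the final document
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1, 2, 2]], dtype=torch.int32)

# block_causal_mask allows a (q_idx, kv_idx) pair only when both tokens share a
# document id AND q_idx >= kv_idx, so token 3 (the first token of document 1)
# cannot attend to tokens 0-2, and because seq_idx is indexed by the batch
# index b, no token ever attends across rows.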