Supports block causal mask #1001

Merged (11 commits, Apr 3, 2025)
1 change: 1 addition & 0 deletions tests/integration_tests.py
@@ -302,6 +302,7 @@ def build_test_list():
"--parallelism.data_parallel_shard_degree=4",
"--activation_checkpoint.mode='full'",
"--model.use_flex_attn",
"--model.attn_mask_type='block_causal'",
]
],
"FSDP+FLEX_ATTN",
16 changes: 15 additions & 1 deletion torchtitan/config_manager.py
@@ -196,7 +196,21 @@ def __init__(self):
self.parser.add_argument(
"--model.use_flex_attn",
action="store_true",
help="Whether to use Flex Attention.",
help="""
Whether to use Flex Attention.
Mixed usage of SDPA and FlexAttention is not supported yet.
""",
)
self.parser.add_argument(
"--model.attn_mask_type",
type=str,
default="causal",
choices=["causal", "block_causal"],
help="""
Specifies the type of bias/mask used for attention. If SDPA is used,
only the causal mask is supported by default. If FlexAttention is used,
both causal and block_causal masks are supported.
""",
)
self.parser.add_argument(
"--model.tokenizer_path",
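For orientation, here is a minimal sketch (not part of the diff) of how the two new options map onto the attention implementations added in torchtitan/models/attention.py below; note that the SDPA path rejects anything other than the plain causal mask.

```python
# Illustrative only: how --model.use_flex_attn and --model.attn_mask_type
# select an attention implementation (see build_attention further down).
from torchtitan.models.attention import build_attention

flex_attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")
sdpa_attn = build_attention(use_flex_attn=False, attn_mask_type="causal")

# SDPA with block_causal is rejected at construction time.
try:
    build_attention(use_flex_attn=False, attn_mask_type="block_causal")
except ValueError as err:
    print(err)  # TorchTitan with SDPA currently only supports causal mask.
```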
124 changes: 124 additions & 0 deletions torchtitan/models/attention.py
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from typing import Callable, ClassVar, Optional

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)


class FlexAttention(torch.nn.Module):
# We register flex_attention-related attributes as class variables because we
# need to amortize the cost of compilation.
flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
used_attn_mask_types: ClassVar[set[str]] = set()
# Maps attention mask type to the created BlockMask.
# This allows us to keep track of the block masks created for each
# new batch and to update them when a new batch arrives. It also
# allows users to create different block masks for different layers.
block_masks: ClassVar[dict[str, BlockMask]] = {}

# Instance variables.
attn_mask_type: str

def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type not in ["causal", "block_causal"]:
raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
self.attn_mask_type = attn_mask_type
FlexAttention.used_attn_mask_types.add(attn_mask_type)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
block_mask = FlexAttention.block_masks[self.attn_mask_type]
return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)

@staticmethod
def _get_causal_mask_fn() -> Callable:
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

return causal_mask

@staticmethod
def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
# batch is [b, s, h, d] shape
mask = batch == eos_id
mask[:, -1] = True
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]

def block_causal_mask(b, h, q_idx, kv_idx):
return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
Contributor:
NOTE: this is a 2D mask, enforcing block causal attention within each row of a batch. There is no cross-row attention.
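As a toy, editorial illustration (not part of the PR) of the seq_idx bookkeeping in _get_block_causal_mask_fn, assuming eos_id = 0:

```python
import torch

eos_id = 0
# One packed row holding three documents: [5, 7, EOS], [3, 9, EOS], [2, 4].
batch = torch.tensor([[5, 7, 0, 3, 9, 0, 2, 4]])

mask = batch == eos_id
mask[:, -1] = True  # force a document boundary at the end of the row
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]

print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1, 2, 2]], dtype=torch.int32)
# block_causal_mask then allows q_idx -> kv_idx only when both positions share
# the same seq_idx value and q_idx >= kv_idx, i.e. causal attention per document.
```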


return block_causal_mask

@staticmethod
@torch.no_grad()
def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
# batch is [b, s, h, d] shape
for attn_mask_type in FlexAttention.used_attn_mask_types:
match attn_mask_type:
case "causal":
if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
continue
# We don't care about batch dimension --
# all samples have the same lower triangle mask.
batch_dimension = 1
mask_fn = FlexAttention._get_causal_mask_fn()
case "block_causal":
if eos_id is None:
raise RuntimeError(
"eos_id must be provided for block_causal mask."
)
batch_dimension = batch.shape[0]
mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
case _:
raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")

seq_len = batch.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
mask_fn, batch_dimension, None, seq_len, seq_len
)
FlexAttention.block_masks[attn_mask_type] = block_mask


class ScaledDotProductAttention(torch.nn.Module):
def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def build_attention(use_flex_attn: bool, attn_mask_type: str):
if use_flex_attn:
return FlexAttention(attn_mask_type)
else:
return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
FlexAttention.init_attention_mask(batch, eos_id)
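A rough end-to-end usage sketch of this module (shapes, dtype, and eos_id are made up, and a GPU is assumed; in the PR itself this wiring happens inside the Llama model and the trainer):

```python
import torch
from torchtitan.models.attention import build_attention, init_attention_mask

# Construction time (per attention layer): pick the implementation and mask type.
attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")

# Per training step: rebuild the shared BlockMask from the current token batch.
tokens = torch.randint(1, 1000, (2, 128), device="cuda")  # [batch, seq_len] token ids
init_attention_mask(tokens, eos_id=0)

# Inside the layer's forward: q, k, v are [batch, n_heads, seq_len, head_dim].
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
out = attn(q, k, v)  # uses the block-causal BlockMask built above
```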
56 changes: 16 additions & 40 deletions torchtitan/models/llama/model.py
@@ -8,20 +8,15 @@


from dataclasses import dataclass
from typing import Callable, ClassVar, Optional, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.models.norms import build_norm
from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol

@@ -45,12 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
norm_type: str = "rmsnorm"

use_flex_attn: bool = False
attn_mask_type: str = "causal"
eos_id: int = 0
Contributor:
Is this necessary? I think the model should not have to know about properties of the tokenizer like this

@fegin (Contributor Author, Mar 27, 2025):
Why is this an issue? It is configurable and is given by the tokenizer, not defined by the model. Otherwise, how will the attention module separate the tokens?

Contributor:
@ebsmothers
I feel it would be one way or another -- we either put eos_id here, or each iteration would take a mask_mod function as model input. It is not clear to me which is cleaner -- I feel the current one is not bad. WDYT?

@fegin (Contributor Author):
I don't have a strong opinion on which approach is better. Both are viable approaches. I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

Contributor:
@fegin maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader. You are already iterating over a bunch of samples that you want to pack, right? Why not just use this opportunity to construct a list of sequence lengths in the pack? Then this can be used to construct a BlockMask for flex without the model needing to know anything about eos_id. @tianyu-l iiuc this is kinda similar to what you're proposing, but passing BlockMask instead of the mask_mod function

@tianyu-l (Contributor, Mar 28, 2025):
@fegin I also didn't get this

> I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

An alternative seems to pass the block_mask_fn (say from data loader, or an util function) to model forward. Technically, in this case, the model doesn't know eos_id.

@fegin (Contributor Author, Mar 31, 2025):
@ebsmothers, @tianyu-l

> maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader.

This is debatable. If you are building the block causal mask inside the dataloader, you are polluting the dataloader. There are options to use either Flex or Flash, and there can be more than just a simple block causal mask. Why does the dataloader need to know whether the model is using Flex or Flash, or what kind of masking the model uses? It is just a question of which component you choose to expose to information about the other component.

We discussed internally whether to couple mask building with the dataloader; the opinion was to decouple the dataloader from the attention implementation, since researchers can experiment with different attention masks without changing the dataloader -- knowing EOS_ID is enough. I would prefer to keep that decision, even though it was not made specifically for TorchTitan; the discussion was about how to do CP + Flex for PTD.

Contributor:
Just one more quick point, this is our generic Document mask re-writer: https://github.com/pytorch-labs/attention-gym/blob/001b36d625aceae8a03f59241113e4797122db1d/attn_gym/masks/document_mask.py#L33. It takes in two things: another mask_mod and a tensor specifying the boundaries. I find this de-coupling pretty attractive, where you decouple the generation of the extra metadata from which mask_mod you want to perform; in this concrete case the choice is "causal". Someone has to own setting up your input tensor in the packed format, and the generation of the extra metadata likely should be colocated with that IMO.
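For context, a rough sketch of that composition pattern (this is not the attention-gym API itself; the helper name and signature below are illustrative): the document/boundary metadata is generated separately and wrapped around an arbitrary inner mask_mod such as causal.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def with_document_mask(inner_mask_mod, document_id):
    # document_id[b, i] is the index of the packed document that token i of row b belongs to.
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_id[b, q_idx] == document_id[b, kv_idx]
        return same_doc & inner_mask_mod(b, h, q_idx, kv_idx)
    return mask_mod

document_id = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])
block_mask = create_block_mask(
    with_document_mask(causal_mask, document_id), 1, None, 8, 8,
    device="cpu",  # use the batch's device in practice
)
```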


def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
self.norm_type = job_config.model.norm_type
self.vocab_size = tokenizer.n_words
self.max_seq_len = job_config.training.seq_len
self.use_flex_attn = job_config.model.use_flex_attn
self.attn_mask_type = job_config.model.attn_mask_type

def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int:
l, h, q, t = (
@@ -123,7 +121,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.

@@ -138,7 +136,7 @@ def apply_rotary_emb(
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
@@ -179,13 +177,6 @@ class Attention(nn.Module):

"""

# We registered flex_attention related attributes as class variables as we
# need to amortize the cost of compilation. Enabling per-instance flex_attention
# is not supported.
block_mask: ClassVar[Optional[BlockMask]] = None
use_flex_attn: ClassVar[bool] = False
flex_attn: ClassVar[Optional[Callable]] = None

def __init__(self, model_args: TransformerModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
@@ -205,7 +196,7 @@ def __init__(self, model_args: TransformerModelArgs):
self.wo = nn.Linear(
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)
self.use_flex_attn = model_args.use_flex_attn
self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
@@ -249,35 +240,14 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

# we use casual mask for training
if self.use_flex_attn:
# assert False, (type(xq), type(xk), type(xv))
self._init_flex_attn(seqlen=seqlen)
output = self.flex_attn(xq, xk, xv, block_mask=self.block_mask)
else:
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
output = self.sdpa(xq, xk, xv)

output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bs, seqlen, -1)
return self.wo(output)

@torch.no_grad()
def _init_flex_attn(self, seqlen: int) -> None:
if self.block_mask is not None:
return

def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

compiled_create_block_mask = torch.compile(create_block_mask)
self.block_mask = compiled_create_block_mask(
causal_mask, None, None, seqlen, seqlen
)
self.flex_attn = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)


class FeedForward(nn.Module):
"""
@@ -420,6 +390,7 @@ def __init__(self, model_args: TransformerModelArgs):
self.model_args = model_args
self.vocab_size = model_args.vocab_size
self.n_layers = model_args.n_layers
self.eos_id = model_args.eos_id

self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

@@ -500,6 +471,11 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Output logits after applying the Transformer model.

"""
# TODO: We will need to change the forward() signature so that tokens are
# always passed in.
if self.model_args.use_flex_attn:
init_attention_mask(tokens, eos_id=self.eos_id)

# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

15 changes: 10 additions & 5 deletions torchtitan/models/llama/parallelize_llama.py
@@ -72,15 +72,20 @@ def parallelize_llama(
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)

if job_config.activation_checkpoint.mode != "none":
if (
job_config.activation_checkpoint.mode == "selective"
and job_config.model.use_flex_attn
):
if job_config.model.use_flex_attn:
if job_config.activation_checkpoint.mode == "selective":
raise ValueError(
"FlexAttention is not compatible with selective AC yet. "
"See https://github.com/pytorch/pytorch/issues/147879"
)

if parallel_dims.cp_enabled:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP