diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 59a73efbf..1e005706e 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -302,6 +302,7 @@ def build_test_list():
                 "--parallelism.data_parallel_shard_degree=4",
                 "--activation_checkpoint.mode='full'",
                 "--model.use_flex_attn",
+                "--model.attn_mask_type='block_causal'",
             ]
         ],
         "FSDP+FLEX_ATTN",
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index b2134ecc6..06452d261 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -196,7 +196,21 @@ def __init__(self):
         self.parser.add_argument(
             "--model.use_flex_attn",
             action="store_true",
-            help="Whether to use Flex Attention.",
+            help="""
+            Whether to use Flex Attention.
+            Mixed usage of SDPA and FlexAttention is not supported yet.
+            """,
+        )
+        self.parser.add_argument(
+            "--model.attn_mask_type",
+            type=str,
+            default="causal",
+            choices=["causal", "block_causal"],
+            help="""
+            Specifies the type of bias/mask used for attention. If SDPA is used,
+            only the causal mask is supported by default. If FlexAttention is used,
+            both causal and block_causal masks are supported.
+            """,
         )
         self.parser.add_argument(
             "--model.tokenizer_path",
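For reference, the snippet below is a minimal, self-contained sketch of how the two options behave at parse time. It uses a bare argparse parser as a hypothetical stand-in for torchtitan's ConfigManager (which defines many more options and nests them under JobConfig), so the structure here is illustrative only: --model.attn_mask_type defaults to "causal" and rejects anything other than the two listed choices.

import argparse

# Hypothetical stand-in for the relevant part of ConfigManager; not the real parser.
parser = argparse.ArgumentParser()
parser.add_argument("--model.use_flex_attn", action="store_true")
parser.add_argument(
    "--model.attn_mask_type",
    type=str,
    default="causal",
    choices=["causal", "block_causal"],
)

# Default: FlexAttention off, causal masking.
args = parser.parse_args([])
assert getattr(args, "model.use_flex_attn") is False
assert getattr(args, "model.attn_mask_type") == "causal"

# FlexAttention with document (block-causal) masking, as in the integration test above.
args = parser.parse_args(["--model.use_flex_attn", "--model.attn_mask_type=block_causal"])
assert getattr(args, "model.attn_mask_type") == "block_causal"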
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
new file mode 100644
index 000000000..94e447cde
--- /dev/null
+++ b/torchtitan/models/attention.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from typing import Callable, ClassVar, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    create_block_mask,
+    flex_attention,
+)
+
+
+class FlexAttention(torch.nn.Module):
+    # We register flex_attention related attributes as class variables as we
+    # need to amortize the cost of compilation.
+    flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
+    used_attn_mask_types: ClassVar[set[str]] = set()
+    # Mapping from attention mask type to the created BlockMask.
+    # This allows us to keep track of the created block masks for each
+    # new batch. We will use this to update the block masks when a
+    # new batch is created. This also allows users to create different
+    # block masks for different layers.
+    block_masks: ClassVar[dict[str, BlockMask]] = {}
+
+    # Instance variables.
+    attn_mask_type: str
+
+    def __init__(self, attn_mask_type: str) -> None:
+        super().__init__()
+        if attn_mask_type not in ["causal", "block_causal"]:
+            raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
+        self.attn_mask_type = attn_mask_type
+        FlexAttention.used_attn_mask_types.add(attn_mask_type)
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        block_mask = FlexAttention.block_masks[self.attn_mask_type]
+        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)
+
+    @staticmethod
+    def _get_causal_mask_fn() -> Callable:
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        return causal_mask
+
+    @staticmethod
+    def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
+        # batch is a [b, s] tensor of token IDs.
+        mask = batch == eos_id
+        mask[:, -1] = True
+        acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
+        seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
+        seq_idx[:, 1:] = acc_mask[:, :-1]
+
+        def block_causal_mask(b, h, q_idx, kv_idx):
+            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
+
+        return block_causal_mask
+
+    @staticmethod
+    @torch.no_grad()
+    def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
+        # batch is a [b, s] tensor of token IDs.
+        for attn_mask_type in FlexAttention.used_attn_mask_types:
+            match attn_mask_type:
+                case "causal":
+                    if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
+                        continue
+                    # We don't care about the batch dimension --
+                    # all samples have the same lower triangle mask.
+                    batch_dimension = 1
+                    mask_fn = FlexAttention._get_causal_mask_fn()
+                case "block_causal":
+                    if eos_id is None:
+                        raise RuntimeError(
+                            "eos_id must be provided for block_causal mask."
+                        )
+                    batch_dimension = batch.shape[0]
+                    mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
+                case _:
+                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
+
+            seq_len = batch.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_fn, batch_dimension, None, seq_len, seq_len
+            )
+            FlexAttention.block_masks[attn_mask_type] = block_mask
+
+
+class ScaledDotProductAttention(torch.nn.Module):
+    def __init__(self, attn_mask_type: str) -> None:
+        super().__init__()
+        if attn_mask_type != "causal":
+            raise ValueError(
+                "TorchTitan with SDPA currently only supports causal mask."
+            )
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+
+def build_attention(use_flex_attn: bool, attn_mask_type: str):
+    if use_flex_attn:
+        return FlexAttention(attn_mask_type)
+    else:
+        return ScaledDotProductAttention(attn_mask_type)
+
+
+def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
+    FlexAttention.init_attention_mask(batch, eos_id)
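To make the block_causal semantics concrete, here is a small standalone sketch (illustrative only, not part of the new module) that reproduces the seq_idx bookkeeping from _get_block_causal_mask_fn on a toy token batch and materializes the mask densely. Tokens may only attend causally within their own EOS-delimited document; create_block_mask evaluates the same predicate lazily in blocks rather than as a dense matrix.

import torch

# Toy [b, s] batch of token IDs; eos_id = 0 splits sample 0 into two documents.
eos_id = 0
batch = torch.tensor([[5, 7, 0, 3, 9, 0]])

# Same bookkeeping as _get_block_causal_mask_fn: give every position a document
# index that increments right after each EOS token.
mask = batch == eos_id
mask[:, -1] = True
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1]], dtype=torch.int32)

# Evaluate the mask predicate densely for sample 0, just for illustration.
q_idx = torch.arange(batch.shape[1]).unsqueeze(-1)
kv_idx = torch.arange(batch.shape[1]).unsqueeze(0)
dense = (seq_idx[0, q_idx] == seq_idx[0, kv_idx]) & (q_idx >= kv_idx)
print(dense.long())
# Two 3x3 lower-triangular blocks on the diagonal: each document is causal
# internally and cannot attend to the other document.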
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 4ab6da41e..d96827a68 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -8,20 +8,15 @@
 
 from dataclasses import dataclass
-from typing import Callable, ClassVar, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn.attention.flex_attention import (
-    BlockMask,
-    create_block_mask,
-    flex_attention,
-)
 
 from torchtitan.components.tokenizer import Tokenizer
 from torchtitan.config_manager import JobConfig
-
+from torchtitan.models.attention import build_attention, init_attention_mask
 from torchtitan.models.norms import build_norm
 from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
 
 
@@ -45,12 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
     norm_type: str = "rmsnorm"
 
     use_flex_attn: bool = False
+    attn_mask_type: str = "causal"
+    eos_id: int = 0
 
     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.norm_type = job_config.model.norm_type
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
         self.use_flex_attn = job_config.model.use_flex_attn
+        self.attn_mask_type = job_config.model.attn_mask_type
 
     def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int:
         l, h, q, t = (
@@ -123,7 +121,7 @@ def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor.
 
@@ -138,7 +136,7 @@ def apply_rotary_emb(
         freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
 
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
@@ -179,13 +177,6 @@ class Attention(nn.Module):
 
     """
 
-    # We registered flex_attention related attributes as class variables as we
-    # need to amortize the cost of compilation. Enabling per-instance flex_attention
-    # is not supported.
-    block_mask: ClassVar[Optional[BlockMask]] = None
-    use_flex_attn: ClassVar[bool] = False
-    flex_attn: ClassVar[Optional[Callable]] = None
-
     def __init__(self, model_args: TransformerModelArgs):
         super().__init__()
         self.n_heads = model_args.n_heads
@@ -205,7 +196,7 @@ def __init__(self, model_args: TransformerModelArgs):
         self.wo = nn.Linear(
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
-        self.use_flex_attn = model_args.use_flex_attn
+        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -249,35 +240,14 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        # we use casual mask for training
-        if self.use_flex_attn:
-            # assert False, (type(xq), type(xk), type(xv))
-            self._init_flex_attn(seqlen=seqlen)
-            output = self.flex_attn(xq, xk, xv, block_mask=self.block_mask)
-        else:
-            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = self.sdpa(xq, xk, xv)
+
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         output = output.view(bs, seqlen, -1)
         return self.wo(output)
 
-    @torch.no_grad()
-    def _init_flex_attn(self, seqlen: int) -> None:
-        if self.block_mask is not None:
-            return
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        compiled_create_block_mask = torch.compile(create_block_mask)
-        self.block_mask = compiled_create_block_mask(
-            causal_mask, None, None, seqlen, seqlen
-        )
-        self.flex_attn = torch.compile(
-            flex_attention, mode="max-autotune-no-cudagraphs"
-        )
-
 
 class FeedForward(nn.Module):
     """
@@ -420,6 +390,7 @@ def __init__(self, model_args: TransformerModelArgs):
         self.model_args = model_args
        self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
 
@@ -500,6 +471,11 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
+        # TODO: We need to change the forward() signature to allow tokens to
+        # always be passed in.
+        if self.model_args.use_flex_attn:
+            init_attention_mask(tokens, eos_id=self.eos_id)
+
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
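As a sanity check on the SDPA fallback path returned by build_attention, the standalone sketch below (toy shapes, not torchtitan code) shows that is_causal=True is equivalent to passing an explicit lower-triangular boolean mask, which is why the SDPA wrapper only accepts attn_mask_type="causal".

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_heads, seqlen, head_dim = 2, 4, 8, 16  # arbitrary toy sizes
q, k, v = (torch.randn(bs, n_heads, seqlen, head_dim) for _ in range(3))

# What ScaledDotProductAttention.forward computes.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The same computation with the causal mask spelled out explicitly.
causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

torch.testing.assert_close(out_causal, out_masked)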
diff --git a/torchtitan/models/llama/parallelize_llama.py b/torchtitan/models/llama/parallelize_llama.py
index 7649d0dfb..ed2e6f0c7 100644
--- a/torchtitan/models/llama/parallelize_llama.py
+++ b/torchtitan/models/llama/parallelize_llama.py
@@ -72,15 +72,20 @@ def parallelize_llama(
             enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
 
-    if job_config.activation_checkpoint.mode != "none":
-        if (
-            job_config.activation_checkpoint.mode == "selective"
-            and job_config.model.use_flex_attn
-        ):
+    if job_config.model.use_flex_attn:
+        if job_config.activation_checkpoint.mode == "selective":
             raise ValueError(
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
+
+        if parallel_dims.cp_enabled:
+            raise ValueError(
+                "FlexAttention is not compatible with CP yet. "
+                "We are still working on this."
+            )
+
+    if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
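To summarize the reordered guards, the standalone helper below (illustrative only; names and messages are simplified from the real parallelize_llama) mirrors the new validation order: with FlexAttention enabled, selective AC and context parallelism are rejected up front, and AC wrapping is applied afterwards only when its mode is not "none".

def validate_flex_attn_config(use_flex_attn: bool, ac_mode: str, cp_enabled: bool) -> None:
    # Simplified mirror of the checks added above; the real code raises with a
    # link to the tracking issue and then proceeds to AC/compile/FSDP wrapping.
    if use_flex_attn:
        if ac_mode == "selective":
            raise ValueError("FlexAttention is not compatible with selective AC yet.")
        if cp_enabled:
            raise ValueError("FlexAttention is not compatible with CP yet.")


# Full AC with FlexAttention is allowed; selective AC or CP is not.
validate_flex_attn_config(use_flex_attn=True, ac_mode="full", cp_enabled=False)
for bad in ({"ac_mode": "selective", "cp_enabled": False},
            {"ac_mode": "full", "cp_enabled": True}):
    try:
        validate_flex_attn_config(use_flex_attn=True, **bad)
    except ValueError as exc:
        print(exc)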