
Supports block causal mask #1001


Merged: 11 commits merged into main from attention_module on Apr 3, 2025

Conversation

@fegin (Contributor) commented Mar 21, 2025:

This PR adds support for block causal / document masks. We currently only add this feature for FlexAttention.

[Two screenshots attached]
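For readers new to the feature, the block below is a minimal, self-contained sketch of block causal (document) masking with FlexAttention. It is illustrative only: the eos_id value, shapes, and the hand-placed EOS positions are assumptions, and it does not reproduce this PR's module structure.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"  # FlexAttention is primarily exercised on GPU
B, H, S, D = 2, 8, 128, 64
eos_id = 0  # assumed tokenizer EOS id

tokens = torch.randint(1, 100, (B, S), device=device)
tokens[:, 31] = eos_id  # pretend each row packs several documents
tokens[:, 77] = eos_id

# Give every token the index of the document it belongs to.
acc = torch.cumsum((tokens == eos_id).int(), dim=1)
doc_id = torch.zeros_like(acc)
doc_id[:, 1:] = acc[:, :-1]  # shift so each EOS stays in its own document

def block_causal(b, h, q_idx, kv_idx):
    # Attend only within the same document, and only to current/earlier positions.
    return (doc_id[b, q_idx] == doc_id[b, kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(block_causal, B, None, S, S, device=device)

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
# For speed, flex_attention would normally be wrapped in torch.compile.
out = flex_attention(q, k, v, block_mask=block_mask)
```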

@fegin requested review from XilunWu and tianyu-l on March 21, 2025 16:59
@facebook-github-bot added the CLA Signed label on Mar 21, 2025
@fegin changed the title from "Move SDPA logic to a separate attention module" to "Supports block causal mask" on Mar 26, 2025
@fegin requested a review from drisspg on March 26, 2025 16:32
@fegin force-pushed the attention_module branch from 1b11d89 to bd4d192 on March 26, 2025 17:20
@tianyu-l (Contributor) left a comment:

Looks nice in general. I left some comments on UX.

@@ -198,6 +198,17 @@ def __init__(self):
            action="store_true",
            help="Whether to use Flex Attention.",
        )
        self.parser.add_argument(
            "--model.attn_bias_type",
Contributor:

Let's consistently use either mask or bias, not both. mask sounds better to me; bias doesn't sound accurate.

Contributor:

Shall we change the CI on FlexAttention to use block_causal?

Collaborator:

Hmm, the implementation choice (flex or flash) is fine to use as a knob, but I don't think this is something we should add as a "config". Attention variants are very specific model details, which should be defined as part of the model (via ModelArgs), instead of being a knob for how to do attention.

@fegin (PR author):

Isn't attn_bias_type similar to norm_type, which is also exposed for configuration? layer_norm and rmsnorm are not mathematically the same. But the main issue is that block_causal is not supported by SDPA (without using NJT). So it is quite confusing if we treat attn_bias_type purely as part of ModelArgs; then we won't be able to use SDPA for llama3.

Contributor:

@wanchaol
I think the tricky part is -- some model configs are determined from training configs, specifically the items here https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L49
norm_type and vocab_size are good examples -- they are model configs, but need to come from the job config or the tokenizer.
I feel this attn_mask_type being causal or block_causal is similar. WDYT?

@fegin (PR author):

@wanchaol Are you suggesting that we should allow users to customize ModelArgs through TOML? Right now, users can only customize ModelArgs by creating a new model configuration.

Contributor:

One way to do it is to remove all model-specific configs from the base JobConfig and use https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig inside every model folder to define model-specific args in TOML, e.g. --model.attn_mask_type in LLMs.
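To make this concrete, here is a hypothetical sketch of a model-local argument registration; extend_parser and its wiring are illustrative assumptions, not torchtitan's actual extension API:

```python
import argparse

def extend_parser(parser: argparse.ArgumentParser) -> None:
    """Hypothetical per-model hook that registers LLM-specific flags."""
    parser.add_argument(
        "--model.attn_mask_type",
        type=str,
        default="causal",
        choices=["causal", "block_causal"],
        help="Attention mask variant; block_causal requires FlexAttention.",
    )

# Usage sketch: the model folder registers this hook so the flag only exists
# for models that actually support it.
parser = argparse.ArgumentParser()
extend_parser(parser)
args = parser.parse_args(["--model.attn_mask_type", "block_causal"])
print(getattr(args, "model.attn_mask_type"))  # "block_causal"
```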

@fegin (PR author):

@tianyu-l That's reasonable; my question is mainly about how we are going to let users customize their model arguments. Creating a new model configuration for each variant can lead to a combinatorial number of configurations.

@wanchaol (Collaborator) commented Apr 1, 2025:

I was actually thinking that users should customize the model args by creating a new model configuration, instead of being able to configure them via TOML. Is there any reason we want to add a new TOML config instead of adding a new model config? I guess mainly because it's simpler for us to experiment on Llama models?

In general I feel such fine-grained control over model details can't easily be achieved by something like TOML (maybe something like Fiddle can).

@tianyu-l's suggestion might also work. Given that we already have norm_type here, adding another option is OK, but I do want us to think more about how to handle model-arg customization.

Contributor:

@wanchaol
Not all people fork torchtitan. A lot of people use torchtitan as a submodule (to get the latest features!). Their central request is how to easily configure things without touching the core part, e.g. #790.

Currently, for model config, the lightest way to add a new one is to create a new TrainSpec with a new config dictionary: https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L69

That sounds slightly too heavy to me if it's just switching things like sequence length (for RoPE) or causal vs. block-causal masking, but I'm OK with that.

I personally think that with a per-model customized argparser and protocols like this https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L49, it doesn't look too bad.

        return causal_mask

    @classmethod
    def _get_block_causal_mask_fn(cls, batch: torch.Tensor, eos_id: int) -> Callable:
Contributor:

eos_id never changes throughout training. Shall we make it a class variable?

@@ -45,12 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_bias_type: str = "causal"
    eos_id: int = 0
Contributor:

Is this necessary? I think the model should not have to know about properties of the tokenizer like this.

@fegin (PR author) commented Mar 27, 2025:

Why is this an issue? It is configurable and is given by the tokenizer, not defined by the model. Otherwise, how would the attention module separate the tokens?

Contributor:

@ebsmothers
I feel it would be one way or another -- we either put eos_id here, or each iteration takes a mask_mod function as model input. It is not clear to me which is cleaner -- I feel the current one is not bad. WDYT?

@fegin (PR author):

I don't have a strong opinion on which approach is better; both are viable. I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or passed as an argument through forward.

Contributor:

@fegin maybe I am misunderstanding something here. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader. You are already iterating over a bunch of samples that you want to pack, right? Why not use this opportunity to construct a list of sequence lengths in the pack? Then this can be used to construct a BlockMask for flex without the model needing to know anything about eos_id. @tianyu-l IIUC this is similar to what you're proposing, but passing a BlockMask instead of the mask_mod function.
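A rough sketch of this alternative, assuming the dataloader already knows the per-document lengths of each pack; the function name, shapes, and padding handling are illustrative, not code from this PR:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def block_mask_from_seq_lens(seq_lens: list[int], seq_len: int, device="cuda"):
    """Build a document mask for one packed row from its sequence lengths."""
    # Map each position to the document it belongs to; padding gets its own id.
    doc_id = torch.full((seq_len,), len(seq_lens), dtype=torch.long, device=device)
    doc_id[: sum(seq_lens)] = torch.repeat_interleave(
        torch.arange(len(seq_lens), device=device),
        torch.tensor(seq_lens, device=device),
    )

    def mask_mod(b, h, q_idx, kv_idx):
        return (doc_id[q_idx] == doc_id[kv_idx]) & (q_idx >= kv_idx)

    # B=None / H=None broadcasts the same mask over batch and heads.
    return create_block_mask(mask_mod, None, None, seq_len, seq_len, device=device)

# e.g. a pack of three documents of lengths 5, 3, 8 in a length-32 row:
# block_mask = block_mask_from_seq_lens([5, 3, 8], seq_len=32)
```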

@tianyu-l (Contributor) commented Mar 28, 2025:

@fegin I also didn't get this:

> I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

An alternative seems to be passing the block_mask_fn (say from the data loader, or a util function) to model forward. Technically, in this case, the model doesn't know eos_id.

@fegin (PR author) commented Mar 31, 2025:

@ebsmothers, @tianyu-l

> maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader.

This is debatable. If you build the block causal mask inside the dataloader, you are polluting the dataloader. There are options to use either Flex or Flash, and there can be more than just simple block causal. Why does the dataloader need to know whether the model uses Flex or Flash, or what kind of masking the model uses? It comes down to which component you choose to make aware of the other component's details.

We discussed internally whether to couple mask construction with the dataloader, and there was an opinion to decouple the dataloader from the attention implementation, since researchers can try different attention masks without changing the dataloader; knowing EOS_ID is enough. I would prefer to keep that decision, even though it was not made specifically for torchtitan (the discussion was about how to do CP + Flex for PTD).

Contributor:

Just one more quick point: this is our generic document mask re-writer: https://github.com/pytorch-labs/attention-gym/blob/001b36d625aceae8a03f59241113e4797122db1d/attn_gym/masks/document_mask.py#L33. It takes in two things: another mask_mod and a tensor specifying the boundaries. I find this decoupling pretty attractive, since it separates the generation of the extra metadata from which mask_mod you want to apply; in this concrete case the choice is "causal". But someone has to own setting up your input tensor in the packed format, and the generation of the extra metadata should likely be colocated with that, IMO.
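The decoupling described here can be sketched roughly as follows. This is an illustration of the idea only; the names are assumptions and it is not the attention-gym generate_doc_mask_mod implementation linked above:

```python
import torch

def with_document_mask(inner_mask_mod, document_ids: torch.Tensor):
    """Wrap any mask_mod so attention is also restricted to within-document pairs.

    document_ids: (B, S) tensor mapping every token to its document index.
    Producing it should live next to whatever packs the input tensor.
    """
    def doc_mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return same_doc & inner_mask_mod(b, h, q_idx, kv_idx)

    return doc_mask_mod

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# block causal == document restriction composed with plain causal:
# mask_mod = with_document_mask(causal, document_ids)
```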

@fegin (PR author) commented Apr 2, 2025:

There are ongoing discussions regarding 1) where to initialize the mask, and 2) how to let users customize the attention type. There is no conclusion yet, so I will keep the current implementation and revisit once we have more models using block causal and other advanced attention types.

@fegin force-pushed the attention_module branch from bd4d192 to 4d63939 on April 2, 2025 20:26
@fegin (PR author) commented Apr 2, 2025:

New changes:

  1. Address the comments above. Still keep eos_id and attn_mask_type in ModelArgs.
  2. Allow using different block masks for different layers; id(batch) and attn_mask_type are used to identify an existing block mask.
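A minimal sketch of how such a cache keyed by id(batch) and attn_mask_type could look; the class and method names are illustrative, not the PR's actual code:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

class BlockMaskCache:
    """Reuse one BlockMask per (batch, mask type) so all layers share it."""

    def __init__(self):
        self._cache: dict[tuple[int, str], BlockMask] = {}

    def get(self, batch: torch.Tensor, attn_mask_type: str, mask_mod) -> BlockMask:
        # Same batch object and same mask type => same BlockMask.
        # (A real implementation would also evict masks from previous batches,
        # since id() values can be reused after a tensor is freed.)
        key = (id(batch), attn_mask_type)
        if key not in self._cache:
            B, S = batch.shape
            self._cache[key] = create_block_mask(
                mask_mod, B, None, S, S, device=batch.device
            )
        return self._cache[key]
```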

@fegin requested a review from tianyu-l on April 2, 2025 20:27
@tianyu-l (Contributor) left a comment:

As discussed offline, let's fix the PP microbatch-related logic.

@fegin requested a review from tianyu-l on April 3, 2025 00:14
@tianyu-l (Contributor) left a comment:

Had one question about the block causal mask fn.

@tianyu-l (Contributor) left a comment:

LGTM. I feel that after we modify the mask, the benchmarks also need to be redone, but that can be a follow-up.

seq_idx[:, 1:] = acc_mask[:, :-1]

def block_causal_mask(b, h, q_idx, kv_idx):
    return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
Contributor:

NOTE: this is a 2D mask, enforcing block causal attention within each row of a batch. There is no cross-row attention.
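For context, a reconstruction of how seq_idx in the fragment above might be produced from eos_id; the eos_mask helper and the surrounding function are my reading of the diff, not a verbatim copy:

```python
import torch

def build_seq_idx(batch: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Label every token in each row with the index of its document."""
    eos_mask = batch == eos_id                      # (B, S) bool
    acc_mask = torch.cumsum(eos_mask.int(), dim=1)  # running count of EOS tokens
    seq_idx = torch.zeros_like(acc_mask)
    seq_idx[:, 1:] = acc_mask[:, :-1]               # shift so EOS ends its own document
    return seq_idx

# With seq_idx in scope, block_causal_mask above allows attention only between
# positions that share a document index and respect causal ordering, row by row.
```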

@tianyu-l merged commit 7090a7b into main on Apr 3, 2025
6 checks passed
@tianyu-l deleted the attention_module branch on April 3, 2025 20:27