Supports block causal mask #1001
Conversation
Looks nice in general. I left some comments on UX.
torchtitan/config_manager.py (outdated)
@@ -198,6 +198,17 @@ def __init__(self):
            action="store_true",
            help="Whether to use Flex Attention.",
        )
        self.parser.add_argument(
            "--model.attn_bias_type",
Let's consistently use `mask` or `bias`, but not both. `mask` sounds better to me; `bias` doesn't sound accurate.
Shall we change the CI on FlexAttention to use `block_causal`?
Hmm, the implementation choice (flex or flash) is fine to be used as a knob, but I don't think this is something we should add as a "config". Attention variants are very specific model details, which should be defined as part of the model (via `ModelArgs`), instead of being a knob on how to do attention.
Isn't `attn_bias_type` similar to `norm_type`, which is also exposed for configuration? layer_norm and rmsnorm are not mathematically the same either. But the main issue is that `block_causal` is not supported by SDPA (without using NJT). So it is quite confusing if we purely treat `attn_bias_type` as a `ModelArgs` field; then we won't be able to use SDPA for llama3.
@wanchaol I think the tricky part is -- some model configs are determined from training configs, specifically the items here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L49. `norm_type` and `vocab_size` are good examples -- they are model config, but need to come from the job config or the tokenizer. I feel this `attn_mask_type` being `causal` or `block_causal` is similar. WDYT?
@wanchaol Are you suggesting that we should allow users to customize `ModelArgs` through TOML? Right now, users can only customize `ModelArgs` by creating a new model configuration.
One way to do it is to remove all model-specific configs from the base `JobConfig`, and use https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig inside every model folder to define model-specific args in TOML, e.g. `--model.attn_mask_type` for LLMs.
@tianyu-l That's reasonable; my question is mainly about how we are going to let users customize their model arguments. Creating different model configurations can lead to a combinatorial explosion of configurations.
I was actually thinking that users should just customize the model args by creating a new model configuration instead of configuring them via TOML. Is there any reason we want to add a new TOML config instead of adding a new model config? I guess mainly because it's simpler for us to experiment on llama models?
In general I feel such fine-grained control over model details can't be easily achieved by something like TOML (maybe something like fiddle can).
@tianyu-l's suggestion might also work. Given that we already have `norm_type` here, adding another option is OK, but I do want us to think more about how to handle model arg customization.
@wanchaol Not all people fork torchtitan. A lot of people use torchtitan as a submodule (to get the latest features!). Their central request is how to easily configure things without touching the core part, e.g. #790.
Currently, for model config, the lightest way to add a new one is to create a new TrainSpec with a new config dictionary: https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L69. That sounds slightly too heavy to me if it's just switching things like sequence length (for RoPE) or causal vs. block-causal masking, but I'm OK with it.
I personally think that with a per-model customized argparser and protocols like this https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L49, it doesn't look too bad.
torchtitan/models/attention.py (outdated)
        return causal_mask

    @classmethod
    def _get_block_causal_mask_fn(cls, batch: torch.Tensor, eos_id: int) -> Callable:
`eos_id` never changes throughout training. Shall we make it a variable of this class?
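A hypothetical sketch of this suggestion (the class and method names here are illustrative, not the actual torchtitan code): store `eos_id` once on the attention class so the mask builder no longer needs it as an argument.

```python
from typing import Callable, Optional

import torch

class FlexAttentionWrapper:  # illustrative name
    # Set once at model construction; eos_id never changes during training.
    eos_id: Optional[int] = None

    @classmethod
    def set_eos_id(cls, eos_id: int) -> None:
        cls.eos_id = eos_id

    @classmethod
    def _get_block_causal_mask_fn(cls, batch: torch.Tensor) -> Callable:
        assert cls.eos_id is not None, "call set_eos_id() first"
        ...  # build the block-causal mask_mod using cls.eos_id
```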
@@ -45,12 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_bias_type: str = "causal"
    eos_id: int = 0
Is this necessary? I think the model should not have to know about properties of the tokenizer like this.
Why is this an issue? It is configurable and is given by the tokenizer, not defined by the model. Otherwise, how will the attention module separate the tokens?
@ebsmothers I feel it would be one way or another -- we either put `eos_id` here, or each iteration would take a `mask_mod` function as model input. It is not clear to me which is cleaner -- I feel the current one is not bad. WDYT?
I don't have a strong opinion on which approach is better; both are viable. I just want to emphasize that the model has to know `eos_id`, whether it is saved as an instance variable or passed as an argument through forward.
@fegin maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader. You are already iterating over a bunch of samples that you want to pack, right? Why not use this opportunity to construct a list of sequence lengths in the pack? Then this can be used to construct a BlockMask for flex without the model needing to know anything about eos_id. @tianyu-l IIUC this is similar to what you're proposing, but passing a `BlockMask` instead of the `mask_mod` function.
@fegin I also didn't get this:

> I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

An alternative seems to be passing the `block_mask_fn` (say, from the data loader or a util function) to model forward. Technically, in this case, the model doesn't know `eos_id`.
> maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader.

This is debatable. If you build the block causal mask inside the dataloader, you are polluting the dataloader. There are options to use either Flex or Flash, and there can be more than just simple block causal. Why does the dataloader need to know whether the model is using Flex or Flash, or what kind of masking the model uses? It is really a question of which component you choose to expose to information about the other.
We discussed internally whether to couple mask building with the dataloader, and there was an opinion to decouple the dataloader from the attention implementation: researchers can try different attention masks without changing the dataloader, and just knowing EOS_ID is enough. I would prefer to keep that decision, even though it was not made specifically for TorchTitan -- the discussion there was about how to do CP + Flex for PTD.
Just one more quick point: this is our generic document mask re-writer: https://github.com/pytorch-labs/attention-gym/blob/001b36d625aceae8a03f59241113e4797122db1d/attn_gym/masks/document_mask.py#L33. It takes in two things: another mask_mod and a tensor specifying the boundaries. I find this decoupling pretty attractive, where you separate the generation of the extra metadata from which mask_mod you want to perform -- in this concrete case it is "causal". Someone has to own setting up your input tensor in the packed format, and the generation of the extra metadata likely should be colocated with that, IMO.
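A rough sketch of that decoupling pattern, simplified from the linked attention-gym helper (the name `compose_document_mask` and the `document_ids` input are illustrative; the real implementation at the link also remaps indices to be document-local):

```python
import torch

def compose_document_mask(inner_mask_mod, document_ids: torch.Tensor):
    # document_ids[i] is the index of the document that position i belongs to;
    # whoever packs the batch owns producing this metadata.
    def doc_mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[q_idx] == document_ids[kv_idx]
        return same_doc & inner_mask_mod(b, h, q_idx, kv_idx)

    return doc_mask_mod

# Composing with a plain causal mask_mod yields block causal:
# causal_mod = lambda b, h, q_idx, kv_idx: q_idx >= kv_idx
# block_causal_mod = compose_document_mask(causal_mod, document_ids)
```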
There are ongoing discussions regarding 1) where to initialize the mask and 2) how to let users customize the attention type. There is no conclusion yet, so I will keep the current implementation, and we can continue the discussion after we have more models using block causal and other advanced attention types.
As discussed offline, let's fix the PP microbatch-related logic.
Had one question about block causal mask fn.
LGTM. I feel that after we modify the mask, the benchmarks also need to be redone, but that can be a follow-up.
        seq_idx[:, 1:] = acc_mask[:, :-1]

        def block_causal_mask(b, h, q_idx, kv_idx):
            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
NOTE: this is a 2D mask, enforcing block causal attention within each row of a batch. There is no cross-row attention.
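For readers following the diff above, here is a minimal end-to-end sketch of how the `seq_idx` construction could be derived from `eos_id` and turned into a FlexAttention `BlockMask`; this is an illustration under those assumptions, not the exact torchtitan implementation:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def build_block_causal_mask(batch: torch.Tensor, eos_id: int):
    # batch: (B, S) token ids; every EOS token closes the current document.
    B, S = batch.shape
    acc_mask = (batch == eos_id).cumsum(dim=1)
    seq_idx = torch.zeros_like(acc_mask)
    # Tokens after an EOS belong to the next document; the EOS itself stays
    # with the document it terminates.
    seq_idx[:, 1:] = acc_mask[:, :-1]

    def block_causal_mask(b, h, q_idx, kv_idx):
        # Same document (same seq_idx) and causal (q_idx >= kv_idx).
        return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

    return create_block_mask(block_causal_mask, B, None, S, S, device=batch.device)
```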
This PR adds support for a block causal/document mask. Currently we only add this feature for FlexAttention.