[RFC] on-the-fly packing #2819
Conversation
This reverts commit 901723f.
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2819. Note: links to docs will display an error until the docs builds have been completed.
❌ As of commit 4c505e0 with merge base 3d73591: 2 new failures, 1 cancelled job (please retry), and 3 unrelated failures (flaky or already broken on the merge base). 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Creates an empty pack.

Returns:
    dict[str, list[Any]]: An empty dictionary with lists as values.
there is no real requirement for it to be a dictionary of lists. We only need finalize_pack to be a dictionary of tensors.
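For illustration, a minimal sketch of the contract being discussed (the function names and keys here are assumptions, not the actual torchtune API): the in-progress pack can stay a loose dict of lists as long as finalization produces a dict of tensors.

import torch
from typing import Any

# Hypothetical names: accumulate into lists, only finalize into tensors.
def create_empty_pack() -> dict[str, list[Any]]:
    return {"input_ids": [], "labels": [], "document_ids": []}

def finalize_pack(pack: dict[str, list[Any]]) -> dict[str, torch.Tensor]:
    # Concatenate whatever was accumulated into flat long tensors.
    return {key: torch.tensor(values, dtype=torch.long) for key, values in pack.items()}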
@@ -611,6 +612,7 @@ def forward(
            This parameter is required during inference if caches have been setup. Default is None.
        input_embeds (Optional[torch.Tensor]): Pass these instead of tokens to short-circuit token embeddings
            and skip straight to the transformer layers. Shape ``[b x s x d]``. Default: None
        **kwargs (dict): Keyword arguments to pass to the transformer layers.
Needed to add this because we do model(**batch), and batch may contain keys that the model doesn't use.
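A minimal sketch of why the signature change helps (the module and batch keys below are made up for illustration): the recipe calls model(**batch), so forward() must tolerate keys it does not consume.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 8)

    # **kwargs absorbs batch keys the model does not consume (e.g. collate metadata).
    def forward(self, tokens: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.embed(tokens)

batch = {"tokens": torch.randint(0, 100, (2, 4)), "seq_lens": torch.tensor([4, 4])}
out = TinyModel()(**batch)  # would raise a TypeError without **kwargs in forward()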
This is a very good PR!
I have a nit, and something that I think we should change a bit.
First, the nit: I wonder if we shouldn't just make a Pack class (or just use NestedTensors?). It's nice and self-contained, so it would make sense to me. Then we can have a .to_padded() to convert.
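To make the suggestion concrete, a Pack could look roughly like this minimal sketch (field names, shapes, and padding behavior are assumptions, not a concrete proposal from the thread):

from dataclasses import dataclass
import torch
import torch.nn.functional as F

@dataclass
class Pack:
    """One pack of concatenated samples plus the per-sample lengths."""
    input_ids: torch.Tensor  # shape [total_tokens]
    labels: torch.Tensor     # shape [total_tokens]
    seq_lens: torch.Tensor   # shape [num_samples]

    def to_padded(self, max_len: int, pad_id: int = 0, ignore_idx: int = -100) -> dict[str, torch.Tensor]:
        # Each field gets its own fill value when padding up to max_len.
        pad = max_len - self.input_ids.numel()
        return {
            "input_ids": F.pad(self.input_ids, (0, pad), value=pad_id),
            "labels": F.pad(self.labels, (0, pad), value=ignore_idx),
            "seq_lens": self.seq_lens,
        }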
The thing I think we should revisit is the packing strategy abstraction. I like the idea a lot, but it currently does not implement a way to swap actual strategies; instead it deals with formats (SFT vs DPO).
Left detailed comments inline
torchtune/data/_collate.py
if isinstance(batch[0][key], torch.Tensor):
    collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
else:
    # TODO: Remove? I don't see a situation where it would not be a tensor.
I agree
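If the non-tensor branch is dropped, the collate reduces to roughly this sketch (not the PR's actual code; it assumes every pack already has identical, padded shapes):

import torch

def packed_collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Every pack shares the same keys and padded shapes, so stacking
    # along a new batch dimension is all that is needed.
    return {key: torch.stack([sample[key] for sample in batch], dim=0) for key in batch[0]}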
self._resuming = True


class TextPackingStrategy(PackingStrategy[dict[str, list[int]]]):
I like the abstraction of exposing a packing strategy; this is super cool and a good abstraction for others to customize, as there are indeed different strategies one can follow depending on what they care about.
My main feedback is that this is not currently really implementing any strategy; we are instead simply implementing a format -- you should probably call this SFTPackingFormat and the next class DPOPackingFormat.
Let us be more concrete here. I think that the two main viable strategies are greedy packing vs. binned packing. We should just implement greedy packing for now, and let others take up the mantle of doing binned packing (file an issue with help wanted), as it is genuinely useful for long-context training.
The way that I would layer this would be the following:
- The IterablePackedDataset owns having a big buffer that's shuffled and owns keeping it full (what we do is reasonable, we can be fancier by having a background process fill it continuously and using reservoir sampling to shuffle in-place, but that's a minor thing). It provides any strategy with the buffer, and leaves how to go from buffer to batches to the strategy. Note that it's also possible that not every sample in the buffer ends up being put into batches, the strategy should have the option to also discard data it doesn't want.
- The strategy is what is responsible for the actual batches. So, a greedy strategy would just do a while loop where it greedily takes samples until it crosses the token budget. When that happens, it puts the last consumed sample "back" and pads to max length, emitting the batch. A binned strategy would instead bin the whole buffer into K bins based on length, and sample amongst them according to the weights it's provided.
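A minimal sketch of the greedy loop described above (the buffer handling, names, and padding step are illustrative assumptions, not the PR's implementation):

import torch
import torch.nn.functional as F
from collections import deque
from collections.abc import Iterator

def greedy_packs(buffer: deque, max_tokens: int, pad_id: int = 0) -> Iterator[torch.Tensor]:
    """Greedily take samples until the token budget is crossed, then emit a padded pack.

    Assumes every individual sample fits within max_tokens.
    """
    current: list[torch.Tensor] = []
    used = 0
    while buffer:
        sample = buffer.popleft()
        if current and used + sample.numel() > max_tokens:
            # Budget crossed: put the last sample "back", pad to max length, emit.
            buffer.appendleft(sample)
            pack = torch.cat(current)
            yield F.pad(pack, (0, max_tokens - pack.numel()), value=pad_id)
            current, used = [], 0
        else:
            current.append(sample)
            used += sample.numel()
    if current:
        pack = torch.cat(current)
        yield F.pad(pack, (0, max_tokens - pack.numel()), value=pad_id)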
Now, having both strategies and formats seems too much 🤔
On the other hand, not having the format abstraction means that if we ever wanted to add more strategies, we would risk having to implement multiple classes (e.g. SFTBinningPackingStrategy and DPOBinningPackingStrategy), which suuuuuucks.
In the general case, I believe that this cannot fully be avoided, but for the cases that people actually care about, I believe we can avoid it by being clever on how we write the strategies.
For example, both the greedy and the "simple" binning strategies would simply ignore your format and look at the input_ids column to make their decisions, and simply pack everything as it is given. We can also expose an input_ids_key_name arg to be more general.
We cannot exclude that someone may want to look at multiple columns to come up with a crazy packing strategy, but in that case they can always write a super custom class.
And then, rather than have a Format ABC, I believe I would just push this to the Dataset. We can make a base class that implements the buffer prefetching and the basics, and that requires you to implement a couple of methods to tell it what format you want (literally the ones you have in the current Strategy classes). But if you feel like Formats work better, I'm OK with it; this is just a minor preference.
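A rough sketch of what pushing the format into the dataset could look like (the abstract method names below are invented for illustration, not taken from the PR):

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

class PackedDatasetBase(ABC):
    """Owns the buffer/prefetching and the packing loop; subclasses only describe the format."""

    def __init__(self, dataset: Iterable[dict[str, Any]], max_tokens: int):
        self.dataset = dataset
        self.max_tokens = max_tokens

    @abstractmethod
    def sample_size(self, sample: dict[str, Any]) -> int:
        """How many tokens this sample would contribute to a pack."""

    @abstractmethod
    def add_to_pack(self, pack: dict[str, Any], sample: dict[str, Any]) -> None:
        """Append one sample's fields to the in-progress pack."""

class PackedSFTTextDataset(PackedDatasetBase):
    def sample_size(self, sample: dict[str, Any]) -> int:
        return len(sample["input_ids"])

    def add_to_pack(self, pack: dict[str, Any], sample: dict[str, Any]) -> None:
        pack.setdefault("input_ids", []).extend(sample["input_ids"])
        pack.setdefault("labels", []).extend(sample["labels"])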
I wonder if we shouldn't just make a Pack class (or, just use NestedTensors?) [...] Then we can have a .to_padded() to convert.
- Do you mean Pack class as a dataclass for the output?
- The thing about .to_padded() is that each [key, value] pair may need different logic, e.g. pad_id, ignore_idx, some image-related info. Also, I'm not sure how NestedTensors interact with distributed/compile/etc.
but it currently is not implementing a way to swap actual strategies
- I see, that's fair. In this case, to swap to binned packing we could make the change in IterablePackedDataset, more specifically in _find_next_fitting_sample. The current strategy is semi-greedy: it is greedy until the very next example doesn't fit, then it iterates over the buffer. My intuition is that binning would NOT improve much upon this, but maybe I am wrong. I also worry about biasing the sample selection by size.
something like:
IterablePackedDataset(format=DPOPackingFormat, packing_strategy="binning")
All that it needs is the strategy.get_sample_size. The buffer already saves the info as self._buffer.append((sample, sample_size)).
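For reference, the semi-greedy scan described above could look roughly like this (the (sample, sample_size) buffer layout follows the comment; everything else is illustrative):

from collections import deque
from typing import Any, Optional

def find_next_fitting_sample(
    buffer: deque, remaining_tokens: int
) -> Optional[tuple[Any, int]]:
    """Return the first buffered sample whose size fits the remaining token budget."""
    for i, (sample, sample_size) in enumerate(buffer):
        if sample_size <= remaining_tokens:
            del buffer[i]  # remove the chosen sample from the buffer
            return sample, sample_size
    return None  # nothing fits; finalize the current pack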
The IterablePackedDataset owns having a big buffer that's shuffled
- I was not planning on having shuffling in IterablePackedDataset.
In my design, it would be like this:
for dataset in datasets:
    dataset.to_iterable(num_shards)
    dataset.shuffle()
    dataset.distribute()
dataset = interleave_datasets(datasets, weights)
PackedDataset(dataset)
That way, PackedDataset is only responsible for packing, and it seems safer to shuffle before distribute and interleave. Let me know what you think.
rather than have a Format ABC, I believe I would just push this to the Dataset.
I thought about merging the classes and just having class PackedSFTTextDataset(PackedDataset) (or whatever we name it). The reason I left them as 2 classes was to give a clear message that one does NOT need to know about the other. But I am flexible here if other people have the same preference as you.
# NOTE: For demonstration purposes only.
class DPOPackingStrategy(PackingStrategy[dict[str, list[int]]]):
See comment above
Curious, how is a packed batch represented? With an NJT? At least, NJT seems to have been introduced precisely for this (and, as a tensor subclass, to be a foundation for plugging into torch.compile / distributed). Although, it's not clear if existing NJT kernels are suitable for CUDA graphs... (for this, kernels would need not to specialize on the actual total numel or on ...)
Hey, thanks for the comment! We could use nested tensors, but I am not sure how well they work with all of PyTorch's ecosystem. Using the vanilla tensors we already have enough issues with optimizers, parallelism, compile, sdpa/flash attention, different devices (AMD, CUDA, etc.), quantization, etc., hence why I didn't go down that path. So a packed batch in this PR is just a batch of concatenated tensors. The example in the PR should help show what the concatenated tensors look like.
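To make the representation concrete, a packed batch in this design is just ordinary dense tensors, roughly like the sketch below (key names and values are illustrative, not the exact ones in the PR):

import torch

# Two samples of lengths 3 and 2 packed into one row of length 6 (one pad position).
packed_batch = {
    "input_ids": torch.tensor([[11, 12, 13, 21, 22, 0]]),   # concatenated, then padded with 0
    "labels": torch.tensor([[11, 12, 13, 21, 22, -100]]),   # pad positions ignored in the loss
    "document_ids": torch.tensor([[0, 0, 0, 1, 1, 2]]),     # which document each position belongs to
}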
Are the representations of concatenated tensors / offsets compatible with NJT? I thought the promise of NJT was that PyTorch's regular ops can now dispatch to NJT-specific kernels... and that there is some SDPA support for NJT directly... @jbschlosser, or am I missing something? Or, if NJT is not a good fit for this use case (varlen sequences packed together) and does not compose sufficiently well with other aspects, why does NJT exist?.. I wonder if at least the NJT raw kernels consuming concatenated tensors/offsets can be useful for the packed tensors in this PR.
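For what it's worth, in recent PyTorch versions a concatenated-values-plus-offsets layout can be wrapped into a jagged-layout NestedTensor; a minimal sketch (not something this PR uses, and behavior may vary by PyTorch version):

import torch

# Concatenated "values" for two sequences of lengths 3 and 2, plus offsets [0, 3, 5].
values = torch.randn(5, 4)
offsets = torch.tensor([0, 3, 5])

# Wrap into a jagged-layout nested tensor (the values buffer is not copied).
nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
padded = nt.to_padded_tensor(padding=0.0)  # dense [2, 3, 4] tensor for comparison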
What:
Packing is the process of putting together samples until a certain target size is reached. This is done to reduce the number of padding tokens in a batch. To avoid contamination between samples, we use a document-level causal mask. To make it faster, we use flex attention to handle the special mask.
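For context on the mask mentioned above, a document-level causal mask for flex attention looks roughly like this sketch (the document_ids layout is an assumption; torchtune's actual _mask_mod may differ):

import torch
from torch.nn.attention.flex_attention import create_block_mask

# One packed row of 6 tokens: positions 0-2 belong to document 0, positions 3-5 to document 1.
document_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])

def document_causal_mask(b, h, q_idx, kv_idx):
    # Causal within a document, and never attend across document boundaries.
    same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
    return (q_idx >= kv_idx) & same_doc

block_mask = create_block_mask(document_causal_mask, B=1, H=None, Q_LEN=6, KV_LEN=6, device="cpu")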
Example:
Goal:
Context:
Solution:
i) a PackingStrategy that defines a) how to pack and b) the _mask_mod used for flex attention;
ii) an IterablePackedDataset that takes a) any PackingStrategy and b) an iterable dataset as input, and yields packed samples;
iii) a packed_collate_fn that takes the batch of packed samples and a mask_fn (e.g. strategy.create_block_mask) to generate the attention mask on the fly.
To define a new packing strategy, the user only needs to implement the PackingStrategy class.
Implementation:
Updated full_finetune_distributed.py to use IterablePackedDataset when packing is enabled. There are challenges related to iterable datasets, and these will be tackled in a separate iterable dataset PR. The changes made here were to enable it to run for this RFC.
Not in this PR: