
[RFC] on-the-fly packing #2819


Open · wants to merge 57 commits into base: impl-step-based-ckpt

Conversation

@felipemello1 (Contributor) commented Jun 12, 2025

What:

Packing is the process of putting samples together until a target size is reached, which reduces the number of padding tokens in a batch. To avoid contamination between samples, we use a document-level causal mask, and to make this fast we use flex attention to handle the special mask.

Example:

# The current pack with one sample
pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]}

# The next sample to be added
sample = {"tokens": [5, 6], "labels": [7, 8]}

# After adding the sample
added_docs = add_sample_to_pack(pack, sample, next_doc_id=1)
print(pack)
>>> {"tokens": [1, 2, 5, 6],
    "labels": [3, 4, 7, 8],
    "document_ids": [0, 0, 1, 1],
    "input_pos": [0, 1, 0, 1]}

create_block_causal_mask(document_ids)
>>> [
     [1, 0, 0, 0],
     [1, 1, 0, 0],
     [0, 0, 1, 0],
     [0, 0, 1, 1],
    ]
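For reference, here is a minimal sketch (an editorial illustration, not the PR's actual code) of how this document-level causal mask can be expressed as a flex attention mask_mod. Depending on the PyTorch version, create_block_mask may require sequence lengths padded to its block size; the toy lengths below are just for illustration.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# document_ids from the packed example above
document_ids = torch.tensor([0, 0, 1, 1])

def document_causal_mask(b, h, q_idx, kv_idx):
    # a token may only attend to earlier tokens within the same document
    same_doc = document_ids[q_idx] == document_ids[kv_idx]
    causal = q_idx >= kv_idx
    return same_doc & causal

# B=None / H=None broadcast the mask over batch and heads;
# built on CPU here for illustration, in training it lives on the GPU
block_mask = create_block_mask(document_causal_mask, B=None, H=None, Q_LEN=4, KV_LEN=4, device="cpu")
# out = flex_attention(q, k, v, block_mask=block_mask)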

Goal:

  1. Make packing a first-class citizen in TorchTune, available for all sorts of models and recipes.

Context:

  1. We currently have map-style packing. We pre-process the entire dataset before training starts, which is not scalable.
  2. Packing is only present for SFT + text data. There is no contract for how to extend it to multimodal, DPO, etc.
  3. The collate function has to be aware of the packing logic; this is currently hardcoded in the recipe with if/else branches.

Solution:

  1. Implement a new on-the-fly packing that takes any iterable dataset as input;
  2. The packing contract consists of:
    i) a PackingStrategy that defines a) how to pack and b) the _mask_mod used for flex attention;
    ii) an IterablePackedDataset that takes a) any PackingStrategy and b) any iterable dataset as input, and yields packed samples;
    iii) a packed_collate_fn that takes the batch of packed samples and a mask_fn (e.g. strategy.create_block_mask) to generate the attention mask on the fly.
    To define a new packing strategy, the user only needs to implement the PackingStrategy class (a minimal sketch of this contract follows below).
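A minimal sketch of what these pieces could look like; the names mirror the PR, but the signatures and arguments are illustrative assumptions rather than the actual implementation:

from abc import ABC, abstractmethod
from typing import Any

import torch


class PackingStrategy(ABC):
    """Illustrative contract: how to build a pack and how to mask it."""

    @abstractmethod
    def create_empty_pack(self) -> dict[str, list[Any]]:
        """Returns an empty pack to be filled with samples."""

    @abstractmethod
    def get_sample_size(self, sample: dict[str, Any]) -> int:
        """Returns the size (in tokens) of a single sample."""

    @abstractmethod
    def add_sample_to_pack(self, pack: dict[str, list[Any]], sample: dict[str, Any], next_doc_id: int) -> int:
        """Appends a sample to the pack and returns the number of documents added."""

    @abstractmethod
    def finalize_pack(self, pack: dict[str, list[Any]], target_tokens: int, next_doc_id: int) -> dict[str, torch.Tensor]:
        """Pads the pack to target_tokens and converts it to a dict of tensors."""

    # the strategy would also own the _mask_mod / create_block_mask used by the collate fn


# Illustrative wiring (hypothetical argument names):
# strategy = TextPackingStrategy(padding_idx=0, ignore_idx=-100)
# packed_ds = IterablePackedDataset(dataset=my_iterable_ds, strategy=strategy, target_tokens_per_pack=max_seq_len)
# dataloader = DataLoader(packed_ds, batch_size=2, collate_fn=partial(packed_collate_fn, mask_fn=strategy.create_block_mask))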

Implementation:

  1. Updated full_finetune_distributed.py to use IterablePackedDataset when packing is enabled. There are challenges related to iterable datasets that will be tackled in a separate iterable dataset PR; the changes here are only what was needed to make this RFC runnable.

Not in this PR:

  1. Logging: Since we cannot do len(iterable_dataset), we need to add proper logging/metadata to help users understand how far along they are in each dataset, plus metrics about the samples (avg num tokens, avg num samples per pack, etc.).
  2. Packing-aware loss: For SFT, the same loss works for map-style and packed data. This is not the case for DPO/GRPO, which would need different masking. Future work will have to handle how to associate packing with a loss that supports it.
  3. Packing-aware metrics: Advanced metrics, such as logprob per sample, would need to be aware of packing.
  4. Tokenization: For advanced packing, e.g. shared prompts in GRPO/DPO, we will need extra metadata from upstream datasets, e.g. prompt length.

Test command:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/3B_full tokenizer.max_seq_len=16384 metric_logger=torchtune.training.metric_logging.WandBLogger batch_size=2 max_steps_per_epoch=100 compile=True output_dir="/data/users/felipemello/experiments" dataset.packed=True


pytorch-bot commented Jun 12, 2025

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2819 (links to docs will display an error until the docs builds have been completed).

❌ As of commit 4c505e0 with merge base 3d73591: 2 new failures, 1 cancelled job, and 3 unrelated failures (flaky or broken trunk). This comment was automatically generated by Dr. CI and updates every 15 minutes.

Creates an empty pack.

Returns:
dict[str, list[Any]]: An empty dictionary with lists as values.
@felipemello1 (Contributor, Author) commented Jun 12, 2025:

There is no real requirement for it to be a dictionary of lists; we only need the output of finalize_pack to be a dictionary of tensors.

@@ -611,6 +612,7 @@ def forward(
This parameter is required during inference if caches have been setup. Default is None.
input_embeds (Optional[torch.Tensor]): Pass these instead of tokens to short-circuit token embeddings
and skip straight to the transformer layers. Shape ``[b x s x d]``. Default: None
**kwargs (dict): Keyword arguments to pass to the transformer layers.
@felipemello1 (Contributor, Author) commented:
Needed to add this because we do model(**batch), and batch may contain keys that the model doesn't use.

@Darktex (Contributor) left a comment:

This is a very good PR!

I have a nit, and something that I think we should change a bit.

First, the nit: I wonder if we shouldn't just make a Pack class (or, just use NestedTensors?). It's nice and self-contained so it would make sense to me. Then we can have a .to_padded() to convert.

The thing I think we should revisit is the packing strategy abstraction. I like the idea a lot, but it currently is not implementing a way to swap actual strategies and instead deals with formats (SFT vs DPO).

Left detailed comments inline

if isinstance(batch[0][key], torch.Tensor):
    collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
else:
    # TODO: Remove? I don't see a situation where it would not be a tensor.
@Darktex (Contributor) commented:
I agree

self._resuming = True


class TextPackingStrategy(PackingStrategy[dict[str, list[int]]]):
@Darktex (Contributor) commented:
I like the abstraction of exposing a packing strategy, this is super cool and a good abstraction for others to customize as indeed there are different strategies that one can follow depending on what they care about.

My main feedback is that this is not currently really implementing any strategy; we are instead simply implementing a format -- you should probably call this SFTPackingFormat and the next class DPOPackingFormat.

Let us be more concrete here. I think that the two main viable strategies are greedy packing vs binned packing. We should just implement greedy packing for now, but let others take on the mantle of doing binned packing (file an issue with help wanted), as it is genuinely useful for long-context training.

The way that I would layer this would be the following:

  • The IterablePackedDataset owns having a big buffer that's shuffled and owns keeping it full (what we do is reasonable, we can be fancier by having a background process fill it continuously and using reservoir sampling to shuffle in-place, but that's a minor thing). It provides any strategy with the buffer, and leaves how to go from buffer to batches to the strategy. Note that it's also possible that not every sample in the buffer ends up being put into batches, the strategy should have the option to also discard data it doesn't want.
  • The strategy is what is responsible for the actual batches. So, a greedy strategy would just do a while loop where it greedily takes samples until it crosses the token budget. When that happens, it puts the last consumed sample "back" and pads to max length, emitting the batch (see the sketch after this list). A binned strategy would instead bin the whole buffer into K bins based on length, and sample amongst them according to the weights it's provided.
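A rough sketch of the greedy loop described above (an editorial illustration with hypothetical helper names, not the PR's implementation):

def greedy_pack(buffer, target_tokens, get_size, finalize):
    # Greedily consume samples from the buffer until the token budget would be crossed.
    pack, used = [], 0
    while buffer:
        size = get_size(buffer[0])
        if used + size > target_tokens:
            break  # leave the sample in the buffer, i.e. put it "back"
        pack.append(buffer.pop(0))
        used += size
    # Pad the concatenated pack to the target length and emit it.
    return finalize(pack, target_tokens)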

@Darktex (Contributor) commented:

Now, having both strategies and formats seems too much 🤔

On the other hand, not having the format abstraction means that if we ever wanted to add more strategies, we would risk having to implement multiple classes (e.g. SFTBinningPackingStrategy and DPOBinningPackingStrategy) which suuuuuucks.

In the general case, I believe that this cannot fully be avoided, but for the cases that people actually care about, I believe we can avoid it by being clever on how we write the strategies.

For example, both the greedy and "simple" binning strategies would simply ignore your format, look at the input_ids column to make their decisions, and pack everything as it's given. We can also expose an input_ids_key_name arg to be more general.

We cannot exclude that someone may want to look at multiple columns to come up with a crazy packing strategy, but in that case they can always write a super custom class.

And then, rather than have a Format ABC, I believe I would just push this to the Dataset. We can make a base class that implements the buffer prefetching and the basics, and that requires you to implement a couple of methods to tell it what format you want (literally the ones you have in the current Strategy classes). But if you feel like Formats work better I'm OK with it; this is just a minor preference.

@felipemello1 (Contributor, Author) commented Jun 12, 2025:

I wonder if we shouldn't just make a Pack class (or, just use NestedTensors?) [...] Then we can have a .to_padded() to convert.

  1. Do you mean Pack class as a dataclass for the output?
  2. The thing about .to_padded() is that each [key,value] may have a different logic, e.g. pad_id, ignore_idx, some image related info. Also, not sure how NestedTensors interact with distributed/compile/etc.

but it currently is not implementing a way to swap actual strategies

  1. I see. That's fair. In this case, to swap to binned we could make the change in IterablePackedDataset, more specifically in _find_next_fitting_sample. The current strategy is semi-greedy: it is greedy until the very next example doesn't fit, then it iterates over the buffer. My intuition is that binning would NOT improve much upon this, but maybe I am wrong. I also worry about biasing the sample selection by size.

something like:

IterablePackedDataset(format = DPOPackingFormat, packing_strategy='binning')

All that it needs is the strategy.get_sample_size. The buffer already saves the info as self._buffer.append((sample, sample_size)).

The IterablePackedDataset owns having a big buffer that's shuffled

  1. I was not planning on having shuffle in IterablePackedDataset. In my design, it would look like this:
for dataset in datasets:
	dataset.to_iterable(num_shards)
	dataset.shuffle()
	dataset.distribute()

dataset = interleave_datasets(datasets, weights)

PackedDataset(dataset)

That way, PackedDataset is only responsible for packing, and it seems safer to shuffle before distribute and interleave. Let me know what you think.

rather than have a Format ABC, I believe I would just push this to the Dataset.

I thought about merging the classes and just having class PackedSFTTextDataset(PackedDataset) (or whatever we name it). The reason I left them as two classes was to give a clear message that one does NOT need to know about the other. But I am flexible here if other people have the same preference as you.


# NOTE: For demonstration purposes only.

class DPOPackingStrategy(PackingStrategy[dict[str, list[int]]]):
@Darktex (Contributor) commented:
See comment above

@vadimkantorov commented Jun 16, 2025

Curious, how is the packed batch represented? With an NJT? At least, NJT seems to have been introduced precisely for this, and as a tensor subclass to be a foundation for plugging into torch.compile / distributed. Although, it's not clear whether existing NJT kernels are suitable for CUDA graphs... (for this, kernels need to not specialize on the actual total numel or on .shape[0], and to read them from GPU memory).

@felipemello1 (Contributor, Author) commented Jun 16, 2025


hey, thanks for the comment!

We could use nested tensors, but I am not sure how well they work with all of PyTorch's ecosystem. Using vanilla tensors we already have enough issues with optimizers, parallelism, compile, sdpa/flash attention, different devices (amd, cuda, etc.), quantization, etc., hence why I didn't go down that path.

So a packed batch in this PR is just a batch of concatenated tensors. The example in the PR description should help show what the concatenated tensors look like.
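To illustrate (assumed toy shapes, not the PR's exact collate output): each pack is a dense 1D tensor of concatenated samples padded to max_seq_len, and the collate function simply stacks packs along the batch dimension.

import torch

pack_a = {"tokens": torch.tensor([1, 2, 5, 6]), "document_ids": torch.tensor([0, 0, 1, 1])}
pack_b = {"tokens": torch.tensor([7, 8, 9, 0]), "document_ids": torch.tensor([0, 0, 0, 0])}

batch = {key: torch.stack([pack_a[key], pack_b[key]], dim=0) for key in pack_a}
print(batch["tokens"].shape)  # torch.Size([2, 4]) -> [batch, max_seq_len]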

@vadimkantorov commented Jun 16, 2025

Are the representations of concatenated tensors / offsets compatible with NJT?

I thought the promise of NJT was that PyTorch's regular ops can now dispatch to NJT-specific kernels... and that there is some SDPA support for NJT directly...

@jbschlosser or am I missing something?

Or, if NJT is not a good fit for this use case (varlen sequences packed together) and does not compose sufficiently well with other aspects, why does NJT exist?.. I wonder if at least the raw NJT kernels consuming concatenated tensors/offsets could be useful for the packed tensors in this PR.

@felipemello1 felipemello1 changed the base branch from main to dcp_ckpt_test July 2, 2025 20:39
@felipemello1 felipemello1 changed the base branch from dcp_ckpt_test to impl-step-based-ckpt July 2, 2025 20:39