
Conversation

@SalmanMohammadi (Contributor) commented Feb 19, 2025

The bug was raised in #2064. When setting up KV caches, we use tokenizer.max_seq_len to determine the shape of the KV cache. Since this was not set in the config, the recipe would error out.

The direct fix introduced in this PR is to construct the KV cache shape dynamically from the context length of the current batch, rather than from the fixed tokenizer.max_seq_len.
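
To illustrate the idea, here is a minimal plain-PyTorch sketch (not the actual recipe diff; `build_kv_cache`, `max_generated_tokens`, and the head dimensions below are illustrative assumptions):

```python
# Minimal sketch (plain PyTorch, not the torchtune diff): size the KV cache
# from the current batch's context length plus the generation budget, rather
# than from a fixed tokenizer.max_seq_len (which errors out when left as null).
import torch

def build_kv_cache(prompt_tokens, max_generated_tokens, num_heads, head_dim,
                   dtype=torch.bfloat16):
    batch_size, context_len = prompt_tokens.shape
    # The cache only needs room for the prompt plus the tokens we will generate.
    cache_len = context_len + max_generated_tokens
    k_cache = torch.zeros(batch_size, num_heads, cache_len, head_dim, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    return k_cache, v_cache

# Example: 2 prompts of 128 tokens each, generating up to 64 new tokens.
prompts = torch.zeros(2, 128, dtype=torch.long)
k, v = build_kv_cache(prompts, max_generated_tokens=64, num_heads=32, head_dim=128)
print(k.shape)  # torch.Size([2, 32, 192, 128])
```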

I also introduce a default of tokenizer.max_seq_len=512, since this is the config I used in #2066 and it is a more sensible starting value than null.

Please see training logs below for this branch:

tune run ppo_full_finetune_single_device \
--config mistral/7B_full_ppo_low_memory \
metric_logger=torchtune.training.metric_logging.WandBLogger \
metric_logger.project=ppo_cfg_updates \
metric_logger.name={branch} \
num_steps=4992 \
tokenizer.path=/workspace/Mistral-7B-Instruct-v0.2/tokenizer.model \
checkpointer.checkpoint_dir=/workspace/Mistral-7B-Instruct-v0.2/ \
ref_policy_checkpointer.checkpoint_dir=/workspace/Mistral-7B-Instruct-v0.2/ \
value_checkpointer.checkpoint_dir=/workspace/RM-Mistral-7B/ \
reward_checkpointer.checkpoint_dir=/workspace/RM-Mistral-7B/ \
log_peak_memory_stats=True \
tokenizer.max_seq_len=512 \
compile={compile} \
forward_batch_size=128 \
batch_size=128 \
ppo_batch_size=128 \
gradient_accumulation_steps=1 \
enable_kv_cache=True \
ppo_epochs=1 \
enable_activation_checkpointing=True 
[screenshot: WandB training logs for the run above]

@pytorch-bot (bot) commented Feb 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2412

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a07588a with merge base 504cbea:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Feb 19, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@codecov-commenter

Codecov Report

Attention: Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Project coverage is 23.47%. Comparing base (504cbea) to head (8f529c8).
Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| recipes/ppo_full_finetune_single_device.py | 0.00% | 2 Missing ⚠️ |

❗ The number of reports uploaded differs between BASE (504cbea) and HEAD (8f529c8): HEAD has 2 fewer uploads than BASE.

| Flag | BASE (504cbea) | HEAD (8f529c8) |
| --- | --- | --- |
|  | 3 | 1 |
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2412       +/-   ##
===========================================
- Coverage   63.87%   23.47%   -40.41%     
===========================================
  Files         368      373        +5     
  Lines       21873    22403      +530     
===========================================
- Hits        13971     5258     -8713     
- Misses       7902    17145     +9243     

☔ View full report in Codecov by Sentry.

@SalmanMohammadi changed the title from "Ensuring max_seq_len is set in PPO recipe when using kv-cacheing" to "Update KVCache maximum sequence length configuration in PPO recipe" on Feb 19, 2025
@SalmanMohammadi merged commit fe17fad into meta-pytorch:main on Feb 25, 2025
17 checks passed
@SalmanMohammadi deleted the mistral_ppo_config branch on February 25, 2025 at 13:29