Commit fe17fad

Update KVCache maximum sequence length configuration in PPO recipe (#2412)
1 parent 7b654ea commit fe17fad

2 files changed: +8 −7 lines changed

recipes/configs/mistral/7B_full_ppo_low_memory.yaml

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ output_dir: /tmp/torchtune/mistral_7B/full_ppo_low_memory # /tmp may be deleted
 tokenizer:
   _component_: torchtune.models.mistral.mistral_tokenizer
   path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model
-  max_seq_len: null
+  max_seq_len: 512
 
 # Dataset
 dataset:
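
For context on the config change: setting max_seq_len to 512 (instead of null) caps the length of prompts produced by the tokenizer, which in turn bounds the per-batch KV-cache size computed in the recipe change below. A minimal sketch, assuming the tokenizer file from the config has already been downloaded to the path shown, of loading just this fragment through torchtune's config utilities:

from omegaconf import OmegaConf
from torchtune import config

# Standalone illustration, not part of the commit: instantiate only the
# updated tokenizer section of the recipe config.
yaml_fragment = """
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model
  max_seq_len: 512
"""
cfg = OmegaConf.create(yaml_fragment)
tokenizer = config.instantiate(cfg.tokenizer)

# Prompts are now truncated to 512 tokens rather than left unbounded.
print(tokenizer.max_seq_len)  # 512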

recipes/ppo_full_finetune_single_device.py

Lines changed: 7 additions & 6 deletions

@@ -234,14 +234,13 @@ def setup(self, cfg: DictConfig) -> None:
 
         # setup a context manager for enabling KV-cacheing during
         # trajectory generation if enabled in the config
-        self.cache_ctx_manager = lambda enable_kv_cache: (
+        self.cache_ctx_manager = lambda enable_kv_cache, decoder_max_seq_len: (
             local_kv_cache(
                 self._policy_model,
                 batch_size=self._forward_batch_size,
                 dtype=self._dtype,
-                decoder_max_seq_len=self._tokenizer.max_seq_len
-                + self._max_generated_tokens,
                 device=self._device,
+                decoder_max_seq_len=decoder_max_seq_len,
             )
             if enable_kv_cache
             else contextlib.nullcontext()

@@ -770,9 +769,12 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory:
             Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising
                 the current trajectory.
         """
-
+        _, context_length = input_ids.shape
         # step 1: generate responses, and logits corresponding to the responses using the current policy
-        with self.cache_ctx_manager(self.enable_kv_cache):
+        with self.cache_ctx_manager(
+            self.enable_kv_cache,
+            decoder_max_seq_len=context_length + self._max_generated_tokens,
+        ):
             query_responses, logits = generation.generate(
                 model=self._policy_model,
                 prompt=input_ids,

@@ -782,7 +784,6 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory:
                 pad_id=self._tokenizer.pad_id,
                 rng=self._rng,
             )
-        _, context_length = input_ids.shape
         responses = query_responses[:, context_length:].clone()
         query_response_padding_masks = query_responses != self._tokenizer.pad_id
 
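The recipe change replaces a KV-cache size fixed at self._tokenizer.max_seq_len + self._max_generated_tokens (which fails when max_seq_len is null) with a size computed per batch from the actual prompt length. A minimal arithmetic sketch of the difference; the generation budget and batch shape below are illustrative, not values from the config:

import torch

# Illustrative numbers only (not from the recipe config).
max_generated_tokens = 58        # hypothetical generation budget
tokenizer_max_seq_len = 512      # the value now set in the YAML config
input_ids = torch.randint(0, 32_000, (4, 307))  # fake batch of 4 prompts

_, context_length = input_ids.shape

# Old sizing: fixed for every batch, and broken when max_seq_len was null.
old_decoder_max_seq_len = tokenizer_max_seq_len + max_generated_tokens  # 570

# New sizing: follows the prompt length of the current batch, so the
# KV cache allocated by local_kv_cache is only as large as needed.
new_decoder_max_seq_len = context_length + max_generated_tokens  # 365

print(old_decoder_max_seq_len, new_decoder_max_seq_len)

Inside the recipe, generate_trajectory now passes this per-batch value through the cache_ctx_manager lambda to local_kv_cache, as shown in the diff above.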