@@ -234,14 +234,13 @@ def setup(self, cfg: DictConfig) -> None:

         # setup a context manager for enabling KV-cacheing during
         # trajectory generation if enabled in the config
-        self.cache_ctx_manager = lambda enable_kv_cache: (
+        self.cache_ctx_manager = lambda enable_kv_cache, decoder_max_seq_len: (
             local_kv_cache(
                 self._policy_model,
                 batch_size=self._forward_batch_size,
                 dtype=self._dtype,
-                decoder_max_seq_len=self._tokenizer.max_seq_len
-                + self._max_generated_tokens,
                 device=self._device,
+                decoder_max_seq_len=decoder_max_seq_len,
             )
             if enable_kv_cache
             else contextlib.nullcontext()
@@ -770,9 +769,12 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory:
             Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising
                 the current trajectory.
         """
-
+        _, context_length = input_ids.shape
         # step 1: generate responses, and logits corresponding to the responses using the current policy
-        with self.cache_ctx_manager(self.enable_kv_cache):
+        with self.cache_ctx_manager(
+            self.enable_kv_cache,
+            decoder_max_seq_len=context_length + self._max_generated_tokens,
+        ):
             query_responses, logits = generation.generate(
                 model=self._policy_model,
                 prompt=input_ids,
@@ -782,7 +784,6 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory:
                 pad_id=self._tokenizer.pad_id,
                 rng=self._rng,
             )
-        _, context_length = input_ids.shape
         responses = query_responses[:, context_length:].clone()
         query_response_padding_masks = query_responses != self._tokenizer.pad_id

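Taken out of diff form, the change is a small pattern: build the KV-cache context per generation call and size it to the batch's actual prompt length plus the generation budget, rather than to the tokenizer's global max_seq_len. The sketch below illustrates that pattern under this reading; make_cache_ctx_manager and setup_fn are hypothetical stand-ins (setup_fn plays the role of torchtune's local_kv_cache), not code from the recipe.

import contextlib

def make_cache_ctx_manager(model, batch_size, dtype, device, setup_fn):
    # setup_fn stands in for a cache-setup context manager such as local_kv_cache:
    # anything that allocates KV caches sized for decoder_max_seq_len tokens.
    return lambda enable_kv_cache, decoder_max_seq_len: (
        setup_fn(
            model,
            batch_size=batch_size,
            dtype=dtype,
            device=device,
            decoder_max_seq_len=decoder_max_seq_len,
        )
        if enable_kv_cache
        else contextlib.nullcontext()
    )

# At generation time the cache is sized per batch instead of once at setup:
#   _, context_length = input_ids.shape
#   with cache_ctx_manager(True, decoder_max_seq_len=context_length + max_generated_tokens):
#       ...  # generate up to max_generated_tokens past the prompt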