@@ -341,6 +341,9 @@ def forward(
341341 value_states = self .v_norm (value_states )
342342 value_states = value_states .transpose (1 , 2 )
343343
344+ # CHANGED: the local KV states are also returned, so that under gradient checkpointing (where the layer
345+ # runs cache-less) the calling model can write the cache outside the checkpointed call.
346+ local_key_states , local_value_states = key_states , value_states
344347 if past_key_values is not None :
345348 key_states , value_states = past_key_values .update (key_states , value_states , self .layer_idx )
346349 # CHANGED: removed the `if self.store_full_length_kv` branch
@@ -364,7 +367,7 @@ def forward(
364367
365368 attn_output = attn_output .reshape (* input_shape , - 1 ).contiguous ()
366369 attn_output = self .o_proj (attn_output )
367- return attn_output , attn_weights
370+ return attn_output , attn_weights , local_key_states , local_value_states
368371
369372
370373class DiffusionGemmaDecoderTextAttention (nn .Module ):
@@ -418,7 +421,8 @@ def forward(
418421 hidden_states : torch .Tensor ,
419422 position_embeddings : torch .Tensor ,
420423 attention_mask : torch .Tensor | None ,
421- past_key_values : Cache | None = None ,
424+ encoder_key_states : torch .Tensor | None = None ,
425+ encoder_value_states : torch .Tensor | None = None ,
422426 ** kwargs : Unpack [FlashAttentionKwargs ],
423427 ) -> tuple [torch .Tensor , torch .Tensor | None , tuple [torch .Tensor ] | None ]:
424428 # The code in this function is adapted from Gemma4TextAttention. ** The modified parts are clearly indicated **
@@ -443,11 +447,9 @@ def forward(
443447 value_states = self .v_norm (value_states )
444448 value_states = value_states .transpose (1 , 2 )
445449
446- if past_key_values is not None :
447- # CHANGED: instead of calling `past_key_values.update()` which updates the KV cache in-place and returns
448- # the full KV states, we first obtain the encoder cache contents, and then append the current KV states.
449- encoder_key_states = past_key_values .layers [self .layer_idx ].keys
450- encoder_value_states = past_key_values .layers [self .layer_idx ].values
450+ if encoder_key_states is not None :
451+ # CHANGED: the encoder KV states are passed as plain tensors (extracted from the encoder cache by the
452+ # calling model) so they survive gradient checkpointing; the canvas KV states are appended to them.
451453 key_states = torch .cat ([encoder_key_states , key_states ], dim = 2 )
452454 value_states = torch .cat ([encoder_value_states , value_states ], dim = 2 )
453455 # CHANGED: removed the `if self.store_full_length_kv` branch
@@ -602,11 +604,11 @@ def forward(
602604 position_ids : torch .LongTensor | None = None ,
603605 past_key_values : Cache | None = None ,
604606 ** kwargs ,
605- ) -> torch .Tensor :
607+ ) -> tuple [ torch .Tensor , torch . Tensor , torch . Tensor ] :
606608 residual = hidden_states
607609
608610 hidden_states = self .input_layernorm (hidden_states )
609- hidden_states , _ = self .self_attn (
611+ hidden_states , _ , key_states , value_states = self .self_attn (
610612 hidden_states = hidden_states ,
611613 position_embeddings = position_embeddings ,
612614 attention_mask = attention_mask ,
@@ -638,7 +640,7 @@ def forward(
638640 hidden_states = residual + hidden_states
639641
640642 hidden_states *= self .layer_scalar
641- return hidden_states
643+ return hidden_states , key_states , value_states
642644
643645
644646class DiffusionGemmaDecoderTextLayer (GradientCheckpointingLayer ):
@@ -675,7 +677,8 @@ def forward(
675677 position_embeddings : torch .Tensor = None ,
676678 attention_mask : torch .Tensor | None = None ,
677679 position_ids : torch .LongTensor | None = None ,
678- past_key_values : Cache | None = None ,
680+ encoder_key_states : torch .Tensor | None = None ,
681+ encoder_value_states : torch .Tensor | None = None ,
679682 ** kwargs ,
680683 ) -> torch .Tensor :
681684 residual = hidden_states
@@ -686,7 +689,8 @@ def forward(
686689 position_embeddings = position_embeddings ,
687690 attention_mask = attention_mask ,
688691 position_ids = position_ids ,
689- past_key_values = past_key_values ,
692+ encoder_key_states = encoder_key_states ,
693+ encoder_value_states = encoder_value_states ,
690694 ** kwargs ,
691695 )
692696 hidden_states = self .post_attention_layernorm (hidden_states )
@@ -798,7 +802,7 @@ def forward(self, inputs_embeds, self_conditioning_signal: torch.Tensor) -> torc
798802class DiffusionGemmaPreTrainedModel (PreTrainedModel ):
799803 config : DiffusionGemmaConfig
800804 base_model_prefix = "model"
801- supports_gradient_checkpointing = False
805+ supports_gradient_checkpointing = True
802806 _no_split_modules = [
803807 "DiffusionGemmaDecoderTextLayer" ,
804808 "DiffusionGemmaEncoderTextLayer" ,
@@ -940,14 +944,20 @@ def forward(
940944
941945 # decoder layers
942946 for i , encoder_layer in enumerate (self .layers [: self .config .num_hidden_layers ]):
943- hidden_states = encoder_layer (
947+ # Under gradient checkpointing the layer runs cache-less and the cache write happens outside the
948+ # checkpointed call instead: the returned KV states, as checkpoint outputs, keep the gradient path
949+ # from the decoder open (an in-layer write would be lost and double-applied on recomputation).
950+ checkpointing = encoder_layer .gradient_checkpointing and self .training
951+ hidden_states , key_states , value_states = encoder_layer (
944952 hidden_states ,
945953 position_embeddings = position_embeddings [self .config .layer_types [i ]],
946954 attention_mask = causal_mask_mapping [self .config .layer_types [i ]],
947955 position_ids = position_ids ,
948- past_key_values = past_key_values ,
956+ past_key_values = None if checkpointing else past_key_values ,
949957 ** kwargs ,
950958 )
959+ if checkpointing :
960+ past_key_values .update (key_states , value_states , i )
951961
952962 hidden_states = self .norm (hidden_states )
953963
@@ -1289,7 +1299,8 @@ def forward(
12891299 position_embeddings = position_embeddings [self .text_config .layer_types [i ]],
12901300 attention_mask = mask_mapping [self .text_config .layer_types [i ]],
12911301 position_ids = decoder_position_ids ,
1292- past_key_values = past_key_values ,
1302+ encoder_key_states = past_key_values .layers [i ].keys if past_key_values is not None else None ,
1303+ encoder_value_states = past_key_values .layers [i ].values if past_key_values is not None else None ,
12931304 ** kwargs ,
12941305 )
12951306
0 commit comments