Skip to content

Commit 6ac3d9d

Browse files
committed
Support gradient checkpointing in DiffusionGemma
1 parent 8014139 commit 6ac3d9d

2 files changed

Lines changed: 54 additions & 32 deletions

File tree

src/transformers/models/diffusion_gemma/modeling_diffusion_gemma.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ def forward(
341341
value_states = self.v_norm(value_states)
342342
value_states = value_states.transpose(1, 2)
343343

344+
# CHANGED: the local KV states are also returned, so that under gradient checkpointing (where the layer
345+
# runs cache-less) the calling model can write the cache outside the checkpointed call.
346+
local_key_states, local_value_states = key_states, value_states
344347
if past_key_values is not None:
345348
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
346349
# CHANGED: removed the `if self.store_full_length_kv` branch
@@ -364,7 +367,7 @@ def forward(
364367

365368
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
366369
attn_output = self.o_proj(attn_output)
367-
return attn_output, attn_weights
370+
return attn_output, attn_weights, local_key_states, local_value_states
368371

369372

370373
class DiffusionGemmaDecoderTextAttention(nn.Module):
@@ -418,7 +421,8 @@ def forward(
418421
hidden_states: torch.Tensor,
419422
position_embeddings: torch.Tensor,
420423
attention_mask: torch.Tensor | None,
421-
past_key_values: Cache | None = None,
424+
encoder_key_states: torch.Tensor | None = None,
425+
encoder_value_states: torch.Tensor | None = None,
422426
**kwargs: Unpack[FlashAttentionKwargs],
423427
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
424428
# The code in this function is adapted from Gemma4TextAttention. ** The modified parts are clearly indicated **
@@ -443,11 +447,9 @@ def forward(
443447
value_states = self.v_norm(value_states)
444448
value_states = value_states.transpose(1, 2)
445449

446-
if past_key_values is not None:
447-
# CHANGED: instead of calling `past_key_values.update()` which updates the KV cache in-place and returns
448-
# the full KV states, we first obtain the encoder cache contents, and then append the current KV states.
449-
encoder_key_states = past_key_values.layers[self.layer_idx].keys
450-
encoder_value_states = past_key_values.layers[self.layer_idx].values
450+
if encoder_key_states is not None:
451+
# CHANGED: the encoder KV states are passed as plain tensors (extracted from the encoder cache by the
452+
# calling model) so they survive gradient checkpointing; the canvas KV states are appended to them.
451453
key_states = torch.cat([encoder_key_states, key_states], dim=2)
452454
value_states = torch.cat([encoder_value_states, value_states], dim=2)
453455
# CHANGED: removed the `if self.store_full_length_kv` branch
@@ -602,11 +604,11 @@ def forward(
602604
position_ids: torch.LongTensor | None = None,
603605
past_key_values: Cache | None = None,
604606
**kwargs,
605-
) -> torch.Tensor:
607+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
606608
residual = hidden_states
607609

608610
hidden_states = self.input_layernorm(hidden_states)
609-
hidden_states, _ = self.self_attn(
611+
hidden_states, _, key_states, value_states = self.self_attn(
610612
hidden_states=hidden_states,
611613
position_embeddings=position_embeddings,
612614
attention_mask=attention_mask,
@@ -638,7 +640,7 @@ def forward(
638640
hidden_states = residual + hidden_states
639641

640642
hidden_states *= self.layer_scalar
641-
return hidden_states
643+
return hidden_states, key_states, value_states
642644

643645

644646
class DiffusionGemmaDecoderTextLayer(GradientCheckpointingLayer):
@@ -675,7 +677,8 @@ def forward(
675677
position_embeddings: torch.Tensor = None,
676678
attention_mask: torch.Tensor | None = None,
677679
position_ids: torch.LongTensor | None = None,
678-
past_key_values: Cache | None = None,
680+
encoder_key_states: torch.Tensor | None = None,
681+
encoder_value_states: torch.Tensor | None = None,
679682
**kwargs,
680683
) -> torch.Tensor:
681684
residual = hidden_states
@@ -686,7 +689,8 @@ def forward(
686689
position_embeddings=position_embeddings,
687690
attention_mask=attention_mask,
688691
position_ids=position_ids,
689-
past_key_values=past_key_values,
692+
encoder_key_states=encoder_key_states,
693+
encoder_value_states=encoder_value_states,
690694
**kwargs,
691695
)
692696
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -798,7 +802,7 @@ def forward(self, inputs_embeds, self_conditioning_signal: torch.Tensor) -> torc
798802
class DiffusionGemmaPreTrainedModel(PreTrainedModel):
799803
config: DiffusionGemmaConfig
800804
base_model_prefix = "model"
801-
supports_gradient_checkpointing = False
805+
supports_gradient_checkpointing = True
802806
_no_split_modules = [
803807
"DiffusionGemmaDecoderTextLayer",
804808
"DiffusionGemmaEncoderTextLayer",
@@ -940,14 +944,20 @@ def forward(
940944

941945
# encoder layers
942946
for i, encoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
943-
hidden_states = encoder_layer(
947+
# Under gradient checkpointing the layer runs cache-less and the cache write happens outside the
948+
# checkpointed call instead: the returned KV states, as checkpoint outputs, keep the gradient path
949+
# from the decoder open (an in-layer write would carry no gradient and be repeated on recomputation).
950+
checkpointing = encoder_layer.gradient_checkpointing and self.training
951+
hidden_states, key_states, value_states = encoder_layer(
944952
hidden_states,
945953
position_embeddings=position_embeddings[self.config.layer_types[i]],
946954
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
947955
position_ids=position_ids,
948-
past_key_values=past_key_values,
956+
past_key_values=None if checkpointing else past_key_values,
949957
**kwargs,
950958
)
959+
if checkpointing:
960+
past_key_values.update(key_states, value_states, i)
951961

952962
hidden_states = self.norm(hidden_states)
953963

@@ -1289,7 +1299,8 @@ def forward(
12891299
position_embeddings=position_embeddings[self.text_config.layer_types[i]],
12901300
attention_mask=mask_mapping[self.text_config.layer_types[i]],
12911301
position_ids=decoder_position_ids,
1292-
past_key_values=past_key_values,
1302+
encoder_key_states=past_key_values.layers[i].keys if past_key_values is not None else None,
1303+
encoder_value_states=past_key_values.layers[i].values if past_key_values is not None else None,
12931304
**kwargs,
12941305
)
12951306

src/transformers/models/diffusion_gemma/modular_diffusion_gemma.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ def forward(
269269
value_states = self.v_norm(value_states)
270270
value_states = value_states.transpose(1, 2)
271271

272+
# CHANGED: the local KV states are also returned, so that under gradient checkpointing (where the layer
273+
# runs cache-less) the calling model can write the cache outside the checkpointed call.
274+
local_key_states, local_value_states = key_states, value_states
272275
if past_key_values is not None:
273276
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
274277
# CHANGED: removed the `if self.store_full_length_kv` branch
@@ -292,7 +295,7 @@ def forward(
292295

293296
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
294297
attn_output = self.o_proj(attn_output)
295-
return attn_output, attn_weights
298+
return attn_output, attn_weights, local_key_states, local_value_states
296299

297300

298301
class DiffusionGemmaDecoderTextAttention(nn.Module):
@@ -346,7 +349,8 @@ def forward(
346349
hidden_states: torch.Tensor,
347350
position_embeddings: torch.Tensor,
348351
attention_mask: torch.Tensor | None,
349-
past_key_values: Cache | None = None,
352+
encoder_key_states: torch.Tensor | None = None,
353+
encoder_value_states: torch.Tensor | None = None,
350354
**kwargs: Unpack[FlashAttentionKwargs],
351355
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
352356
# The code in this function is adapted from Gemma4TextAttention. ** The modified parts are clearly indicated **
@@ -371,11 +375,9 @@ def forward(
371375
value_states = self.v_norm(value_states)
372376
value_states = value_states.transpose(1, 2)
373377

374-
if past_key_values is not None:
375-
# CHANGED: instead of calling `past_key_values.update()` which updates the KV cache in-place and returns
376-
# the full KV states, we first obtain the encoder cache contents, and then append the current KV states.
377-
encoder_key_states = past_key_values.layers[self.layer_idx].keys
378-
encoder_value_states = past_key_values.layers[self.layer_idx].values
378+
if encoder_key_states is not None:
379+
# CHANGED: the encoder KV states are passed as plain tensors (extracted from the encoder cache by the
380+
# calling model) so they survive gradient checkpointing; the canvas KV states are appended to them.
379381
key_states = torch.cat([encoder_key_states, key_states], dim=2)
380382
value_states = torch.cat([encoder_value_states, value_states], dim=2)
381383
# CHANGED: removed the `if self.store_full_length_kv` branch
@@ -478,11 +480,11 @@ def forward(
478480
position_ids: torch.LongTensor | None = None,
479481
past_key_values: Cache | None = None,
480482
**kwargs,
481-
) -> torch.Tensor:
483+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
482484
residual = hidden_states
483485

484486
hidden_states = self.input_layernorm(hidden_states)
485-
hidden_states, _ = self.self_attn(
487+
hidden_states, _, key_states, value_states = self.self_attn(
486488
hidden_states=hidden_states,
487489
position_embeddings=position_embeddings,
488490
attention_mask=attention_mask,
@@ -514,7 +516,7 @@ def forward(
514516
hidden_states = residual + hidden_states
515517

516518
hidden_states *= self.layer_scalar
517-
return hidden_states
519+
return hidden_states, key_states, value_states
518520

519521

520522
class DiffusionGemmaDecoderTextLayer(Gemma4TextDecoderLayer):
@@ -551,7 +553,8 @@ def forward(
551553
position_embeddings: torch.Tensor = None,
552554
attention_mask: torch.Tensor | None = None,
553555
position_ids: torch.LongTensor | None = None,
554-
past_key_values: Cache | None = None,
556+
encoder_key_states: torch.Tensor | None = None,
557+
encoder_value_states: torch.Tensor | None = None,
555558
**kwargs,
556559
) -> torch.Tensor:
557560
residual = hidden_states
@@ -562,7 +565,8 @@ def forward(
562565
position_embeddings=position_embeddings,
563566
attention_mask=attention_mask,
564567
position_ids=position_ids,
565-
past_key_values=past_key_values,
568+
encoder_key_states=encoder_key_states,
569+
encoder_value_states=encoder_value_states,
566570
**kwargs,
567571
)
568572
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -644,7 +648,7 @@ def forward(self, inputs_embeds, self_conditioning_signal: torch.Tensor) -> torc
644648
class DiffusionGemmaPreTrainedModel(PreTrainedModel):
645649
config: DiffusionGemmaConfig
646650
base_model_prefix = "model"
647-
supports_gradient_checkpointing = False
651+
supports_gradient_checkpointing = True
648652
_no_split_modules = [
649653
"DiffusionGemmaDecoderTextLayer",
650654
"DiffusionGemmaEncoderTextLayer",
@@ -786,14 +790,20 @@ def forward(
786790

787791
# encoder layers
788792
for i, encoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
789-
hidden_states = encoder_layer(
793+
# Under gradient checkpointing the layer runs cache-less and the cache write happens outside the
794+
# checkpointed call instead: the returned KV states, as checkpoint outputs, keep the gradient path
795+
# from the decoder open (an in-layer write would carry no gradient and be repeated on recomputation).
796+
checkpointing = encoder_layer.gradient_checkpointing and self.training
797+
hidden_states, key_states, value_states = encoder_layer(
790798
hidden_states,
791799
position_embeddings=position_embeddings[self.config.layer_types[i]],
792800
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
793801
position_ids=position_ids,
794-
past_key_values=past_key_values,
802+
past_key_values=None if checkpointing else past_key_values,
795803
**kwargs,
796804
)
805+
if checkpointing:
806+
past_key_values.update(key_states, value_states, i)
797807

798808
hidden_states = self.norm(hidden_states)
799809

@@ -1097,7 +1107,8 @@ def forward(
10971107
position_embeddings=position_embeddings[self.text_config.layer_types[i]],
10981108
attention_mask=mask_mapping[self.text_config.layer_types[i]],
10991109
position_ids=decoder_position_ids,
1100-
past_key_values=past_key_values,
1110+
encoder_key_states=past_key_values.layers[i].keys if past_key_values is not None else None,
1111+
encoder_value_states=past_key_values.layers[i].values if past_key_values is not None else None,
11011112
**kwargs,
11021113
)
11031114

0 commit comments

Comments
 (0)