Commit e522d6a

refactor
1 parent e3b4b64 commit e522d6a

File tree: 2 files changed (+5, −5 lines)

megatron/core/tensor_parallel/random.py

Lines changed: 1 addition & 4 deletions
@@ -4,7 +4,6 @@
 # repo: https://github.com/pytorch/pytorch
 
 import contextlib
-import functools
 import logging
 from typing import Optional, Union
 
@@ -543,11 +542,9 @@ def backward(ctx, *args):
         return (None, None) + grads
 
 
-def checkpoint(function, distribute_saved_activations, *args, **kwargs):
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    if kwargs:
-        function = functools.partial(function, **kwargs)
     return CheckpointFunction.apply(function, distribute_saved_activations, *args)
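The effect of this change is that checkpoint() no longer accepts keyword arguments and no longer wraps the callable itself; a caller that needs keyword arguments must bind them before checkpointing. A minimal sketch of the before/after calling convention, using a hypothetical helper and placeholder names (mlp_fn, x) that are not part of this commit:

import functools

from megatron.core import tensor_parallel


def run_mlp_with_recompute(mlp_fn, x, padding_mask=None):
    # Hypothetical helper for illustration only.
    #
    # Before this commit, keyword arguments could be handed to checkpoint()
    # directly, and checkpoint() bound them internally via functools.partial:
    #     tensor_parallel.checkpoint(mlp_fn, False, x, padding_mask=padding_mask)
    #
    # After this commit, checkpoint() takes only positional args, so the
    # caller binds the keyword argument itself before checkpointing:
    bound_fn = functools.partial(mlp_fn, padding_mask=padding_mask)
    return tensor_parallel.checkpoint(bound_fn, False, x)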

megatron/core/transformer/transformer_layer.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
+import functools
 import logging
 import warnings
 from abc import ABC
@@ -672,7 +673,9 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None):
                 )
             else:
                 mlp_output_with_bias = tensor_parallel.checkpoint(
-                    self.mlp, False, pre_mlp_layernorm_output, padding_mask=padding_mask
+                    functools.partial(self.mlp, padding_mask=padding_mask),
+                    False,
+                    pre_mlp_layernorm_output,
                 )
         elif should_chunk_mlp_for_prefill:
             # Chunk input along sequence dimension
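Wrapping self.mlp with functools.partial is behaviorally equivalent to the old keyword pass-through: padding_mask is baked into the callable, so only the positional activation tensor flows through CheckpointFunction.apply. A small self-contained sketch of that equivalence, with a toy mlp standing in for the real module (an assumption for illustration, not the actual Megatron MLP):

import functools

import torch


def mlp(hidden_states, padding_mask=None):
    # Toy stand-in for self.mlp, used only to show the partial-binding pattern.
    if padding_mask is not None:
        hidden_states = hidden_states.masked_fill(padding_mask, 0.0)
    return hidden_states * 2


x = torch.randn(4, 8)
mask = torch.zeros(4, 8, dtype=torch.bool)

# functools.partial bakes the keyword argument into the callable; calling the
# partial with just the positional tensor matches the direct keyword call.
bound_mlp = functools.partial(mlp, padding_mask=mask)
assert torch.equal(bound_mlp(x), mlp(x, padding_mask=mask))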
