
Commit 969fda9

Gavin Zhang authored and facebook-github-bot committed
refactor the total norm computation in grad clipping in APS (#3243)
Summary: Refactored the previous code for applying gradient clipping across DDP and FSDP parameters. Added a new function _compute_total_norm() that takes the replicated and sharded params provided in the GradientClippingOptimizer class and computes the total gradient norm of the given parameters. Differential Revision: D79128843
1 parent 3bdf9f3 commit 969fda9
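For context, a minimal single-process sketch of the idea behind _compute_total_norm(): per-group gradient norms are combined by summing p-th powers and taking the p-th root for a finite p-norm, or by taking a max for the infinity norm. The helper name combine_group_norms and the string group keys below are illustrative only, not the torchrec API; the real function also handles DTensor local shards and cross-rank reduction over process groups.

# Illustrative sketch only (not the torchrec implementation): combine per-group
# gradient norms into one total norm. `combine_group_norms` and the string group
# keys are made-up names; the commit's _compute_total_norm additionally handles
# DTensor local shards and cross-rank reduction via process groups.
from typing import Dict, List

import torch


def combine_group_norms(
    group_grads: Dict[str, List[torch.Tensor]], norm_type: float = 2.0
) -> torch.Tensor:
    """For finite p, accumulate |g|_p ** p per group, sum, and take the p-th root;
    for the infinity norm, take the max over groups."""
    per_group = [
        torch.linalg.vector_norm(torch.cat([g.flatten() for g in grads]), norm_type)
        for grads in group_grads.values()
    ]
    if norm_type == torch.inf:
        return torch.stack(per_group).max()
    return torch.stack([n ** norm_type for n in per_group]).sum() ** (1.0 / norm_type)


if __name__ == "__main__":
    grads = {
        "replicated": [torch.randn(4), torch.randn(3)],
        "sharded": [torch.randn(5)],
    }
    total = combine_group_norms(grads, norm_type=2.0)
    # Sanity check: matches the 2-norm of all gradients concatenated together.
    reference = torch.linalg.vector_norm(
        torch.cat([g for gs in grads.values() for g in gs])
    )
    assert torch.allclose(total, reference)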

File tree

1 file changed: +82 -80 lines changed

torchrec/optim/clipping.py

Lines changed: 82 additions & 80 deletions
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -135,98 +135,101 @@ def step(self, closure: Any = None) -> None:
         super().step(closure)
         self._step_num += 1
 
-    @torch.no_grad()
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
-        max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
+
+        # converts self._norm_type to a float if it's a string. Used in the case where self._norm_type is 'inf'.
         all_grads = []
         total_grad_norm = None
 
-        # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
-            sharded_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in dist_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            if len(sharded_grads) == 0:
-                continue
-            all_grads.extend(sharded_grads)
-
-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
+        sharded_params = self._sharded_params
+        replicate_params = self._replicate_params
 
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+        # Process distributed parameters and gradients
+        sharded_grads = {
+            pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
+        }
+        all_grads.extend(*sharded_grads.values())
 
         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in self._replicate_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        replicate_grads = _get_grads(replicate_params)
+        all_grads.extend(replicate_grads)
+
+        total_grad_norm = _compute_total_norm(
+            replicate_grads=replicate_grads,
+            sharded_grads=sharded_grads,
+            norm_type=self._norm_type,
+            max_grad_norm=self._max_gradient,
+        )
 
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
 
 
+def _get_grads(
+    param_list: List[torch.Tensor],
+) -> List[torch.Tensor]:
+    """Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
+    grads = [
+        p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+        for p in param_list
+        if p.grad is not None and p.grad.numel() > 0
+    ]
+    return grads
+
+
+def _compute_total_norm(
+    replicate_grads: List[torch.Tensor],
+    sharded_grads: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]],
+    norm_type: float = 2.0,  # can be a normal float, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both replicate grads and sharded grads, compute the total norm of the gradients of the full replicate params and the
+    full sharded param (parameters with a process group).
+
+    Args:
+        replicate_grads (List[torch.Tensor]): list of gradients for replicate params
+        sharded_grads (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of gradients for sharded params
+        norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
+        max_grad_norm (float): max gradient norm.
+    """
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add
+
+    # We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
+    # this is specifically for the case where sharded_grad_norm is 0, and replicate_grad_norm is not,
+    # because by default torch.tensor(0.0) is on cpu, and replicate_grad_norm is on GPU. For MTIA
+    # specifically, adding a tensor on cpu and a tensor on GPU will result in an error.
+    for pgs, dist_params in sharded_grads.items():
+        current_shard_norm = _batch_cal_norm(dist_params, max_grad_norm, norm_type, pgs)
+        sharded_grad_norm = combine_norm_operator(
+            sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+        )
+    # compute |W|^p corresponding to all replicate params W
+    # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
+    replicate_grad_norm: torch.Tensor = (
+        _batch_cal_norm(replicate_grads, max_grad_norm, norm_type)
+        if replicate_grads
+        else torch.tensor(0.0)
+    ).to(sharded_grad_norm.device)
+
+    # In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
+    # sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
+    # To compute the total norm, we need to take max(max(|W_sharded|), max(|W_replicate|).
+    combined_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
+    total_grad_norm = (
+        combined_norm.pow(1.0 / norm_type) if norm_type != torch.inf else combined_norm
+    )
+
+    return total_grad_norm
+
+
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,
@@ -236,7 +239,6 @@ def _batch_cal_norm(
     """Helper function that calculates the norm of a list of gradients in batches. If process_groups
     are passed in, the norm will be aggregated across all ranks in the process group.
    """
-
    global use_64bit_grad_norm
    if use_64bit_grad_norm:
        grad_norms = torch.linalg.vector_norm(
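
As a side note on the final scaling step, which this commit leaves unchanged: once the total norm is known, every gradient is multiplied in place by min(1, max_norm / (total_norm + eps)). A tiny standalone sketch follows, with made-up tensors and an assumed helper name clip_by_total_norm_ (not a torchrec API):

# Toy sketch of the clipping application kept by clip_grad_norm_ after the
# refactor: scale every gradient in place by min(1, max_norm / (total_norm + eps)).
# Tensors and the helper name clip_by_total_norm_ are made up for illustration.
from typing import List

import torch


def clip_by_total_norm_(
    grads: List[torch.Tensor], total_norm: torch.Tensor, max_norm: float
) -> None:
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    # In-place multiply of every gradient by the (<= 1) scaling factor.
    torch._foreach_mul_(grads, clip_coef_clamped)


if __name__ == "__main__":
    grads = [torch.full((3,), 4.0), torch.full((2,), 3.0)]
    total_norm = torch.linalg.vector_norm(torch.cat(grads))  # sqrt(3*16 + 2*9) ≈ 8.12
    clip_by_total_norm_(grads, total_norm, max_norm=1.0)
    print(torch.linalg.vector_norm(torch.cat(grads)))  # ≈ 1.0 after clipping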
