
refactor the total norm computation in grad clipping in APS #3243

Open · wants to merge 1 commit into base: main
181 changes: 95 additions & 86 deletions torchrec/optim/clipping.py
@@ -59,7 +59,7 @@ def __init__(
super().__init__(optimizer)
self._clipping = clipping
self._max_gradient = max_gradient
self._norm_type = norm_type
self._norm_type = float(norm_type)
self._check_meta: bool = True
self._enable_global_grad_clip = enable_global_grad_clip
self._step_num = 0
@@ -122,121 +122,130 @@ def step(self, closure: Any = None) -> None:
for p in self._replicate_params
]
torch.nn.utils.clip_grad_norm_(
replicate_params,
self._max_gradient,
norm_type=float(self._norm_type),
parameters=replicate_params,
max_norm=self._max_gradient,
norm_type=self._norm_type,
)
else:
self.clip_grad_norm_()

elif self._clipping == GradientClipping.VALUE:
torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)
torch.nn.utils.clip_grad_value_(
parameters=self._replicate_params, clip_value=self._max_gradient
)

super().step(closure)
self._step_num += 1

@torch.no_grad()
def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
"""Clip the gradient norm of all parameters."""
max_norm = self._max_gradient
norm_type = float(self._norm_type)

# converts self._norm_type to a float if it's a string. Used in the case where self._norm_type is 'inf'.
all_grads = []
total_grad_norm = None
sharded_params = self._sharded_params
replicate_params = self._replicate_params

# Process distributed parameters and gradients
for pgs, dist_params in self._sharded_params.items():
sharded_grads = [
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
for p in dist_params
if p.grad is not None and p.grad.numel() > 0
]
if len(sharded_grads) == 0:
continue
all_grads.extend(sharded_grads)

sharded_grad_norm = _batch_cal_norm(
sharded_grads,
max_norm,
norm_type,
pgs,
)
total_grad_norm = (
sharded_grad_norm
if total_grad_norm is None
else (
torch.maximum(total_grad_norm, sharded_grad_norm)
if norm_type == torch.inf
else total_grad_norm + sharded_grad_norm
)
)

square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
sharded_grads = {
pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
}
all_grads.extend(*sharded_grads.values())

# Process replicated parameters and gradients
if self._replicate_params:
replicated_grads = [
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
for p in self._replicate_params
if p.grad is not None and p.grad.numel() > 0
]
all_grads.extend(replicated_grads)

replicated_grad_norm = _batch_cal_norm(
replicated_grads,
max_norm,
norm_type,
None,
)
total_grad_norm = (
replicated_grad_norm
if total_grad_norm is None
else (
torch.maximum(total_grad_norm, replicated_grad_norm)
if norm_type == torch.inf
else total_grad_norm + replicated_grad_norm
)
)
square_replicated_grad_norm = replicated_grad_norm
else:
square_replicated_grad_norm = 0

global log_grad_norm
if log_grad_norm:
if total_grad_norm is not None and norm_type != torch.inf:
# pyre-ignore[58]
grad_norm = total_grad_norm ** (1.0 / norm_type)
else:
grad_norm = total_grad_norm

rank = dist.get_rank()
logger.info(
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
)

# Aggregation
if total_grad_norm is None:
return
replicate_grads = _get_grads(replicate_params)
all_grads.extend(replicate_grads)

total_grad_norm = _compute_total_norm(
replicate_grads=replicate_grads,
sharded_grads=sharded_grads,
norm_type=self._norm_type,
max_grad_norm=self._max_gradient,
)

if norm_type != torch.inf:
# pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
total_grad_norm = total_grad_norm ** (1.0 / norm_type)
# pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
torch._foreach_mul_(all_grads, clip_coef_clamped)
return total_grad_norm
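
For readers skimming the refactor, the end-to-end behavior of clip_grad_norm_ after this change can be illustrated with a small single-process sketch: accumulate the p-th power of each gradient's norm, take the p-th root, and rescale every gradient by max_norm / total_norm whenever the total exceeds the limit. This is a hypothetical standalone illustration only (no process groups, no DTensors); the helper name clip_grads_by_total_norm is not part of this PR.

import torch

def clip_grads_by_total_norm(grads, max_norm, norm_type=2.0):
    # Combine per-tensor norms the same way the refactored code combines group norms:
    # max of |g| for the infinity norm, otherwise the p-th root of the summed |g|^p.
    if norm_type == torch.inf:
        total_norm = torch.stack([g.abs().max() for g in grads]).max()
    else:
        total_norm = torch.stack([g.norm(norm_type) ** norm_type for g in grads]).sum() ** (1.0 / norm_type)
    # Scale all gradients in place when the total norm exceeds max_norm.
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    torch._foreach_mul_(grads, clip_coef)
    return total_norm

# Two gradients with norms 3 and 4 give a total L2 norm of 5, so both are scaled by ~1/5.
grads = [torch.tensor([3.0]), torch.tensor([4.0])]
clip_grads_by_total_norm(grads, max_norm=1.0)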


def _get_grads(
param_list: List[torch.Tensor],
) -> List[torch.Tensor]:
"""Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
grads = [
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
for p in param_list
if p.grad is not None and p.grad.numel() > 0
]
return grads
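
As a quick, hypothetical illustration of the filtering that _get_grads performs (ignoring the DTensor branch, which requires a distributed setup), parameters whose gradient is missing or empty are simply dropped:

import torch

w1 = torch.nn.Parameter(torch.ones(3))
w2 = torch.nn.Parameter(torch.ones(3))
w1.grad = torch.full((3,), 0.5)  # has a non-empty gradient -> kept
# w2.grad remains None           -> skipped by the filter

grads = [p.grad for p in [w1, w2] if p.grad is not None and p.grad.numel() > 0]
assert len(grads) == 1 and torch.equal(grads[0], torch.full((3,), 0.5))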


def _compute_total_norm(
replicate_grads: List[torch.Tensor],
sharded_grads: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]],
norm_type: float = 2.0, # can be a normal float, or torch.inf
max_grad_norm: float = 1.0,
) -> torch.Tensor:
"""
Given both replicate grads and sharded grads, compute the total norm of the gradients across the full set of
replicate params and the full set of sharded params (parameters associated with a process group).

Args:
replicate_grads (List[torch.Tensor]): list of gradients for replicate params
sharded_grads (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of gradients for sharded params
norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
max_grad_norm (float): max gradient norm.
"""

## compute the norm |W|^p corresponding to all sharded params W
sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add

# We need to move sharded_grad_norm to the same device as the current shard norm so that we can add to it (or take the max).
# This matters specifically when sharded_grad_norm is 0 but replicate_grad_norm is not: torch.tensor(0.0) lives on the CPU
# by default while replicate_grad_norm is on the GPU, and on MTIA in particular adding a CPU tensor to a GPU tensor
# raises an error.
for pgs, dist_params in sharded_grads.items():
current_shard_norm = _batch_cal_norm(
grad_list=dist_params,
max_norm=max_grad_norm,
norm_type=norm_type,
process_groups=pgs,
)
sharded_grad_norm = combine_norm_operator(
sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
)
# compute |W|^p corresponding to all replicate params W
# Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
replicate_grad_norm: torch.Tensor = (
_batch_cal_norm(
grad_list=replicate_grads, max_norm=max_grad_norm, norm_type=norm_type
)
if replicate_grads
else torch.tensor(0.0)
).to(sharded_grad_norm.device)

# In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
# sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
# To compute the total norm, we need to take max(max(|W_sharded|), max(|W_replicate|)).
combined_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
total_grad_norm = (
combined_norm.pow(1.0 / norm_type) if norm_type != torch.inf else combined_norm
)

return total_grad_norm
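
A worked example of the combination step, under the assumption (stated in _batch_cal_norm's docstring) that each group contributes the p-th power of its norm: for norm_type=2 the partial results are summed and the square root is taken at the end, while for torch.inf the partial maxima would be combined with torch.maximum and no root is needed. The sketch below is standalone and hypothetical, with no process-group reduction:

import torch

def batch_norm_pow(grads, norm_type=2.0):
    # Stand-in for _batch_cal_norm without the process-group all-reduce:
    # returns max |g| for the infinity norm, otherwise the sum of |g|^p (the p-th power of the norm).
    if norm_type == torch.inf:
        return torch.stack([g.abs().max() for g in grads]).max()
    return torch.stack([g.norm(norm_type) ** norm_type for g in grads]).sum()

norm_type = 2.0
sharded_grads = [torch.tensor([3.0])]    # contributes 3^2 = 9
replicate_grads = [torch.tensor([4.0])]  # contributes 4^2 = 16

combined = batch_norm_pow(sharded_grads, norm_type) + batch_norm_pow(replicate_grads, norm_type)
total_grad_norm = combined.pow(1.0 / norm_type)  # sqrt(9 + 16) = 5.0
assert torch.isclose(total_grad_norm, torch.tensor(5.0))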


def _batch_cal_norm(
grad_list: List[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
process_groups: Optional[Tuple[dist.ProcessGroup]] = None,
) -> torch.Tensor:
"""Helper function that calculates the norm of a list of gradients in batches. If process_groups
are passed in, the norm will be aggregated across all ranks in the process group.
"""Helper function that calculates the p-th power of the norm of a list of gradients in batches.
If process_groups are passed in, the norm will be aggregated across all ranks in the process group.
"""

global use_64bit_grad_norm
if use_64bit_grad_norm:
grad_norms = torch.linalg.vector_norm(