'''This file provides a JAX implementation of GSAM.'''

import jax
import jax.numpy as jnp

def dual_vector(y):
  """Returns the solution of max_x y^T x s.t. ||x||_2 <= 1.
  Args:
    y: A pytree of numpy ndarrays, the vector y in the equation above.
  Returns:
    The normalized pytree y / ||y||_2 and the norm ||y||_2.
  """
  gradient_norm = jnp.sqrt(sum(
    jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
  normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y)
  return normalized_gradient, gradient_norm

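# Example usage of dual_vector (a minimal sketch with illustrative values, not
# part of the original file): it L2-normalizes a whole gradient pytree and also
# returns the global norm, so the input can be recovered as norm * normalized.
#
#   grads = {'w': jnp.array([3.0, 0.0]), 'b': jnp.array([0.0, 4.0])}
#   unit_grads, norm = dual_vector(grads)
#   # norm == 5.0; unit_grads['w'] == [0.6, 0.0]; unit_grads['b'] == [0.0, 0.8]
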
def gsam_gradient(loss_fn, params, inputs, targets,
                  rho_max, rho_min, alpha, lr, lr_max, lr_min, eps=1e-12,
                  adaptive_perturbation=False, minimize_fp=True):
  """
  Get the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-).
  Args:
    loss_fn: the loss function.
    params: the model weights.
    inputs: the inputs to the loss function.
    targets: the targets to the loss function.
    rho_max: the maximum rho value for the perturbation of weights.
    rho_min: the minimum rho value for the perturbation of weights.
    alpha: the trade-off coefficient for the surrogate gap term,
      see Algorithm 1 in the paper.
    lr: the current learning rate.
    lr_max: the maximum learning rate.
    lr_min: the minimum learning rate.
    eps: the epsilon value for numerical stability.
    adaptive_perturbation: if False, use the same perturbation as SAM:
      all parameters are treated as a single vector, and the perturbation
      norm is the norm of that whole vector.
      If True, the perturbation norm is proportional to the parameter norm,
      which stabilizes training when different layers have weights of
      different scales. Empirically, setting it to True can handle a rho
      about 10x larger than setting it to False.
    minimize_fp: if True, minimize (f_p, h) as in the original GSAM;
      if False, minimize (f, h), where f is the clean loss, f_p is the
      perturbed loss, and h is the surrogate gap.
      If True, the training dynamics are closer to SAM than to conventional
      training, and you might observe several loss spikes during training.
      If False, the training dynamics are closer to conventional training
      and are often more stable (fewer loss spikes).
  Returns:
    l_clean: the loss function value.
    g_gsam: the GSAM gradient. g_gsam is not averaged across workers;
      call "jax.lax.pmean" to average it.

  Note:
    Setting `rho_max=rho_min` and `alpha=0` reduces GSAM to SAM.
  """
  l_clean, g_clean = jax.value_and_grad(loss_fn)(params, inputs, targets)
  g_clean_normalized, g_clean_length = dual_vector(g_clean)

  if lr_max == lr_min:
    sam_rho = rho_max
  else:
    sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
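  # The perturbation radius rho follows the learning rate linearly; e.g. with
  # rho_min=0.1, rho_max=0.6 and lr halfway between lr_min and lr_max,
  # sam_rho = 0.1 + 0.5 * (0.6 - 0.1) = 0.35 (values here are illustrative).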

  # Per-worker perturbation.
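  # This is the SAM ascent step, param_sam = params + rho * g / (||g|| + eps);
  # with adaptive_perturbation=True, each parameter's offset is additionally
  # scaled by its own magnitude |param|.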
  if adaptive_perturbation:
    param_sam = jax.tree_map(
        lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
        params, g_clean)
  else:
    param_sam = jax.tree_map(
        lambda a, b: a + sam_rho * b / (g_clean_length + eps),
        params, g_clean)

  # Get gradients at perturbed weights.
  _, g_robust = jax.value_and_grad(loss_fn)(param_sam, inputs, targets)

  # Decompose gradients.
  g_clean_flatten, _ = jax.tree_util.tree_flatten(g_clean)
  g_robust_flatten, _ = jax.tree_util.tree_flatten(g_robust)

  if minimize_fp:
    # Decompose g_clean into components parallel and orthogonal to g_robust.
    g_robust_normalized, _ = dual_vector(g_robust)
    g_robust_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_robust_normalized)

    g_clean_projection_norm = sum(
        jnp.vdot(p, q) for (p, q) in
        zip(g_robust_normalized_flatten, g_clean_flatten))
    g_clean_residual = jax.tree_map(
        lambda a, b: a - g_clean_projection_norm * b,
        g_clean, g_robust_normalized)

    # Get GSAM gradient.
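    # g_gsam = g_robust - alpha * g_clean_residual, where g_clean_residual is
    # the component of g_clean orthogonal to g_robust.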
    g_gsam = jax.tree_map(lambda a, b: a - b * alpha,
                          g_robust, g_clean_residual)
  else:
    # Decompose g_robust into components parallel and orthogonal to g_clean.
    # g_clean_normalized was already computed from g_clean above.
    g_clean_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_clean_normalized)

    g_robust_projection_norm = sum(
        jnp.vdot(p, q) for (p, q) in
        zip(g_clean_normalized_flatten, g_robust_flatten))
    g_robust_residual = jax.tree_map(
        lambda a, b: a - g_robust_projection_norm * b,
        g_robust, g_clean_normalized)

    # Get GSAM gradient.
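    # g_gsam = g_clean + alpha * g_robust_residual, where g_robust_residual is
    # the component of g_robust orthogonal to g_clean.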
    g_gsam = jax.tree_map(lambda a, b: a + b * alpha,
                          g_clean, g_robust_residual)

  # Always return the clean loss (rather than the perturbed loss).
  return l_clean, g_gsam
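
# The commented block below is a minimal training-step sketch, not part of the
# original file. It assumes an optax optimizer and a loss_fn(params, inputs,
# targets) signature; the names train_step, opt and the hyperparameter values
# are illustrative only.
#
#   import functools
#   import optax
#
#   opt = optax.sgd(learning_rate=0.1, momentum=0.9)
#
#   @functools.partial(jax.pmap, axis_name='batch')
#   def train_step(params, opt_state, inputs, targets, lr):
#     loss, grads = gsam_gradient(
#         loss_fn, params, inputs, targets,
#         rho_max=0.6, rho_min=0.1, alpha=0.4,
#         lr=lr, lr_max=0.1, lr_min=0.0)
#     # gsam_gradient returns per-worker gradients; average across devices.
#     grads = jax.lax.pmean(grads, axis_name='batch')
#     loss = jax.lax.pmean(loss, axis_name='batch')
#     updates, opt_state = opt.update(grads, opt_state, params)
#     params = optax.apply_updates(params, updates)
#     return params, opt_state, loss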