Commit 136deda

implement gsam in jax (#8)
1 parent 8921d51 commit 136deda

3 files changed: +613 −0 lines changed

big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Pre-training ViT on ILSVRC-2012 with GSAM (https://arxiv.org/abs/2203.08065).

Run training of a B/32 model:

big_vision.trainers.proj.gsam.train \
    --config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc


def get_config(arg=None):
  """Config for training."""
  arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False)
  config = mlc.ConfigDict()

  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'
  config.cache_raw = not arg.runlocal  # Needs up to 120GB of RAM!
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
  config.num_classes = 1000
  config.loss = 'sigmoid_xent'
  config.batch_size = 4096
  config.num_epochs = 300

  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
      pp_common.format(lbl='label')
  )
  pp = 'decode|resize_small(256)|central_crop(224)' + pp_common

  # Aggressive pre-fetching because our models here are small, so we can not
  # only afford it, the smallest models also need it to avoid being
  # bottlenecked by the input pipeline. Play around with it for -L models, though.
  config.prefetch_to_host = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.checkpoint_steps = 1000

  # Model section
  config.model_name = 'vit'
  config.model = dict(
      variant=arg.variant,
      rep_size=False,
      pool_type='gap',
  )
  config.init_head_bias = -10.0

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='float32')
  # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
  # almost always behaves exactly like Adam, but at a fraction of the memory
  # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
  # good idea to try it when you are memory-bound!
  # config.optax_name = 'big_vision.scale_by_adafactor'
  # A good flag to play with when hitting instabilities is the following:
  # config.optax = dict(beta2_cap=0.95)

  config.lr = 0.003
  config.wd = 0.001  # Default is 0.0001; the paper used 0.3, i.e. effective wd = 0.3 * lr.
  config.schedule = dict(
      warmup_steps=10_000,
      decay_type='linear',
      linear_end=0.01,
  )

  # GSAM settings.
  # Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM.
  config.gsam = dict(
      rho_max=0.6,
      rho_min=0.1,
      alpha=0.6,
      lr_max=config.get_ref('lr'),
      lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
  )

  # Eval section
  eval_common = dict(
      type='classification',
      dataset='imagenet2012',
      pp_fn=pp.format(lbl='label'),
      loss_name=config.loss,
      log_steps=2500,  # Very fast O(seconds), so it's fine to run it often.
  )
  config.evals = {}
  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
  config.evals.val = {**eval_common, 'split': 'validation'}
  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}

  config.evals.real = {**eval_common}
  config.evals.real.dataset = 'imagenet2012_real'
  config.evals.real.split = 'validation'
  config.evals.real.pp_fn = pp.format(lbl='real_label')

  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.fewshot.log_steps = 10_000

  # Make a few things much smaller for quick local debugging test runs.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8
    config.evals.minival.split = 'train[:16]'
    config.evals.val.split = 'validation[:16]'
    config.evals.real.split = 'validation[:16]'
    config.evals.v2.split = 'test[:16]'

  return config
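
For orientation, here is a small sketch of how this config resolves when instantiated. The key=value override string is an assumption based on the bvcc.parse_arg call above rather than a documented interface, and the import path simply mirrors the file path in the docstring.

# Hypothetical instantiation sketch; assumes the config module is importable
# under the path given in the docstring above.
from big_vision.configs.proj.gsam import vit_i1k_gsam_no_aug

config = vit_i1k_gsam_no_aug.get_config()                # Full B/32 training setup.
debug = vit_i1k_gsam_no_aug.get_config('runlocal=True')  # Tiny local debug setup.

# GSAM hyper-parameters live under config.gsam; lr_min tracks the schedule
# through ConfigDict references: lr_min = linear_end * lr = 0.01 * 0.003 = 3e-5.
assert config.gsam.rho_max == 0.6 and config.gsam.rho_min == 0.1
assert debug.batch_size == 8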

big_vision/trainers/proj/gsam/gsam.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
"""This file provides a JAX implementation of GSAM."""

import jax
import jax.numpy as jnp


def dual_vector(y):
  """Returns the solution of max_x y^T x s.t. ||x||_2 <= 1.

  Args:
    y: A pytree of numpy ndarrays; the vector y in the equation above.
  """
  gradient_norm = jnp.sqrt(sum(
      jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
  normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y)
  return normalized_gradient, gradient_norm


def gsam_gradient(loss_fn, params, inputs, targets,
                  rho_max, rho_min, alpha, lr, lr_max, lr_min, eps=1e-12,
                  adaptive_perturbation=False, minimize_fp=True):
  """Computes the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-).

  Args:
    loss_fn: the loss function.
    params: the model weights.
    inputs: the inputs to the loss function.
    targets: the targets of the loss function.
    rho_max: the maximum rho value for the perturbation of the weights.
    rho_min: the minimum rho value for the perturbation of the weights.
    alpha: the alpha value of the rho schedule, see Algorithm 1 in the paper.
    lr: the current learning rate.
    lr_max: the maximum learning rate.
    lr_min: the minimum learning rate.
    eps: the epsilon value for numerical stability.
    adaptive_perturbation: if False, use the same perturbation as SAM: all
      parameters are treated as a single vector, and the perturbation norm is
      the norm of that whole vector. If True, the perturbation norm is
      proportional to the parameter norm, which stabilizes training when
      different layers have weights of different scales. Empirically, setting
      it to True can handle a 10x larger rho than setting it to False.
    minimize_fp: if True, minimize (f_p, h), the original GSAM; if False,
      minimize (f, h), where f is the clean loss, f_p is the perturbed loss,
      and h is the surrogate gap. If True, the training dynamics are closer to
      SAM than to conventional training, and you might observe several loss
      spikes during training. If False, the training dynamics are closer to
      conventional training and are often more stable (fewer loss spikes).

  Returns:
    l_clean: the loss function value.
    g_gsam: the GSAM gradient. g_gsam is not averaged across workers; call
      "jax.lax.pmean" to average it.

  Note:
    Setting `rho_max=rho_min` and `alpha=0` reduces GSAM to SAM.
  """
  l_clean, g_clean = jax.value_and_grad(loss_fn)(params, inputs, targets)
  g_clean_normalized, g_clean_length = dual_vector(g_clean)

  if lr_max == lr_min:
    sam_rho = rho_max
  else:
    sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)

  # Per-worker perturbation.
  if adaptive_perturbation:
    param_sam = jax.tree_map(
        lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
        params, g_clean)
  else:
    param_sam = jax.tree_map(
        lambda a, b: a + sam_rho * b / (g_clean_length + eps),
        params, g_clean)

  # Get gradients at the perturbed weights.
  _, g_robust = jax.value_and_grad(loss_fn)(param_sam, inputs, targets)

  # Decompose gradients.
  g_clean_flatten, _ = jax.tree_util.tree_flatten(g_clean)
  g_robust_flatten, _ = jax.tree_util.tree_flatten(g_robust)

  if minimize_fp:
    # Decompose g_clean into components parallel and orthogonal to g_robust.
    g_robust_normalized, _ = dual_vector(g_robust)
    g_robust_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_robust_normalized)

    g_clean_projection_norm = sum(
        jnp.vdot(p, q)
        for (p, q) in zip(g_robust_normalized_flatten, g_clean_flatten))
    g_clean_residual = jax.tree_map(
        lambda a, b: a - g_clean_projection_norm * b,
        g_clean, g_robust_normalized)

    # Get the GSAM gradient.
    g_gsam = jax.tree_map(lambda a, b: a - b * alpha,
                          g_robust, g_clean_residual)
  else:
    # Decompose g_robust into components parallel and orthogonal to g_clean.
    g_clean_normalized, g_clean_length = dual_vector(g_clean)
    g_clean_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_clean_normalized)

    g_robust_projection_norm = sum(
        jnp.vdot(p, q)
        for (p, q) in zip(g_clean_normalized_flatten, g_robust_flatten))
    g_robust_residual = jax.tree_map(
        lambda a, b: a - g_robust_projection_norm * b,
        g_robust, g_clean_normalized)

    # Get the GSAM gradient.
    g_gsam = jax.tree_map(lambda a, b: a + b * alpha,
                          g_clean, g_robust_residual)

  # Always return the clean loss (rather than the perturbed loss).
  return l_clean, g_gsam
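
To show how this would be wired into a trainer, below is a minimal, self-contained sketch of a data-parallel training step built around gsam_gradient. Only the gsam_gradient signature, the jax.lax.pmean requirement from its docstring, and the hyper-parameter values from the config above come from this commit; the toy linear model, the optax.sgd optimizer, and the pmap wiring are illustrative assumptions, not the big_vision trainer.

# Minimal usage sketch (not the big_vision trainer).
import functools

import jax
import jax.numpy as jnp
import optax

from big_vision.trainers.proj.gsam.gsam import gsam_gradient  # Path as in this commit.

tx = optax.sgd(3e-3)  # Illustrative; the real trainer builds the optimizer from config.optax_name.

def loss_fn(params, inputs, targets):
  # Toy linear model standing in for the ViT; any (params, batch) -> scalar loss works.
  logits = inputs @ params['w'] + params['b']
  return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, targets))

# params, opt_state, inputs, targets carry a leading device axis; lr is broadcast.
@functools.partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0, None))
def train_step(params, opt_state, inputs, targets, lr):
  # rho_max/rho_min/alpha and the lr range mirror config.gsam above.
  loss, grads = gsam_gradient(
      loss_fn, params, inputs, targets,
      rho_max=0.6, rho_min=0.1, alpha=0.6,
      lr=lr, lr_max=3e-3, lr_min=0.01 * 3e-3)
  # Per the docstring, g_gsam is per-worker and must be averaged across devices.
  grads = jax.lax.pmean(grads, axis_name='batch')
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, jax.lax.pmean(loss, axis_name='batch')

Here params and opt_state are assumed to already be replicated across local devices (e.g. via flax.jax_utils.replicate), and lr is the per-step value produced by the linear-decay schedule in the config, which is what drives the rho interpolation between rho_max and rho_min.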
