
Commit 22b07c8: Add lion algorithm
Parent: fb2f492

File tree: 3 files changed, +284 -0 lines

reference_algorithms/paper_baselines/lion/__init__.py

Whitespace-only changes.

reference_algorithms/paper_baselines/lion/pytorch/__init__.py

Whitespace-only changes.
Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
from __future__ import annotations
import collections
from typing import Tuple, Callable, Any, Dict, Iterator, List, Optional

from absl import logging
import torch
from torch.optim.optimizer import Optimizer

import torch.distributed.nn as dist_nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

# Default Lion parameters.
HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 2e-4,
    "one_minus_beta1": 0.05,
    "beta2": 0.98,
    "weight_decay": 0.5,
    "warmup_factor": 0.02
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
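
For reference (not part of the diff), the Lion betas that init_optimizer_state below derives from these defaults work out as follows:

# Illustrative only, not in the committed file: effective Lion coefficients
# implied by HPARAMS (see init_optimizer_state below).
beta1 = 1.0 - HPARAMS.one_minus_beta1  # 0.95
beta2 = HPARAMS.beta2                  # 0.98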

# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
class Lion(Optimizer):
  def __init__(
      self,
      params,
      lr: float = 1e-4,
      betas: Tuple[float, float] = (0.9, 0.99),
      weight_decay: float = 0.0,
  ):
    if not 0.0 <= lr:
      raise ValueError('Invalid learning rate: {}'.format(lr))
    if not 0.0 <= betas[0] < 1.0:
      raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
    if not 0.0 <= betas[1] < 1.0:
      raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

    Args:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      the loss.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        # Perform stepweight decay.
        p.data.mul_(1 - group['lr'] * group['weight_decay'])

        grad = p.grad
        state = self.state[p]
        # State initialization.
        if len(state) == 0:
          # Exponential moving average of gradient values.
          state['exp_avg'] = torch.zeros_like(p)

        exp_avg = state['exp_avg']
        beta1, beta2 = group['betas']

        # Weight update.
        update = exp_avg * beta1 + grad * (1 - beta1)

        p.add_(update.sign_(), alpha=-group['lr'])

        # Decay the momentum running average coefficient.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss
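
For context (not part of this commit), the Lion class above can be exercised on its own, outside the AlgoPerf harness. A minimal sketch, assuming only torch; the toy model and data are made up for illustration, and the betas/weight decay mirror the HPARAMS defaults:

# Illustrative only, not in the committed file: standalone use of Lion.
import torch

model = torch.nn.Linear(4, 1)
opt = Lion(model.parameters(), lr=2e-4, betas=(0.95, 0.98), weight_decay=0.5)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Each parameter is first scaled by (1 - lr * weight_decay) (decoupled decay),
# then moved by -lr * sign(beta1 * m + (1 - beta1) * g), so step magnitudes
# are uniform across coordinates.
opt.step()
opt.zero_grad()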

def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Creates a Lion optimizer and a learning rate schedule."""
  del model_state
  del rng
  del hyperparameters

  hyperparameters = HPARAMS

  optimizer_state = {
      'optimizer':
          Lion(
              model_params.parameters(),
              lr=HPARAMS.learning_rate,
              betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
              weight_decay=HPARAMS.weight_decay)
  }

  def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
    warmup_steps = int(hyperparameters.warmup_factor * step_hint)
    warmup = LinearLR(
        optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
    cosine_steps = max(step_hint - warmup_steps, 1)
    cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
    return SequentialLR(
        optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])

  optimizer_state['scheduler'] = pytorch_cosine_warmup(
      workload.step_hint, HPARAMS, optimizer_state['optimizer'])
  optimizer_state['hyperparameters'] = hyperparameters

  return optimizer_state
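
As an illustration (not part of the commit), with a hypothetical step_hint of 10,000 the schedule above resolves to int(0.02 * 10_000) = 200 linear warmup steps followed by cosine decay over the remaining 9,800 steps. A standalone sketch of the same construction, driven by a throwaway SGD optimizer:

# Illustrative only, not in the committed file: the warmup + cosine schedule
# that pytorch_cosine_warmup builds, for a hypothetical step_hint of 10_000.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

step_hint = 10_000
warmup_steps = int(0.02 * step_hint)  # warmup_factor * step_hint = 200
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
warmup = LinearLR(
    opt, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
cosine = CosineAnnealingLR(opt, T_max=step_hint - warmup_steps)
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_steps])

for _ in range(5):
  opt.step()
  sched.step()
print(sched.get_last_lr())  # ~5e-6: the LR ramps linearly toward 2e-4 over 200 steps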

def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState,
    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del train_state
  del eval_results
  del hyperparameters

  hyperparameters = HPARAMS

  current_model = current_param_container
  current_model.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
      params=current_model,
      augmented_and_preprocessed_input_batch=batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=True)

  label_smoothing = (
      hyperparameters.label_smoothing if hasattr(HPARAMS,
                                                 'label_smoothing') else 0.0)
  if hasattr(hyperparameters, 'grad_clip'):
    grad_clip = hyperparameters.grad_clip
  else:
    grad_clip = None

  loss_dict = workload.loss_fn(
      label_batch=batch['targets'],
      logits_batch=logits_batch,
      mask_batch=batch.get('weights'),
      label_smoothing=label_smoothing)
  summed_loss = loss_dict['summed']
  n_valid_examples = loss_dict['n_valid_examples']
  if USE_PYTORCH_DDP:
    # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
    summed_loss = dist_nn.all_reduce(summed_loss)
    n_valid_examples = dist_nn.all_reduce(n_valid_examples)
  loss = summed_loss / n_valid_examples

  loss.backward()

  if grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(
        current_model.parameters(), max_norm=grad_clip)
  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

  # Log training metrics - loss, grad_norm, batch_size.
  if global_step <= 100 or global_step % 500 == 0:
    with torch.no_grad():
      parameters = [p for p in current_model.parameters() if p.grad is not None]
      grad_norm = torch.norm(
          torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
    if workload.metrics_logger is not None:
      workload.metrics_logger.append_scalar_metrics(
          {
              'loss': loss.item(),
              'grad_norm': grad_norm.item(),
          }, global_step)
    logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
                 global_step,
                 loss.item(),
                 grad_norm.item())

  return (optimizer_state, current_param_container, new_model_state)

def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)

def get_batch_size(workload_name):
  # Return the global batch size.
  if hasattr(HPARAMS, "batch_size"):
    return HPARAMS.batch_size
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 256
  elif workload_name == 'librispeech_deepspeech':
    return 256
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')
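
A quick note (not part of the commit): HPARAMS above is a namedtuple without a batch_size field, so the hasattr check falls through to the per-workload defaults, e.g.:

# Illustrative only, not in the committed file.
get_batch_size('wmt')        # -> 128
get_batch_size('criteo1tb')  # -> 262_144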

def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.
  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
