from __future__ import annotations
import collections
from typing import Tuple, Any, Dict, Iterator, List, Optional

from absl import logging
import torch
from torch.optim.optimizer import Optimizer

import torch.distributed.nn as dist_nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

# Default Lion hyperparameters.
HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 2e-4,
    "one_minus_beta1": 0.05,
    "beta2": 0.98,
    "weight_decay": 0.5,
    "warmup_factor": 0.02
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
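
# This is a fixed-hyperparameter submission: the `hyperparameters` argument
# received by the entry points below is deleted and replaced by the HPARAMS
# namedtuple above.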

# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
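# Lion keeps one momentum buffer m per parameter and applies a sign-based
# update with decoupled weight decay (see `step` below):
#   update = sign(beta1 * m + (1 - beta1) * g)
#   w      <- w - lr * (update + weight_decay * w)
#   m      <- beta2 * m + (1 - beta2) * g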
class Lion(Optimizer):
  def __init__(
      self,
      params,
      lr: float = 1e-4,
      betas: Tuple[float, float] = (0.9, 0.99),
      weight_decay: float = 0.0,
  ):
    if not 0.0 <= lr:
      raise ValueError('Invalid learning rate: {}'.format(lr))
    if not 0.0 <= betas[0] < 1.0:
      raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
    if not 0.0 <= betas[1] < 1.0:
      raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

    Args:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      The loss returned by the closure, or None if no closure was given.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        # Perform decoupled weight decay, scaled by the learning rate.
        p.data.mul_(1 - group['lr'] * group['weight_decay'])

        grad = p.grad
        state = self.state[p]
        # State initialization.
        if len(state) == 0:
          # Exponential moving average of gradient values.
          state['exp_avg'] = torch.zeros_like(p)

        exp_avg = state['exp_avg']
        beta1, beta2 = group['betas']

        # Weight update: take the sign of an interpolation between the
        # momentum and the current gradient.
        update = exp_avg * beta1 + grad * (1 - beta1)

        p.add_(update.sign_(), alpha=-group['lr'])

        # Update the exponential moving average of the gradient.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss
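
# Example usage of the Lion class above outside of the benchmark harness
# (illustrative only; `model`, `inputs`, `targets`, and `loss_fn` are
# placeholders, not names defined in this file):
#
#   optimizer = Lion(model.parameters(), lr=2e-4, betas=(0.95, 0.98),
#                    weight_decay=0.5)
#   loss = loss_fn(model(inputs), targets)
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()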


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Creates a Lion optimizer and a learning rate schedule."""
  del model_state
  del rng
  del hyperparameters

  # Ignore the passed-in hyperparameters and use the fixed HPARAMS instead.
  hyperparameters = HPARAMS

  optimizer_state = {
      'optimizer':
          Lion(
              model_params.parameters(),
              lr=HPARAMS.learning_rate,
              betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
              weight_decay=HPARAMS.weight_decay)
  }

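  # LR schedule: linear warmup from ~0 to the base learning rate over
  # `warmup_factor * step_hint` steps, then cosine decay over the remaining
  # steps. For example, with warmup_factor=0.02, a step_hint of 10_000 steps
  # would give 200 warmup steps (the actual step_hint is workload-dependent).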
  def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
    warmup_steps = int(hyperparameters.warmup_factor * step_hint)
    warmup = LinearLR(
        optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
    cosine_steps = max(step_hint - warmup_steps, 1)
    cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
    return SequentialLR(
        optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])

  optimizer_state['scheduler'] = pytorch_cosine_warmup(
      workload.step_hint, HPARAMS, optimizer_state['optimizer'])
  optimizer_state['hyperparameters'] = hyperparameters

  return optimizer_state


def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState,
    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del train_state
  del eval_results
  del hyperparameters

  # Ignore the passed-in hyperparameters and use the fixed HPARAMS instead.
  hyperparameters = HPARAMS

  current_model = current_param_container
  current_model.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
      params=current_model,
      augmented_and_preprocessed_input_batch=batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=True)

  label_smoothing = (
      hyperparameters.label_smoothing
      if hasattr(hyperparameters, 'label_smoothing') else 0.0)
  if hasattr(hyperparameters, 'grad_clip'):
    grad_clip = hyperparameters.grad_clip
  else:
    grad_clip = None

  loss_dict = workload.loss_fn(
      label_batch=batch['targets'],
      logits_batch=logits_batch,
      mask_batch=batch.get('weights'),
      label_smoothing=label_smoothing)
  summed_loss = loss_dict['summed']
  n_valid_examples = loss_dict['n_valid_examples']
  if USE_PYTORCH_DDP:
    # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
    summed_loss = dist_nn.all_reduce(summed_loss)
    n_valid_examples = dist_nn.all_reduce(n_valid_examples)
  loss = summed_loss / n_valid_examples
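  # dist_nn.all_reduce (torch.distributed.nn) is the autograd-aware collective,
  # so summing the loss and valid-example counts across workers here also
  # yields consistently scaled gradients during the backward pass below.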

  loss.backward()

  if grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(
        current_model.parameters(), max_norm=grad_clip)
  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

  # Log training metrics: loss and grad_norm.
  if global_step <= 100 or global_step % 500 == 0:
    with torch.no_grad():
      parameters = [p for p in current_model.parameters() if p.grad is not None]
      grad_norm = torch.norm(
          torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
    if workload.metrics_logger is not None:
      workload.metrics_logger.append_scalar_metrics(
          {
              'loss': loss.item(),
              'grad_norm': grad_norm.item(),
          }, global_step)
    logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
                 global_step,
                 loss.item(),
                 grad_norm.item())

  return (optimizer_state, current_param_container, new_model_state)


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  """Return the global batch size for the given workload."""
  if hasattr(HPARAMS, 'batch_size'):
    return HPARAMS.batch_size
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 256
  elif workload_name == 'librispeech_deepspeech':
    return 256
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch