2 changes: 2 additions & 0 deletions DOCUMENTATION.md
@@ -199,6 +199,7 @@ def update_params(
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState
@@ -212,6 +213,7 @@ def update_params(
- The `loss_fn` produces a per-example loss and a summed loss (both for a single device only), both of which can be used.
- Allowed to update the optimizer state.
- Uses the `model_fn` of the `workload` to decouple the loss from the model, so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
- The submission can access the elapsed training time and further information about the evaluations through `train_state` (see the sketch after this list).
- The submission can access the target evaluation metric via the `workload` variable.
- **A call to this function will be considered a step**
- The time between a call to this function and the next call to this function will be considered the per-step time.
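As a rough, non-normative sketch of how a submission might use the new argument (not part of this PR's diff): the `accumulated_submission_time` key and the `workload.max_allowed_runtime_sec` attribute below are assumptions for illustration only; consult the harness for the actual contents of `train_state`.

```python
from typing import Any, Dict, List, Optional, Tuple


def update_params(workload,
                  current_param_container,
                  current_params_types,
                  model_state,
                  hyperparameters,
                  batch,
                  loss_type,
                  optimizer_state,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng,
                  train_state: Optional[Dict[str, Any]] = None):
  """Sketch: inspect the elapsed training time before the usual update."""
  if train_state is not None:
    # Key name assumed for illustration; the harness defines the actual contents.
    elapsed_sec = train_state.get('accumulated_submission_time', 0.0)
    if elapsed_sec > 0.75 * workload.max_allowed_runtime_sec:
      # E.g. switch to a cheaper update rule for the remaining time budget.
      pass
  # ... regular gradient computation and optimizer step go here ...
  return optimizer_state, current_param_container, model_state
```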
6 changes: 4 additions & 2 deletions algorithmic_efficiency/spec.py
@@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload,
OptimizerState,
List[Tuple[int, float]],
int,
RandomState
RandomState,
Optional[Dict[str, Any]]
],
UpdateReturn]

@@ -424,7 +425,8 @@ def update_params(workload: Workload,
optimizer_state: OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
rng: RandomState,
train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass

@@ -252,20 +252,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
@@ -252,20 +252,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
@@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
@@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
@@ -264,20 +264,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

@@ -264,20 +264,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
@@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
@@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for CIFAR10."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from flax import jax_utils
import jax
@@ -110,21 +110,24 @@ def _loss_fn(params):

# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del global_step
del train_state
del eval_results
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())