
Construct train step from an objective function and optimizer #595

@seanmor5

Description


Right now the only way to construct a train step is using a loss function and an optimizer:

  def train_step(model, loss, optimizer, opts \\ []) do

This is suitable for most cases, but in some instances it may be easier to let a user pass a full objective function to differentiate through, rather than just a loss function. In the default train step, the constructed objective function is:

  objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
    # hack to use trainable parameters as grad
    model_state =
      update_in(model_state, [Access.key!(:data)], fn data ->
        tree_merge(data, trainable_parameters, fn _, _, v -> v end)
      end)

    model_out = forward_model_fn.(model_state, inp)
    unscaled_loss = loss_fn.(tar, model_out.prediction)
    scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

    {model_out, scaled_loss, unscaled_loss}
  end

If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
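As a rough illustration of what this could look like, here is a hypothetical sketch. An arity of `train_step` accepting an objective function does not exist today; the objective's argument order and its three-element return tuple `{model_out, scaled_loss, unscaled_loss}` are assumed to match the default objective shown above, and `forward_model_fn`, `scale_loss`, and `custom_loss` are placeholder names, not real Axon APIs:

```elixir
# Hypothetical: the user supplies the objective to differentiate through,
# instead of just a loss function. Here the forward function receives the
# trainable parameters directly, avoiding the tree_merge hack above.
objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
  model_out = forward_model_fn.(model_state, trainable_parameters, inp)

  # Arbitrary objective, e.g. a loss with a parameter-dependent penalty
  unscaled_loss = custom_loss.(tar, model_out.prediction, trainable_parameters)
  scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

  {model_out, scaled_loss, unscaled_loss}
end

# Hypothetical construction from an objective rather than a loss
train_step = Axon.Loop.train_step(model, objective_fn, optimizer)
```

Since the objective already closes over the forward pass, this form would let users express things like regularized losses, auxiliary losses, or multi-output objectives without re-implementing parameter updates, loss scaling, or gradient handling themselves.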
