Right now the only way to construct a train step is using a loss function and an optimizer:
```elixir
def train_step(model, loss, optimizer, opts \\ []) do
```
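For context, a minimal sketch of what that looks like from the user's side today via `Axon.Loop.trainer/3`, which builds on `train_step` (the model and data here are illustrative):

```elixir
# A single-input model; the shapes and data are placeholders.
model =
  Axon.input("features", shape: {nil, 8})
  |> Axon.dense(1)

# Batches of {input, target} tuples.
train_data = [{Nx.iota({4, 8}, type: :f32), Nx.iota({4, 1}, type: :f32)}]

# Today the only knobs are the loss and the optimizer.
model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.run(train_data, %{}, epochs: 1)
```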
This is suitable for most cases, but in some instances it may be easier to allow a user to pass an objective function to differentiate through, rather than just the loss function. In the default train step the constructed objective function is:
```elixir
objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
  # hack to use trainable parameters as grad: merge the differentiable
  # parameters back into the full model state so the forward pass depends
  # on them and grad can be taken w.r.t. trainable_parameters alone
  model_state =
    update_in(model_state, [Access.key!(:data)], fn data ->
      tree_merge(data, trainable_parameters, fn _, _, v -> v end)
    end)

  model_out = forward_model_fn.(model_state, inp)
  # compute the loss, then apply loss scaling (e.g. for mixed precision)
  unscaled_loss = loss_fn.(tar, model_out.prediction)
  scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

  {model_out, scaled_loss, unscaled_loss}
end
```
If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
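As a hypothetical shape (not a committed API), the train step could hand a user-supplied objective the forward function and current state, and expect the same `{model_out, scaled_loss, unscaled_loss}` triple the default objective returns:

```elixir
# Hypothetical: a user-defined objective that adds a regularization term on
# top of the base loss. `forward_fn` stands in for the internal forward pass;
# the exact arguments and where loss scaling happens are open questions.
objective_fn = fn forward_fn, model_state, inp, tar ->
  model_out = forward_fn.(model_state, inp)
  base_loss = Axon.Losses.mean_squared_error(tar, model_out.prediction, reduction: :mean)

  # Purely illustrative penalty term; a real objective might instead
  # regularize over the trainable parameters.
  penalty = Nx.multiply(1.0e-4, Nx.sum(Nx.multiply(model_out.prediction, model_out.prediction)))

  {model_out, Nx.add(base_loss, penalty), base_loss}
end

# which a train_step variant might then accept directly:
# train_step(model, objective_fn, optimizer)
```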