diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index f1c696b9..b5e00cd6 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1563,7 +1563,7 @@ def get_estimator(model_type, vocabulary, mesh_shape, def train_model(estimator, vocabulary, sequence_length, batch_size, train_dataset_fn, train_steps, ensemble_inputs, - dataset_split="train"): + dataset_split="train", skip_seen_data=False): """Train a Mesh-TF model. Args: @@ -1585,10 +1585,13 @@ def train_model(estimator, vocabulary, sequence_length, batch_size, configure Unitransformer.ensemble to the right size. If None, then all models are trained on the same inputs. dataset_split: str, which dataset split to train on. + skip_seen_data: a boolean, is `False` by default. Used when a training run + restarts to skip already seen data. """ def input_fn(params): del params + dataset = train_dataset_fn( sequence_length=sequence_length, vocabulary=vocabulary, @@ -1596,6 +1599,12 @@ def input_fn(params): dataset = dataset.repeat().batch( batch_size * (ensemble_inputs or 1), drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + + # On the first time data is read in after relaunching, skip data that has + # already been seen. + if skip_seen_data: + recovered_step = estimator.get_variable_value("global_step") + dataset = dataset.skip(recovered_step) return dataset estimator.train(input_fn=input_fn, max_steps=train_steps) @@ -2117,7 +2126,8 @@ def run(tpu_job_name, perplexity_eval_steps=100, init_checkpoint=None, ensemble_inputs=None, - train_model_fn=train_model): + train_model_fn=train_model, + skip_seen_data=False): """Run training, eval, or inference depending on `mode`. Args: @@ -2173,6 +2183,8 @@ def run(tpu_job_name, init_checkpoint: a string, see `get_estimator` docstring for details. ensemble_inputs: an integer, see `train_model` docstring for details. train_model_fn: an optional train function, is `train_model` by default. + skip_seen_data: a boolean, is `False` by default. Used when a training run + restarts to skip already seen data. """ if isinstance(sequence_length, int): sequence_length = {"inputs": sequence_length, @@ -2247,8 +2259,11 @@ def run(tpu_job_name, # train_model if train_dataset_fn is None: raise ValueError("Must provide train_dataset_fn through gin") + train_model_fn(estimator, vocabulary, sequence_length, batch_size, - train_dataset_fn, train_steps, ensemble_inputs) + train_dataset_fn, train_steps, ensemble_inputs, + skip_seen_data=skip_seen_data) + elif mode == "perplexity_eval": if eval_dataset_fn is None: if train_dataset_fn is not None: