
Commit 8d3d043

Update evo2 ModelCheckpoint args (#935)
### Description

Adds new arguments to control `ModelCheckpoint`.

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [x] Documentation update
- [ ] Other (please describe):

### Pre-submit Checklist

- [x] I have tested these changes locally
- [x] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [x] All existing tests pass successfully

---

Signed-off-by: Jared Wilber <[email protected]>
1 parent 9e34dd9 commit 8d3d043

File tree

2 files changed: +41 −5 lines changed

sub-packages/bionemo-evo2/README.md

Lines changed: 9 additions & 1 deletion

````diff
@@ -49,7 +49,7 @@ usage: train_evo2 [-h] (-d DATASET_CONFIG | --mock-data) [--dataset-dir DATASET_
 [--debug-ddp-parity-freq DEBUG_DDP_PARITY_FREQ] [--hybrid-override-pattern HYBRID_OVERRIDE_PATTERN] [--num-layers NUM_LAYERS] [--create-tflops-callback] [--log-parameters-and-shapes] [--lr LR] [--min-lr MIN_LR]
 [--warmup-steps WARMUP_STEPS] [--nsys-profiling] [--nsys-start-step NSYS_START_STEP] [--nsys-end-step NSYS_END_STEP] [--no-renormalize-loss] [--nsys-ranks NSYS_RANKS [NSYS_RANKS ...]]
 [--activation-checkpoint-recompute-num-layers ACTIVATION_CHECKPOINT_RECOMPUTE_NUM_LAYERS] [--disable-checkpointing] [--clip-grad CLIP_GRAD] [--seq-len-interpolation-factor SEQ_LEN_INTERPOLATION_FACTOR]
-[--overlap-param-gather] [--overlap-grad-reduce] [--hidden-dropout HIDDEN_DROPOUT] [--attention-dropout ATTENTION_DROPOUT] [--no-activation-checkpointing | --selective-activation-checkpointing]
+[--overlap-param-gather] [--overlap-grad-reduce] [--hidden-dropout HIDDEN_DROPOUT] [--attention-dropout ATTENTION_DROPOUT] [--save-top-k SAVE_TOP_K] [--metric-to-monitor-for-checkpoints METRIC_TO_MONITOR_FOR_CHECKPOINTS] [--save-last-checkpoint] [--no-save-last-checkpoint] [--no-activation-checkpointing | --selective-activation-checkpointing]

 Train a Hyena model using NeMo 2.0.

@@ -179,6 +179,14 @@ options:
     Dropout probability for the hyena layers (default: 0.0)
 --attention-dropout ATTENTION_DROPOUT
     Dropout probability for the attention layers. (default: 0.0)
+--save-top-k SAVE_TOP_K
+    Number of best checkpoints to keep. Set to -1 to save all checkpoints. (default: 5)
+--metric-to-monitor-for-checkpoints METRIC_TO_MONITOR_FOR_CHECKPOINTS
+    Metric to monitor for checkpoints. (default: val_loss)
+--save-last-checkpoint
+    Save the last checkpoint. (default: True)
+--no-save-last-checkpoint
+    Disable saving the last checkpoint. (default: True)
 --no-activation-checkpointing
 --selective-activation-checkpointing
 ```
````
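The `--save-last-checkpoint` / `--no-save-last-checkpoint` pair uses the common argparse idiom of two flags writing to one destination. A minimal standalone sketch of that idiom (this is an illustrative script, not the actual `train_evo2` parser):

```python
import argparse

# Reproduction of the paired on/off flags added in this commit.
# Both flags share the dest "save_last_checkpoint", so whichever
# appears on the command line wins; with neither, the default is True.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--save-last-checkpoint",
    action="store_true",
    default=True,
    help="Save the last checkpoint.",
)
parser.add_argument(
    "--no-save-last-checkpoint",
    action="store_false",
    dest="save_last_checkpoint",
    default=True,
    help="Disable saving the last checkpoint.",
)

print(parser.parse_args([]).save_last_checkpoint)                             # True
print(parser.parse_args(["--no-save-last-checkpoint"]).save_last_checkpoint)  # False
```

Note this is why the README's help text shows `(default: True)` for both flags: argparse reports the default of the shared destination, not a per-flag default.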

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 32 additions & 4 deletions

```diff
@@ -38,7 +38,6 @@
 )
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning.pytorch import callbacks as nl_callbacks
-from nemo.lightning.pytorch.callbacks import ModelCheckpoint
 from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
@@ -389,6 +388,31 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         default=0.0,
         help="Dropout probability for the attention layers.",
     )
+    parser.add_argument(
+        "--save-top-k",
+        type=int,
+        default=5,
+        help="Number of best checkpoints to keep. Set to -1 to save all checkpoints.",
+    )
+    parser.add_argument(
+        "--metric-to-monitor-for-checkpoints",
+        type=str,
+        default="val_loss",
+        help="Metric to monitor for checkpoints.",
+    )
+    parser.add_argument(
+        "--save-last-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the last checkpoint.",
+    )
+    parser.add_argument(
+        "--no-save-last-checkpoint",
+        action="store_false",
+        dest="save_last_checkpoint",
+        default=True,
+        help="Disable saving the last checkpoint.",
+    )
     recompute_group = parser.add_mutually_exclusive_group(required=False)
     recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False)
     recompute_group.add_argument("--selective-activation-checkpointing", action="store_true", default=False)
@@ -601,11 +625,15 @@ def train(args: argparse.Namespace) -> nl.Trainer:

     if args.create_checkpoint_callback:
         checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
-        checkpoint_callback = ModelCheckpoint(
-            every_n_train_steps=args.val_check_interval,
+        checkpoint_callback = nl_callbacks.ModelCheckpoint(
             dirpath=checkpoint_path,
-            save_top_k=5,
+            save_last=args.save_last_checkpoint,
+            monitor=args.metric_to_monitor_for_checkpoints,
+            save_top_k=args.save_top_k,
+            every_n_train_steps=args.val_check_interval,
             always_save_context=True,
+            filename="{epoch}-{step}-{consumed_samples}",
+            save_weights_only=False,
             save_optim_on_train_end=True,
             save_context_on_train_end=True,
         )
```
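The wiring from the new CLI arguments to the checkpoint callback's constructor kwargs can be shown without importing NeMo. In this sketch, `CheckpointConfig` and `build_checkpoint_config` are hypothetical stand-ins for `nl_callbacks.ModelCheckpoint` and the call in `train()`; only the argument-to-kwarg mapping mirrors the diff:

```python
import argparse
from dataclasses import dataclass

# Hypothetical stand-in for nl_callbacks.ModelCheckpoint, showing which
# CLI arguments feed which constructor kwargs after this commit.
@dataclass
class CheckpointConfig:
    dirpath: str
    save_last: bool
    monitor: str
    save_top_k: int
    every_n_train_steps: int
    filename: str = "{epoch}-{step}-{consumed_samples}"

def build_checkpoint_config(args: argparse.Namespace) -> CheckpointConfig:
    # Mirrors the call in train(): --save-top-k replaces the previously
    # hard-coded save_top_k=5, and the new flags drive save_last/monitor.
    return CheckpointConfig(
        dirpath="results/checkpoints",  # stand-in for Path(save_dir) / "checkpoints"
        save_last=args.save_last_checkpoint,
        monitor=args.metric_to_monitor_for_checkpoints,
        save_top_k=args.save_top_k,
        every_n_train_steps=args.val_check_interval,
    )

cfg = build_checkpoint_config(argparse.Namespace(
    save_last_checkpoint=True,
    metric_to_monitor_for_checkpoints="val_loss",
    save_top_k=5,
    val_check_interval=100,
))
print(cfg.monitor, cfg.save_top_k)  # val_loss 5
```

Keeping the defaults (`save_top_k=5`, `monitor="val_loss"`, `save_last=True`) identical to the previously hard-coded behavior makes this a non-breaking change, as the PR description states.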
