Skip to content

Commit 3c68452

Browse files
authored
use strict=false when loading checkpoints with mfsdp (#1390)
Should fix failing TOT recipe tests Signed-off-by: Peter St. John <[email protected]>
1 parent 8935b13 commit 3c68452

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

bionemo-recipes/recipes/esm2_native_te/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def load_checkpoint_mfsdp(
240240
}
241241
torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=checkpoint_path)
242242

243-
model.load_state_dict(ckpt_state_dict["model"])
243+
model.load_state_dict(ckpt_state_dict["model"], strict=False)
244244
optimizer.load_state_dict(ckpt_state_dict["optimizer"])
245245
scheduler.load_state_dict(ckpt_state_dict["scheduler"])
246246
dataloader = load_dataloader(dataloader, checkpoint_path, dist_config)

bionemo-recipes/recipes/vit/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def load_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
8383
state_dict["optimizer"] = optimizer.state_dict()
8484
torch.distributed.checkpoint.load(state_dict, checkpoint_id=checkpoint_path)
8585
if model is not None:
86-
model.load_state_dict(state_dict["model"])
86+
model.load_state_dict(state_dict["model"], strict=False)
8787
if optimizer is not None:
8888
optimizer.load_state_dict(state_dict["optimizer"])
8989
```

bionemo-recipes/recipes/vit/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def load_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
6767
state_dict["optimizer"] = optimizer.state_dict()
6868
torch.distributed.checkpoint.load(state_dict, checkpoint_id=checkpoint_path)
6969
if model is not None:
70-
model.load_state_dict(state_dict["model"])
70+
model.load_state_dict(state_dict["model"], strict=False)
7171
if optimizer is not None:
7272
optimizer.load_state_dict(state_dict["optimizer"])
7373

0 commit comments

Comments
 (0)