Commit 52d6609

ekmb and pre-commit-ci[bot] authored and committed
remove pos emb from state dict for old models (#7068)
* remove pos emb from state dict
  Signed-off-by: Evelina <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* move to nlp_model
  Signed-off-by: Evelina <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update comment
  Signed-off-by: Evelina <[email protected]>
* fix nmt test
  Signed-off-by: Evelina <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix nmt test
  Signed-off-by: Evelina <[email protected]>
---------
Signed-off-by: Evelina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: jubick1337 <[email protected]>
1 parent 47e4ac9 commit 52d6609

1 file changed: +11 -1 lines changed


nemo/collections/nlp/models/nlp_model.py

Lines changed: 11 additions & 1 deletion
@@ -16,7 +16,7 @@
 import hashlib
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
@@ -385,3 +385,13 @@ def load_from_checkpoint(
         finally:
             cls._set_model_restore_state(is_being_restored=False)
         return checkpoint
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        # starting with transformers v4.31.0, buffer for position_ids is persistent=False
+        if (
+            self.bert_model is not None
+            and "position_ids" not in self.bert_model.embeddings._modules
+            and "bert_model.embeddings.position_ids" in state_dict
+        ):
+            del state_dict["bert_model.embeddings.position_ids"]
+        super(NLPModel, self).load_state_dict(state_dict, strict=strict)
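
Note on the override above: it deletes the stale key from the caller's mapping in place, and `state_dict` is annotated as a read-only `Mapping`. A minimal sketch of a non-mutating variant follows, assuming the goal is only to drop keys that the current model no longer expects. The helper name `drop_stale_keys` and its parameters are hypothetical and not part of NeMo or PyTorch, and plain lists stand in for tensors so the example runs without any third-party dependency.

```python
from collections import OrderedDict
from typing import Any, Iterable, Mapping

# Key named in the diff above; older checkpoints may still carry it.
STALE_KEY = "bert_model.embeddings.position_ids"


def drop_stale_keys(
    state_dict: Mapping[str, Any],
    expected_keys: Iterable[str],
    stale_keys: Iterable[str] = (STALE_KEY,),
) -> "OrderedDict[str, Any]":
    """Return a copy of state_dict without stale keys the model no longer expects.

    Hypothetical helper for illustration, not part of NeMo or PyTorch.
    """
    expected = set(expected_keys)
    cleaned = OrderedDict(state_dict)  # shallow copy; the caller's mapping is untouched
    for key in stale_keys:
        # Only drop a key if the checkpoint has it and the model does not expect it.
        if key in cleaned and key not in expected:
            del cleaned[key]
    return cleaned


# Toy stand-ins for tensors: plain lists are enough to show the key handling.
old_checkpoint = OrderedDict(
    [
        (STALE_KEY, [0, 1, 2]),
        ("bert_model.embeddings.word_embeddings.weight", [[0.1, 0.2]]),
    ]
)
model_keys = ["bert_model.embeddings.word_embeddings.weight"]

cleaned = drop_stale_keys(old_checkpoint, model_keys)
print(list(cleaned))                # ['bert_model.embeddings.word_embeddings.weight']
print(STALE_KEY in old_checkpoint)  # True: the original mapping was not modified
```

Inside a `load_state_dict` override, the expected keys could be taken from `self.state_dict().keys()`, and the cleaned copy passed on to `super().load_state_dict(cleaned, strict=strict)`. Returning that call's result would also preserve the missing and unexpected key report that `torch.nn.Module.load_state_dict` produces.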
