Skip to content
12 changes: 11 additions & 1 deletion nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import hashlib
import json
import os
from typing import Any, Optional
from typing import Any, Mapping, Optional

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -385,3 +385,13 @@ def load_from_checkpoint(
finally:
cls._set_model_restore_state(is_being_restored=False)
return checkpoint

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    """Load ``state_dict`` into the model, tolerating a stale BERT ``position_ids`` entry.

    Starting with transformers v4.31.0 the ``position_ids`` buffer of the BERT
    embeddings is registered with ``persistent=False``, so it is no longer part
    of the module's state dict. Checkpoints saved before that change still
    carry ``bert_model.embeddings.position_ids``, which would be reported as an
    unexpected key. The entry is dropped only when the current embeddings
    module does not expect it.

    Args:
        state_dict: Mapping of parameter/buffer names to values. It is not
            modified; a filtered shallow copy is loaded when a key is dropped.
        strict: Forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns (for ``torch.nn.Module``
        this is the named tuple of ``missing_keys`` and ``unexpected_keys``).
    """
    # Local import keeps this block self-contained without touching file-level imports.
    from collections import OrderedDict

    stale_key = "bert_model.embeddings.position_ids"

    # Not every model is guaranteed to expose ``bert_model.embeddings``; fall back to None.
    bert_model = getattr(self, "bert_model", None)
    embeddings = getattr(bert_model, "embeddings", None)

    if (
        embeddings is not None
        and stale_key in state_dict
        # ``position_ids`` is a buffer, so it never appears in ``_modules``.
        # A persistent buffer shows up in the module's own state dict, while a
        # non-persistent (or absent) one does not -- only then is the key stale.
        and "position_ids" not in embeddings.state_dict()
    ):
        # Build a filtered copy instead of deleting from the caller's mapping,
        # preserving the ``_metadata`` attribute that torch attaches to state dicts.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict((key, value) for key, value in state_dict.items() if key != stale_key)
        if metadata is not None:
            state_dict._metadata = metadata

    # Propagate the parent's result so callers can inspect missing/unexpected keys.
    return super(NLPModel, self).load_state_dict(state_dict, strict=strict)