
Commit 29d355e

refactor examples to accommodate Lightning-AI/pytorch-lightning#18105
1 parent: 41ba761

File tree: 3 files changed (+23 additions, -5 deletions)

3 files changed

+23
-5
lines changed

src/fts_examples/stable/fts_superglue.py

Lines changed: 12 additions & 4 deletions

@@ -117,15 +117,23 @@ def __init__(
         super().__init__()
         task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK
         self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]
+        self.init_hparams = {
+            "model_name_or_path": model_name_or_path,
+            "task_name": task_name,
+            "max_seq_length": max_seq_length,
+            "train_batch_size": train_batch_size,
+            "eval_batch_size": eval_batch_size,
+            "dataloader_kwargs": dataloader_kwargs,
+            "tokenizers_parallelism": tokenizers_parallelism,
+        }
+        self.save_hyperparameters(self.init_hparams)
         self.dataloader_kwargs = {
             "num_workers": dataloader_kwargs.get("num_workers", 0),
             "pin_memory": dataloader_kwargs.get("pin_memory", False),
         }
-        self.save_hyperparameters()
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.hparams.tokenizers_parallelism else "false"
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.hparams.model_name_or_path, use_fast=True, local_files_only=False
-        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.model_name_or_path, use_fast=True,
+                                                       local_files_only=False)

     def prepare_data(self):
         """Load the SuperGLUE dataset."""

src/fts_examples/stable/ipynb_src/fts_superglue_nb.py

Lines changed: 10 additions & 1 deletion

@@ -250,11 +250,20 @@ def __init__(
         super().__init__()
         task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK
         self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]
+        self.init_hparams = {
+            "model_name_or_path": model_name_or_path,
+            "task_name": task_name,
+            "max_seq_length": max_seq_length,
+            "train_batch_size": train_batch_size,
+            "eval_batch_size": eval_batch_size,
+            "dataloader_kwargs": dataloader_kwargs,
+            "tokenizers_parallelism": tokenizers_parallelism,
+        }
+        self.save_hyperparameters(self.init_hparams)
         self.dataloader_kwargs = {
             "num_workers": dataloader_kwargs.get("num_workers", 0),
             "pin_memory": dataloader_kwargs.get("pin_memory", False),
         }
-        self.save_hyperparameters()
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.hparams.tokenizers_parallelism else "false"
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.hparams.model_name_or_path, use_fast=True, local_files_only=False

src/fts_examples/stable/test_examples.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@
     "does not have many workers",
     "is smaller than the logging interval",
     "sentencepiece tokenizer that you are converting",
+    "`resume_download` is deprecated",  # required because of upstream usage as of 2.2.2
     "distutils Version classes are deprecated",  # still required as of PyTorch/Lightning 2.2
     "Please use torch.utils._pytree.register_pytree_node",  # temp allow deprecated behavior of transformers
     "We are importing from `pydantic",  # temp pydantic import migration warning
