@@ -41,6 +41,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
             "log_every_n_steps=1",
             "gradient_accumulation_steps=1",
             "clip_grad_norm=100",
+            "tokenizer.max_seq_len=512",
         ] + dummy_stack_exchange_dataset_config()
 
     @pytest.mark.parametrize("save_adapter_weights_only", [False, True])
@@ -93,6 +94,8 @@ def test_training_state_on_resume(
 
         expected_loss_values = get_loss_values_from_metric_logger(log_file)
 
+        resumed_log_dir = (tmpdir / "resumed/").mkdir()
+        resumed_log_file = gen_log_file_name(resumed_log_dir)
         # Resume training
         cmd_2 = f"""
         tune run lora_dpo_single_device \
@@ -106,7 +109,7 @@ def test_training_state_on_resume(
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA2 \
             resume_from_checkpoint=True \
-            metric_logger.filename={log_file} \
+            metric_logger.filename={resumed_log_file} \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
         """.split()
@@ -116,10 +119,10 @@ def test_training_state_on_resume(
         runpy.run_path(TUNE_PATH, run_name="__main__")
 
         # Second epoch only
-        loss_values = get_loss_values_from_metric_logger(log_file)[:2]
+        resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)
 
         torch.testing.assert_close(
-            loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+            resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
         )
 
     @pytest.mark.integration_test
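
The diff above amounts to two things: the resumed run logs to its own metric-logger file so the original run's losses are not overwritten, and only the overlapping epoch is compared between the two runs. A minimal, self-contained sketch of that comparison, with made-up loss values standing in for what the metric-logger files would actually contain:

```python
import torch

# Illustrative values only; in the test these come from the two metric-logger files.
full_run_losses = [2.31, 2.05, 1.87, 1.72]   # two epochs, two steps each
resumed_losses = [1.87, 1.72]                # epoch 2 re-run from the checkpoint

# The resumed run's first logged epoch should reproduce the full run's second epoch,
# so compare its leading losses against the tail of the uninterrupted run.
torch.testing.assert_close(
    torch.tensor(resumed_losses[:2]),
    torch.tensor(full_run_losses[2:]),
    rtol=1e-5,
    atol=1e-5,
)
```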