Skip to content

Commit 05620fe

Browse files
finally fixing tests
1 parent 4c8e4d4 commit 05620fe

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/recipes/test_lora_dpo_single_device.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
4141
"log_every_n_steps=1",
4242
"gradient_accumulation_steps=1",
4343
"clip_grad_norm=100",
44+
"tokenizer.max_seq_len=512",
4445
] + dummy_stack_exchange_dataset_config()
4546

4647
@pytest.mark.parametrize("save_adapter_weights_only", [False, True])
@@ -93,6 +94,8 @@ def test_training_state_on_resume(
9394

9495
expected_loss_values = get_loss_values_from_metric_logger(log_file)
9596

97+
resumed_log_dir = (tmpdir / "resumed/").mkdir()
98+
resumed_log_file = gen_log_file_name(resumed_log_dir)
9699
# Resume training
97100
cmd_2 = f"""
98101
tune run lora_dpo_single_device \
@@ -106,7 +109,7 @@ def test_training_state_on_resume(
106109
checkpointer.output_dir={tmpdir} \
107110
checkpointer.model_type=LLAMA2 \
108111
resume_from_checkpoint=True \
109-
metric_logger.filename={log_file} \
112+
metric_logger.filename={resumed_log_file} \
110113
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
111114
tokenizer.prompt_template=null \
112115
""".split()
@@ -116,10 +119,10 @@ def test_training_state_on_resume(
116119
runpy.run_path(TUNE_PATH, run_name="__main__")
117120

118121
# Second epoch only
119-
loss_values = get_loss_values_from_metric_logger(log_file)[:2]
122+
resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)
120123

121124
torch.testing.assert_close(
122-
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
125+
resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
123126
)
124127

125128
@pytest.mark.integration_test

0 commit comments

Comments
 (0)