 
 
 class TestLoRAFinetuneSingleDeviceRecipe:
-    def _get_test_config_overrides(
-        self,
-        device: str = "cpu",
-        enable_ac: bool = False,
-        dtype_str: str = "fp32",
-        epochs: int = 2,
-    ):
+    def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
         return [
             "batch_size=8",
-            f"device={device}",
+            "device=cpu",
             f"dtype={dtype_str}",
-            f"enable_activation_checkpointing={enable_ac}",
+            "enable_activation_checkpointing=False",
             "dataset.train_on_input=False",
             "seed=9",
             f"epochs={epochs}",
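For readability, here is the simplified helper as it reads after this hunk, reassembled into runnable form. Treat it as a sketch: the tail of the returned list may continue past the end of the hunk above.

class TestLoRAFinetuneSingleDeviceRecipe:
    # After this change, device and activation checkpointing are pinned;
    # only dtype and epochs remain configurable.
    def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
        return [
            "batch_size=8",
            "device=cpu",
            f"dtype={dtype_str}",
            "enable_activation_checkpointing=False",
            "dataset.train_on_input=False",
            "seed=9",
            f"epochs={epochs}",
            # ... any remaining overrides continue beyond this hunk
        ]

Calling it with, e.g., dtype_str="bf16" and epochs=1 yields the override list for a one-epoch bf16 run.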
@@ -67,24 +61,13 @@ def _fetch_qlora_expected_loss_values(self, dtype):
     @pytest.mark.integration_test
     @pytest.mark.parametrize("compile", [True, False])
     @pytest.mark.parametrize(
-        "config, model_type, ckpt_type, enable_activation_checkpointing, enable_activation_offloading",
+        "config, model_type, ckpt_type",
         [
-            ("llama2/7B_lora_single_device", "llama2", "meta", False, False),
-            ("llama2/7B_lora_single_device", "llama2", "meta", True, True),
-            ("llama3/8B_lora_single_device", "llama3", "tune", True, False),
+            ("llama2/7B_lora_single_device", "llama2", "meta"),
+            ("llama3/8B_lora_single_device", "llama3", "tune"),
         ],
     )
-    def test_loss(
-        self,
-        compile,
-        config,
-        model_type,
-        ckpt_type,
-        enable_activation_checkpointing,
-        enable_activation_offloading,
-        tmpdir,
-        monkeypatch,
-    ):
+    def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch):
         ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
         ckpt = model_type + "_" + ckpt_type
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
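For context, the trimmed parametrization expands to four test cases per compile/config cross product (two compile values times two config tuples), down from six before this change. A quick stand-in sketch of the resulting matrix:

# Stand-in sketch of how pytest expands the reduced parametrization:
# 2 compile values x 2 (config, model_type, ckpt_type) tuples = 4 cases.
import itertools

compile_values = [True, False]
config_tuples = [
    ("llama2/7B_lora_single_device", "llama2", "meta"),
    ("llama3/8B_lora_single_device", "llama3", "tune"),
]
for compile_flag, (config, model_type, ckpt_type) in itertools.product(
    compile_values, config_tuples
):
    print(compile_flag, config, model_type, ckpt_type)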
@@ -105,21 +88,11 @@ def test_loss(
             tokenizer.prompt_template=null \
             metric_logger.filename={log_file} \
             compile={compile} \
-            enable_activation_checkpointing={enable_activation_checkpointing} \
-            enable_activation_offloading={enable_activation_offloading} \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
 
-        cmd = (
-            cmd
-            + self._get_test_config_overrides(
-                device="cuda",
-                enable_ac=enable_activation_checkpointing,
-                dtype_str="fp32",
-            )
-            + model_config
-        )
+        cmd = cmd + self._get_test_config_overrides(dtype_str="fp32") + model_config
         monkeypatch.setattr(sys, "argv", cmd)
         with pytest.raises(SystemExit, match="")
             runpy.run_path(TUNE_PATH, run_name="__main__")
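The unchanged tail of the test runs the recipe in-process. A minimal, self-contained sketch of that pattern, with argv and tune_path standing in for the assembled command and the suite's TUNE_PATH constant:

# Minimal sketch of the in-process invocation pattern used by the test:
# patch the assembled `tune run` argv into sys.argv, then execute the CLI
# entry point with runpy. The tune CLI exits via sys.exit(), so the test
# treats a clean SystemExit as success.
import runpy
import sys

import pytest


def run_recipe_in_process(argv, monkeypatch, tune_path):
    monkeypatch.setattr(sys, "argv", argv)
    with pytest.raises(SystemExit, match=""):
        runpy.run_path(tune_path, run_name="__main__")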