@@ -152,8 +152,11 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
         * embed_dim: 64
         * max_seq_len: 128
         """
-        checkpoint_file_1 = tmp_path / "llama2_hf_checkpoint_01.pt"
-        checkpoint_file_2 = tmp_path / "llama2_hf_checkpoint_02.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file_1 = checkpoint_dir / "llama2_hf_checkpoint_01.pt"
+        checkpoint_file_2 = checkpoint_dir / "llama2_hf_checkpoint_02.pt"

         torch.save(state_dict_1, checkpoint_file_1)
         torch.save(state_dict_2, checkpoint_file_2)
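This hunk stages the fixture's input files in a dedicated subdirectory of pytest's per-test `tmp_path` instead of its root. A minimal sketch of the same staging pattern (the fixture name and dummy state dict below are illustrative, not from the diff):

```python
from pathlib import Path

import pytest
import torch


@pytest.fixture
def staged_checkpoint(tmp_path: Path) -> Path:
    # tmp_path is a fresh per-test directory provided by pytest; staging
    # inputs in a subdirectory keeps them separate from test outputs.
    checkpoint_dir = tmp_path / "checkpoint_dir"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_file = checkpoint_dir / "dummy_checkpoint.pt"
    torch.save({"weight": torch.zeros(2, 2)}, checkpoint_file)
    return checkpoint_file
```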
@@ -163,7 +166,7 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
             "num_attention_heads": 4,
             "num_key_value_heads": 4,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -174,23 +177,27 @@ def single_file_checkpointer(
         self, llama2_hf_checkpoints, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file, _ = llama2_hf_checkpoints
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="LLAMA2",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     @pytest.fixture
     def multi_file_checkpointer(
         self, llama2_hf_checkpoints, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file_1, checkpoint_file_2],
             model_type="LLAMA2",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
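The net effect of the fixture changes is a read/write split under `tmp_path`. A sketch of the resulting per-test layout (the shard names assume torchtune's `"model-{cpt_idx}-of-{num_shards}"` template, which is an assumption to verify against your checkout):

```python
# tmp_path/
# ├── checkpoint_dir/          # read side: .pt shards + config.json
# │   ├── llama2_hf_checkpoint_01.pt
# │   ├── llama2_hf_checkpoint_02.pt
# │   └── config.json
# └── output_dir/              # write side: populated by save_checkpoint()
#     └── epoch_1/
#         ├── model-00001-of-00002.safetensors
#         └── model-00002-of-00002.safetensors
```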
@@ -242,7 +249,7 @@ def test_load_save_checkpoint_single_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file = Path.joinpath(
-            checkpoint_file.parent,
+            checkpoint_file.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         ).with_suffix(".safetensors")
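The new expression climbs out of the staged input directory before descending into the output directory: `checkpoint_file` lives in `tmp_path/checkpoint_dir`, so `.parent.parent` is `tmp_path` itself. A quick pathlib illustration (the concrete path is hypothetical):

```python
from pathlib import Path

checkpoint_file = Path("/tmp/pytest-0/checkpoint_dir/llama2_hf_checkpoint_01.pt")

# .parent strips the filename; a second .parent strips checkpoint_dir.
assert checkpoint_file.parent == Path("/tmp/pytest-0/checkpoint_dir")
assert checkpoint_file.parent.parent == Path("/tmp/pytest-0")

# Mirrors the test's expected-output construction.
output_root = checkpoint_file.parent.parent / "output_dir"
print(output_root / "epoch_1")  # /tmp/pytest-0/output_dir/epoch_1
```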
@@ -306,12 +313,12 @@ def test_save_load_checkpoint_multiple_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file_1 = Path.joinpath(
-            checkpoint_file_1.parent,
+            checkpoint_file_1.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="2".zfill(5)),
         ).with_suffix(".safetensors")
         output_file_2 = Path.joinpath(
-            checkpoint_file_2.parent,
+            checkpoint_file_2.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="2".zfill(5), num_shards="2".zfill(5)),
         ).with_suffix(".safetensors")
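The expected shard names come from the `SHARD_FNAME` template the test imports. Assuming torchtune's value of `"model-{cpt_idx}-of-{num_shards}"` (an assumption; check the constant in your version), the format calls resolve like this:

```python
# Assumed value of torchtune's SHARD_FNAME constant; verify upstream.
SHARD_FNAME = "model-{cpt_idx}-of-{num_shards}"

print(SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="2".zfill(5)))
# model-00001-of-00002
print(SHARD_FNAME.format(cpt_idx="2".zfill(5), num_shards="2".zfill(5)))
# model-00002-of-00002
```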
@@ -338,12 +345,14 @@ def test_load_save_adapter_only(
         single_file_checkpointer.save_checkpoint(state_dict, epoch=2, adapter_only=True)

         output_file_1 = Path.joinpath(
-            tmp_path,
+            tmp_path / "output_dir",
             "epoch_2",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         )
         output_file_2 = Path.joinpath(
-            tmp_path, "epoch_2", f"{ADAPTER_MODEL_FNAME}.safetensors"
+            tmp_path / "output_dir",
+            "epoch_2",
+            f"{ADAPTER_MODEL_FNAME}.safetensors",
         )

         with pytest.raises(ValueError, match="Unable to load checkpoint from"):
@@ -437,12 +446,16 @@ def test_save_checkpoint_in_peft_format(

         # Load saved adapter weights and config from file for comparison
         adapter_weights_file = Path.joinpath(
-            checkpoint_file.parent, "epoch_1", f"{ADAPTER_MODEL_FNAME}.safetensors"
+            checkpoint_file.parent.parent / "output_dir",
+            "epoch_1",
+            f"{ADAPTER_MODEL_FNAME}.safetensors",
         )
         actual_adapter_state_dict = safe_torch_load(adapter_weights_file)

         adapter_config_file = Path.joinpath(
-            checkpoint_file.parent, "epoch_1", f"{ADAPTER_CONFIG_FNAME}.json"
+            checkpoint_file.parent.parent / "output_dir",
+            "epoch_1",
+            f"{ADAPTER_CONFIG_FNAME}.json",
         )
         with open(adapter_config_file, "r") as f:
             adapter_config = json.load(f)
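`safe_torch_load` is torchtune's suffix-aware loader; for a `.safetensors` file it amounts to a direct read with the safetensors library. A minimal sketch of that equivalent read (the filename assumes `ADAPTER_MODEL_FNAME == "adapter_model"`, which is an assumption here):

```python
from safetensors.torch import load_file

# Returns a dict[str, torch.Tensor] keyed by parameter name, loaded on CPU.
adapter_state_dict = load_file("output_dir/epoch_1/adapter_model.safetensors")
```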
@@ -558,7 +571,10 @@ def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
         * intermediate_dim: 256

         """
-        checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file = checkpoint_dir / "mistral_reward_model_hf_checkpoint.pt"

         torch.save(state_dict, checkpoint_file)

@@ -568,7 +584,7 @@ def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
             "num_key_value_heads": 4,
             "num_classes": 1,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -579,11 +595,13 @@ def single_file_checkpointer(
         self, mistral_reward_model_hf_checkpoint, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file = mistral_reward_model_hf_checkpoint
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="REWARD",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
@@ -636,7 +654,7 @@ def test_load_save_checkpoint_single_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file = Path.joinpath(
-            checkpoint_file.parent,
+            checkpoint_file.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         ).with_suffix(".safetensors")
@@ -708,7 +726,10 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
         * head_dim: 16

         """
-        checkpoint_file = tmp_path / "gemma_hf_checkpoint.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file = checkpoint_dir / "gemma_hf_checkpoint.pt"

         torch.save(state_dict, checkpoint_file)

@@ -719,7 +740,7 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
             "head_dim": _HEAD_DIM,
             "intermediate_size": _HIDDEN_DIM,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -730,11 +751,13 @@ def single_file_checkpointer(
         self, gemma_hf_checkpoint, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file = gemma_hf_checkpoint
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="GEMMA",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
740763 def test_load_save_checkpoint_single_file (
@@ -788,7 +811,7 @@ def test_load_save_checkpoint_single_file(
788811 # assumes we know what the name of the file is. This is fine, breaking this logic
789812 # should be something we capture through this test
790813 output_file = Path .joinpath (
791- checkpoint_file .parent ,
814+ checkpoint_file .parent . parent / "output_dir" ,
792815 "epoch_1" ,
793816 SHARD_FNAME .format (cpt_idx = "1" .zfill (5 ), num_shards = "1" .zfill (5 )),
794817 ).with_suffix (".safetensors" )