2323log = get_logger ("DEBUG" )
2424
2525
def save_config(config: DictConfig) -> "Optional[Path]":
    """
    Save the OmegaConf configuration to a YAML file at
    ``{config.output_dir}/torchtune_config.yaml``.

    Args:
        config (DictConfig): The OmegaConf config object to be saved. It must
            contain an ``output_dir`` attribute specifying where the
            configuration file should be saved.

    Returns:
        Optional[Path]: The path to the saved configuration file, or ``None``
            if saving failed (a warning is logged instead of raising).

    Note:
        If the specified ``output_dir`` does not exist, it will be created.
    """
    try:
        output_dir = Path(config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        output_config_fname = output_dir / "torchtune_config.yaml"
        OmegaConf.save(config, output_config_fname)
        return output_config_fname
    except Exception as e:
        # Deliberately best-effort: persisting the config must never abort a
        # training run, so we swallow the error after logging it.
        log.warning(f"Error saving config.\nError: \n{e}.")
        return None
50+
2651class MetricLoggerInterface (Protocol ):
2752 """Abstract metric logger."""
2853
@@ -42,7 +67,7 @@ def log(
4267 pass
4368
4469 def log_config (self , config : DictConfig ) -> None :
45- """Logs the config
70+ """Logs the config as file
4671
4772 Args:
4873 config (DictConfig): config to log
@@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
99124 self ._file .write (f"Step { step } | { name } :{ data } \n " )
100125 self ._file .flush ()
101126
127+ def log_config (self , config : DictConfig ) -> None :
128+ _ = save_config (config )
129+
102130 def log_dict (self , payload : Mapping [str , Scalar ], step : int ) -> None :
103131 self ._file .write (f"Step { step } | " )
104132 for name , data in payload .items ():
@@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface):
119147 def log (self , name : str , data : Scalar , step : int ) -> None :
120148 print (f"Step { step } | { name } :{ data } " )
121149
150+ def log_config (self , config : DictConfig ) -> None :
151+ _ = save_config (config )
152+
122153 def log_dict (self , payload : Mapping [str , Scalar ], step : int ) -> None :
123154 print (f"Step { step } | " , end = "" )
124155 for name , data in payload .items ():
@@ -183,6 +214,10 @@ def __init__(
183214 # Use dir if specified, otherwise use log_dir.
184215 self .log_dir = kwargs .pop ("dir" , log_dir )
185216
217+ # create log_dir if missing
218+ if not os .path .exists (self .log_dir ):
219+ os .makedirs (self .log_dir )
220+
186221 _ , self .rank = get_world_size_and_rank ()
187222
188223 if self ._wandb .run is None and self .rank == 0 :
@@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None:
219254 self ._wandb .config .update (
220255 resolved , allow_val_change = self .config_allow_val_change
221256 )
222- try :
223- output_config_fname = Path (
224- os .path .join (
225- config .output_dir ,
226- "torchtune_config.yaml" ,
227- )
228- )
229- OmegaConf .save (config , output_config_fname )
230257
231- log .info (f"Logging { output_config_fname } to W&B under Files" )
258+ # Also try to save the config as a file
259+ output_config_fname = save_config (config )
260+ try :
232261 self ._wandb .save (
233262 output_config_fname , base_path = output_config_fname .parent
234263 )
235-
236264 except Exception as e :
237265 log .warning (
238- f"Error saving { output_config_fname } to W&B.\n Error: \n { e } ."
266+ f"Error uploading { output_config_fname } to W&B.\n Error: \n { e } ."
239267 "Don't worry the config will be logged the W&B workspace"
240268 )
241269
@@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
305333 if self ._writer :
306334 self ._writer .add_scalar (name , data , global_step = step , new_style = True )
307335
336+ def log_config (self , config : DictConfig ) -> None :
337+ _ = save_config (config )
338+
308339 def log_dict (self , payload : Mapping [str , Scalar ], step : int ) -> None :
309340 for name , data in payload .items ():
310341 self .log (name , data , step )
@@ -387,13 +418,16 @@ def __init__(
387418 "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
388419 ) from e
389420
421+ # Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
422+ if "log_dir" in kwargs :
423+ del kwargs ["log_dir" ]
424+
390425 _ , self .rank = get_world_size_and_rank ()
391426
392427 # Declare it early so further methods don't crash in case of
393428 # Experiment Creation failure due to mis-named configuration for
394429 # example
395430 self .experiment = None
396-
397431 if self .rank == 0 :
398432 self .experiment = comet_ml .start (
399433 api_key = api_key ,
@@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None:
421455 self .experiment .log_parameters (resolved )
422456
423457 # Also try to save the config as a file
458+ output_config_fname = save_config (config )
424459 try :
425- self ._log_config_as_file (config )
460+ self .experiment .log_asset (
461+ output_config_fname , file_name = output_config_fname .name
462+ )
426463 except Exception as e :
427- log .warning (f"Error saving Config to disk.\n Error: \n { e } ." )
428- return
429-
430- def _log_config_as_file (self , config : DictConfig ):
431- output_config_fname = Path (
432- os .path .join (
433- config .checkpointer .checkpoint_dir ,
434- "torchtune_config.yaml" ,
435- )
436- )
437- OmegaConf .save (config , output_config_fname )
438-
439- self .experiment .log_asset (
440- output_config_fname , file_name = "torchtune_config.yaml"
441- )
464+ log .warning (f"Failed to upload config to Comet assets. Error: { e } " )
442465
443466 def close (self ) -> None :
444467 if self .experiment is not None :
0 commit comments