2828 },
2929 "d" : 4 ,
3030 "f" : 8 ,
31+ "g" : "foo" ,
32+ "h" : "${g}/bar" ,
3133}
3234
3335
@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
5052 ):
5153 _ = _get_component_from_path ("torchtune.models.dummy" )
5254
53- @mock .patch ("torchtune.config._parse.OmegaConf.load" , return_value = _CONFIG )
55+ @mock .patch (
56+ "torchtune.config._parse.OmegaConf.load" , return_value = OmegaConf .create (_CONFIG )
57+ )
5458 def test_merge_yaml_and_cli_args (self , mock_load ):
5559 parser = TuneRecipeArgumentParser ("test parser" )
5660 yaml_args , cli_args = parser .parse_known_args (
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
6367 "d=6" , # Test overriding a flat param
6468 "e=7" , # Test adding a new param
6569 "~f" , # Test removing a param
70+ "g=bazz" , # Test interpolation happens after override
6671 ]
6772 )
6873 conf = _merge_yaml_and_cli_args (yaml_args , cli_args )
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
7580 assert conf .d == 6 , f"d == { conf .d } , not 6 as set in overrides."
7681 assert conf .e == 7 , f"e == { conf .e } , not 7 as set in overrides."
7782 assert "f" not in conf , f"f == { conf .f } , not removed as set in overrides."
83+ assert conf .h == "bazz/bar" , f"h == { conf .h } , not bazz/bar as set in overrides."
7884 mock_load .assert_called_once ()
7985
8086 yaml_args , cli_args = parser .parse_known_args (
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):
185191
186192 # Test removing non-existent param fails
187193 cfg = copy .deepcopy (_CONFIG )
188- with pytest .raises (KeyError , match = "'g '" ):
189- _remove_key_by_dotpath (cfg , "g " )
194+ with pytest .raises (KeyError , match = "'i '" ):
195+ _remove_key_by_dotpath (cfg , "i " )
0 commit comments