Skip to content

Commit 2bab3ec

Browse files
committed
fix test
1 parent 78b7d0d commit 2bab3ec

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tests/torchtune/config/test_config_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
},
2929
"d": 4,
3030
"f": 8,
31+
"g": "foo",
32+
"h": "${g}/bar",
3133
}
3234

3335

@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
5052
):
5153
_ = _get_component_from_path("torchtune.models.dummy")
5254

53-
@mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
55+
@mock.patch(
56+
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
57+
)
5458
def test_merge_yaml_and_cli_args(self, mock_load):
5559
parser = TuneRecipeArgumentParser("test parser")
5660
yaml_args, cli_args = parser.parse_known_args(
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
6367
"d=6", # Test overriding a flat param
6468
"e=7", # Test adding a new param
6569
"~f", # Test removing a param
70+
"g=bazz", # Test interpolation happens after override
6671
]
6772
)
6873
conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
7580
assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
7681
assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
7782
assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides."
83+
assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides."
7884
mock_load.assert_called_once()
7985

8086
yaml_args, cli_args = parser.parse_known_args(
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):
185191

186192
# Test removing non-existent param fails
187193
cfg = copy.deepcopy(_CONFIG)
188-
with pytest.raises(KeyError, match="'g'"):
189-
_remove_key_by_dotpath(cfg, "g")
194+
with pytest.raises(KeyError, match="'i'"):
195+
_remove_key_by_dotpath(cfg, "i")

0 commit comments

Comments
 (0)