diff --git a/tests/conftest.py b/tests/conftest.py index 0309781e7b..7bab05dfb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -253,7 +253,9 @@ def mxnet_eia_latest_py_version(): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_training_py_version(pytorch_training_version, request): - if Version(pytorch_training_version) >= Version("2.0"): + if Version(pytorch_training_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_training_version) >= Version("2.0"): return "py310" elif Version(pytorch_training_version) >= Version("1.13"): return "py39"