Skip to content

Commit ad958ca

Browse files
author
Ervin T
authored
Modify Yamato tests (#4584)
1 parent 03f7e79 commit ad958ca

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

ml-agents/tests/yamato/training_int_tests.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ def run_training(python_version: str, csharp_version: str) -> bool:
2626
f"Running training with python={python_version or latest} and c#={csharp_version or latest}"
2727
)
2828
output_dir = "models" if python_version else "results"
29-
nn_file_expected = f"./{output_dir}/{run_id}/3DBall.nn"
3029
onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx"
3130
frozen_graph_file_expected = f"./{output_dir}/{run_id}/3DBall/frozen_graph_def.pb"
3231

33-
if os.path.exists(nn_file_expected):
32+
if os.path.exists(onnx_file_expected):
3433
# Should never happen - make sure nothing leftover from an old test.
3534
print("Artifacts from previous build found!")
3635
return False
@@ -96,21 +95,16 @@ def run_training(python_version: str, csharp_version: str) -> bool:
9695
if csharp_version is None and python_version is None:
9796
model_artifacts_dir = os.path.join(get_base_output_path(), "models")
9897
os.makedirs(model_artifacts_dir, exist_ok=True)
99-
shutil.copy(nn_file_expected, model_artifacts_dir)
10098
shutil.copy(onnx_file_expected, model_artifacts_dir)
10199
shutil.copy(frozen_graph_file_expected, model_artifacts_dir)
102100

103-
if (
104-
res.returncode != 0
105-
or not os.path.exists(nn_file_expected)
106-
or not os.path.exists(onnx_file_expected)
107-
):
101+
if res.returncode != 0 or not os.path.exists(onnx_file_expected):
108102
print("mlagents-learn run FAILED!")
109103
return False
110104

111105
if csharp_version is None and python_version is None:
112106
# Use abs path so that loading doesn't get confused
113-
model_path = os.path.abspath(os.path.dirname(nn_file_expected))
107+
model_path = os.path.abspath(os.path.dirname(onnx_file_expected))
114108
for extension in ["nn", "onnx"]:
115109
inference_ok = run_inference(env_path, model_path, extension)
116110
if not inference_ok:

0 commit comments

Comments
 (0)