Skip to content

Commit 27e9986

Browse files
author
Chris Elion
authored
enforce onnx conversion (expect tf2 CI to fail) (#3600)
1 parent 470fcc0 commit 27e9986

File tree

6 files changed

+54
-19
lines changed

6 files changed

+54
-19
lines changed

.circleci/config.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ jobs:
1818
pip_constraints:
1919
type: string
2020
description: Constraints file that is passed to "pip install". We constrain libraries to older versions on older Python runtimes, in order to help ensure compatibility.
21+
enforce_onnx_conversion:
22+
type: integer
23+
default: 0
24+
description: Whether to raise an exception if ONNX models couldn't be saved.
2125
executor: << parameters.executor >>
2226
working_directory: ~/repo
2327

2428
# Run additional numpy checks on unit tests
2529
environment:
2630
TEST_ENFORCE_NUMPY_FLOAT32: 1
31+
TEST_ENFORCE_ONNX_CONVERSION: << parameters.enforce_onnx_conversion >>
2732

2833
steps:
2934
- checkout
@@ -217,6 +222,8 @@ workflows:
217222
pyversion: 3.7.3
218223
# Test python 3.7 with the newest supported versions
219224
pip_constraints: test_constraints_max_tf1_version.txt
225+
# Make sure ONNX conversion passes here (recent version of tensorflow 1.x)
226+
enforce_onnx_conversion: 1
220227
- build_python:
221228
name: python_3.7.3+tf2
222229
executor: python373

docs/Unity-Inference-Engine.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ There are currently two supported model formats:
3333
* ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).
3434

3535
Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed in pip.
36-
tf2onnx does not currently support tensorflow 2.0.0 or later.
36+
tf2onnx does not currently support tensorflow versions earlier than 1.12.0, or 2.0.0 and later.
3737

3838
## Using the Unity Inference Engine
3939

ml-agents/mlagents/model_serialization.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from distutils.util import strtobool
2+
import os
13
import logging
24
from typing import Any, List, Set, NamedTuple
5+
from distutils.version import LooseVersion
36

47
try:
58
import onnx
@@ -18,6 +21,11 @@
1821
from tensorflow.python.framework import graph_util
1922
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
2023

24+
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
25+
# ONNX is only tested on 1.12.0 and later
26+
ONNX_EXPORT_ENABLED = False
27+
28+
2129
logger = logging.getLogger("mlagents.trainers")
2230

2331
POSSIBLE_INPUT_NODES = frozenset(
@@ -67,18 +75,28 @@ def export_policy_model(
6775
logger.info(f"Exported {settings.model_path}.nn file")
6876

6977
# Save to onnx too (if we were able to import it)
70-
if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
71-
try:
72-
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
73-
onnx_output_path = settings.model_path + ".onnx"
74-
with open(onnx_output_path, "wb") as f:
75-
f.write(onnx_graph.SerializeToString())
76-
logger.info(f"Converting to {onnx_output_path}")
77-
except Exception:
78-
logger.exception(
79-
"Exception trying to save ONNX graph. Please report this error on "
80-
"https://github.com/Unity-Technologies/ml-agents/issues and "
81-
"attach a copy of frozen_graph_def.pb"
78+
if ONNX_EXPORT_ENABLED:
79+
if settings.convert_to_onnx:
80+
try:
81+
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
82+
onnx_output_path = settings.model_path + ".onnx"
83+
with open(onnx_output_path, "wb") as f:
84+
f.write(onnx_graph.SerializeToString())
85+
logger.info(f"Converting to {onnx_output_path}")
86+
except Exception:
87+
# Make conversion errors fatal depending on environment variables (only done during CI)
88+
if _enforce_onnx_conversion():
89+
raise
90+
logger.exception(
91+
"Exception trying to save ONNX graph. Please report this error on "
92+
"https://github.com/Unity-Technologies/ml-agents/issues and "
93+
"attach a copy of frozen_graph_def.pb"
94+
)
95+
96+
else:
97+
if _enforce_onnx_conversion():
98+
raise RuntimeError(
99+
"ONNX conversion enforced, but couldn't import dependencies."
82100
)
83101

84102

@@ -203,3 +221,16 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str
203221
for n in nodes:
204222
logger.info("\t" + n)
205223
return nodes
224+
225+
226+
def _enforce_onnx_conversion() -> bool:
    """Return whether ONNX conversion failures should be treated as fatal.

    The decision is read from the ``TEST_ENFORCE_ONNX_CONVERSION`` environment
    variable. An unset variable, or any value that is not a recognized
    "true" spelling, yields ``False``.

    :return: ``True`` only when the variable is set to one of ``y``, ``yes``,
        ``t``, ``true``, ``on`` or ``1`` (compared case-insensitively).
    """
    env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
    val = os.environ.get(env_var_name)
    if val is None:
        return False

    # These are the truthy spellings accepted by distutils.util.strtobool.
    # Matching them directly returns a real bool (strtobool returns the ints
    # 1/0), needs no blanket "except Exception" for unrecognized values, and
    # avoids calling into distutils, which was removed in Python 3.12.
    # Values such as "false", "0", "" or "maybe" all map to False.
    return val.lower() in ("y", "yes", "t", "true", "on", "1")

test_constraints_max_tf1_version.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33
# For projects with upper bounds, we should periodically update this list to the latest release version
44
grpcio>=1.23.0
55
numpy>=1.17.2
6-
# Temporary workaround for https://github.com/tensorflow/tensorflow/issues/36179 and https://github.com/tensorflow/tensorflow/issues/36188
7-
tensorflow>=1.14.0,<1.15.1
6+
tensorflow>=1.15.2,<2.0.0
87
h5py>=2.10.0

test_constraints_min_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ grpcio==1.11.0
33
numpy==1.14.1
44
Pillow==4.2.1
55
protobuf==3.6
6-
tensorflow==1.7
6+
tensorflow==1.7.0
77
h5py==2.9.0

test_requirements.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,4 @@ pytest>4.0.0,<6.0.0
33
pytest-cov==2.6.1
44
pytest-xdist
55

6-
# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
7-
# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
8-
tf2onnx>=1.5.5; python_version < '3.7'
6+
tf2onnx>=1.5.5

0 commit comments

Comments
 (0)