Skip to content

Commit a5626aa

Browse files
author
Chris Elion
authored
Support for ONNX export (#3101)
1 parent 2c2f930 commit a5626aa

File tree

8 files changed

+227
-48
lines changed

8 files changed

+227
-48
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
77
## [Unreleased]
88
### Major Changes
99
- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
10+
- Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
11+
Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly
1012

1113
### Minor Changes
1214
- Monitor.cs was moved to Examples. (#3372)

docs/Unity-Inference-Engine.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ but we only tested for the following platforms :
2727
* iOS
2828
* Android
2929

30+
## Supported formats
31+
There are currently two supported model formats:
32+
* Barracuda (`.nn`) files use a proprietary format produced by the [`tensorflow_to_barracuda.py`]() script.
33+
* ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).
34+
35+
Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed in pip.
36+
tf2onnx does not currently support tensorflow 2.0.0 or later.
37+
3038
## Using the Unity Inference Engine
3139

32-
When using a model, drag the `.nn` file into the **Model** field
33-
in the Inspector of the Agent.
40+
When using a model, drag the model file into the **Model** field in the Inspector of the Agent.
3441
Select the **Inference Device** : CPU or GPU you want to use for Inference.
3542

3643
**Note:** For most of the models generated with the ML-Agents toolkit, CPU will be faster than GPU.

ml-agents-envs/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def run(self):
4646
install_requires=[
4747
"cloudpickle",
4848
"grpcio>=1.11.0",
49-
"numpy>=1.13.3,<2.0",
49+
"numpy>=1.14.1,<2.0",
5050
"Pillow>=4.2.1",
5151
"protobuf>=3.6",
5252
],
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import logging
2+
from typing import Any, List, Set, NamedTuple
3+
4+
try:
5+
import onnx
6+
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
7+
from tf2onnx import optimizer
8+
9+
ONNX_EXPORT_ENABLED = True
10+
except ImportError:
11+
# Either onnx and tf2onnx not installed, or they're not compatible with the version of tensorflow
12+
ONNX_EXPORT_ENABLED = False
13+
pass
14+
15+
from mlagents.tf_utils import tf
16+
17+
from tensorflow.python.platform import gfile
18+
from tensorflow.python.framework import graph_util
19+
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
20+
21+
logger = logging.getLogger("mlagents.trainers")
22+
23+
POSSIBLE_INPUT_NODES = frozenset(
24+
[
25+
"action_masks",
26+
"epsilon",
27+
"prev_action",
28+
"recurrent_in",
29+
"sequence_length",
30+
"vector_observation",
31+
]
32+
)
33+
34+
POSSIBLE_OUTPUT_NODES = frozenset(
35+
["action", "action_probs", "recurrent_out", "value_estimate"]
36+
)
37+
38+
MODEL_CONSTANTS = frozenset(
39+
["action_output_shape", "is_continuous_control", "memory_size", "version_number"]
40+
)
41+
VISUAL_OBSERVATION_PREFIX = "visual_observation_"
42+
43+
44+
class SerializationSettings(NamedTuple):
45+
model_path: str
46+
brain_name: str
47+
convert_to_barracuda: bool = True
48+
convert_to_onnx: bool = True
49+
onnx_opset: int = 9
50+
51+
52+
def export_policy_model(
53+
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
54+
) -> None:
55+
"""
56+
Exports latest saved model to .nn format for Unity embedding.
57+
"""
58+
frozen_graph_def = _make_frozen_graph(settings, graph, sess)
59+
# Save frozen graph
60+
frozen_graph_def_path = settings.model_path + "/frozen_graph_def.pb"
61+
with gfile.GFile(frozen_graph_def_path, "wb") as f:
62+
f.write(frozen_graph_def.SerializeToString())
63+
64+
# Convert to barracuda
65+
if settings.convert_to_barracuda:
66+
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
67+
logger.info(f"Exported {settings.model_path}.nn file")
68+
69+
# Save to onnx too (if we were able to import it)
70+
if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
71+
try:
72+
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
73+
onnx_output_path = settings.model_path + ".onnx"
74+
with open(onnx_output_path, "wb") as f:
75+
f.write(onnx_graph.SerializeToString())
76+
logger.info(f"Converting to {onnx_output_path}")
77+
except Exception:
78+
logger.exception(
79+
"Exception trying to save ONNX graph. Please report this error on "
80+
"https://github.com/Unity-Technologies/ml-agents/issues and "
81+
"attach a copy of frozen_graph_def.pb"
82+
)
83+
84+
85+
def _make_frozen_graph(
86+
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
87+
) -> tf.GraphDef:
88+
with graph.as_default():
89+
target_nodes = ",".join(_process_graph(settings, graph))
90+
graph_def = graph.as_graph_def()
91+
output_graph_def = graph_util.convert_variables_to_constants(
92+
sess, graph_def, target_nodes.replace(" ", "").split(",")
93+
)
94+
return output_graph_def
95+
96+
97+
def convert_frozen_to_onnx(
98+
settings: SerializationSettings, frozen_graph_def: tf.GraphDef
99+
) -> Any:
100+
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
101+
102+
# Some constants in the graph need to be read by the inference system.
103+
# These aren't used by the model anywhere, so trying to make sure they propagate
104+
# through conversion and import is a losing battle. Instead, save them now,
105+
# so that we can add them back later.
106+
constant_values = {}
107+
for n in frozen_graph_def.node:
108+
if n.name in MODEL_CONSTANTS:
109+
val = n.attr["value"].tensor.int_val[0]
110+
constant_values[n.name] = val
111+
112+
inputs = _get_input_node_names(frozen_graph_def)
113+
outputs = _get_output_node_names(frozen_graph_def)
114+
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
115+
116+
frozen_graph_def = tf_optimize(
117+
inputs, outputs, frozen_graph_def, fold_constant=True
118+
)
119+
120+
with tf.Graph().as_default() as tf_graph:
121+
tf.import_graph_def(frozen_graph_def, name="")
122+
with tf.Session(graph=tf_graph):
123+
g = process_tf_graph(
124+
tf_graph,
125+
input_names=inputs,
126+
output_names=outputs,
127+
opset=settings.onnx_opset,
128+
)
129+
130+
onnx_graph = optimizer.optimize_graph(g)
131+
model_proto = onnx_graph.make_model(settings.brain_name)
132+
133+
# Save the constant values back the graph initializer.
134+
# This will ensure the importer gets them as global constants.
135+
constant_nodes = []
136+
for k, v in constant_values.items():
137+
constant_node = _make_onnx_node_for_constant(k, v)
138+
constant_nodes.append(constant_node)
139+
model_proto.graph.initializer.extend(constant_nodes)
140+
return model_proto
141+
142+
143+
def _make_onnx_node_for_constant(name: str, value: int) -> Any:
144+
tensor_value = onnx.TensorProto(
145+
data_type=onnx.TensorProto.INT32,
146+
name=name,
147+
int32_data=[value],
148+
dims=[1, 1, 1, 1],
149+
)
150+
return tensor_value
151+
152+
153+
def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
154+
"""
155+
Get the list of input node names from the graph.
156+
Names are suffixed with ":0"
157+
"""
158+
node_names = _get_frozen_graph_node_names(frozen_graph_def)
159+
input_names = node_names & POSSIBLE_INPUT_NODES
160+
161+
# Check visual inputs sequentially, and exit as soon as we don't find one
162+
vis_index = 0
163+
while True:
164+
vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}"
165+
if vis_node_name in node_names:
166+
input_names.add(vis_node_name)
167+
else:
168+
break
169+
vis_index += 1
170+
# Append the port
171+
return [f"{n}:0" for n in input_names]
172+
173+
174+
def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
175+
"""
176+
Get the list of output node names from the graph.
177+
Names are suffixed with ":0"
178+
"""
179+
node_names = _get_frozen_graph_node_names(frozen_graph_def)
180+
output_names = node_names & POSSIBLE_OUTPUT_NODES
181+
# Append the port
182+
return [f"{n}:0" for n in output_names]
183+
184+
185+
def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
186+
"""
187+
Get all the node names from the graph.
188+
"""
189+
names = set()
190+
for node in frozen_graph_def.node:
191+
names.add(node.name)
192+
return names
193+
194+
195+
def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str]:
196+
"""
197+
Gets the list of the output nodes present in the graph for inference
198+
:return: list of node names
199+
"""
200+
all_nodes = [x.name for x in graph.as_graph_def().node]
201+
nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS]
202+
logger.info("List of nodes to export for brain :" + settings.brain_name)
203+
for n in nodes:
204+
logger.info("\t" + n)
205+
return nodes

ml-agents/mlagents/trainers/tf_policy.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
from typing import Any, Dict, List, Optional
33

44
import numpy as np
5+
56
from mlagents.tf_utils import tf
67
from mlagents import tf_utils
78

89
from mlagents_envs.exception import UnityException
910
from mlagents.trainers.policy import Policy
1011
from mlagents.trainers.action_info import ActionInfo
11-
from tensorflow.python.platform import gfile
12-
from tensorflow.python.framework import graph_util
13-
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
1412
from mlagents.trainers.trajectory import SplitObservations
1513
from mlagents.trainers.buffer import AgentBuffer
1614
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
@@ -34,17 +32,6 @@ class TFPolicy(Policy):
3432
functions to interact with it to perform evaluate and updating.
3533
"""
3634

37-
possible_output_nodes = [
38-
"action",
39-
"value_estimate",
40-
"action_probs",
41-
"recurrent_out",
42-
"memory_size",
43-
"version_number",
44-
"is_continuous_control",
45-
"action_output_shape",
46-
]
47-
4835
def __init__(self, seed, brain, trainer_parameters):
4936
"""
5037
Initialized the policy.
@@ -328,35 +315,6 @@ def save_model(self, steps):
328315
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
329316
)
330317

331-
def export_model(self):
332-
"""
333-
Exports latest saved model to .nn format for Unity embedding.
334-
"""
335-
336-
with self.graph.as_default():
337-
target_nodes = ",".join(self._process_graph())
338-
graph_def = self.graph.as_graph_def()
339-
output_graph_def = graph_util.convert_variables_to_constants(
340-
self.sess, graph_def, target_nodes.replace(" ", "").split(",")
341-
)
342-
frozen_graph_def_path = self.model_path + "/frozen_graph_def.pb"
343-
with gfile.GFile(frozen_graph_def_path, "wb") as f:
344-
f.write(output_graph_def.SerializeToString())
345-
tf2bc.convert(frozen_graph_def_path, self.model_path + ".nn")
346-
logger.info("Exported " + self.model_path + ".nn file")
347-
348-
def _process_graph(self):
349-
"""
350-
Gets the list of the output nodes present in the graph for inference
351-
:return: list of node names
352-
"""
353-
all_nodes = [x.name for x in self.graph.as_graph_def().node]
354-
nodes = [x for x in all_nodes if x in self.possible_output_nodes]
355-
logger.info("List of nodes to export for brain :" + self.brain.brain_name)
356-
for n in nodes:
357-
logger.info("\t" + n)
358-
return nodes
359-
360318
def update_normalization(self, vector_obs: np.ndarray) -> None:
361319
"""
362320
If this policy normalizes vector observations, this will update the norm values in the graph.

ml-agents/mlagents/trainers/trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mlagents_envs.exception import UnityException
1313
from mlagents_envs.timers import set_gauge
14+
from mlagents.model_serialization import export_policy_model, SerializationSettings
1415
from mlagents.trainers.tf_policy import TFPolicy
1516
from mlagents.trainers.stats import StatsReporter
1617
from mlagents.trainers.trajectory import Trajectory
@@ -192,7 +193,9 @@ def export_model(self, name_behavior_id: str) -> None:
192193
"""
193194
Exports the model
194195
"""
195-
self.get_policy(name_behavior_id).export_model()
196+
policy = self.get_policy(name_behavior_id)
197+
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
198+
export_policy_model(settings, policy.graph, policy.sess)
196199

197200
def _write_summary(self, step: int) -> None:
198201
"""

test_constraints_min_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pip constraints to use the *lowest* versions allowed in ml-agents/setup.py
22
grpcio==1.11.0
3-
numpy==1.13.3
3+
numpy==1.14.1
44
Pillow==4.2.1
55
protobuf==3.6
66
tensorflow==1.7

test_requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@
22
pytest>4.0.0,<6.0.0
33
pytest-cov==2.6.1
44
pytest-xdist
5+
6+
# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
7+
# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
8+
tf2onnx>=1.5.5; python_version < '3.7'

0 commit comments

Comments
 (0)