Converting a Transformer LM from TensorFlow to ONNX and then running the ONNX model results in the error

```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from /home/agerstenberger/projects/asr/quantization/onnx/lm/sis-setups/work/apptek_asr/onnx/convert/ConvertTFCheckpointToOnnxJob.41TMCrq7LXoD/output/model.onnx failed:This is an invalid model.
Type Error: Type 'tensor(bool)' of input parameter (Cast__183:0) of operator (Mul) in node (output/rec/dec_0_self_att_att/MatrixBandPart) is invalid.
```

using the latest `tensorflow` and `onnx*` packages.
One fix could be replacing in the `SelfAttentionLayer` (returnn/tf/layers/rec.py, lines 7997 to 8003 in dbef0ca):

```python
from returnn.tf.util.basic import matrix_triangular
# (1,1,num_queries,num_keys)
energy_mask = matrix_triangular((1, 1, num_queries, num_queries), dtype=tf.bool, lower=True)
if num_queries is not num_keys:
  energy_mask_left = tf.ones((1, 1, num_queries, num_keys - num_queries), dtype=tf.bool)
  energy_mask = tf.concat([energy_mask_left, energy_mask], axis=-1)
```
with something like the following, using `is_onnx_export_global`:

```python
...
mask_dtype = tf.bool if not util.is_onnx_export_global() else tf.float32
energy_mask = matrix_triangular((1, 1, num_queries, num_queries), dtype=mask_dtype, lower=True)
if num_queries is not num_keys:
  energy_mask_left = tf.ones((1, 1, num_queries, num_keys - num_queries), dtype=mask_dtype)
  energy_mask = tf.concat([energy_mask_left, energy_mask], axis=-1)
if util.is_onnx_export_global():
  energy_mask = tf.cast(energy_mask, dtype=tf.bool)
...
```
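To illustrate the dtype juggling in that fix, here is a small NumPy sketch (NumPy stands in for the TF ops, `build_energy_mask` is a made-up helper for illustration, and the shapes are just example values): during ONNX export the mask is built in float and only cast to bool at the end, so no intermediate op sees a bool tensor.

```python
import numpy as np

def build_energy_mask(num_queries, num_keys, onnx_export=False):
    # During ONNX export build the mask in float32 and cast to bool only at
    # the end, so intermediate ops avoid bool inputs (mirrors the proposed fix).
    dtype = np.float32 if onnx_export else np.bool_
    # Lower-triangular (causal) mask of shape (1, 1, num_queries, num_queries).
    energy_mask = np.tril(np.ones((1, 1, num_queries, num_queries), dtype=dtype))
    if num_queries != num_keys:
        # Full attention to the extra key positions on the left.
        left = np.ones((1, 1, num_queries, num_keys - num_queries), dtype=dtype)
        energy_mask = np.concatenate([left, energy_mask], axis=-1)
    if onnx_export:
        energy_mask = energy_mask.astype(np.bool_)
    return energy_mask

mask = build_energy_mask(3, 5, onnx_export=True)
print(mask.shape)  # -> (1, 1, 3, 5)
```

Both branches produce the same bool mask; only the dtypes of the intermediate tensors differ.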
This used to work with an older version of tf2onnx. However, the most recent tf2onnx additionally yields this error during conversion:

```
Tensorflow op [output/rec/dec_0_self_att_att/ones_like: OnesLike] is not supported
```

and then at runtime:

```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from /home/agerstenberger/projects/asr/quantization/onnx/lm/sis-setups/work/apptek_asr/onnx/convert/ConvertTFCheckpointToOnnxJob.41TMCrq7LXoD/output/model.onnx failed:This is an invalid model.
In Node, ("output/rec/dec_0_self_att_att/ones_like", OnesLike, "", -1) : ("output/rec/dec_0_self_att_att/energy:0": tensor(float),) -> ("output/rec/dec_0_self_att_att/ones_like:0",) , Error No Op registered for OnesLike with domain_version of 15
```

The op is not converted but just copied from the original graph (the default behaviour of tf2onnx for unknown ops), with input tensor `output/rec/dec_0_self_att_att/energy:0`.
Currently I'm not sure when that changed, but I can look into it.
It kind of sounds like a bug to me (I have to try older versions again).
If that's the case, should we wait for a fix, or provide a workaround?
For a workaround I see two approaches:

- Replace occurrences of `tf.ones_like` in the `SelfAttentionLayer` with `tf.fill` ops. TensorFlow actually calls `tf.fill` in `tf.ones`, which is called by `tf.ones_like`. I'm not certain that this is a good idea.
- Provide a custom conversion handler via the `tf_op` decorator (see here) and plug it into the conversion job. This would also take care of other appearances of `tf.ones_like`. No need to adapt `returnn` further.
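For the second approach, a handler could look roughly like this. This is an untested sketch, not a working implementation: the `ctx`/`make_node` calls follow tf2onnx's internal handler API as I understand it, `OnesLikeHandler` is a made-up name, and the float dtype is assumed. The idea is to lower `OnesLike` to `Shape` + `ConstantOfShape` (available from opset 9):

```python
# Untested sketch of a custom tf2onnx handler for OnesLike.
from onnx import helper, TensorProto
from tf2onnx.handler import tf_op


@tf_op("OnesLike")
class OnesLikeHandler:
    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        # Shape of the input tensor.
        shape_node = ctx.make_node("Shape", [node.input[0]])
        # Scalar 1.0 used as the fill value (dtype assumed float here;
        # a real handler would derive it from the input's dtype).
        one = helper.make_tensor("value", TensorProto.FLOAT, [1], [1.0])
        ctx.remove_node(node.name)
        # Tensor of ones with that shape, reusing the original outputs.
        ctx.make_node(
            "ConstantOfShape",
            [shape_node.output[0]],
            attr={"value": one},
            outputs=node.output,
            name=node.name,
        )
```

Registering such a handler before calling the conversion would then apply to every `OnesLike` in the graph, not just the one in the attention layer.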
What are your thoughts?