Converting self attention models from TensorFlow to Onnx #1455

@Gerstenberger

Description

Converting a Transformer LM from TensorFlow to ONNX and then running the ONNX model results in the error

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from /home/agerstenberger/projects/asr/quantization/onnx/lm/sis-setups/work/apptek_asr/onnx/convert/ConvertTFCheckpointToOnnxJob.41TMCrq7LXoD/output/model.onnx failed:This is an invalid model. 
Type Error: Type 'tensor(bool)' of input parameter (Cast__183:0) of operator (Mul) in node (output/rec/dec_0_self_att_att/MatrixBandPart) is invalid.

when using the latest TensorFlow and onnx* packages.

One fix could be to replace, in the SelfAttentionLayer,

```python
from returnn.tf.util.basic import matrix_triangular

# final shape: (1, 1, num_queries, num_keys)
energy_mask = matrix_triangular((1, 1, num_queries, num_queries), dtype=tf.bool, lower=True)
if num_queries != num_keys:
    energy_mask_left = tf.ones((1, 1, num_queries, num_keys - num_queries), dtype=tf.bool)
    energy_mask = tf.concat([energy_mask_left, energy_mask], axis=-1)
```
with a variant that uses is_onnx_export_global:

```python
...
mask_dtype = tf.bool if not util.is_onnx_export_global() else tf.float32
energy_mask = matrix_triangular((1, 1, num_queries, num_queries), dtype=mask_dtype, lower=True)
if num_queries != num_keys:
    energy_mask_left = tf.ones((1, 1, num_queries, num_keys - num_queries), dtype=mask_dtype)
    energy_mask = tf.concat([energy_mask_left, energy_mask], axis=-1)
if util.is_onnx_export_global():
    energy_mask = tf.cast(energy_mask, dtype=tf.bool)
...
```
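The mask construction and the float32 round-trip that the workaround relies on can be sketched with a NumPy stand-in (illustrative only, not the RETURNN implementation; `make_energy_mask` is a hypothetical helper name):

```python
# Illustrative NumPy stand-in (not RETURNN code) for the causal energy mask
# and the float32 round-trip from the is_onnx_export_global workaround.
import numpy as np

def make_energy_mask(num_queries, num_keys, dtype=bool):
    # lower-triangular block over the trailing num_queries keys
    mask = np.tril(np.ones((num_queries, num_queries), dtype=dtype))
    if num_queries != num_keys:
        # all-ones block for the (e.g. cached) keys to the left
        left = np.ones((num_queries, num_keys - num_queries), dtype=dtype)
        mask = np.concatenate([left, mask], axis=-1)
    return mask[None, None]  # add the (1, 1, ...) broadcast dims

bool_mask = make_energy_mask(3, 5)
assert bool_mask[0, 0, 0].tolist() == [True, True, True, False, False]

# Building the mask in float32 and casting back to bool yields the same
# values, which is why the dtype switch should be safe.
float_mask = make_energy_mask(3, 5, dtype=np.float32)
assert np.array_equal(float_mask.astype(bool), bool_mask)
```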

This used to work with an older version of tf2onnx. However, the most recent tf2onnx additionally yields the following error during conversion:

Tensorflow op [output/rec/dec_0_self_att_att/ones_like: OnesLike] is not supported

and then at runtime:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from /home/agerstenberger/projects/asr/quantization/onnx/lm/sis-setups/work/apptek_asr/onnx/convert/ConvertTFCheckpointToOnnxJob.41TMCrq7LXoD/output/model.onnx failed:This is an invalid model. 
In Node, ("output/rec/dec_0_self_att_att/ones_like", OnesLike, "", -1) : ("output/rec/dec_0_self_att_att/energy:0": tensor(float),) -> ("output/rec/dec_0_self_att_att/ones_like:0",) , Error No Op registered for OnesLike with domain_version of 15

Here the OnesLike node is not converted but copied verbatim from the original graph (with input tensor output/rec/dec_0_self_att_att/energy:0), which is the default behaviour of tf2onnx for unknown ops.
I'm currently not sure when that changed, but I can look into it.

This kind of sounds like a tf2onnx bug to me (I have to try older versions again).
If that's the case, should we wait for an upstream fix, or provide a workaround?

For a workaround, I see two approaches:

  1. Replace occurrences of tf.ones_like in the SelfAttentionLayer with tf.fill ops. TensorFlow actually calls tf.fill inside tf.ones, which is called by tf.ones_like. I'm not certain that this is a good idea.
  2. Provide a custom conversion handler via the tf_op decorator (see here) and plug it into the conversion job. This would also take care of other appearances of tf.ones_like, with no need to adapt returnn further.
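Regarding approach 1, the equivalence it relies on can be sanity-checked with a NumPy analogue (illustrative only; np.full stands in for tf.fill, and the point is that Fill, unlike OnesLike here, has a tf2onnx conversion):

```python
# NumPy analogue of workaround 1 (illustrative, not TF code):
# tf.ones_like(x) produces the same tensor as tf.fill(tf.shape(x), 1)
# with a matching dtype.
import numpy as np

x = np.random.randn(2, 4).astype(np.float32)
ones_like = np.ones_like(x)                    # stand-in for tf.ones_like(x)
filled = np.full(x.shape, 1, dtype=x.dtype)    # stand-in for tf.fill(...)
assert np.array_equal(ones_like, filled)
```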

What are your thoughts?
