Describe the bug
When converting a Bert2Bert model from the TensorFlow official models, I get the exact same error whether I convert from the saved model (pb) or from a function:
Traceback (most recent call last):
  File "C:/dev/ml/TextGenerator/text_generator/models/bert2bert/save_model.py", line 167, in <module>
    test()
  File "C:/dev/ml/TextGenerator/text_generator/models/bert2bert/save_model.py", line 160, in test
    "segment_ids": tf.TensorSpec(shape=(None, 200,), dtype=tf.int32)
  File "C:\dev\ml\TextGenerator\venv\lib\site-packages\tf2onnx\convert.py", line 533, in from_function
    frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)
  File "C:\dev\ml\TextGenerator\venv\lib\site-packages\tf2onnx\tf_loader.py", line 247, in from_function
    graph_def = tf_optimize(input_names, output_names, graph_def)
  File "C:\dev\ml\TextGenerator\venv\lib\site-packages\tf2onnx\tf_loader.py", line 666, in tf_optimize
    graph_def = tf_optimize_grappler(input_names, output_names, graph_def, fold_constant)
  File "C:\dev\ml\TextGenerator\venv\lib\site-packages\tf2onnx\tf_loader.py", line 650, in tf_optimize_grappler
    graph_def = tf_opt.OptimizeGraph(config, meta_graph)
  File "C:\dev\ml\TextGenerator\venv\lib\site-packages\tensorflow\python\grappler\tf_optimizer.py", line 58, in OptimizeGraph
    graph_id, strip_default_attributes)
tensorflow.python.framework.errors_impl.InvalidArgumentError: input resource[0] expected type resource != float, the type of bert2_bert_while_decoder_gather_resource_0[0]
    In {{node bert2_bert/while/decoder/Gather}}
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10.0.19042
- Tensorflow Version: 2.5.0
- Python version: 3.7.6
To Reproduce
Here is minimal code to reproduce the issue. It uses the Bert2Bert model from the TensorFlow official models.
import tensorflow as tf
import tf2onnx
from official.nlp.nhnet.models import Bert2Bert, get_bert2bert_layers
from official.nlp.nhnet.configs import UNITTEST_CONFIG, BERT2BERTConfig

# Build a small Bert2Bert model from the unit-test configuration.
bert2bert_config = BERT2BERTConfig.from_args(**UNITTEST_CONFIG, len_title=32)
bert_layer, decoder_layer = get_bert2bert_layers(params=bert2bert_config)
bert2bert = Bert2Bert(bert2bert_config, bert_layer, decoder_layer)

@tf.function()
def serve(inputs):
    return bert2bert(inputs=inputs, mode="predict")

# This call raises the InvalidArgumentError shown in the traceback.
model_proto, _ = tf2onnx.convert.from_function(
    function=serve,
    opset=14,
    input_signature=[{
        "input_ids": tf.TensorSpec(shape=(None, 200,), dtype=tf.int32),
        "input_mask": tf.TensorSpec(shape=(None, 200,), dtype=tf.int32),
        "segment_ids": tf.TensorSpec(shape=(None, 200,), dtype=tf.int32)
    }],
)
If it is simpler, I have also attached a pb of the Bert2Bert model (saved_model.zip). To reproduce the exact same bug, simply run:
python -m tf2onnx.convert --saved_model path/to/pb --output path/to/onnx --tag serve --signature_def serve --opset 14
Additional context
However, when I freeze the graph with my own method and then run the tf2onnx conversion on the resulting frozen graph, it works perfectly fine. Here is how I froze the graph:
import pathlib

import tensorflow as tf
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# saved_model_dir, batch_size, max_seq_length and grappler_config are defined
# elsewhere in my script (not shown here).
saved_model = tf.saved_model.load(saved_model_dir)
concrete_fn = saved_model.signatures['serve']

# Pin the three model inputs to static shapes before freezing.
concrete_fn.inputs[0].set_shape([batch_size, max_seq_length])
concrete_fn.inputs[1].set_shape([batch_size, max_seq_length])
concrete_fn.inputs[2].set_shape([batch_size, max_seq_length])

# Inline the variables as constants.
frozen_concrete_fn = convert_variables_to_constants_v2(concrete_fn)
frozen_concrete_graph_def = frozen_concrete_fn.graph.as_graph_def()

# Keep only the real data inputs; skip any resource-typed inputs.
input_tensors = [
    tensor for tensor in frozen_concrete_fn.inputs
    if tensor.dtype != tf.resource
]
output_tensors = frozen_concrete_fn.outputs

# Run Grappler on the already frozen graph.
frozen_concrete_graph_def = run_graph_optimizations(
    frozen_concrete_graph_def,
    input_tensors,
    output_tensors,
    config=get_grappler_config(list(grappler_config)),
    graph=frozen_concrete_fn.graph
)

# Write the frozen graph next to the saved model directory.
output_dir = pathlib.Path(saved_model_dir).parent
frozen_graph_name = f"frozen_graph.bs{batch_size}.sl{max_seq_length}.pb"
tf.io.write_graph(graph_or_graph_def=frozen_concrete_graph_def,
                  name=frozen_graph_name,
                  logdir=str(output_dir.absolute()),
                  as_text=False)
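The exact command used to convert the frozen graph is not shown above. As a rough sketch only, assuming the tf2onnx command line is used with its --graphdef, --inputs and --outputs options, the invocation could be assembled as below. The helper build_tf2onnx_command, the file names, the batch size and sequence length in the file name, and all tensor names are hypothetical placeholders; the real tensor names have to be read from the frozen graph.

```python
import shlex


def build_tf2onnx_command(graphdef_path, output_path, input_names, output_names, opset=14):
    """Hypothetical helper: build the argument list for a tf2onnx frozen-graph conversion."""
    return [
        "python", "-m", "tf2onnx.convert",
        "--graphdef", graphdef_path,
        "--output", output_path,
        # tf2onnx takes comma-separated tensor names, including the ":0" suffix.
        "--inputs", ",".join(input_names),
        "--outputs", ",".join(output_names),
        "--opset", str(opset),
    ]


# All names below are placeholders, not taken from the actual graph.
cmd = build_tf2onnx_command(
    "frozen_graph.bs1.sl200.pb",
    "bert2bert.onnx",
    input_names=["input_ids:0", "input_mask:0", "segment_ids:0"],
    output_names=["Identity:0"],
)
print(" ".join(shlex.quote(arg) for arg in cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) from the virtual environment that has tf2onnx installed.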