Skip to content

Fix issue with model esrgan-tf2_1 #1098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions examples/benchmark_tfmodel_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
The following code compares the speed of tensorflow against onnxruntime
with a model downloaded from Tensorflow Hub.
"""
import time
import numpy
from tqdm import tqdm
import tensorflow_hub as hub
import onnxruntime as ort


def generate_random_images(shape=(100, 100), n=10):
imgs = []
for i in range(n):
sh = (1,) + shape + (3,)
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
img = img.astype(numpy.float32)
imgs.append(img)
return imgs


def measure_time(fct, imgs):
results = []
times = []
for img in tqdm(imgs):
begin = time.perf_counter()
result = fct(img)
end = time.perf_counter()
results.append(result)
times.append(end - begin)
return results, times


imgs = generate_random_images()

# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
ort = ort.InferenceSession('esrgan-tf2.onnx')
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
results_ort, duration_ort = measure_time(fct_ort, imgs)
print(len(imgs), duration_ort)

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
results_tf, duration_tf = measure_time(model, imgs)
print(len(imgs), duration_tf)

print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))
8 changes: 7 additions & 1 deletion tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,13 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
try:
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
except OSError:
model_path = dir_name
logger.info("Load model(2) from %r", model_path)
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
graph_def, input_names, outputs = res[:3]
elif self.model_type in ["keras"]:
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
else:
Expand Down
18 changes: 18 additions & 0 deletions tests/run_pretrained_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,24 @@ benchtf-gru:
##
## standard image nets
##

esrgan-tf2:
# url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
url: https://github.com/captain-pool/GSOC/releases/download/1.0.0/esrgan.tar.gz
model: ersgan
model_type: saved_model
input_get: get_beach
opset_constraints:
"onnx":
"min": 10
inputs:
"input_0:0": [1, 50, 50, 3]
outputs:
- Identity:0
rtol: 0.02
atol: 0.0005
tf_min_version: 2.1

inception_v3_slim:
url: https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
model: inception_v3_2016_08_28_frozen.pb
Expand Down
5 changes: 4 additions & 1 deletion tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,10 @@ def _reducemean_handler(self, trans, node):
def _slice_handler(self, trans, node):
axes = None
if self._g.opset < 10:
axes = node.get_attr("axes").ints
axes_values = node.get_attr("axes")
if not axes_values:
return False
axes = axes_values.ints
if axes == [0, 1, 2, 3]:
node.set_attr("axes", NCHW_TO_NHWC)
return self._switch_transpose_and_node(node, trans)
Expand Down
11 changes: 10 additions & 1 deletion tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,16 @@ def from_graphdef(model_path, input_names, output_names):
with tf_session() as sess:
graph_def = tf_graphdef()
with tf_gfile.GFile(model_path, 'rb') as f:
graph_def.ParseFromString(f.read())
try:
content = f.read()
except Exception as e:
raise OSError(
"Unable to load file '{}'.".format(model_path)) from e
try:
graph_def.ParseFromString(content)
except Exception as e:
raise RuntimeError(
"Unable to parse file '{}'.".format(model_path)) from e
tf.import_graph_def(graph_def, name='')
input_names = inputs_without_resource(sess, input_names)
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
Expand Down