Skip to content

Commit 6131c95

Browse files
committed
fix links, add script to benchmark
1 parent 189bcf0 commit 6131c95

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

examples/benchmark_tfmodel_ort.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
The following code compares the speed of tensorflow against onnxruntime
3+
with a model downloaded from Tensorflow Hub.
4+
"""
5+
import time
6+
import numpy
7+
from tqdm import tqdm
8+
import tensorflow_hub as hub
9+
import tf2onnx
10+
import onnxruntime as ort
11+
12+
13+
def generate_random_images(shape=(100, 100), n=10):
14+
imgs = []
15+
for i in range(n):
16+
sh = (1,) + shape + (3,)
17+
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
18+
img = img.astype(numpy.float32)
19+
imgs.append(img)
20+
return imgs
21+
22+
23+
def measure_time(fct, imgs):
24+
results = []
25+
times = []
26+
for img in tqdm(imgs):
27+
begin = time.perf_counter()
28+
result = fct(img)
29+
end = time.perf_counter()
30+
results.append(result)
31+
times.append(end - begin)
32+
return results, times
33+
34+
35+
imgs = generate_random_images()
36+
37+
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
38+
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
39+
ort = ort.InferenceSession('esrgan-tf2.onnx')
40+
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
41+
results_ort, duration_ort = measure_time(fct_ort, imgs)
42+
print(len(imgs), duration_ort)
43+
44+
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
45+
results_tf, duration_tf = measure_time(model, imgs)
46+
print(len(imgs), duration_tf)
47+
48+
print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))

tests/run_pretrained_models.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,12 @@ benchtf-gru:
117117
##
118118

119119
esrgan-tf2:
120-
url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
120+
# url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
121+
url: https://github.com/captain-pool/GSOC/releases/download/1.0.0/esrgan.tar.gz
121122
model: fixme
122123
input_get: get_beach
123124
inputs:
124-
"input_0:0": [1, 416, 416, 3]
125+
"input_0:0": [1, 50, 50, 3]
125126
outputs:
126127
- Identity:0
127128

0 commit comments

Comments
 (0)