|
| 1 | +""" |
| 2 | +The following code compares the speed of tensorflow against onnxruntime |
| 3 | +with a model downloaded from Tensorflow Hub. |
| 4 | +""" |
| 5 | +import time |
| 6 | +import numpy |
| 7 | +from tqdm import tqdm |
| 8 | +import tensorflow_hub as hub |
| 9 | +import tf2onnx |
| 10 | +import onnxruntime as ort |
| 11 | + |
| 12 | + |
| 13 | +def generate_random_images(shape=(100, 100), n=10): |
| 14 | + imgs = [] |
| 15 | + for i in range(n): |
| 16 | + sh = (1,) + shape + (3,) |
| 17 | + img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255 |
| 18 | + img = img.astype(numpy.float32) |
| 19 | + imgs.append(img) |
| 20 | + return imgs |
| 21 | + |
| 22 | + |
| 23 | +def measure_time(fct, imgs): |
| 24 | + results = [] |
| 25 | + times = [] |
| 26 | + for img in tqdm(imgs): |
| 27 | + begin = time.perf_counter() |
| 28 | + result = fct(img) |
| 29 | + end = time.perf_counter() |
| 30 | + results.append(result) |
| 31 | + times.append(end - begin) |
| 32 | + return results, times |
| 33 | + |
| 34 | + |
| 35 | +imgs = generate_random_images() |
| 36 | + |
| 37 | +# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1 |
| 38 | +# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12 |
| 39 | +ort = ort.InferenceSession('esrgan-tf2.onnx') |
| 40 | +fct_ort = lambda img: ort.run(None, {'input_0:0': img}) |
| 41 | +results_ort, duration_ort = measure_time(fct_ort, imgs) |
| 42 | +print(len(imgs), duration_ort) |
| 43 | + |
| 44 | +model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1") |
| 45 | +results_tf, duration_tf = measure_time(model, imgs) |
| 46 | +print(len(imgs), duration_tf) |
| 47 | + |
| 48 | +print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf)) |
0 commit comments