Skip to content

Commit 8434c67

Browse files
authored
Change tflite tests from sharkimporter -> sharkdownloader (huggingface#182)
* Change tflite tests from SharkImporter to SharkDownloader. * Mark all uint/int tflite SharkDownloader tests as xfail.
1 parent 79caf72 commit 8434c67

File tree

39 files changed

+592
-645
lines changed

39 files changed

+592
-645
lines changed

shark/shark_downloader.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
"bool": np.bool_,
2424
"int32": np.int32,
2525
"int64": np.int64,
26+
"uint8": np.uint8,
27+
"int8": np.int8,
2628
}
2729

2830

@@ -32,7 +34,7 @@ def __init__(
3234
model_name: str,
3335
tank_url: str = "https://storage.googleapis.com/shark_tank",
3436
local_tank_dir: str = "./../gen_shark_tank/tflite",
35-
model_type: str = "tflite-tosa",
37+
model_type: str = "tflite",
3638
input_json: str = "input.json",
3739
input_type: str = "int32",
3840
):
@@ -84,7 +86,7 @@ def get_inputs(self):
8486

8587
def load_json_input(self):
8688
print("load json inputs")
87-
if self.model_type in ["tflite-tosa"]:
89+
if self.model_type in ["tflite"]:
8890
input_url = (
8991
self.tank_url + "/" + str(self.model_name) + "/" + "input.json"
9092
)
@@ -109,7 +111,7 @@ def load_json_input(self):
109111
return self.inputs
110112

111113
def load_mlir_model(self):
112-
if self.model_type in ["tflite-tosa"]:
114+
if self.model_type in ["tflite"]:
113115
self.mlir_url = (
114116
self.tank_url
115117
+ "/"

tank/person_detect/person_detect_tflite_test.py renamed to tank/albert_lite_base/albert_lite_base_tflite_importer_test.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,36 @@
44
import pytest
55
import unittest
66
from shark.parser import shark_args
7-
import os
8-
import sys
9-
import urllib.request
10-
from PIL import Image
117
from shark.tflite_utils import TFLitePreprocessor
128

139

14-
# model_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite"
15-
10+
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
11+
# model_path = model_path
1612

13+
# Inputs modified to be useful albert inputs.
1714
def generate_inputs(input_details):
18-
exe_basename = os.path.basename(sys.argv[0])
19-
workdir = os.path.join(os.path.dirname(__file__), "../tmp", exe_basename)
20-
os.makedirs(workdir, exist_ok=True)
21-
22-
img_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
23-
local_path = "/".join([workdir, "person.bmp"])
24-
urllib.request.urlretrieve(img_path, local_path)
25-
26-
shape = input_details[0]["shape"]
27-
im = np.array(Image.open(local_path).resize((shape[1], shape[2]))).astype(
28-
input_details[0]["dtype"]
15+
for input in input_details:
16+
print(str(input["shape"]), input["dtype"].__name__)
17+
18+
args = []
19+
args.append(
20+
np.random.randint(
21+
low=0,
22+
high=256,
23+
size=input_details[0]["shape"],
24+
dtype=input_details[0]["dtype"],
25+
)
26+
)
27+
args.append(
28+
np.ones(
29+
shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
30+
)
31+
)
32+
args.append(
33+
np.zeros(
34+
shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
35+
)
2936
)
30-
args = [im.reshape(shape)]
3137
return args
3238

3339

@@ -41,12 +47,14 @@ def compare_results(mlir_results, tflite_results, details):
4147
tflite_result = tflite_results[i]
4248
mlir_result = mlir_result.astype(np.single)
4349
tflite_result = tflite_result.astype(np.single)
50+
print("mlir_result.shape", mlir_result.shape)
51+
print("tflite_result.shape", tflite_result.shape)
4452
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
4553
max_error = np.max(np.abs(mlir_result - tflite_result))
4654
print("Max error (%d): %f", i, max_error)
4755

4856

49-
class PersonDetectionTfliteModuleTester:
57+
class AlbertTfliteModuleTester:
5058
def __init__(
5159
self,
5260
dynamic=False,
@@ -64,25 +72,7 @@ def create_and_check_module(self):
6472
shark_args.save_vmfb = self.save_vmfb
6573

6674
# Preprocess to get SharkImporter input args
67-
# The input has known expected values. We hardcode this value.
68-
input_details = [
69-
{
70-
"shape": [1, 96, 96, 1],
71-
"dtype": np.int8,
72-
"index": 0,
73-
}
74-
]
75-
output_details = [
76-
{
77-
"shape": [1, 2],
78-
"dtype": np.int8,
79-
}
80-
]
81-
tflite_preprocessor = TFLitePreprocessor(
82-
model_name="person_detect",
83-
input_details=input_details,
84-
output_details=output_details,
85-
)
75+
tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")
8676
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
8777
inputs = tflite_preprocessor.get_inputs()
8878
tflite_interpreter = tflite_preprocessor.get_interpreter()
@@ -104,8 +94,20 @@ def create_and_check_module(self):
10494
mlir_dialect="tflite",
10595
)
10696

107-
# Case2: Use manually set inputs
97+
# Case1: Use shark_importer default generate inputs
98+
shark_module.compile()
99+
mlir_results = shark_module.forward(inputs)
100+
## post process results for compare
101+
input_details, output_details = tflite_preprocessor.get_model_details()
102+
mlir_results = list(mlir_results)
103+
for i in range(len(output_details)):
104+
dtype = output_details[i]["dtype"]
105+
mlir_results[i] = mlir_results[i].astype(dtype)
106+
tflite_results = tflite_preprocessor.get_raw_model_output()
107+
compare_results(mlir_results, tflite_results, output_details)
108108

109+
# Case2: Use manually set inputs
110+
input_details, output_details = tflite_preprocessor.get_model_details()
109111
inputs = generate_inputs(input_details) # new inputs
110112

111113
shark_module = SharkInference(
@@ -117,31 +119,34 @@ def create_and_check_module(self):
117119
shark_module.compile()
118120
mlir_results = shark_module.forward(inputs)
119121
## post process results for compare
120-
# The input has known expected values. We hardcode this value.
121-
tflite_results = [np.array([[-113, 113]], dtype=np.int8)]
122+
tflite_results = tflite_preprocessor.get_raw_model_output()
122123
compare_results(mlir_results, tflite_results, output_details)
123124
# print(mlir_results)
124125

125126

126-
class PersonDetectionTfliteModuleTest(unittest.TestCase):
127+
class AlbertTfliteModuleTest(unittest.TestCase):
127128
@pytest.fixture(autouse=True)
128129
def configure(self, pytestconfig):
129130
self.save_mlir = pytestconfig.getoption("save_mlir")
130131
self.save_vmfb = pytestconfig.getoption("save_vmfb")
131132

132133
def setUp(self):
133-
self.module_tester = PersonDetectionTfliteModuleTester(self)
134+
self.module_tester = AlbertTfliteModuleTester(self)
134135
self.module_tester.save_mlir = self.save_mlir
135136

136-
@pytest.mark.skip(reason="TFLite is broken with this model")
137+
import sys
138+
139+
@pytest.mark.xfail(
140+
sys.platform == "darwin", reason="known macos tflite install issue"
141+
)
137142
def test_module_static_cpu(self):
138143
self.module_tester.dynamic = False
139144
self.module_tester.device = "cpu"
140145
self.module_tester.create_and_check_module()
141146

142147

143148
if __name__ == "__main__":
144-
# module_tester = PersonDetectionTfliteModuleTester()
149+
# module_tester = AlbertTfliteModuleTester()
145150
# module_tester.save_mlir = True
146151
# module_tester.save_vmfb = True
147152
# module_tester.create_and_check_module()

tank/albert_lite_base/albert_lite_base_tflite_mlir_test.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

tank/albert_lite_base/albert_lite_base_tflite_test.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from shark.shark_importer import SharkImporter
2+
from shark.shark_downloader import SharkDownloader
33
from shark.shark_inference import SharkInference
44
import pytest
55
import unittest
@@ -70,58 +70,26 @@ def __init__(
7070
def create_and_check_module(self):
7171
shark_args.save_mlir = self.save_mlir
7272
shark_args.save_vmfb = self.save_vmfb
73-
74-
# Preprocess to get SharkImporter input args
75-
tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")
76-
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
77-
inputs = tflite_preprocessor.get_inputs()
78-
tflite_interpreter = tflite_preprocessor.get_interpreter()
79-
80-
# Use SharkImporter to get SharkInference input args
81-
my_shark_importer = SharkImporter(
82-
module=tflite_interpreter,
83-
inputs=inputs,
84-
frontend="tflite",
85-
raw_model_file=raw_model_file_path,
86-
)
87-
mlir_model, func_name = my_shark_importer.import_mlir()
88-
89-
# Use SharkInference to get inference result
90-
shark_module = SharkInference(
91-
mlir_module=mlir_model,
92-
function_name=func_name,
93-
device=self.device,
94-
mlir_dialect="tflite",
73+
shark_downloader = SharkDownloader(
74+
model_name="albert_lite_base",
75+
tank_url="https://storage.googleapis.com/shark_tank",
76+
local_tank_dir="./../gen_shark_tank",
77+
model_type="tflite",
78+
input_json="input.json",
79+
input_type="int32",
9580
)
96-
97-
# Case1: Use shark_importer default generate inputs
98-
shark_module.compile()
99-
mlir_results = shark_module.forward(inputs)
100-
## post process results for compare
101-
input_details, output_details = tflite_preprocessor.get_model_details()
102-
mlir_results = list(mlir_results)
103-
for i in range(len(output_details)):
104-
dtype = output_details[i]["dtype"]
105-
mlir_results[i] = mlir_results[i].astype(dtype)
106-
tflite_results = tflite_preprocessor.get_raw_model_output()
107-
compare_results(mlir_results, tflite_results, output_details)
108-
109-
# Case2: Use manually set inputs
110-
input_details, output_details = tflite_preprocessor.get_model_details()
111-
inputs = generate_inputs(input_details) # new inputs
81+
tflite_tosa_model = shark_downloader.get_mlir_file()
82+
inputs = shark_downloader.get_inputs()
11283

11384
shark_module = SharkInference(
114-
mlir_module=mlir_model,
115-
function_name=func_name,
85+
mlir_module=tflite_tosa_model,
86+
function_name="main",
11687
device=self.device,
11788
mlir_dialect="tflite",
11889
)
11990
shark_module.compile()
120-
mlir_results = shark_module.forward(inputs)
121-
## post process results for compare
122-
tflite_results = tflite_preprocessor.get_raw_model_output()
123-
compare_results(mlir_results, tflite_results, output_details)
124-
# print(mlir_results)
91+
shark_module.forward(inputs)
92+
# print(shark_results)
12593

12694

12795
class AlbertTfliteModuleTest(unittest.TestCase):
@@ -146,9 +114,24 @@ def test_module_static_cpu(self):
146114

147115

148116
if __name__ == "__main__":
117+
unittest.main()
149118
# module_tester = AlbertTfliteModuleTester()
150-
# module_tester.save_mlir = True
151-
# module_tester.save_vmfb = True
152119
# module_tester.create_and_check_module()
153120

154-
unittest.main()
121+
# TEST RESULT:
122+
# (shark.venv) nod% python albert_lite_base_tflite_mlir_test.py
123+
# load json inputs
124+
# TMP_MODEL_DIR = shark/SHARK/shark/./../gen_shark_tank/tflite
125+
# Model has not been download.shark_downloader will automatically download by tank_url if provided. You can also manually to download the model from shark_tank by yourself.
126+
# TMP_MODELNAME_DIR = shark/SHARK/shark/./../gen_shark_tank/tflite/albert_lite_base
127+
# Download mlir model https://storage.googleapis.com/shark_tank/tflite/albert_lite_base/albert_lite_base_tosa.mlir
128+
# Get tosa.mlir model return
129+
# Target triple found:x86_64-linux-gnu
130+
# (shark.venv) nod% python albert_lite_base_tflite_mlir_test.py
131+
# load json inputs
132+
# TMP_MODEL_DIR = shark/SHARK/shark/./../gen_shark_tank/tflite
133+
# TMP_MODELNAME_DIR = shark/SHARK/shark/./../gen_shark_tank/tflite/albert_lite_base
134+
# Model has been downloaded before. shark/SHARK/shark/./../gen_shark_tank/tflite/albert_lite_base/albert_lite_base_tosa.mlir
135+
# Get tosa.mlir model return
136+
# Target triple found:x86_64-linux-gnu
137+
#

0 commit comments

Comments
 (0)