Skip to content

Commit a635b6f

Browse files
authored
Add tf/torch/mhlo/tosa support for SharkDownloader (huggingface#151)
1 parent e8aa105 commit a635b6f

File tree

2 files changed

+63
-33
lines changed

2 files changed

+63
-33
lines changed

shark/shark_downloader.py

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ def __init__(
4040
self.local_tank_dir = local_tank_dir
4141
self.tank_url = tank_url
4242
self.model_type = model_type
43-
self.input_json = input_json
44-
self.input_type = input_type_to_np_dtype[input_type]
43+
self.input_json = input_json # optional if you don't have input
44+
self.input_type = input_type_to_np_dtype[
45+
input_type
46+
] # optional if you don't have input
4547
self.mlir_file = None # .mlir file local address.
48+
self.mlir_url = None
4649
self.inputs = None # Input has to be (list of np.array) for sharkInference.forward use
4750
self.mlir_model = []
4851

@@ -73,51 +76,78 @@ def load_json_input(self):
7376
]
7477
else:
7578
print(
76-
"No json input required for current model. You could call setup_inputs(you_inputs)."
79+
"No json input required for current model type. You could call setup_inputs(YOU_INPUTS)."
7780
)
7881
return self.inputs
7982

8083
def load_mlir_model(self):
81-
if self.model_type in ["tflite-tosa"]:
82-
workdir = os.path.join(
83-
os.path.dirname(__file__), self.local_tank_dir
84+
workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
85+
os.makedirs(workdir, exist_ok=True)
86+
print(f"TMP_MODEL_DIR = {workdir}")
87+
# use model name get dir.
88+
model_name_dir = os.path.join(workdir, str(self.model_name))
89+
if not os.path.exists(model_name_dir):
90+
print(
91+
"Model has not been download."
92+
"shark_downloader will automatically download by tank_url if provided."
93+
" You can also manually to download the model from shark_tank by yourself."
8494
)
85-
os.makedirs(workdir, exist_ok=True)
86-
print(f"TMP_MODEL_DIR = {workdir}")
95+
os.makedirs(model_name_dir, exist_ok=True)
96+
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
8797

88-
# use model name get dir.
89-
model_name_dir = os.path.join(workdir, str(self.model_name))
90-
if not os.path.exists(model_name_dir):
91-
print(
92-
"Model has not been download."
93-
"shark_downloader will automatically download by tank_url if provided."
94-
" You can also manually to download the model from shark_tank by yourself."
95-
)
96-
os.makedirs(model_name_dir, exist_ok=True)
97-
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
98-
99-
mlir_url = (
98+
if self.model_type in ["tflite-tosa"]:
99+
self.mlir_url = (
100100
self.tank_url
101-
+ "/tflite/"
101+
+ "/"
102102
+ str(self.model_name)
103103
+ "/"
104104
+ str(self.model_name)
105-
+ "_tosa.mlir"
105+
+ "_tflite.mlir"
106106
)
107107
self.mlir_file = "/".join(
108-
[model_name_dir, str(self.model_name) + "_tosa.mlir"]
108+
[model_name_dir, str(self.model_name) + "_tfite.mlir"]
109+
)
110+
elif self.model_type in ["tensorflow"]:
111+
self.mlir_url = (
112+
self.tank_url
113+
+ "/"
114+
+ str(self.model_name)
115+
+ "/"
116+
+ str(self.model_name)
117+
+ "_tf.mlir"
118+
)
119+
self.mlir_file = "/".join(
120+
[model_name_dir, str(self.model_name) + "_tf.mlir"]
121+
)
122+
elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
123+
self.mlir_url = (
124+
self.tank_url
125+
+ "/"
126+
+ str(self.model_name)
127+
+ "/"
128+
+ str(self.model_name)
129+
+ "_"
130+
+ str(self.model_type)
131+
+ ".mlir"
132+
)
133+
self.mlir_file = "/".join(
134+
[
135+
model_name_dir,
136+
str(self.model_name) + "_" + str(self.model_type) + ".mlir",
137+
]
109138
)
110-
if os.path.exists(self.mlir_file):
111-
print("Model has been downloaded before.", self.mlir_file)
112-
else:
113-
print("Download mlir model", mlir_url)
114-
urllib.request.urlretrieve(mlir_url, self.mlir_file)
115-
116-
print("Get tosa.mlir model return")
117-
with open(self.mlir_file) as f:
118-
self.mlir_model = f.read()
119139
else:
120140
print("Unsupported mlir model")
141+
142+
if os.path.exists(self.mlir_file):
143+
print("Model has been downloaded before.", self.mlir_file)
144+
else:
145+
print("Download mlir model", self.mlir_url)
146+
urllib.request.urlretrieve(self.mlir_url, self.mlir_file)
147+
148+
print("Get .mlir model return")
149+
with open(self.mlir_file) as f:
150+
self.mlir_model = f.read()
121151
return self.mlir_model
122152

123153
def setup_inputs(self, inputs):

tank/tflite/albert_lite_base/albert_lite_base_tflite_mlir_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def create_and_check_module(self):
2424
self.shark_downloader = SharkDownloader(
2525
model_name="albert_lite_base",
2626
tank_url="https://storage.googleapis.com/shark_tank",
27-
local_tank_dir="./../gen_shark_tank/tflite",
27+
local_tank_dir="./../gen_shark_tank",
2828
model_type="tflite-tosa",
2929
input_json="input.json",
3030
input_type="int32",

0 commit comments

Comments
 (0)