@@ -40,9 +40,12 @@ def __init__(
40
40
self .local_tank_dir = local_tank_dir
41
41
self .tank_url = tank_url
42
42
self .model_type = model_type
43
- self .input_json = input_json
44
- self .input_type = input_type_to_np_dtype [input_type ]
43
+ self .input_json = input_json # optional if you don't have input
44
+ self .input_type = input_type_to_np_dtype [
45
+ input_type
46
+ ] # optional if you don't have input
45
47
self .mlir_file = None # .mlir file local address.
48
+ self .mlir_url = None
46
49
self .inputs = None # Input has to be (list of np.array) for sharkInference.forward use
47
50
self .mlir_model = []
48
51
@@ -73,51 +76,78 @@ def load_json_input(self):
73
76
]
74
77
else :
75
78
print (
76
- "No json input required for current model. You could call setup_inputs(you_inputs )."
79
+ "No json input required for current model type . You could call setup_inputs(YOU_INPUTS )."
77
80
)
78
81
return self .inputs
79
82
80
83
def load_mlir_model (self ):
81
- if self .model_type in ["tflite-tosa" ]:
82
- workdir = os .path .join (
83
- os .path .dirname (__file__ ), self .local_tank_dir
84
+ workdir = os .path .join (os .path .dirname (__file__ ), self .local_tank_dir )
85
+ os .makedirs (workdir , exist_ok = True )
86
+ print (f"TMP_MODEL_DIR = { workdir } " )
87
+ # use model name get dir.
88
+ model_name_dir = os .path .join (workdir , str (self .model_name ))
89
+ if not os .path .exists (model_name_dir ):
90
+ print (
91
+ "Model has not been download."
92
+ "shark_downloader will automatically download by tank_url if provided."
93
+ " You can also manually to download the model from shark_tank by yourself."
84
94
)
85
- os .makedirs (workdir , exist_ok = True )
86
- print (f"TMP_MODEL_DIR = { workdir } " )
95
+ os .makedirs (model_name_dir , exist_ok = True )
96
+ print (f"TMP_MODELNAME_DIR = { model_name_dir } " )
87
97
88
- # use model name get dir.
89
- model_name_dir = os .path .join (workdir , str (self .model_name ))
90
- if not os .path .exists (model_name_dir ):
91
- print (
92
- "Model has not been download."
93
- "shark_downloader will automatically download by tank_url if provided."
94
- " You can also manually to download the model from shark_tank by yourself."
95
- )
96
- os .makedirs (model_name_dir , exist_ok = True )
97
- print (f"TMP_MODELNAME_DIR = { model_name_dir } " )
98
-
99
- mlir_url = (
98
+ if self .model_type in ["tflite-tosa" ]:
99
+ self .mlir_url = (
100
100
self .tank_url
101
- + "/tflite/ "
101
+ + "/"
102
102
+ str (self .model_name )
103
103
+ "/"
104
104
+ str (self .model_name )
105
- + "_tosa .mlir"
105
+ + "_tflite .mlir"
106
106
)
107
107
self .mlir_file = "/" .join (
108
- [model_name_dir , str (self .model_name ) + "_tosa.mlir" ]
108
+ [model_name_dir , str (self .model_name ) + "_tfite.mlir" ]
109
+ )
110
+ elif self .model_type in ["tensorflow" ]:
111
+ self .mlir_url = (
112
+ self .tank_url
113
+ + "/"
114
+ + str (self .model_name )
115
+ + "/"
116
+ + str (self .model_name )
117
+ + "_tf.mlir"
118
+ )
119
+ self .mlir_file = "/" .join (
120
+ [model_name_dir , str (self .model_name ) + "_tf.mlir" ]
121
+ )
122
+ elif self .model_type in ["torch" , "jax" , "mhlo" , "tosa" ]:
123
+ self .mlir_url = (
124
+ self .tank_url
125
+ + "/"
126
+ + str (self .model_name )
127
+ + "/"
128
+ + str (self .model_name )
129
+ + "_"
130
+ + str (self .model_type )
131
+ + ".mlir"
132
+ )
133
+ self .mlir_file = "/" .join (
134
+ [
135
+ model_name_dir ,
136
+ str (self .model_name ) + "_" + str (self .model_type ) + ".mlir" ,
137
+ ]
109
138
)
110
- if os .path .exists (self .mlir_file ):
111
- print ("Model has been downloaded before." , self .mlir_file )
112
- else :
113
- print ("Download mlir model" , mlir_url )
114
- urllib .request .urlretrieve (mlir_url , self .mlir_file )
115
-
116
- print ("Get tosa.mlir model return" )
117
- with open (self .mlir_file ) as f :
118
- self .mlir_model = f .read ()
119
139
else :
120
140
print ("Unsupported mlir model" )
141
+
142
+ if os .path .exists (self .mlir_file ):
143
+ print ("Model has been downloaded before." , self .mlir_file )
144
+ else :
145
+ print ("Download mlir model" , self .mlir_url )
146
+ urllib .request .urlretrieve (self .mlir_url , self .mlir_file )
147
+
148
+ print ("Get .mlir model return" )
149
+ with open (self .mlir_file ) as f :
150
+ self .mlir_model = f .read ()
121
151
return self .mlir_model
122
152
123
153
def setup_inputs (self , inputs ):
0 commit comments