44
44
"ctx" : None ,
45
45
"stub" : None ,
46
46
"api" : None ,
47
- "metadata" : None ,
47
+ "signature_key" : None ,
48
+ "parsed_signature" : None ,
49
+ "model_metadata" : None ,
48
50
"request_handler" : None ,
49
51
"class_set" : set (),
50
52
}
@@ -115,8 +117,8 @@ def after_request(response):
115
117
116
118
117
119
def create_prediction_request (sample ):
118
- signature_def = local_cache ["metadata " ]["signatureDef" ]
119
- signature_key = local_cache ["api" ][ "tf_serving" ][ " signature_key" ]
120
+ signature_def = local_cache ["model_metadata " ]["signatureDef" ]
121
+ signature_key = local_cache ["signature_key" ]
120
122
prediction_request = predict_pb2 .PredictRequest ()
121
123
prediction_request .model_spec .name = "model"
122
124
prediction_request .model_spec .signature_name = signature_key
@@ -179,7 +181,7 @@ def run_predict(sample, debug=False):
179
181
if request_handler is not None and util .has_function (request_handler , "pre_inference" ):
180
182
try :
181
183
prepared_sample = request_handler .pre_inference (
182
- sample , local_cache ["metadata " ]["signatureDef" ]
184
+ sample , local_cache ["model_metadata " ]["signatureDef" ]
183
185
)
184
186
debug_obj ("pre_inference" , prepared_sample , debug )
185
187
except Exception as e :
@@ -196,7 +198,9 @@ def run_predict(sample, debug=False):
196
198
197
199
if request_handler is not None and util .has_function (request_handler , "post_inference" ):
198
200
try :
199
- result = request_handler .post_inference (result , local_cache ["metadata" ]["signatureDef" ])
201
+ result = request_handler .post_inference (
202
+ result , local_cache ["model_metadata" ]["signatureDef" ]
203
+ )
200
204
debug_obj ("post_inference" , result , debug )
201
205
except Exception as e :
202
206
raise UserRuntimeException (
@@ -207,9 +211,7 @@ def run_predict(sample, debug=False):
207
211
208
212
209
213
def validate_sample (sample ):
210
- signature = extract_signature (
211
- local_cache ["metadata" ]["signatureDef" ], local_cache ["api" ]["tf_serving" ]["signature_key" ]
212
- )
214
+ signature = local_cache ["parsed_signature" ]
213
215
for input_name , _ in signature .items ():
214
216
if input_name not in sample :
215
217
raise UserException ('missing key "{}"' .format (input_name ))
@@ -252,38 +254,56 @@ def predict(deployment_name, api_name):
252
254
253
255
254
256
def extract_signature (signature_def , signature_key ):
255
- if (
256
- signature_def .get (signature_key ) is None
257
- or signature_def [signature_key ].get ("inputs" ) is None
258
- ):
259
- raise UserException (
260
- 'unable to find "' + signature_key + "\" in model's signature definition"
261
- )
262
-
263
- metadata = {}
264
- for input_name , input_metadata in signature_def [signature_key ]["inputs" ].items ():
265
- metadata [input_name ] = {
257
+ logger .info ("signature defs found in model: {}" .format (signature_def ))
258
+
259
+ available_keys = list (signature_def .keys ())
260
+ if len (available_keys ) == 0 :
261
+ raise UserException ("unable to find signature defs in model" )
262
+
263
+ if signature_key is None :
264
+ if len (available_keys ) == 1 :
265
+ logger .info (
266
+ "signature_key was not configured by user, using signature key '{}' found in signature def map" .format (
267
+ available_keys [0 ]
268
+ )
269
+ )
270
+ signature_key = available_keys [0 ]
271
+ else :
272
+ raise UserException (
273
+ "signature_key was not configured by user, please specify one the following keys '{}' found in signature def map" .format (
274
+ "', '" .join (available_keys )
275
+ )
276
+ )
277
+ else :
278
+ if signature_def .get (signature_key ) is None :
279
+ possibilities_str = "key: '{}'" .format (available_keys [0 ])
280
+ if len (available_keys ) > 1 :
281
+ possibilities_str = "keys: '{}'" .format ("', '" .join (available_keys ))
282
+
283
+ raise UserException (
284
+ "signature_key '{}' was not found in signature def map, but found the following {}" .format (
285
+ signature_key , possibilities_str
286
+ )
287
+ )
288
+
289
+ signature_def_val = signature_def .get (signature_key )
290
+
291
+ if signature_def_val .get ("inputs" ) is None :
292
+ raise UserException ("unable to find 'inputs' in signature def '{}'" .format (signature_key ))
293
+
294
+ parsed_signature = {}
295
+ for input_name , input_metadata in signature_def_val ["inputs" ].items ():
296
+ parsed_signature [input_name ] = {
266
297
"shape" : [int (dim ["size" ]) for dim in input_metadata ["tensorShape" ]["dim" ]],
267
298
"type" : DTYPE_TO_TF_TYPE [input_metadata ["dtype" ]].name ,
268
299
}
269
- return metadata
300
+ return signature_key , parsed_signature
270
301
271
302
272
303
@app .route ("/<app_name>/<api_name>/signature" , methods = ["GET" ])
273
304
def get_signature (app_name , api_name ):
274
- ctx = local_cache ["ctx" ]
275
- api = local_cache ["api" ]
276
-
277
- try :
278
- metadata = extract_signature (
279
- local_cache ["metadata" ]["signatureDef" ],
280
- local_cache ["api" ]["tf_serving" ]["signature_key" ],
281
- )
282
- except Exception as e :
283
- logger .exception ("failed to get signature" )
284
- return jsonify (error = str (e )), 404
285
-
286
- response = {"signature" : metadata }
305
+ signature = local_cache ["parsed_signature" ]
306
+ response = {"signature" : signature }
287
307
return jsonify (response )
288
308
289
309
@@ -375,7 +395,7 @@ def start(args):
375
395
limit = 60
376
396
for i in range (limit ):
377
397
try :
378
- local_cache ["metadata " ] = run_get_model_metadata ()
398
+ local_cache ["model_metadata " ] = run_get_model_metadata ()
379
399
break
380
400
except Exception as e :
381
401
if i > 6 :
@@ -385,14 +405,18 @@ def start(args):
385
405
sys .exit (1 )
386
406
387
407
time .sleep (5 )
388
- logger . info (
389
- "model_signature: {}" . format (
390
- extract_signature (
391
- local_cache [ "metadata " ]["signatureDef" ],
392
- local_cache [ "api" ][ "tf_serving" ][ "signature_key" ],
393
- )
394
- )
408
+
409
+ signature_key = None
410
+ if api . get ( "tf_serving" ) is not None and api [ "tf_serving" ]. get ( "signature_key" ) is not None :
411
+ signature_key = api [ "tf_serving " ]["signature_key" ]
412
+
413
+ key , parsed_signature = extract_signature (
414
+ local_cache [ "model_metadata" ][ "signatureDef" ], signature_key
395
415
)
416
+
417
+ local_cache ["signature_key" ] = key
418
+ local_cache ["parsed_signature" ] = parsed_signature
419
+ logger .info ("model_signature: {}" .format (local_cache ["parsed_signature" ]))
396
420
serve (app , listen = "*:{}" .format (args .port ))
397
421
398
422
0 commit comments