@@ -289,14 +289,11 @@ def run_predict(sample):
     ctx = local_cache["ctx"]
     request_handler = local_cache.get("request_handler")
 
-    logger.info("sample: " + util.pp_str_flat(sample))
-
     prepared_sample = sample
     if request_handler is not None and util.has_function(request_handler, "pre_inference"):
         prepared_sample = request_handler.pre_inference(
             sample, local_cache["metadata"]["signatureDef"]
         )
-        logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))
 
     validate_sample(prepared_sample)
 
@@ -308,24 +305,18 @@ def run_predict(sample):
         )
 
         transformed_sample = transform_sample(prepared_sample)
-        logger.info("transformed_sample: " + util.pp_str_flat(transformed_sample))
 
         prediction_request = create_prediction_request(transformed_sample)
         response_proto = local_cache["stub"].Predict(prediction_request, timeout=300.0)
         result = parse_response_proto(response_proto)
-
         result["transformed_sample"] = transformed_sample
-        logger.info("inference: " + util.pp_str_flat(result))
     else:
         prediction_request = create_raw_prediction_request(prepared_sample)
         response_proto = local_cache["stub"].Predict(prediction_request, timeout=300.0)
         result = parse_response_proto_raw(response_proto)
 
-        logger.info("inference: " + util.pp_str_flat(result))
-
     if request_handler is not None and util.has_function(request_handler, "post_inference"):
         result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"])
-        logger.info("post_inference: " + util.pp_str_flat(result))
 
     return result
 
@@ -352,10 +343,8 @@ def validate_sample(sample):
             raise UserException('missing key "{}"'.format(input_name))
 
 
-def prediction_failed(sample, reason=None):
-    message = "prediction failed for sample: {}".format(util.pp_str_flat(sample))
-    if reason:
-        message += " ({})".format(reason)
+def prediction_failed(reason):
+    message = "prediction failed: " + reason
 
     logger.error(message)
     return message, status.HTTP_406_NOT_ACCEPTABLE
@@ -380,16 +369,12 @@ def predict(deployment_name, api_name):
     response = {}
 
     if not util.is_dict(payload) or "samples" not in payload:
-        util.log_pretty_flat(payload, logging_func=logger.error)
-        return prediction_failed(payload, "top level `samples` key not found in request")
+        return prediction_failed('top level "samples" key not found in request')
 
     predictions = []
     samples = payload["samples"]
     if not util.is_list(samples):
-        util.log_pretty_flat(samples, logging_func=logger.error)
-        return prediction_failed(
-            payload, "expected the value of key `samples` to be a list of json objects"
-        )
+        return prediction_failed('expected the value of key "samples" to be a list of json objects')
 
     for i, sample in enumerate(payload["samples"]):
         try:
@@ -402,14 +387,14 @@ def predict(deployment_name, api_name):
                     api["name"]
                 )
             )
-            return prediction_failed(sample, str(e))
+            return prediction_failed(str(e))
         except Exception as e:
            logger.exception(
                 "An error occurred, see `cortex logs -v api {}` for more details.".format(
                     api["name"]
                 )
             )
-            return prediction_failed(sample, str(e))
+            return prediction_failed(str(e))
 
         predictions.append(result)
 