diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 1d08176e31..4d5f661da7 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -39,6 +39,4 @@ var ( TelemetryURL = "https://telemetry.cortexlabs.dev" MaxClassesPerRequest = 75 // cloudwatch.GeMetricData can get up to 100 metrics per request, avoid multiple requests and have room for other stats - - DefaultTFServingSignatureKey = "predict" ) diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index 5cd1f921b6..1f5ed41ce0 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -23,7 +23,6 @@ import ( "strings" "github.com/aws/aws-sdk-go/service/s3" - "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/aws" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" @@ -115,7 +114,7 @@ var apiValidation = &cr.StructValidation{ { StructField: "SignatureKey", StringValidation: &cr.StringValidation{ - Default: consts.DefaultTFServingSignatureKey, + Required: true, }, }, }, @@ -277,12 +276,6 @@ func (api *API) Validate(projectFileMap map[string][]byte) error { } } - if api.ModelFormat == TensorFlowModelFormat && api.TFServing == nil { - api.TFServing = &TFServingOptions{ - SignatureKey: consts.DefaultTFServingSignatureKey, - } - } - if api.ModelFormat != TensorFlowModelFormat && api.TFServing != nil { return errors.Wrap(ErrorTFServingOptionsForTFOnly(api.ModelFormat), Identify(api)) } diff --git a/pkg/workloads/cortex/tf_api/api.py b/pkg/workloads/cortex/tf_api/api.py index c901ec6ad6..f03c0e5875 100644 --- a/pkg/workloads/cortex/tf_api/api.py +++ b/pkg/workloads/cortex/tf_api/api.py @@ -44,7 +44,9 @@ "ctx": None, "stub": None, "api": None, - "metadata": None, + "signature_key": None, + "parsed_signature": None, + "model_metadata": None, "request_handler": None, "class_set": set(), } @@ -115,8 +117,8 @@ def after_request(response): def create_prediction_request(sample): - signature_def = local_cache["metadata"]["signatureDef"] - signature_key = local_cache["api"]["tf_serving"]["signature_key"] + signature_def = local_cache["model_metadata"]["signatureDef"] + signature_key = local_cache["signature_key"] prediction_request = predict_pb2.PredictRequest() prediction_request.model_spec.name = "model" prediction_request.model_spec.signature_name = signature_key @@ -179,7 +181,7 @@ def run_predict(sample, debug=False): if request_handler is not None and util.has_function(request_handler, "pre_inference"): try: prepared_sample = request_handler.pre_inference( - sample, local_cache["metadata"]["signatureDef"] + sample, local_cache["model_metadata"]["signatureDef"] ) debug_obj("pre_inference", prepared_sample, debug) except Exception as e: @@ -196,7 +198,9 @@ def run_predict(sample, debug=False): if request_handler is not None and util.has_function(request_handler, "post_inference"): try: - result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"]) + result = request_handler.post_inference( + result, local_cache["model_metadata"]["signatureDef"] + ) debug_obj("post_inference", result, debug) except Exception as e: raise UserRuntimeException( @@ -207,9 +211,7 @@ def run_predict(sample, debug=False): def validate_sample(sample): - signature = extract_signature( - local_cache["metadata"]["signatureDef"], local_cache["api"]["tf_serving"]["signature_key"] - ) + signature = local_cache["parsed_signature"] for input_name, _ in signature.items(): if input_name not in sample: raise UserException('missing key "{}"'.format(input_name)) @@ -252,38 +254,56 @@ def predict(deployment_name, api_name): def extract_signature(signature_def, signature_key): - if ( - signature_def.get(signature_key) is None - or signature_def[signature_key].get("inputs") is None - ): - raise UserException( - 'unable to find "' + signature_key + "\" in model's signature definition" - ) - - metadata = {} - for input_name, input_metadata in signature_def[signature_key]["inputs"].items(): - metadata[input_name] = { + logger.info("signature defs found in model: {}".format(signature_def)) + + available_keys = list(signature_def.keys()) + if len(available_keys) == 0: + raise UserException("unable to find signature defs in model") + + if signature_key is None: + if len(available_keys) == 1: + logger.info( + "signature_key was not configured by user, using signature key '{}' found in signature def map".format( + available_keys[0] + ) + ) + signature_key = available_keys[0] + else: + raise UserException( + "signature_key was not configured by user, please specify one the following keys '{}' found in signature def map".format( + "', '".join(available_keys) + ) + ) + else: + if signature_def.get(signature_key) is None: + possibilities_str = "key: '{}'".format(available_keys[0]) + if len(available_keys) > 1: + possibilities_str = "keys: '{}'".format("', '".join(available_keys)) + + raise UserException( + "signature_key '{}' was not found in signature def map, but found the following {}".format( + signature_key, possibilities_str + ) + ) + + signature_def_val = signature_def.get(signature_key) + + if signature_def_val.get("inputs") is None: + raise UserException("unable to find 'inputs' in signature def '{}'".format(signature_key)) + + parsed_signature = {} + for input_name, input_metadata in signature_def_val["inputs"].items(): + parsed_signature[input_name] = { "shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]], "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name, } - return metadata + return signature_key, parsed_signature @app.route("///signature", methods=["GET"]) def get_signature(app_name, api_name): - ctx = local_cache["ctx"] - api = local_cache["api"] - - try: - metadata = extract_signature( - local_cache["metadata"]["signatureDef"], - local_cache["api"]["tf_serving"]["signature_key"], - ) - except Exception as e: - logger.exception("failed to get signature") - return jsonify(error=str(e)), 404 - - response = {"signature": metadata} + signature = local_cache["parsed_signature"] + response = {"signature": signature} return jsonify(response) @@ -375,7 +395,7 @@ def start(args): limit = 60 for i in range(limit): try: - local_cache["metadata"] = run_get_model_metadata() + local_cache["model_metadata"] = run_get_model_metadata() break except Exception as e: if i > 6: @@ -385,14 +405,18 @@ def start(args): sys.exit(1) time.sleep(5) - logger.info( - "model_signature: {}".format( - extract_signature( - local_cache["metadata"]["signatureDef"], - local_cache["api"]["tf_serving"]["signature_key"], - ) - ) + + signature_key = None + if api.get("tf_serving") is not None and api["tf_serving"].get("signature_key") is not None: + signature_key = api["tf_serving"]["signature_key"] + + key, parsed_signature = extract_signature( + local_cache["model_metadata"]["signatureDef"], signature_key ) + + local_cache["signature_key"] = key + local_cache["parsed_signature"] = parsed_signature + logger.info("model_signature: {}".format(local_cache["parsed_signature"])) serve(app, listen="*:{}".format(args.port))