Skip to content

Commit 5d434b3

Browse files
authored
Improve signature def detection (#460)
1 parent 5b45a6c commit 5d434b3

File tree

3 files changed

+66
-51
lines changed

3 files changed

+66
-51
lines changed

pkg/consts/consts.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,4 @@ var (
3939
TelemetryURL = "https://telemetry.cortexlabs.dev"
4040

4141
MaxClassesPerRequest = 75 // cloudwatch.GeMetricData can get up to 100 metrics per request, avoid multiple requests and have room for other stats
42-
43-
DefaultTFServingSignatureKey = "predict"
4442
)

pkg/operator/api/userconfig/apis.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"strings"
2424

2525
"github.com/aws/aws-sdk-go/service/s3"
26-
"github.com/cortexlabs/cortex/pkg/consts"
2726
"github.com/cortexlabs/cortex/pkg/lib/aws"
2827
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
2928
"github.com/cortexlabs/cortex/pkg/lib/errors"
@@ -115,7 +114,7 @@ var apiValidation = &cr.StructValidation{
115114
{
116115
StructField: "SignatureKey",
117116
StringValidation: &cr.StringValidation{
118-
Default: consts.DefaultTFServingSignatureKey,
117+
Required: true,
119118
},
120119
},
121120
},
@@ -277,12 +276,6 @@ func (api *API) Validate(projectFileMap map[string][]byte) error {
277276
}
278277
}
279278

280-
if api.ModelFormat == TensorFlowModelFormat && api.TFServing == nil {
281-
api.TFServing = &TFServingOptions{
282-
SignatureKey: consts.DefaultTFServingSignatureKey,
283-
}
284-
}
285-
286279
if api.ModelFormat != TensorFlowModelFormat && api.TFServing != nil {
287280
return errors.Wrap(ErrorTFServingOptionsForTFOnly(api.ModelFormat), Identify(api))
288281
}

pkg/workloads/cortex/tf_api/api.py

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
"ctx": None,
4545
"stub": None,
4646
"api": None,
47-
"metadata": None,
47+
"signature_key": None,
48+
"parsed_signature": None,
49+
"model_metadata": None,
4850
"request_handler": None,
4951
"class_set": set(),
5052
}
@@ -115,8 +117,8 @@ def after_request(response):
115117

116118

117119
def create_prediction_request(sample):
118-
signature_def = local_cache["metadata"]["signatureDef"]
119-
signature_key = local_cache["api"]["tf_serving"]["signature_key"]
120+
signature_def = local_cache["model_metadata"]["signatureDef"]
121+
signature_key = local_cache["signature_key"]
120122
prediction_request = predict_pb2.PredictRequest()
121123
prediction_request.model_spec.name = "model"
122124
prediction_request.model_spec.signature_name = signature_key
@@ -179,7 +181,7 @@ def run_predict(sample, debug=False):
179181
if request_handler is not None and util.has_function(request_handler, "pre_inference"):
180182
try:
181183
prepared_sample = request_handler.pre_inference(
182-
sample, local_cache["metadata"]["signatureDef"]
184+
sample, local_cache["model_metadata"]["signatureDef"]
183185
)
184186
debug_obj("pre_inference", prepared_sample, debug)
185187
except Exception as e:
@@ -196,7 +198,9 @@ def run_predict(sample, debug=False):
196198

197199
if request_handler is not None and util.has_function(request_handler, "post_inference"):
198200
try:
199-
result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"])
201+
result = request_handler.post_inference(
202+
result, local_cache["model_metadata"]["signatureDef"]
203+
)
200204
debug_obj("post_inference", result, debug)
201205
except Exception as e:
202206
raise UserRuntimeException(
@@ -207,9 +211,7 @@ def run_predict(sample, debug=False):
207211

208212

209213
def validate_sample(sample):
210-
signature = extract_signature(
211-
local_cache["metadata"]["signatureDef"], local_cache["api"]["tf_serving"]["signature_key"]
212-
)
214+
signature = local_cache["parsed_signature"]
213215
for input_name, _ in signature.items():
214216
if input_name not in sample:
215217
raise UserException('missing key "{}"'.format(input_name))
@@ -252,38 +254,56 @@ def predict(deployment_name, api_name):
252254

253255

254256
def extract_signature(signature_def, signature_key):
255-
if (
256-
signature_def.get(signature_key) is None
257-
or signature_def[signature_key].get("inputs") is None
258-
):
259-
raise UserException(
260-
'unable to find "' + signature_key + "\" in model's signature definition"
261-
)
262-
263-
metadata = {}
264-
for input_name, input_metadata in signature_def[signature_key]["inputs"].items():
265-
metadata[input_name] = {
257+
logger.info("signature defs found in model: {}".format(signature_def))
258+
259+
available_keys = list(signature_def.keys())
260+
if len(available_keys) == 0:
261+
raise UserException("unable to find signature defs in model")
262+
263+
if signature_key is None:
264+
if len(available_keys) == 1:
265+
logger.info(
266+
"signature_key was not configured by user, using signature key '{}' found in signature def map".format(
267+
available_keys[0]
268+
)
269+
)
270+
signature_key = available_keys[0]
271+
else:
272+
raise UserException(
273+
"signature_key was not configured by user, please specify one the following keys '{}' found in signature def map".format(
274+
"', '".join(available_keys)
275+
)
276+
)
277+
else:
278+
if signature_def.get(signature_key) is None:
279+
possibilities_str = "key: '{}'".format(available_keys[0])
280+
if len(available_keys) > 1:
281+
possibilities_str = "keys: '{}'".format("', '".join(available_keys))
282+
283+
raise UserException(
284+
"signature_key '{}' was not found in signature def map, but found the following {}".format(
285+
signature_key, possibilities_str
286+
)
287+
)
288+
289+
signature_def_val = signature_def.get(signature_key)
290+
291+
if signature_def_val.get("inputs") is None:
292+
raise UserException("unable to find 'inputs' in signature def '{}'".format(signature_key))
293+
294+
parsed_signature = {}
295+
for input_name, input_metadata in signature_def_val["inputs"].items():
296+
parsed_signature[input_name] = {
266297
"shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
267298
"type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
268299
}
269-
return metadata
300+
return signature_key, parsed_signature
270301

271302

272303
@app.route("/<app_name>/<api_name>/signature", methods=["GET"])
273304
def get_signature(app_name, api_name):
274-
ctx = local_cache["ctx"]
275-
api = local_cache["api"]
276-
277-
try:
278-
metadata = extract_signature(
279-
local_cache["metadata"]["signatureDef"],
280-
local_cache["api"]["tf_serving"]["signature_key"],
281-
)
282-
except Exception as e:
283-
logger.exception("failed to get signature")
284-
return jsonify(error=str(e)), 404
285-
286-
response = {"signature": metadata}
305+
signature = local_cache["parsed_signature"]
306+
response = {"signature": signature}
287307
return jsonify(response)
288308

289309

@@ -375,7 +395,7 @@ def start(args):
375395
limit = 60
376396
for i in range(limit):
377397
try:
378-
local_cache["metadata"] = run_get_model_metadata()
398+
local_cache["model_metadata"] = run_get_model_metadata()
379399
break
380400
except Exception as e:
381401
if i > 6:
@@ -385,14 +405,18 @@ def start(args):
385405
sys.exit(1)
386406

387407
time.sleep(5)
388-
logger.info(
389-
"model_signature: {}".format(
390-
extract_signature(
391-
local_cache["metadata"]["signatureDef"],
392-
local_cache["api"]["tf_serving"]["signature_key"],
393-
)
394-
)
408+
409+
signature_key = None
410+
if api.get("tf_serving") is not None and api["tf_serving"].get("signature_key") is not None:
411+
signature_key = api["tf_serving"]["signature_key"]
412+
413+
key, parsed_signature = extract_signature(
414+
local_cache["model_metadata"]["signatureDef"], signature_key
395415
)
416+
417+
local_cache["signature_key"] = key
418+
local_cache["parsed_signature"] = parsed_signature
419+
logger.info("model_signature: {}".format(local_cache["parsed_signature"]))
396420
serve(app, listen="*:{}".format(args.port))
397421

398422

0 commit comments

Comments
 (0)