Skip to content

Commit fdd4b5a

Browse files
authored
Fix tracker key bug (#378)
1 parent 7f7e1ca commit fdd4b5a

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

docs/deployments/apis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Serve models at scale.
1111
model_format: <string> # model format, must be "tensorflow" or "onnx" (default: "onnx" if model path ends with .onnx, "tensorflow" if model path ends with .zip)
1212
request_handler: <string> # path to the request handler implementation file, relative to the cortex root
1313
tracker:
14-
key: <string> # json key to track in the response payload
14+
key: <string> # json key to track if the response payload is a dictionary
1515
model_type: <string> # model type, must be "classification" or "regression"
1616
compute:
1717
min_replicas: <int> # minimum number of replicas (default: 1)

pkg/operator/api/userconfig/apis.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ type API struct {
4242
}
4343

4444
type Tracker struct {
45-
Key string `json:"key" yaml:"key"`
45+
Key *string `json:"key" yaml:"key"`
4646
ModelType ModelType `json:"model_type" yaml:"model_type"`
4747
}
4848

@@ -72,10 +72,8 @@ var apiValidation = &cr.StructValidation{
7272
DefaultNil: true,
7373
StructFieldValidations: []*cr.StructFieldValidation{
7474
{
75-
StructField: "Key",
76-
StringValidation: &cr.StringValidation{
77-
Required: true,
78-
},
75+
StructField: "Key",
76+
StringPtrValidation: &cr.StringPtrValidation{},
7977
},
8078
{
8179
StructField: "ModelType",

pkg/workloads/cortex/lib/api_utils.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,34 @@ def extract_predicted_values(api, predictions):
7373

7474
tracker = api.get("tracker")
7575
for prediction in predictions:
76-
predicted_value = prediction.get(tracker["key"])
77-
if predicted_value is None:
78-
raise ValueError(
79-
"failed to track key '{}': not found in response payload".format(tracker["key"])
80-
)
76+
if tracker.get("key") is not None:
77+
key = tracker["key"]
78+
if type(prediction) != dict:
79+
raise ValueError(
80+
"failed to track key '{}': expected prediction to be of type dict but found '{}'".format(
81+
key, type(prediction)
82+
)
83+
)
84+
if prediction.get(key) is None:
85+
raise ValueError(
86+
"failed to track key '{}': not found in prediction".format(tracker["key"])
87+
)
88+
predicted_value = prediction[key]
89+
else:
90+
predicted_value = prediction
91+
8192
if tracker["model_type"] == "classification":
8293
if type(predicted_value) != str and type(predicted_value) != int:
8394
raise ValueError(
84-
"failed to track key '{}': expected type 'str' or 'int' but encountered '{}'".format(
85-
tracker["key"], type(predicted_value)
95+
"failed to track classification prediction: expected type 'str' or 'int' but encountered '{}'".format(
96+
type(predicted_value)
8697
)
8798
)
8899
else:
89100
if type(predicted_value) != float and type(predicted_value) != int: # allow ints
90101
raise ValueError(
91-
"failed to track key '{}': expected type 'float' or 'int' but encountered '{}'".format(
92-
tracker["key"], type(predicted_value)
102+
"failed to track regression prediction: expected type 'float' or 'int' but encountered '{}'".format(
103+
type(predicted_value)
93104
)
94105
)
95106
predicted_values.append(predicted_value)

0 commit comments

Comments
 (0)