Skip to content

Commit 10c6834

Browse files
committed
Show expected shape on TF API errors
1 parent 2b5863d commit 10c6834

File tree

1 file changed

+20
-6
lines changed
  • pkg/workloads/cortex/tf_api

1 file changed

+20
-6
lines changed

pkg/workloads/cortex/tf_api/api.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,16 @@ def create_prediction_request(transformed_sample):
136136
shape = []
137137
for dim in signature_def[signature_key]["tensorShape"]["dim"]:
138138
shape.append(int(dim["size"]))
139-
tensor_proto = tf.make_tensor_proto(
140-
np.array(value).reshape(shape), dtype=data_type, shape=shape
141-
)
142-
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
139+
140+
try:
141+
tensor_proto = tf.make_tensor_proto(
142+
np.array(value).reshape(shape), dtype=data_type, shape=shape
143+
)
144+
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
145+
except Exception as e:
146+
raise UserException(
147+
'key "{}"'.format(column_name), "expected shape {}".format(shape)
148+
) from e
143149

144150
return prediction_request
145151

@@ -163,8 +169,16 @@ def create_raw_prediction_request(sample):
163169
shape = [1]
164170
value = [value]
165171
sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"]
166-
tensor_proto = tf.make_tensor_proto(value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape)
167-
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
172+
173+
try:
174+
tensor_proto = tf.make_tensor_proto(
175+
value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape
176+
)
177+
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
178+
except Exception as e:
179+
raise UserException(
180+
'key "{}"'.format(column_name), "expected shape {}".format(shape)
181+
) from e
168182

169183
return prediction_request
170184

0 commit comments

Comments
 (0)