diff --git a/examples/iris-classifier/handlers/sklearn.py b/examples/iris-classifier/handlers/sklearn.py index 4389497f79..8be7f4705d 100644 --- a/examples/iris-classifier/handlers/sklearn.py +++ b/examples/iris-classifier/handlers/sklearn.py @@ -23,7 +23,7 @@ def pre_inference(sample, metadata): sample["petal_width"], ] ) - return ((x - scalars["mean"]) / scalars["stddev"]).astype(np.float32) + return (x - scalars["mean"]) / scalars["stddev"] def post_inference(prediction, metadata): diff --git a/pkg/workloads/cortex/onnx_serve/api.py b/pkg/workloads/cortex/onnx_serve/api.py index ff4a54b9bc..315a88c08b 100644 --- a/pkg/workloads/cortex/onnx_serve/api.py +++ b/pkg/workloads/cortex/onnx_serve/api.py @@ -108,10 +108,18 @@ def transform_to_numpy(input_pyobj, input_metadata): if dim is None: target_shape[idx] = 1 - if type(input_pyobj) is not np.ndarray: - np_arr = np.array(input_pyobj, dtype=target_dtype) - else: + if type(input_pyobj) is np.ndarray: np_arr = input_pyobj + if np.issubdtype(np_arr.dtype, np.number) == np.issubdtype(target_dtype, np.number): + if str(np_arr.dtype) != target_dtype: + np_arr = np_arr.astype(target_dtype) + else: + raise ValueError( + "expected dtype '{}' but found '{}'".format(target_dtype, np_arr.dtype) + ) + else: + np_arr = np.array(input_pyobj, dtype=target_dtype) + np_arr = np_arr.reshape(target_shape) return np_arr except Exception as e: