Skip to content

Commit 3bd3129

Browse files
techaddictmateiz
authored andcommitted
SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Author: Sandeep <[email protected]> Closes apache#356 from techaddict/1428 and squashes the following commits: 3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
1 parent 79820fe commit 3bd3129

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

python/pyspark/mllib/_common.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# limitations under the License.
1616
#
1717

18-
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
18+
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
1919
from pyspark import SparkContext, RDD
20+
import numpy as np
2021

2122
from pyspark.serializers import Serializer
2223
import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
4748
return ar.copy()
4849

4950
def _serialize_double_vector(v):
50-
"""Serialize a double vector into a mutually understood format."""
51+
"""Serialize a double vector into a mutually understood format.
52+
53+
>>> x = array([1,2,3])
54+
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
55+
>>> array_equal(y, array([1.0, 2.0, 3.0]))
56+
True
57+
"""
5158
if type(v) != ndarray:
5259
raise TypeError("_serialize_double_vector called on a %s; "
5360
"wanted ndarray" % type(v))
61+
"""complex is only datatype that can't be converted to float64"""
62+
if issubdtype(v.dtype, complex):
63+
raise TypeError("_serialize_double_vector called on a %s; "
64+
"wanted ndarray" % type(v))
5465
if v.dtype != float64:
55-
raise TypeError("_serialize_double_vector called on an ndarray of %s; "
56-
"wanted ndarray of float64" % v.dtype)
66+
v = v.astype(float64)
5767
if v.ndim != 1:
5868
raise TypeError("_serialize_double_vector called on a %ddarray; "
5969
"wanted a 1darray" % v.ndim)

0 commit comments

Comments
 (0)