|
15 | 15 | # limitations under the License.
|
16 | 16 | #
|
17 | 17 |
|
18 |
| -from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape |
| 18 | +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype |
19 | 19 | from pyspark import SparkContext, RDD
|
| 20 | +import numpy as np |
20 | 21 |
|
21 | 22 | from pyspark.serializers import Serializer
|
22 | 23 | import struct
|
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
|
47 | 48 | return ar.copy()
|
48 | 49 |
|
49 | 50 | def _serialize_double_vector(v):
|
50 |
| - """Serialize a double vector into a mutually understood format.""" |
| 51 | + """Serialize a double vector into a mutually understood format. |
| 52 | +
|
| 53 | + >>> x = array([1,2,3]) |
| 54 | + >>> y = _deserialize_double_vector(_serialize_double_vector(x)) |
| 55 | + >>> array_equal(y, array([1.0, 2.0, 3.0])) |
| 56 | + True |
| 57 | + """ |
51 | 58 | if type(v) != ndarray:
|
52 | 59 | raise TypeError("_serialize_double_vector called on a %s; "
|
53 | 60 | "wanted ndarray" % type(v))
|
| 61 | + """complex is only datatype that can't be converted to float64""" |
| 62 | + if issubdtype(v.dtype, complex): |
| 63 | + raise TypeError("_serialize_double_vector called on a %s; " |
| 64 | + "wanted ndarray" % type(v)) |
54 | 65 | if v.dtype != float64:
|
55 |
| - raise TypeError("_serialize_double_vector called on an ndarray of %s; " |
56 |
| - "wanted ndarray of float64" % v.dtype) |
| 66 | + v = v.astype(float64) |
57 | 67 | if v.ndim != 1:
|
58 | 68 | raise TypeError("_serialize_double_vector called on a %ddarray; "
|
59 | 69 | "wanted a 1darray" % v.ndim)
|
|
0 commit comments