
Commit 2721a50

[SPARK-50851][ML][CONNECT][PYTHON] Express ML params with proto.Expression.Literal
### What changes were proposed in this pull request?

Express ML params with `proto.Expression.Literal`:

1. Introduce `Literal.SpecializedArray` for large primitive literal arrays (e.g. the initial model coefficients, which can be large):

```
message SpecializedArray {
  oneof value_type {
    Bools bools = 1;
    Ints ints = 2;
    Longs longs = 3;
    Floats floats = 4;
    Doubles doubles = 5;
    Strings strings = 6;
  }

  message Bools {
    repeated bool values = 1;
  }

  message Ints {
    repeated int32 values = 1;
  }

  message Longs {
    repeated int64 values = 1;
  }

  message Floats {
    repeated float values = 1;
  }

  message Doubles {
    repeated double values = 1;
  }

  message Strings {
    repeated string values = 1;
  }
}
```

2. Replace `proto.Param` with `proto.Expression` to be consistent with the SQL side.

For `Param[Vector]` and `Param[Matrix]`, apply `proto.Expression.Literal.Struct` with the underlying schema of `VectorUDT` and `MatrixUDT`. E.g. for a `Param[Vector]` with value `Vectors.sparse(4, [(1, 1.0), (3, 5.5)])`, the message looks like:

```
literal {
  struct {
    struct_type {
      struct {
        ...  <- schema of VectorUDT
      }
    }
    elements { byte: 0 }
    elements { integer: 4 }
    elements { specialized_array { ints { values: 1 values: 3 } } }
    elements { specialized_array { doubles { values: 1 values: 5.5 } } }
  }
}
```

### Why are the changes needed?

1. To optimize large literal arrays, for both ML and SQL (we can apply it on the SQL side later).
2. To be consistent with the SQL side, e.g. parameterized SQL:

```
// (Optional) A map of parameter names to expressions.
// It cannot coexist with `pos_arguments`.
map<string, Expression.Literal> named_arguments = 4;

// (Optional) A sequence of expressions for positional parameters in the SQL query text.
// It cannot coexist with `named_arguments`.
repeated Expression pos_arguments = 5;
```

3. To minimize the protobuf change.

### Does this PR introduce _any_ user-facing change?

No, refactor-only.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49529 from zhengruifeng/ml_proto_expr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 10dd350 commit 2721a50
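For illustration, here is a minimal, hypothetical round-trip sketch of the encoding described above, using the `serialize_param` / `deserialize_param` helpers changed in `python/pyspark/ml/connect/serialize.py` below; it assumes a PySpark build containing this patch and an active `SparkConnectClient` bound to `client`.

```
# Hypothetical usage sketch (not part of this commit); assumes a PySpark build
# containing this patch and a connected SparkConnectClient available as `client`.
from pyspark.ml.linalg import Vectors
from pyspark.ml.connect.serialize import serialize_param, deserialize_param

vec = Vectors.sparse(4, [(1, 1.0), (3, 5.5)])

# Encoded as Expression.Literal.Struct with VectorUDT's schema:
# [byte type tag, integer size, specialized_array ints (indices), specialized_array doubles (values)]
literal = serialize_param(vec, client)

# Decoding reads the same four struct elements back into a SparseVector.
assert deserialize_param(literal) == vec
```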

File tree

23 files changed: +906, -871 lines changed


mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala

Lines changed: 22 additions & 17 deletions
```
@@ -27,23 +27,7 @@ import org.apache.spark.sql.types._
  */
 private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
-  override def sqlType: StructType = {
-    // type: 0 = sparse, 1 = dense
-    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
-    // set as not nullable, except values since in the future, support for binary matrices might
-    // be added for which values are not needed.
-    // the sparse matrix needs colPtrs and rowIndices, which are set as
-    // null, while building the dense matrix.
-    StructType(Array(
-      StructField("type", ByteType, nullable = false),
-      StructField("numRows", IntegerType, nullable = false),
-      StructField("numCols", IntegerType, nullable = false),
-      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
-      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
-      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
-      StructField("isTransposed", BooleanType, nullable = false)
-    ))
-  }
+  override def sqlType: StructType = MatrixUDT.sqlType
 
   override def serialize(obj: Matrix): InternalRow = {
     val row = new GenericInternalRow(7)
@@ -108,3 +92,24 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
   private[spark] override def asNullable: MatrixUDT = this
 }
+
+private[spark] object MatrixUDT {
+
+  val sqlType: StructType = {
+    // type: 0 = sparse, 1 = dense
+    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
+    // set as not nullable, except values since in the future, support for binary matrices might
+    // be added for which values are not needed.
+    // the sparse matrix needs colPtrs and rowIndices, which are set as
+    // null, while building the dense matrix.
+    StructType(Array(
+      StructField("type", ByteType, nullable = false),
+      StructField("numRows", IntegerType, nullable = false),
+      StructField("numCols", IntegerType, nullable = false),
+      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+      StructField("isTransposed", BooleanType, nullable = false)
+    ))
+  }
+}
```
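For reference, the schema this companion object now exposes on the JVM side mirrors what PySpark already exposes via the classmethod `pyspark.ml.linalg.MatrixUDT.sqlType()`, which is what the updated `serialize_param` below feeds to `pyspark_types_to_proto_types`. A quick, illustrative check:

```
from pyspark.ml.linalg import MatrixUDT

# The seven fields behind the Literal.Struct encoding of a Param[Matrix].
print([f.name for f in MatrixUDT.sqlType().fields])
# ['type', 'numRows', 'numCols', 'colPtrs', 'rowIndices', 'values', 'isTransposed']
```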

mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala

Lines changed: 5 additions & 2 deletions
```
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
  */
 private[spark] class VectorUDT extends UserDefinedType[Vector] {
 
-  override final def sqlType: StructType = _sqlType
+  override final def sqlType: StructType = VectorUDT.sqlType
 
   override def serialize(obj: Vector): InternalRow = {
     obj match {
@@ -86,8 +86,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   override def typeName: String = "vector"
 
   private[spark] override def asNullable: VectorUDT = this
+}
+
+private[spark] object VectorUDT {
 
-  private[this] val _sqlType = {
+  val sqlType = {
     // type: 0 = sparse, 1 = dense
     // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
     // vectors. The "values" field is nullable because we might want to add binary vectors later,
```
python/pyspark/ml/connect/serialize.py

Lines changed: 133 additions & 67 deletions
```
@@ -18,54 +18,107 @@
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.linalg import (
-    Vectors,
-    Matrices,
+    VectorUDT,
+    MatrixUDT,
     DenseVector,
     SparseVector,
     DenseMatrix,
     SparseMatrix,
 )
-from pyspark.sql.connect.expressions import LiteralExpression
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.ml.param import Params
 
 
-def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Param:
-    if isinstance(value, DenseVector):
-        return pb2.Param(vector=pb2.Vector(dense=pb2.Vector.Dense(value=value.values.tolist())))
-    elif isinstance(value, SparseVector):
-        return pb2.Param(
-            vector=pb2.Vector(
-                sparse=pb2.Vector.Sparse(
-                    size=value.size, index=value.indices.tolist(), value=value.values.tolist()
-                )
-            )
-        )
-    elif isinstance(value, DenseMatrix):
-        return pb2.Param(
-            matrix=pb2.Matrix(
-                dense=pb2.Matrix.Dense(
-                    num_rows=value.numRows, num_cols=value.numCols, value=value.values.tolist()
-                )
-            )
-        )
+def literal_null() -> pb2.Expression.Literal:
+    dt = pb2.DataType()
+    dt.null.CopyFrom(pb2.DataType.NULL())
+    return pb2.Expression.Literal(null=dt)
+
+
+def build_int_list(value: List[int]) -> pb2.Expression.Literal:
+    p = pb2.Expression.Literal()
+    p.specialized_array.ints.values.extend(value)
+    return p
+
+
+def build_float_list(value: List[float]) -> pb2.Expression.Literal:
+    p = pb2.Expression.Literal()
+    p.specialized_array.doubles.values.extend(value)
+    return p
+
+
+def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
+    from pyspark.sql.connect.types import pyspark_types_to_proto_types
+    from pyspark.sql.connect.expressions import LiteralExpression
+
+    if isinstance(value, SparseVector):
+        p = pb2.Expression.Literal()
+        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        # type = 0
+        p.struct.elements.append(pb2.Expression.Literal(byte=0))
+        # size
+        p.struct.elements.append(pb2.Expression.Literal(integer=value.size))
+        # indices
+        p.struct.elements.append(build_int_list(value.indices.tolist()))
+        # values
+        p.struct.elements.append(build_float_list(value.values.tolist()))
+        return p
+
+    elif isinstance(value, DenseVector):
+        p = pb2.Expression.Literal()
+        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        # type = 1
+        p.struct.elements.append(pb2.Expression.Literal(byte=1))
+        # size = null
+        p.struct.elements.append(literal_null())
+        # indices = null
+        p.struct.elements.append(literal_null())
+        # values
+        p.struct.elements.append(build_float_list(value.values.tolist()))
+        return p
+
     elif isinstance(value, SparseMatrix):
-        return pb2.Param(
-            matrix=pb2.Matrix(
-                sparse=pb2.Matrix.Sparse(
-                    num_rows=value.numRows,
-                    num_cols=value.numCols,
-                    colptr=value.colPtrs.tolist(),
-                    row_index=value.rowIndices.tolist(),
-                    value=value.values.tolist(),
-                )
-            )
-        )
+        p = pb2.Expression.Literal()
+        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        # type = 0
+        p.struct.elements.append(pb2.Expression.Literal(byte=0))
+        # numRows
+        p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
+        # numCols
+        p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
+        # colPtrs
+        p.struct.elements.append(build_int_list(value.colPtrs.tolist()))
+        # rowIndices
+        p.struct.elements.append(build_int_list(value.rowIndices.tolist()))
+        # values
+        p.struct.elements.append(build_float_list(value.values.tolist()))
+        # isTransposed
+        p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
+        return p
+
+    elif isinstance(value, DenseMatrix):
+        p = pb2.Expression.Literal()
+        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        # type = 1
+        p.struct.elements.append(pb2.Expression.Literal(byte=1))
+        # numRows
+        p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
+        # numCols
+        p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
+        # colPtrs = null
+        p.struct.elements.append(literal_null())
+        # rowIndices = null
+        p.struct.elements.append(literal_null())
+        # values
+        p.struct.elements.append(build_float_list(value.values.tolist()))
+        # isTransposed
+        p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
+        return p
+
     else:
-        literal = LiteralExpression._from_value(value).to_plan(client).literal
-        return pb2.Param(literal=literal)
+        return LiteralExpression._from_value(value).to_plan(client).literal
 
 
 def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
@@ -80,38 +133,51 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
     return result
 
 
-def deserialize_param(param: pb2.Param) -> Any:
-    if param.HasField("literal"):
-        return LiteralExpression._to_value(param.literal)
-    if param.HasField("vector"):
-        vector = param.vector
-        if vector.HasField("dense"):
-            return Vectors.dense(vector.dense.value)
-        elif vector.HasField("sparse"):
-            return Vectors.sparse(vector.sparse.size, vector.sparse.index, vector.sparse.value)
-        else:
-            raise ValueError("Unsupported vector type")
-    if param.HasField("matrix"):
-        matrix = param.matrix
-        if matrix.HasField("dense"):
-            return DenseMatrix(
-                matrix.dense.num_rows,
-                matrix.dense.num_cols,
-                matrix.dense.value,
-                matrix.dense.is_transposed,
-            )
-        elif matrix.HasField("sparse"):
-            return Matrices.sparse(
-                matrix.sparse.num_rows,
-                matrix.sparse.num_cols,
-                matrix.sparse.colptr,
-                matrix.sparse.row_index,
-                matrix.sparse.value,
-            )
+def deserialize_param(literal: pb2.Expression.Literal) -> Any:
+    from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
+    from pyspark.sql.connect.expressions import LiteralExpression
+
+    if literal.HasField("struct"):
+        s = literal.struct
+        schema = proto_schema_to_pyspark_data_type(s.struct_type)
+
+        if schema == VectorUDT.sqlType():
+            assert len(s.elements) == 4
+            tpe = s.elements[0].byte
+            if tpe == 0:
+                size = s.elements[1].integer
+                indices = s.elements[2].specialized_array.ints.values
+                values = s.elements[3].specialized_array.doubles.values
+                return SparseVector(size, indices, values)
+            elif tpe == 1:
+                values = s.elements[3].specialized_array.doubles.values
+                return DenseVector(values)
+            else:
+                raise ValueError(f"Unknown Vector type {tpe}")
+
+        elif schema == MatrixUDT.sqlType():
+            assert len(s.elements) == 7
+            tpe = s.elements[0].byte
+            if tpe == 0:
+                numRows = s.elements[1].integer
+                numCols = s.elements[2].integer
+                colPtrs = s.elements[3].specialized_array.ints.values
+                rowIndices = s.elements[4].specialized_array.ints.values
+                values = s.elements[5].specialized_array.doubles.values
+                isTransposed = s.elements[6].boolean
+                return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+            elif tpe == 1:
+                numRows = s.elements[1].integer
+                numCols = s.elements[2].integer
+                values = s.elements[5].specialized_array.doubles.values
+                isTransposed = s.elements[6].boolean
+                return DenseMatrix(numRows, numCols, values, isTransposed)
+            else:
+                raise ValueError(f"Unknown Matrix type {tpe}")
         else:
-            raise ValueError("Unsupported matrix type")
-
-    raise ValueError("Unsupported param type")
+            raise ValueError(f"Unsupported parameter struct {schema}")
+    else:
+        return LiteralExpression._to_value(literal)
 
 
 def deserialize(ml_command_result_properties: Dict[str, Any]) -> Any:
@@ -126,7 +192,7 @@ def deserialize(ml_command_result_properties: Dict[str, Any]) -> Any:
 
 
 def serialize_ml_params(instance: "Params", client: "SparkConnectClient") -> pb2.MlParams:
-    params: Mapping[str, pb2.Param] = {
+    params: Mapping[str, pb2.Expression.Literal] = {
        k.name: serialize_param(v, client) for k, v in instance._paramMap.items()
     }
     return pb2.MlParams(params=params)
```
python/pyspark/sql/connect/proto/common_pb2.py

Lines changed: 13 additions & 1 deletion
```
@@ -35,7 +35,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3\x01\n\x17\x45xecutorResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x03R\x06\x61mount\x12.\n\x10\x64iscovery_script\x18\x03 \x01(\tH\x00R\x0f\x64iscoveryScript\x88\x01\x01\x12\x1b\n\x06vendor\x18\x04 \x01(\tH\x01R\x06vendor\x88\x01\x01\x42\x13\n\x11_discovery_scriptB\t\n\x07_vendor"R\n\x13TaskResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x01R\x06\x61mount"\xa5\x03\n\x0fResourceProfile\x12\x64\n\x12\x65xecutor_resources\x18\x01 \x03(\x0b\x32\x35.spark.connect.ResourceProfile.ExecutorResourcesEntryR\x11\x65xecutorResources\x12X\n\x0etask_resources\x18\x02 \x03(\x0b\x32\x31.spark.connect.ResourceProfile.TaskResourcesEntryR\rtaskResources\x1al\n\x16\x45xecutorResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12<\n\x05value\x18\x02 \x01(\x0b\x32&.spark.connect.ExecutorResourceRequestR\x05value:\x02\x38\x01\x1a\x64\n\x12TaskResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.TaskResourceRequestR\x05value:\x02\x38\x01"X\n\x06Origin\x12\x42\n\rpython_origin\x18\x01 \x01(\x0b\x32\x1b.spark.connect.PythonOriginH\x00R\x0cpythonOriginB\n\n\x08\x66unction"G\n\x0cPythonOrigin\x12\x1a\n\x08\x66ragment\x18\x01 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x02 \x01(\tR\x08\x63\x61llSiteB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+    b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3\x01\n\x17\x45xecutorResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x03R\x06\x61mount\x12.\n\x10\x64iscovery_script\x18\x03 \x01(\tH\x00R\x0f\x64iscoveryScript\x88\x01\x01\x12\x1b\n\x06vendor\x18\x04 \x01(\tH\x01R\x06vendor\x88\x01\x01\x42\x13\n\x11_discovery_scriptB\t\n\x07_vendor"R\n\x13TaskResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x01R\x06\x61mount"\xa5\x03\n\x0fResourceProfile\x12\x64\n\x12\x65xecutor_resources\x18\x01 \x03(\x0b\x32\x35.spark.connect.ResourceProfile.ExecutorResourcesEntryR\x11\x65xecutorResources\x12X\n\x0etask_resources\x18\x02 \x03(\x0b\x32\x31.spark.connect.ResourceProfile.TaskResourcesEntryR\rtaskResources\x1al\n\x16\x45xecutorResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12<\n\x05value\x18\x02 \x01(\x0b\x32&.spark.connect.ExecutorResourceRequestR\x05value:\x02\x38\x01\x1a\x64\n\x12TaskResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.TaskResourceRequestR\x05value:\x02\x38\x01"X\n\x06Origin\x12\x42\n\rpython_origin\x18\x01 \x01(\x0b\x32\x1b.spark.connect.PythonOriginH\x00R\x0cpythonOriginB\n\n\x08\x66unction"G\n\x0cPythonOrigin\x12\x1a\n\x08\x66ragment\x18\x01 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x02 \x01(\tR\x08\x63\x61llSite"\x1f\n\x05\x42ools\x12\x16\n\x06values\x18\x01 \x03(\x08R\x06values"\x1e\n\x04Ints\x12\x16\n\x06values\x18\x01 \x03(\x05R\x06values"\x1f\n\x05Longs\x12\x16\n\x06values\x18\x01 \x03(\x03R\x06values" \n\x06\x46loats\x12\x16\n\x06values\x18\x01 \x03(\x02R\x06values"!\n\x07\x44oubles\x12\x16\n\x06values\x18\x01 \x03(\x01R\x06values"!\n\x07Strings\x12\x16\n\x06values\x18\x01 \x03(\tR\x06valuesB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
 )
 
 _globals = globals()
@@ -70,4 +70,16 @@
     _globals["_ORIGIN"]._serialized_end = 1091
     _globals["_PYTHONORIGIN"]._serialized_start = 1093
     _globals["_PYTHONORIGIN"]._serialized_end = 1164
+    _globals["_BOOLS"]._serialized_start = 1166
+    _globals["_BOOLS"]._serialized_end = 1197
+    _globals["_INTS"]._serialized_start = 1199
+    _globals["_INTS"]._serialized_end = 1229
+    _globals["_LONGS"]._serialized_start = 1231
+    _globals["_LONGS"]._serialized_end = 1262
+    _globals["_FLOATS"]._serialized_start = 1264
+    _globals["_FLOATS"]._serialized_end = 1296
+    _globals["_DOUBLES"]._serialized_start = 1298
+    _globals["_DOUBLES"]._serialized_end = 1331
+    _globals["_STRINGS"]._serialized_start = 1333
+    _globals["_STRINGS"]._serialized_end = 1366
 # @@protoc_insertion_point(module_scope)
```
