diff --git a/README.md b/README.md index 9c16fb9f..252fd618 100644 --- a/README.md +++ b/README.md @@ -353,7 +353,9 @@ samples_1 = [np.random.randn(2) + center_1 for _ in range(n_sample//2)] labels_1 = [1 for _ in range(n_sample//2)] rows = map(to_row, zip(map(lambda x: x.tolist(), samples_0 + samples_1), labels_0 + labels_1)) -sdf = spark.createDataFrame(rows) +schema = StructType([StructField("inputCol", ArrayType(FloatType())), + StructField("label", LongType())]) +sdf = spark.createDataFrame(rows, schema) ```