Skip to content

Commit a1d1529

Browse files
committed
[SPARK-6475][SQL] recognize array types when infer data types from JavaBeans
Right now if there is an array field in a JavaBean, the user would see an exception in `createDataFrame`. liancheng Author: Xiangrui Meng <[email protected]> Closes #5146 from mengxr/SPARK-6475 and squashes the following commits: 51e87e5 [Xiangrui Meng] validate schemas 4f2df5e [Xiangrui Meng] recognize array types when infer data types from JavaBeans
1 parent 08d4528 commit a1d1529

File tree

2 files changed

+89
-32
lines changed

2 files changed

+89
-32
lines changed

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
12101210
* Returns a Catalyst Schema for the given java bean class.
12111211
*/
12121212
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
  // Infer the full SQL type of the bean; a bean class always infers to a StructType,
  // whose fields become the attributes of the resulting schema.
  val (inferredType, _) = inferDataType(beanClass)
  inferredType.asInstanceOf[StructType].fields.map { field =>
    AttributeReference(field.name, field.dataType, field.nullable)()
  }
}
1218+
/**
 * Infers the corresponding SQL data type of a Java class.
 * @param clazz Java class
 * @return (SQL data type, nullable)
 */
private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
  // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
  clazz match {
    // Classes annotated with @SQLUserDefinedType map to their declared UDT instance.
    case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
      (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

    case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)

    // Java primitives can never hold null, so they map to non-nullable SQL types.
    case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
    case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
    case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
    case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
    case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
    case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
    case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

    // Boxed wrappers may be null, so the corresponding SQL types are nullable.
    case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
    case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
    case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
    case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
    case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
    case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
    case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

    case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
    case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
    case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

    // Java arrays become ArrayType; element nullability follows the component type.
    case c: Class[_] if c.isArray =>
      val (elementType, elementNullable) = inferDataType(c.getComponentType)
      (ArrayType(elementType, elementNullable), true)

    // Anything else is treated as a nested JavaBean and inferred recursively into a
    // StructType from its readable properties (the synthetic "class" property excluded).
    // NOTE(review): there is no cycle detection here — a bean with a self-referential
    // property type would recurse without bound; confirm callers never pass such beans.
    case _ =>
      val beanInfo = Introspector.getBeanInfo(clazz)
      val beanProperties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
      val structFields = beanProperties.map { descriptor =>
        val (propertyType, propertyNullable) = inferDataType(descriptor.getPropertyType)
        new StructField(descriptor.getName, propertyType, propertyNullable)
      }
      (new StructType(structFields), true)
  }
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,39 @@
1717

1818
package test.org.apache.spark.sql;
1919

20+
import java.io.Serializable;
21+
import java.util.Arrays;
22+
2023
import org.junit.After;
2124
import org.junit.Assert;
2225
import org.junit.Before;
2326
import org.junit.Ignore;
2427
import org.junit.Test;
2528

29+
import org.apache.spark.api.java.JavaRDD;
30+
import org.apache.spark.api.java.JavaSparkContext;
2631
import org.apache.spark.sql.*;
32+
import org.apache.spark.sql.test.TestSQLContext;
2733
import org.apache.spark.sql.test.TestSQLContext$;
28-
import static org.apache.spark.sql.functions.*;
34+
import org.apache.spark.sql.types.*;
2935

36+
import static org.apache.spark.sql.functions.*;
3037

3138
public class JavaDataFrameSuite {
39+
private transient JavaSparkContext jsc;
3240
private transient SQLContext context;
3341

3442
@Before
3543
public void setUp() {
3644
// Trigger static initializer of TestData
3745
TestData$.MODULE$.testData();
46+
jsc = new JavaSparkContext(TestSQLContext.sparkContext());
3847
context = TestSQLContext$.MODULE$;
3948
}
4049

4150
@After
4251
public void tearDown() {
52+
jsc = null;
4353
context = null;
4454
}
4555

@@ -90,4 +100,33 @@ public void testShow() {
90100
df.show();
91101
df.show(1000);
92102
}
103+
104+
public static class Bean implements Serializable {
105+
private double a = 0.0;
106+
private Integer[] b = new Integer[]{0, 1};
107+
108+
public double getA() {
109+
return a;
110+
}
111+
112+
public Integer[] getB() {
113+
return b;
114+
}
115+
}
116+
117+
@Test
118+
public void testCreateDataFrameFromJavaBeans() {
119+
Bean bean = new Bean();
120+
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
121+
DataFrame df = context.createDataFrame(rdd, Bean.class);
122+
StructType schema = df.schema();
123+
Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
124+
schema.apply("a"));
125+
Assert.assertEquals(
126+
new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
127+
schema.apply("b"));
128+
Row first = df.select("a", "b").first();
129+
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
130+
Assert.assertArrayEquals(bean.getB(), first.<Integer[]>getAs(1));
131+
}
93132
}

0 commit comments

Comments
 (0)