@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
1210
1210
* Returns a Catalyst Schema for the given java bean class.
1211
1211
*/
1212
1212
protected def getSchema (beanClass : Class [_]): Seq [AttributeReference ] = {
1213
+ val (dataType, _) = inferDataType(beanClass)
1214
+ dataType.asInstanceOf [StructType ].fields.map { f =>
1215
+ AttributeReference (f.name, f.dataType, f.nullable)()
1216
+ }
1217
+ }
1218
+
1219
+ /**
1220
+ * Infers the corresponding SQL data type of a Java class.
1221
+ * @param clazz Java class
1222
+ * @return (SQL data type, nullable)
1223
+ */
1224
+ private def inferDataType (clazz : Class [_]): (DataType , Boolean ) = {
1213
1225
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
1214
- val beanInfo = Introspector .getBeanInfo(beanClass)
1215
-
1216
- // Note: The ordering of elements may differ from when the schema is inferred in Scala.
1217
- // This is because beanInfo.getPropertyDescriptors gives no guarantees about
1218
- // element ordering.
1219
- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
1220
- fields.map { property =>
1221
- val (dataType, nullable) = property.getPropertyType match {
1222
- case c : Class [_] if c.isAnnotationPresent(classOf [SQLUserDefinedType ]) =>
1223
- (c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
1224
- case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
1225
- case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
1226
- case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
1227
- case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
1228
- case c : Class [_] if c == java.lang.Double .TYPE => (DoubleType , false )
1229
- case c : Class [_] if c == java.lang.Byte .TYPE => (ByteType , false )
1230
- case c : Class [_] if c == java.lang.Float .TYPE => (FloatType , false )
1231
- case c : Class [_] if c == java.lang.Boolean .TYPE => (BooleanType , false )
1232
-
1233
- case c : Class [_] if c == classOf [java.lang.Short ] => (ShortType , true )
1234
- case c : Class [_] if c == classOf [java.lang.Integer ] => (IntegerType , true )
1235
- case c : Class [_] if c == classOf [java.lang.Long ] => (LongType , true )
1236
- case c : Class [_] if c == classOf [java.lang.Double ] => (DoubleType , true )
1237
- case c : Class [_] if c == classOf [java.lang.Byte ] => (ByteType , true )
1238
- case c : Class [_] if c == classOf [java.lang.Float ] => (FloatType , true )
1239
- case c : Class [_] if c == classOf [java.lang.Boolean ] => (BooleanType , true )
1240
- case c : Class [_] if c == classOf [java.math.BigDecimal ] => (DecimalType (), true )
1241
- case c : Class [_] if c == classOf [java.sql.Date ] => (DateType , true )
1242
- case c : Class [_] if c == classOf [java.sql.Timestamp ] => (TimestampType , true )
1243
- }
1244
- AttributeReference (property.getName, dataType, nullable)()
1226
+ clazz match {
1227
+ case c : Class [_] if c.isAnnotationPresent(classOf [SQLUserDefinedType ]) =>
1228
+ (c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
1229
+
1230
+ case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
1231
+ case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
1232
+ case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
1233
+ case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
1234
+ case c : Class [_] if c == java.lang.Double .TYPE => (DoubleType , false )
1235
+ case c : Class [_] if c == java.lang.Byte .TYPE => (ByteType , false )
1236
+ case c : Class [_] if c == java.lang.Float .TYPE => (FloatType , false )
1237
+ case c : Class [_] if c == java.lang.Boolean .TYPE => (BooleanType , false )
1238
+
1239
+ case c : Class [_] if c == classOf [java.lang.Short ] => (ShortType , true )
1240
+ case c : Class [_] if c == classOf [java.lang.Integer ] => (IntegerType , true )
1241
+ case c : Class [_] if c == classOf [java.lang.Long ] => (LongType , true )
1242
+ case c : Class [_] if c == classOf [java.lang.Double ] => (DoubleType , true )
1243
+ case c : Class [_] if c == classOf [java.lang.Byte ] => (ByteType , true )
1244
+ case c : Class [_] if c == classOf [java.lang.Float ] => (FloatType , true )
1245
+ case c : Class [_] if c == classOf [java.lang.Boolean ] => (BooleanType , true )
1246
+
1247
+ case c : Class [_] if c == classOf [java.math.BigDecimal ] => (DecimalType (), true )
1248
+ case c : Class [_] if c == classOf [java.sql.Date ] => (DateType , true )
1249
+ case c : Class [_] if c == classOf [java.sql.Timestamp ] => (TimestampType , true )
1250
+
1251
+ case c : Class [_] if c.isArray =>
1252
+ val (dataType, nullable) = inferDataType(c.getComponentType)
1253
+ (ArrayType (dataType, nullable), true )
1254
+
1255
+ case _ =>
1256
+ val beanInfo = Introspector .getBeanInfo(clazz)
1257
+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
1258
+ val fields = properties.map { property =>
1259
+ val (dataType, nullable) = inferDataType(property.getPropertyType)
1260
+ new StructField (property.getName, dataType, nullable)
1261
+ }
1262
+ (new StructType (fields), true )
1245
1263
}
1246
1264
}
1247
1265
}
0 commit comments