Skip to content

Commit 19cbd46

Browse files
support udf instance ser/de after initialization
1 parent e895e0c commit 19cbd46

File tree

5 files changed

+116
-45
lines changed

5 files changed

+116
-45
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ private[hive] object HiveQl {
11281128
Explode(attributes, nodeToExpr(child))
11291129

11301130
case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
1131-
HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
1131+
HiveGenericUdtf(new HiveFunctionCache(functionName), attributes, children.map(nodeToExpr))
11321132

11331133
case a: ASTNode =>
11341134
throw new NotImplementedError(

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 87 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -54,47 +54,80 @@ private[hive] abstract class HiveFunctionRegistry
5454
val functionClassName = functionInfo.getFunctionClass.getName
5555

5656
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
57-
HiveSimpleUdf(functionClassName, children)
57+
HiveSimpleUdf(new HiveFunctionCache(functionClassName), children)
5858
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
59-
HiveGenericUdf(functionClassName, children)
59+
HiveGenericUdf(new HiveFunctionCache(functionClassName), children)
6060
} else if (
6161
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
62-
HiveGenericUdaf(functionClassName, children)
62+
HiveGenericUdaf(new HiveFunctionCache(functionClassName), children)
6363
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
64-
HiveUdaf(functionClassName, children)
64+
HiveUdaf(new HiveFunctionCache(functionClassName), children)
6565
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
66-
HiveGenericUdtf(functionClassName, Nil, children)
66+
HiveGenericUdtf(new HiveFunctionCache(functionClassName), Nil, children)
6767
} else {
6868
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
6969
}
7070
}
7171
}
7272

73-
private[hive] trait HiveFunctionFactory {
74-
val functionClassName: String
73+
/**
74+
* This class provides the UDF creation and also the UDF instance serialization and
75+
* deserialization across process boundaries.
76+
* @param functionClassName UDF class name
77+
*/
78+
class HiveFunctionCache(var functionClassName: String) extends java.io.Externalizable {
79+
// No-arg constructor, needed so java.io.Externalizable can instantiate this class on deserialization
80+
def this() = this(null)
7581

76-
def createFunction[UDFType]() =
77-
getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
78-
}
82+
private var instance: Any = null
7983

80-
private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
81-
self: Product =>
84+
def writeExternal(out: java.io.ObjectOutput) {
85+
// Some UDFs are serializable, but some are not
86+
// Hive Utilities can handle both cases
87+
val baos = new java.io.ByteArrayOutputStream()
88+
HiveShim.serializePlan(instance, baos)
89+
val functionInBytes = baos.toByteArray
8290

83-
type UDFType
84-
type EvaluatedType = Any
91+
// output the function name
92+
out.writeUTF(functionClassName)
8593

86-
def nullable = true
94+
// output the function bytes
95+
out.writeInt(functionInBytes.length)
96+
out.write(functionInBytes, 0, functionInBytes.length)
97+
}
8798

88-
lazy val function = createFunction[UDFType]()
99+
def readExternal(in: java.io.ObjectInput) {
100+
// read the function name
101+
functionClassName = in.readUTF()
89102

90-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
91-
}
103+
// read the function in bytes
104+
val functionInBytesLength = in.readInt()
105+
val functionInBytes = new Array[Byte](functionInBytesLength)
106+
// readFully blocks until the whole buffer is filled; a plain read() may return after fewer bytes
in.readFully(functionInBytes, 0, functionInBytesLength)
92107

93-
private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
94-
extends HiveUdf with HiveInspectors {
108+
// deserialize the function object via Hive Utilities
109+
instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
110+
getContextOrSparkClassLoader.loadClass(functionClassName))
111+
}
112+
113+
def createFunction[UDFType]() = {
114+
if (instance == null) {
115+
instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
116+
}
117+
instance.asInstanceOf[UDFType]
118+
}
119+
}
95120

121+
private[hive] case class HiveSimpleUdf(cache: HiveFunctionCache, children: Seq[Expression])
122+
extends Expression with HiveInspectors with Logging {
123+
type EvaluatedType = Any
96124
type UDFType = UDF
97125

126+
def nullable = true
127+
128+
@transient
129+
lazy val function = cache.createFunction[UDFType]()
130+
98131
@transient
99132
protected lazy val method =
100133
function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
@@ -131,6 +164,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
131164
.convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
132165
returnInspector)
133166
}
167+
168+
override def toString = s"$nodeName#${cache.functionClassName}(${children.mkString(",")})"
134169
}
135170

136171
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -144,16 +179,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
144179
override def get(): AnyRef = wrap(func(), oi)
145180
}
146181

147-
private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
148-
extends HiveUdf with HiveInspectors {
182+
private[hive] case class HiveGenericUdf(cache: HiveFunctionCache, children: Seq[Expression])
183+
extends Expression with HiveInspectors with Logging {
149184
type UDFType = GenericUDF
185+
type EvaluatedType = Any
186+
187+
def nullable = true
188+
189+
@transient
190+
lazy val function = cache.createFunction[UDFType]()
150191

151192
@transient
152193
protected lazy val argumentInspectors = children.map(toInspector)
153194

154195
@transient
155-
protected lazy val returnInspector =
196+
protected lazy val returnInspector = {
156197
function.initializeAndFoldConstants(argumentInspectors.toArray)
198+
}
157199

158200
@transient
159201
protected lazy val isUDFDeterministic = {
@@ -183,18 +225,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
183225
}
184226
unwrap(function.evaluate(deferedObjects), returnInspector)
185227
}
228+
229+
override def toString = s"$nodeName#${cache.functionClassName}(${children.mkString(",")})"
186230
}
187231

188232
private[hive] case class HiveGenericUdaf(
189-
functionClassName: String,
233+
cache: HiveFunctionCache,
190234
children: Seq[Expression]) extends AggregateExpression
191-
with HiveInspectors
192-
with HiveFunctionFactory {
235+
with HiveInspectors {
193236

194237
type UDFType = AbstractGenericUDAFResolver
195238

196239
@transient
197-
protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
240+
protected lazy val resolver: AbstractGenericUDAFResolver = cache.createFunction()
198241

199242
@transient
200243
protected lazy val objectInspector = {
@@ -209,22 +252,22 @@ private[hive] case class HiveGenericUdaf(
209252

210253
def nullable: Boolean = true
211254

212-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
255+
override def toString = s"$nodeName#${cache.functionClassName}(${children.mkString(",")})"
213256

214-
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
257+
def newInstance() = new HiveUdafFunction(cache, children, this)
215258
}
216259

217260
/** It is used as a wrapper for the hive functions which uses UDAF interface */
218261
private[hive] case class HiveUdaf(
219-
functionClassName: String,
262+
cache: HiveFunctionCache,
220263
children: Seq[Expression]) extends AggregateExpression
221-
with HiveInspectors
222-
with HiveFunctionFactory {
264+
with HiveInspectors {
223265

224266
type UDFType = UDAF
225267

226268
@transient
227-
protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
269+
protected lazy val resolver: AbstractGenericUDAFResolver =
270+
new GenericUDAFBridge(cache.createFunction())
228271

229272
@transient
230273
protected lazy val objectInspector = {
@@ -239,10 +282,10 @@ private[hive] case class HiveUdaf(
239282

240283
def nullable: Boolean = true
241284

242-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
285+
override def toString = s"$nodeName#${cache.functionClassName}(${children.mkString(",")})"
243286

244287
def newInstance() =
245-
new HiveUdafFunction(functionClassName, children, this, true)
288+
new HiveUdafFunction(cache, children, this, true)
246289
}
247290

248291
/**
@@ -257,13 +300,13 @@ private[hive] case class HiveUdaf(
257300
* user defined aggregations, which have clean semantics even in a partitioned execution.
258301
*/
259302
private[hive] case class HiveGenericUdtf(
260-
functionClassName: String,
303+
cache: HiveFunctionCache,
261304
aliasNames: Seq[String],
262305
children: Seq[Expression])
263-
extends Generator with HiveInspectors with HiveFunctionFactory {
306+
extends Generator with HiveInspectors {
264307

265308
@transient
266-
protected lazy val function: GenericUDTF = createFunction()
309+
protected lazy val function: GenericUDTF = cache.createFunction()
267310

268311
@transient
269312
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -320,25 +363,24 @@ private[hive] case class HiveGenericUdtf(
320363
}
321364
}
322365

323-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
366+
override def toString = s"$nodeName#${cache.functionClassName}(${children.mkString(",")})"
324367
}
325368

326369
private[hive] case class HiveUdafFunction(
327-
functionClassName: String,
370+
cache: HiveFunctionCache,
328371
exprs: Seq[Expression],
329372
base: AggregateExpression,
330373
isUDAFBridgeRequired: Boolean = false)
331374
extends AggregateFunction
332-
with HiveInspectors
333-
with HiveFunctionFactory {
375+
with HiveInspectors {
334376

335377
def this() = this(null, null, null)
336378

337379
private val resolver =
338380
if (isUDAFBridgeRequired) {
339-
new GenericUDAFBridge(createFunction[UDAF]())
381+
new GenericUDAFBridge(cache.createFunction[UDAF]())
340382
} else {
341-
createFunction[AbstractGenericUDAFResolver]()
383+
cache.createFunction[AbstractGenericUDAFResolver]()
342384
}
343385

344386
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
@@ -361,3 +403,4 @@ private[hive] case class HiveUdafFunction(
361403
function.iterate(buffer, inputs)
362404
}
363405
}
406+

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest {
6060
| getStruct(1).f5 FROM src LIMIT 1
6161
""".stripMargin).first() === Row(1, 2, 3, 4, 5))
6262
}
63+
64+
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
65+
checkAnswer(
66+
sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
67+
8
68+
)
69+
}
6370

6471
test("hive struct udf") {
6572
sql(

sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ import org.apache.spark.sql.catalyst.types.DecimalType
4949
private[hive] object HiveShim {
5050
val version = "0.12.0"
5151

52+
import org.apache.hadoop.hive.ql.exec.Utilities
53+
54+
def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[UDFType]): UDFType = {
55+
Utilities.deserializePlan(is).asInstanceOf[UDFType]
56+
}
57+
58+
def serializePlan(function: Any, out: java.io.OutputStream): Unit = {
59+
Utilities.serializePlan(function, out)
60+
}
61+
5262
def getTableDesc(
5363
serdeClass: Class[_ <: Deserializer],
5464
inputFormatClass: Class[_ <: InputFormat[_, _]],

sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ import scala.language.implicitConversions
4848
private[hive] object HiveShim {
4949
val version = "0.13.1"
5050

51+
import org.apache.hadoop.hive.ql.exec.Utilities
52+
import org.apache.hadoop.hive.conf.HiveConf
53+
54+
def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[UDFType]): UDFType = {
55+
Utilities.deserializePlan(is, clazz, new HiveConf())
56+
}
57+
58+
def serializePlan(function: Any, out: java.io.OutputStream): Unit = {
59+
Utilities.serializePlan(function, out, new HiveConf())
60+
}
61+
5162
def getTableDesc(
5263
serdeClass: Class[_ <: Deserializer],
5364
inputFormatClass: Class[_ <: InputFormat[_, _]],

0 commit comments

Comments
 (0)