@@ -24,14 +24,14 @@ import java.nio.channels._
24
24
import java .nio .channels .spi ._
25
25
import java .util .concurrent .atomic .AtomicInteger
26
26
import java .util .concurrent .{LinkedBlockingDeque , ThreadPoolExecutor , TimeUnit }
27
- import java .util .{Timer , TimerTask }
28
27
29
28
import scala .collection .mutable .{ArrayBuffer , HashMap , HashSet , SynchronizedMap , SynchronizedQueue }
30
29
import scala .concurrent .duration ._
31
30
import scala .concurrent .{Await , ExecutionContext , Future , Promise }
32
31
import scala .language .postfixOps
33
32
34
33
import com .google .common .base .Charsets .UTF_8
34
+ import io .netty .util .{Timeout , TimerTask , HashedWheelTimer }
35
35
36
36
import org .apache .spark ._
37
37
import org .apache .spark .network .sasl .{SparkSaslClient , SparkSaslServer }
@@ -77,7 +77,8 @@ private[nio] class ConnectionManager(
77
77
}
78
78
79
79
private val selector = SelectorProvider .provider.openSelector()
80
- private val ackTimeoutMonitor = new Timer (" AckTimeoutMonitor" , true )
80
+ private val ackTimeoutMonitor =
81
+ new HashedWheelTimer (Utils .namedThreadFactory(" AckTimeoutMonitor" ))
81
82
82
83
private val ackTimeout = conf.getInt(" spark.core.connection.ack.wait.timeout" , 60 )
83
84
@@ -903,8 +904,8 @@ private[nio] class ConnectionManager(
903
904
// memory leaks since cancelled TimerTasks won't necessarily be garbage collected until they are
904
905
// scheduled to run. Therefore, extract the message id from outside of the task:
905
906
val messageId = message.id
906
- val timeoutTask = new TimerTask {
907
- override def run (): Unit = {
907
+ val timeoutTask : TimerTask = new TimerTask {
908
+ override def run (timeout : Timeout ): Unit = {
908
909
messageStatuses.synchronized {
909
910
messageStatuses.remove(messageId).foreach ( s => {
910
911
val e = new IOException (" sendMessageReliably failed because ack " +
@@ -917,8 +918,10 @@ private[nio] class ConnectionManager(
917
918
}
918
919
}
919
920
921
+ val timoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit .SECONDS )
922
+
920
923
val status = new MessageStatus (message, connectionManagerId, s => {
921
- timeoutTask .cancel()
924
+ timoutTaskHandle .cancel()
922
925
s match {
923
926
case scala.util.Failure (e) =>
924
927
// Indicates a failure where we either never sent or never got ACK'd
@@ -947,7 +950,6 @@ private[nio] class ConnectionManager(
947
950
messageStatuses += ((message.id, status))
948
951
}
949
952
950
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000 )
951
953
sendMessage(connectionManagerId, message)
952
954
promise.future
953
955
}
@@ -957,7 +959,7 @@ private[nio] class ConnectionManager(
957
959
}
958
960
959
961
def stop () {
960
- ackTimeoutMonitor.cancel ()
962
+ ackTimeoutMonitor.stop ()
961
963
selectorThread.interrupt()
962
964
selectorThread.join()
963
965
selector.close()
0 commit comments