[SPARK-3922] Refactor spark-core to use Utils.UTF_8 #2781

Closed (wants to merge 2 commits)
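This pull request makes one mechanical change across spark-core: call sites that spelled the charset as the string literal "utf-8", looked it up with Charset.forName("UTF-8"), or borrowed Netty's CharsetUtil.UTF_8 now all use Guava's com.google.common.base.Charsets.UTF_8 constant. A minimal before/after sketch of the pattern (illustrative only, with a made-up value; not code from this PR):

```scala
import com.google.common.base.Charsets.UTF_8

object Utf8RefactorSketch {
  private val blockId = "taskresult_0"  // hypothetical payload

  // Before: the charset is named by a string literal; a typo only fails at
  // runtime, and callers must handle the checked UnsupportedEncodingException.
  val oldStyle: Array[Byte] = blockId.getBytes("utf-8")

  // After: a shared java.nio.charset.Charset constant from Guava, usable for
  // both encoding and decoding with no checked exception.
  val newStyle: Array[Byte] = blockId.getBytes(UTF_8)
  val roundTrip: String = new String(newStyle, UTF_8)
}
```

Passing a Charset also avoids the checked UnsupportedEncodingException declared by the String-based overloads of getBytes and new String. The per-file diff follows.
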
7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -17,7 +17,6 @@

package org.apache.spark

import java.io.IOException
import javax.security.auth.callback.Callback
import javax.security.auth.callback.CallbackHandler
import javax.security.auth.callback.NameCallback
@@ -31,6 +30,8 @@ import javax.security.sasl.SaslException

import scala.collection.JavaConversions.mapAsJavaMap

import com.google.common.base.Charsets.UTF_8

/**
* Implements SASL Client logic for Spark
*/
@@ -111,10 +112,10 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg
CallbackHandler {

private val userName: String =
SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
private val secretKey = securityMgr.getSecretKey()
private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))

/**
* Implementation used to respond to SASL request from the server.
10 changes: 6 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -28,6 +28,8 @@ import javax.security.sasl.Sasl
import javax.security.sasl.SaslException
import javax.security.sasl.SaslServer
import scala.collection.JavaConversions.mapAsJavaMap

import com.google.common.base.Charsets.UTF_8
import org.apache.commons.net.util.Base64

/**
@@ -89,7 +91,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
extends CallbackHandler {

private val userName: String =
SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))

override def handle(callbacks: Array[Callback]) {
logDebug("In the sasl server callback handler")
@@ -101,7 +103,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
case pc: PasswordCallback => {
logDebug("handle: SASL server callback: setting userPassword")
val password: Array[Char] =
SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
pc.setPassword(password)
}
case rc: RealmCallback => {
@@ -159,7 +161,7 @@ private[spark] object SparkSaslServer {
* @return Base64-encoded string
*/
def encodeIdentifier(identifier: Array[Byte]): String = {
new String(Base64.encodeBase64(identifier), "utf-8")
new String(Base64.encodeBase64(identifier), UTF_8)
}

/**
Expand All @@ -168,7 +170,7 @@ private[spark] object SparkSaslServer {
* @return password as a char array.
*/
def encodePassword(password: Array[Byte]): Array[Char] = {
new String(Base64.encodeBase64(password), "utf-8").toCharArray()
new String(Base64.encodeBase64(password), UTF_8).toCharArray()
}
}
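
A side note on the SASL helpers above: Base64 output consists only of ASCII characters, so decoding it as UTF-8 in encodeIdentifier and encodePassword is lossless. A small self-contained check, assuming commons-net's Base64 as imported above (the secret value is made up):

```scala
import com.google.common.base.Charsets.UTF_8
import org.apache.commons.net.util.Base64

object SaslBase64Sketch {
  def main(args: Array[String]): Unit = {
    val secret = "not-a-real-secret".getBytes(UTF_8)  // made-up secret key bytes
    // Base64 output only uses ASCII characters (A-Z, a-z, 0-9, '+', '/', '='),
    // so reading it back as UTF-8 text, as encodeIdentifier/encodePassword do,
    // can never corrupt the encoded value.
    val encoded = new String(Base64.encodeBase64(secret), UTF_8)
    val decoded = Base64.decodeBase64(encoded.getBytes(UTF_8))
    assert(decoded.sameElements(secret))
  }
}
```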

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,14 +19,14 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.hadoop.conf.Configuration
@@ -134,7 +134,7 @@ private[spark] class PythonRDD(
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, "utf-8"),
throw new PythonException(new String(obj, UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
@@ -318,7 +318,6 @@ private object SpecialLengths {
}

private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")

// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
@@ -586,7 +585,7 @@ private[spark] object PythonRDD extends Logging {
}

def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes(UTF8)
val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -849,7 +848,7 @@ private[spark] object PythonRDD extends Logging {

private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
}

/**
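For context, writeUTF above frames each string as a 4-byte big-endian length followed by the UTF-8 bytes, and the Python worker reads the length and then the payload. A sketch of that framing with a hypothetical reader helper (readUTF below is for illustration only and is not part of PythonRDD):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import com.google.common.base.Charsets.UTF_8

object WriteUtfSketch {
  // Same framing as PythonRDD.writeUTF: 4-byte big-endian length, then UTF-8 bytes.
  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
    val bytes = str.getBytes(UTF_8)
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  // Hypothetical mirror of the Python worker's read side, for illustration only.
  def readUTF(dataIn: DataInputStream): String = {
    val length = dataIn.readInt()
    val bytes = new Array[Byte](length)
    dataIn.readFully(bytes)
    new String(bytes, UTF_8)
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    writeUTF("héllo", new DataOutputStream(buffer))
    val decoded = readUTF(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)))
    assert(decoded == "héllo")
  }
}
```
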
core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python

import java.io.{DataOutput, DataInput}
import java.nio.charset.Charset

import com.google.common.base.Charsets.UTF_8

import org.apache.hadoop.io._
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
@@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator {
sc.parallelize(intKeys).saveAsSequenceFile(intPath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) }
sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
).saveAsSequenceFile(bytesPath)
val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
sc.parallelize(bools).saveAsSequenceFile(boolPath)
core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.Map

import akka.actor.ActorRef
import com.google.common.base.Charsets
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}
@@ -178,7 +178,7 @@ private[spark] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val header = "Launch Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
Files.append(header, stderr, Charsets.UTF_8)
Files.append(header, stderr, UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker
import java.io._

import akka.actor.ActorRef
import com.google.common.base.Charsets
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files

import org.apache.spark.{SparkConf, Logging}
@@ -151,7 +151,7 @@ private[spark] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)

val stderr = new File(executorDir, "stderr")
Files.write(header, stderr, Charsets.UTF_8)
Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)

state = ExecutorState.RUNNING
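The DriverRunner and ExecutorRunner changes above are import-only: Guava's Files.write and Files.append already take an explicit java.nio.charset.Charset, so importing the UTF_8 constant directly just shortens the call sites. A standalone sketch of the same pattern (the file and header text are placeholders, not the runners' real paths):

```scala
import java.io.File

import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files

object LaunchHeaderSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder file; the real runners write into the driver/executor work dir.
    val stderr = File.createTempFile("stderr", ".log")
    val header = "Launch Command: %s\n%s\n\n".format("\"java\" \"-cp\" \"...\"", "=" * 40)

    Files.write(header, stderr, UTF_8)                   // ExecutorRunner-style: create/overwrite
    Files.append("process output...\n", stderr, UTF_8)   // DriverRunner-style: append
    print(Files.toString(stderr, UTF_8))
    stderr.deleteOnExit()
  }
}
```
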
core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
@@ -19,13 +19,13 @@ package org.apache.spark.network.netty.client

import java.util.concurrent.TimeoutException

import com.google.common.base.Charsets.UTF_8
import io.netty.bootstrap.Bootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel.socket.SocketChannel
import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.string.StringEncoder
import io.netty.util.CharsetUtil

import org.apache.spark.Logging

@@ -61,7 +61,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
b.handler(new ChannelInitializer[SocketChannel] {
override def initChannel(ch: SocketChannel): Unit = {
ch.pipeline
.addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
.addLast("encoder", new StringEncoder(UTF_8))
// maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
.addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
.addLast("handler", handler)
core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
@@ -17,6 +17,7 @@

package org.apache.spark.network.netty.client

import com.google.common.base.Charsets.UTF_8
import io.netty.buffer.ByteBuf
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

@@ -67,7 +68,7 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
val blockIdLen = in.readInt()
val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
in.readBytes(blockIdBytes)
val blockId = new String(blockIdBytes)
val blockId = new String(blockIdBytes, UTF_8)
val blockSize = totalLen - math.abs(blockIdLen) - 4

def server = ctx.channel.remoteAddress.toString
@@ -76,7 +77,7 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
if (blockIdLen < 0) {
val errorMessageBytes = new Array[Byte](blockSize)
in.readBytes(errorMessageBytes)
val errorMsg = new String(errorMessageBytes)
val errorMsg = new String(errorMessageBytes, UTF_8)
logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")

val listener = outstandingRequests.get(blockId)
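Unlike the purely cosmetic changes, the two edits above in BlockFetchingClientHandler tighten correctness: new String(bytes) with no charset falls back to the JVM's default charset, which is not guaranteed to be UTF-8 on every machine in a cluster. A small sketch of the failure mode under an assumed non-UTF-8 default (the block id is hypothetical):

```scala
import java.nio.charset.Charset

import com.google.common.base.Charsets.UTF_8

object DefaultCharsetPitfall {
  def main(args: Array[String]): Unit = {
    val blockId = "répartition_0"          // hypothetical non-ASCII block id
    val bytes = blockId.getBytes(UTF_8)    // what the sender put on the wire

    // Old behaviour: new String(bytes) decodes with Charset.defaultCharset(),
    // which silently mangles the id on a JVM whose default is, say, ISO-8859-1.
    val decodedWithDefault = new String(bytes)

    // New behaviour: decode with the same charset that was used to encode.
    val decodedExplicitly = new String(bytes, UTF_8)

    println(s"default charset: ${Charset.defaultCharset()}")
    println(s"implicit decode: $decodedWithDefault")
    println(s"explicit decode: $decodedExplicitly")
  }
}
```
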
core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.network.netty.server

import java.net.InetSocketAddress

import com.google.common.base.Charsets.UTF_8
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
@@ -30,7 +31,6 @@ import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.oio.OioServerSocketChannel
import io.netty.handler.codec.LineBasedFrameDecoder
import io.netty.handler.codec.string.StringDecoder
import io.netty.util.CharsetUtil

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.network.netty.NettyConfig
@@ -131,7 +131,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo
override def initChannel(ch: SocketChannel): Unit = {
ch.pipeline
.addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
.addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
.addLast("stringDecoder", new StringDecoder(UTF_8))
.addLast("blockHeaderEncoder", new BlockHeaderEncoder)
.addLast("handler", new BlockServerHandler(dataProvider))
}
core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
@@ -17,13 +17,13 @@

package org.apache.spark.network.netty.server

import com.google.common.base.Charsets.UTF_8
import io.netty.channel.ChannelInitializer
import io.netty.channel.socket.SocketChannel
import io.netty.handler.codec.LineBasedFrameDecoder
import io.netty.handler.codec.string.StringDecoder
import io.netty.util.CharsetUtil
import org.apache.spark.storage.BlockDataProvider

import org.apache.spark.storage.BlockDataProvider

/** Channel initializer that sets up the pipeline for the BlockServer. */
private[netty]
@@ -33,7 +33,7 @@ class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
override def initChannel(ch: SocketChannel): Unit = {
ch.pipeline
.addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
.addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
.addLast("stringDecoder", new StringDecoder(UTF_8))
.addLast("blockHeaderEncoder", new BlockHeaderEncoder)
.addLast("handler", new BlockServerHandler(dataProvider))
}
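Swapping io.netty.util.CharsetUtil.UTF_8 for Guava's constant in the Netty pipelines above is behaviour-preserving: both name the JVM's "UTF-8" java.nio.charset.Charset, and Netty's StringEncoder/StringDecoder constructors accept any Charset. A quick equivalence check, as a sketch assuming Guava and Netty on the classpath:

```scala
import java.nio.charset.Charset

import com.google.common.base.Charsets
import io.netty.util.CharsetUtil

object CharsetEquivalenceSketch {
  def main(args: Array[String]): Unit = {
    // Charset.equals compares canonical names, so Guava's constant and Netty's
    // CharsetUtil.UTF_8 are interchangeable anywhere a Charset is expected,
    // including the StringEncoder/StringDecoder constructors used above.
    assert(Charsets.UTF_8 == CharsetUtil.UTF_8)
    assert(Charsets.UTF_8 == Charset.forName("UTF-8"))
  }
}
```
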
core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -31,6 +31,8 @@ import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8

import org.apache.spark._
import org.apache.spark.util.Utils

@@ -923,7 +925,7 @@ private[nio] class ConnectionManager(
val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
errorMsgByteBuf.get(errorMsgBytes)
val errorMsg = new String(errorMsgBytes, "utf-8")
val errorMsg = new String(errorMsgBytes, UTF_8)
val e = new IOException(
s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
if (!promise.tryFailure(e)) {
core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,8 @@ import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Charsets.UTF_8

import org.apache.spark.util.Utils

private[nio] abstract class Message(val typ: Long, val id: Int) {
@@ -92,7 +94,7 @@ private[nio] object Message {
*/
def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
val exceptionString = Utils.exceptionString(exception)
val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8"))
val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))
val errorMessage = createBufferMessage(serializedExceptionString, ackId)
errorMessage.hasError = true
errorMessage
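Message.createErrorMessage and the ConnectionManager change earlier form a matched pair: the exception text is written as UTF-8 on the sending side and now decoded as UTF-8 on the receiving side, so the message survives unchanged regardless of either JVM's default charset. A round-trip sketch (the exception text is a stand-in for Utils.exceptionString):

```scala
import java.nio.ByteBuffer

import com.google.common.base.Charsets.UTF_8

object ErrorMessageRoundTrip {
  def main(args: Array[String]): Unit = {
    // Sending side, as in Message.createErrorMessage: serialize the exception
    // text as UTF-8 and wrap it in a ByteBuffer payload.
    val exceptionString = "java.io.IOException: connection reset"  // stand-in text
    val payload = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))

    // Receiving side, as in ConnectionManager: copy the buffer out and decode
    // with the same charset.
    val errorMsgBytes = new Array[Byte](payload.limit())
    payload.get(errorMsgBytes)
    val errorMsg = new String(errorMsgBytes, UTF_8)
    assert(errorMsg == exceptionString)
  }
}
```
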
core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.network.netty.client

import java.nio.ByteBuffer

import com.google.common.base.Charsets.UTF_8
import io.netty.buffer.Unpooled
import io.netty.channel.embedded.EmbeddedChannel

@@ -42,7 +43,7 @@ class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester
parsedBlockId = bid
val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
refCntBuf.byteBuffer().get(bytes)
parsedBlockData = new String(bytes)
parsedBlockData = new String(bytes, UTF_8)
}
}
)
core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
@@ -17,12 +17,12 @@

package org.apache.spark.network.netty.server

import com.google.common.base.Charsets.UTF_8
import io.netty.buffer.ByteBuf
import io.netty.channel.embedded.EmbeddedChannel

import org.scalatest.FunSuite


class BlockHeaderEncoderSuite extends FunSuite {

test("encode normal block data") {
@@ -35,7 +35,7 @@ class BlockHeaderEncoderSuite extends FunSuite {

val blockIdBytes = new Array[Byte](blockId.length)
out.readBytes(blockIdBytes)
assert(new String(blockIdBytes) === blockId)
assert(new String(blockIdBytes, UTF_8) === blockId)
assert(out.readableBytes() === 0)

channel.close()
@@ -52,11 +52,11 @@

val blockIdBytes = new Array[Byte](blockId.length)
out.readBytes(blockIdBytes)
assert(new String(blockIdBytes) === blockId)
assert(new String(blockIdBytes, UTF_8) === blockId)

val errorMsgBytes = new Array[Byte](errorMsg.length)
out.readBytes(errorMsgBytes)
assert(new String(errorMsgBytes) === errorMsg)
assert(new String(errorMsgBytes, UTF_8) === errorMsg)
assert(out.readableBytes() === 0)

channel.close()