Skip to content

[SPARK-10004] [shuffle] Perform auth checks when clients read shuffle data. #8218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
* is equivalent to one Spark-level shuffle block.
*/
class NettyBlockRpcServer(
appId: String,
serializer: Serializer,
blockManager: BlockDataManager)
extends RpcHandler with Logging {
Expand All @@ -55,7 +56,7 @@ class NettyBlockRpcServer(
case openBlocks: OpenBlocks =>
val blocks: Seq[ManagedBuffer] =
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
val streamId = streamManager.registerStream(blocks.iterator.asJava)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
private[this] var appId: String = _

override def init(blockDataManager: BlockDataManager): Unit = {
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
if (authEnabled) {
Expand Down
4 changes: 4 additions & 0 deletions network/common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
<artifactId>slf4j-api</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</dependency>
<!--
Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
class exposed in the Java API). The plugin will then remove this dependency from the published
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -70,6 +71,7 @@ public class TransportClient implements Closeable {

private final Channel channel;
private final TransportResponseHandler handler;
@Nullable private String clientId;

public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel);
Expand All @@ -84,6 +86,25 @@ public SocketAddress getSocketAddress() {
return channel.remoteAddress();
}

/**
* Returns the ID used by the client to authenticate itself when authentication is enabled.
*
* @return The client ID, or null if authentication is disabled.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can also be null if clientId hasn't been set yet right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically yes, but I'm pretty sure code cannot get a TransportClient handle before the auth bootstraps have run.

*/
public String getClientId() {
return clientId;
}

/**
* Sets the authenticated client ID. This is meant to be used by the authentication layer.
*
* Trying to set a different client ID after it's been set will result in an exception.
*/
public void setClientId(String id) {
Preconditions.checkState(clientId == null, "Client ID has already been set.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this will never get called when you aren't using authentication, right? Maybe drop in a comment here explaining that, and on getClientId that it will return null without authentication?

this.clientId = id;
}

/**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
Expand Down Expand Up @@ -207,6 +228,7 @@ public void close() {
public String toString() {
return Objects.toStringHelper(this)
.add("remoteAdress", channel.remoteAddress())
.add("clientId", clientId)
.add("isActive", isActive())
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ public void doBootstrap(TransportClient client, Channel channel) {
payload = saslClient.response(response);
}

client.setClientId(appId);

if (encrypt) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback

if (saslServer == null) {
// First message in the handshake, setup the necessary state.
client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;

import com.google.common.base.Preconditions;
import org.apache.spark.network.client.TransportClient;

/**
* StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
Expand All @@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {

/** State of a single stream. */
private static class StreamState {
final String appId;
final Iterator<ManagedBuffer> buffers;

// The channel associated to the stream
Expand All @@ -53,7 +54,8 @@ private static class StreamState {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;

StreamState(Iterator<ManagedBuffer> buffers) {
StreamState(String appId, Iterator<ManagedBuffer> buffers) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
}
}
Expand Down Expand Up @@ -109,15 +111,34 @@ public void connectionTerminated(Channel channel) {
}
}

@Override
public void checkAuthorization(TransportClient client, long streamId) {
if (client.getClientId() != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused about this: since getClientId returns null if the client did not enable spark.authenticate, does that mean any application that did not enable SASL can read my shuffle files?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be true, but I don't believe you can actually set things up like that. Authentication is either enabled on the server or its not, for all clients.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, do you mean if the server enabled authentication, then any client that did not also enable it will fail the handshake in the first place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

StreamState state = streams.get(streamId);
Preconditions.checkArgument(state != null, "Unknown stream ID.");
if (!client.getClientId().equals(state.appId)) {
throw new SecurityException(String.format(
"Client %s not authorized to read stream %d (app %s).",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not disclose the actual appId in the exception message ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

App IDs are not secret.

client.getClientId(),
streamId,
state.appId));
}
}
}

/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
* client connection is closed before the iterator is fully drained, then the remaining buffers
* will all be release()'d.
*
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
*/
public long registerStream(Iterator<ManagedBuffer> buffers) {
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(buffers));
streams.put(myStreamId, new StreamState(appId, buffers));
return myStreamId;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.netty.channel.Channel;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;

/**
* The StreamManager is used to fetch individual chunks from a stream. This is used in
Expand Down Expand Up @@ -60,4 +61,12 @@ public void registerChannel(Channel channel, long streamId) { }
* to read from the associated streams again, so any state can be cleaned up.
*/
public void connectionTerminated(Channel channel) { }

/**
* Verify that the client is authorized to read from the given stream.
*
* @throws SecurityException If client is not authorized.
*/
public void checkAuthorization(TransportClient client, long streamId) { }

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ private void processFetchRequest(final ChunkFetchRequest req) {

ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFi

/** Enables mocking out the StreamManager and BlockManager. */
@VisibleForTesting
ExternalShuffleBlockHandler(
public ExternalShuffleBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
this.streamManager = streamManager;
Expand All @@ -77,17 +77,19 @@ protected void handleMessage(
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
List<ManagedBuffer> blocks = Lists.newArrayList();
checkAuth(client, msg.appId);

List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
}
long streamId = streamManager.registerStream(blocks.iterator());
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());

} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
callback.onSuccess(new byte[0]);

Expand Down Expand Up @@ -126,4 +128,12 @@ public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executor
public void close() {
blockManager.close();
}

private void checkAuth(TransportClient client, String appId) {
if (client.getClientId() != null && !client.getClientId().equals(appId)) {
throw new SecurityException(String.format(
"Client for %s not authorized for application %s.", client.getClientId(), appId));
}
}

}
Loading