Skip to content

Commit 7b42adb

Browse files
committed
Add unit tests
1 parent 8191bcb commit 7b42adb

File tree

9 files changed

+362
-56
lines changed

9 files changed

+362
-56
lines changed

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ private[spark] class BlockManager(
107107
}
108108

109109
var blockManagerId: BlockManagerId = _
110-
// BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port)
111110

112111
// Address of the server that serves this executor's shuffle files. This is either an external
113112
// service, or just our own Executor's BlockManager.
@@ -149,8 +148,6 @@ private[spark] class BlockManager(
149148
private val peerFetchLock = new Object
150149
private var lastPeerFetchTime = 0L
151150

152-
// initialize()
153-
154151
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
155152
* the initialization of the compression codec until it is first used. The reason is that a Spark
156153
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -179,7 +176,7 @@ private[spark] class BlockManager(
179176
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
180177
* where it is only learned after registration with the TaskScheduler).
181178
*
182-
* This method initialies the BlockTransferService and ShuffleClient registers with the
179+
* This method initializes the BlockTransferService and ShuffleClient registers with the
183180
* BlockManagerMaster, starts theBlockManagerWorker actor, and registers with a local shuffle
184181
* service if configured.
185182
*/

core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite {
6060
val conf = new SparkConf
6161
conf.set("spark.authenticate", "true")
6262
conf.set("spark.authenticate.secret", "good")
63+
conf.set("spark.app.id", "app-id")
6364
val securityManager = new SecurityManager(conf)
6465
val manager = new ConnectionManager(0, conf, securityManager)
6566
var numReceivedMessages = 0
@@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite {
9596
test("security mismatch password") {
9697
val conf = new SparkConf
9798
conf.set("spark.authenticate", "true")
99+
conf.set("spark.app.id", "app-id")
98100
conf.set("spark.authenticate.secret", "good")
99101
val securityManager = new SecurityManager(conf)
100102
val manager = new ConnectionManager(0, conf, securityManager)
@@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite {
105107
None
106108
})
107109

108-
val badconf = new SparkConf
109-
badconf.set("spark.authenticate", "true")
110-
badconf.set("spark.authenticate.secret", "bad")
110+
val badconf = conf.clone.set("spark.authenticate.secret", "bad")
111111
val badsecurityManager = new SecurityManager(badconf)
112112
val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
113113
var numReceivedServerMessages = 0

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ public void close() {
189189

190190
/** Returns a stable key for the given channel. Only valid after the channel is connected. */
191191
public String getChannelKey() {
192-
return channel.toString();
192+
return String.format("[%s, %s, %s]", channel.remoteAddress(), channel.localAddress(),
193+
channel.hashCode());
193194
}
194195
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.sasl;
19+
20+
import com.google.common.base.Charsets;
21+
import io.netty.buffer.ByteBuf;
22+
23+
import org.apache.spark.network.protocol.Encodable;
24+
25+
/**
26+
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
27+
* with the given appId. This appId allows a single SaslRpcHandler to multiplex different
28+
* applications who may be using different sets of credentials.
29+
*/
30+
class SaslMessage implements Encodable {
31+
32+
/** Serialization tag used to catch incorrect payloads. */
33+
private static final byte TAG_BYTE = (byte) 0xEA;
34+
35+
public final String appId;
36+
public final byte[] payload;
37+
38+
public SaslMessage(String appId, byte[] payload) {
39+
this.appId = appId;
40+
this.payload = payload;
41+
}
42+
43+
@Override
44+
public int encodedLength() {
45+
// tag + appIdLength + appId + payloadLength + payload
46+
return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
47+
}
48+
49+
@Override
50+
public void encode(ByteBuf buf) {
51+
buf.writeByte(TAG_BYTE);
52+
byte[] idBytes = appId.getBytes(Charsets.UTF_8);
53+
buf.writeInt(idBytes.length);
54+
buf.writeBytes(appId.getBytes(Charsets.UTF_8));
55+
buf.writeInt(payload.length);
56+
buf.writeBytes(payload);
57+
}
58+
59+
public static SaslMessage decode(ByteBuf buf) {
60+
if (buf.readByte() != TAG_BYTE) {
61+
throw new IllegalStateException("Expected SaslMessage, received something else");
62+
}
63+
64+
int idLength = buf.readInt();
65+
byte[] idBytes = new byte[idLength];
66+
buf.readBytes(idBytes);
67+
68+
int payloadLength = buf.readInt();
69+
byte[] payload = new byte[payloadLength];
70+
buf.readBytes(payload);
71+
72+
return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
73+
}
74+
}

network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ public class SaslRpcHandler implements RpcHandler {
4242

4343
private final RpcHandler delegate;
4444
private final SecretKeyHolder secretKeyHolder;
45+
46+
// TODO: Invalidate channels that have closed!
4547
private final ConcurrentMap<String, SparkSaslServer> channelAuthenticationMap;
4648

4749
public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
@@ -81,51 +83,3 @@ public StreamManager getStreamManager() {
8183
}
8284
}
8385

84-
/**
85-
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
86-
* with the given id. This 'id' allows a single SaslRpcHandler to multiplex different applications
87-
* who may be using different sets of credentials.
88-
*/
89-
class SaslMessage implements Encodable {
90-
91-
private static final byte TAG_BYTE = (byte) 0xEA;
92-
93-
public final String appId;
94-
public final byte[] payload;
95-
96-
public SaslMessage(String appId, byte[] payload) {
97-
this.appId = appId;
98-
this.payload = payload;
99-
}
100-
101-
@Override
102-
public int encodedLength() {
103-
return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
104-
}
105-
106-
@Override
107-
public void encode(ByteBuf buf) {
108-
buf.writeByte(TAG_BYTE);
109-
byte[] idBytes = appId.getBytes(Charsets.UTF_8);
110-
buf.writeInt(idBytes.length);
111-
buf.writeBytes(appId.getBytes(Charsets.UTF_8));
112-
buf.writeInt(payload.length);
113-
buf.writeBytes(payload);
114-
}
115-
116-
public static SaslMessage decode(ByteBuf buf) {
117-
if (buf.readByte() != TAG_BYTE) {
118-
throw new IllegalStateException("Expected SaslMessage, received something else");
119-
}
120-
121-
int idLength = buf.readInt();
122-
byte[] idBytes = new byte[idLength];
123-
buf.readBytes(idBytes);
124-
125-
int payloadLength = buf.readInt();
126-
byte[] payload = new byte[payloadLength];
127-
buf.readBytes(payload);
128-
129-
return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
130-
}
131-
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.sasl;
19+
20+
import java.io.IOException;
21+
22+
import com.google.common.collect.Lists;
23+
import org.junit.After;
24+
import org.junit.AfterClass;
25+
import org.junit.BeforeClass;
26+
import org.junit.Test;
27+
28+
import static org.junit.Assert.*;
29+
30+
import org.apache.spark.network.TestUtils;
31+
import org.apache.spark.network.TransportContext;
32+
import org.apache.spark.network.client.RpcResponseCallback;
33+
import org.apache.spark.network.client.TransportClient;
34+
import org.apache.spark.network.client.TransportClientBootstrap;
35+
import org.apache.spark.network.client.TransportClientFactory;
36+
import org.apache.spark.network.server.OneForOneStreamManager;
37+
import org.apache.spark.network.server.RpcHandler;
38+
import org.apache.spark.network.server.StreamManager;
39+
import org.apache.spark.network.server.TransportServer;
40+
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
41+
import org.apache.spark.network.util.SystemPropertyConfigProvider;
42+
import org.apache.spark.network.util.TransportConf;
43+
44+
public class SaslIntegrationSuite {
45+
static ExternalShuffleBlockHandler handler;
46+
static TransportServer server;
47+
static TransportConf conf;
48+
static TransportContext context;
49+
50+
TransportClientFactory clientFactory;
51+
52+
/** Provides a secret key holder which always returns the given secret key. */
53+
static class TestSecretKeyHolder implements SecretKeyHolder {
54+
55+
private final String secretKey;
56+
57+
TestSecretKeyHolder(String secretKey) {
58+
this.secretKey = secretKey;
59+
}
60+
61+
@Override
62+
public String getSaslUser(String appId) {
63+
return "user";
64+
}
65+
@Override
66+
public String getSecretKey(String appId) {
67+
return secretKey;
68+
}
69+
}
70+
71+
72+
@BeforeClass
73+
public static void beforeAll() throws IOException {
74+
SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
75+
SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
76+
conf = new TransportConf(new SystemPropertyConfigProvider());
77+
context = new TransportContext(conf, handler);
78+
server = context.createServer();
79+
}
80+
81+
82+
@AfterClass
83+
public static void afterAll() {
84+
server.close();
85+
}
86+
87+
@After
88+
public void afterEach() {
89+
if (clientFactory != null) {
90+
clientFactory.close();
91+
clientFactory = null;
92+
}
93+
}
94+
95+
@Test
96+
public void testGoodClient() {
97+
clientFactory = context.createClientFactory(
98+
Lists.<TransportClientBootstrap>newArrayList(
99+
new SaslBootstrap("app-id", new TestSecretKeyHolder("good-key"))));
100+
101+
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
102+
String msg = "Hello, World!";
103+
byte[] resp = client.sendRpcSync(msg.getBytes(), 1000);
104+
assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
105+
}
106+
107+
@Test
108+
public void testBadClient() {
109+
clientFactory = context.createClientFactory(
110+
Lists.<TransportClientBootstrap>newArrayList(
111+
new SaslBootstrap("app-id", new TestSecretKeyHolder("bad-key"))));
112+
113+
try {
114+
// Bootstrap should fail on startup.
115+
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
116+
} catch (Exception e) {
117+
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
118+
}
119+
}
120+
121+
@Test
122+
public void testNoSaslClient() {
123+
clientFactory = context.createClientFactory(
124+
Lists.<TransportClientBootstrap>newArrayList());
125+
126+
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
127+
try {
128+
client.sendRpcSync(new byte[13], 1000);
129+
fail("Should have failed");
130+
} catch (Exception e) {
131+
assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
132+
}
133+
134+
try {
135+
// Guessing the right tag byte doesn't magically get you in...
136+
client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000);
137+
fail("Should have failed");
138+
} catch (Exception e) {
139+
assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
140+
}
141+
}
142+
143+
@Test
144+
public void testNoSaslServer() {
145+
RpcHandler handler = new TestRpcHandler();
146+
TransportContext context = new TransportContext(conf, handler);
147+
clientFactory = context.createClientFactory(
148+
Lists.<TransportClientBootstrap>newArrayList(
149+
new SaslBootstrap("app-id", new TestSecretKeyHolder("key"))));
150+
TransportServer server = context.createServer();
151+
try {
152+
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
153+
} catch (Exception e) {
154+
assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
155+
} finally {
156+
server.close();
157+
}
158+
}
159+
160+
/** RPC handler which simply responds with the message it received. */
161+
public static class TestRpcHandler implements RpcHandler {
162+
@Override
163+
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
164+
callback.onSuccess(message);
165+
}
166+
167+
@Override
168+
public StreamManager getStreamManager() {
169+
return new OneForOneStreamManager();
170+
}
171+
}
172+
}

0 commit comments

Comments
 (0)