Skip to content

Commit 68d5421

Browse files
vbabaninrozza
andauthored
Add CSOT to OIDC. (#1741)
JAVA-5357 --------- Co-authored-by: Ross Lawley <[email protected]>
1 parent d04391f commit 68d5421

File tree

8 files changed

+146
-17
lines changed

8 files changed

+146
-17
lines changed

driver-core/src/main/com/mongodb/internal/TimeoutContext.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public static MongoOperationTimeoutException createMongoTimeoutException(final S
7070
public static <T> T throwMongoTimeoutException(final String message) {
7171
throw new MongoOperationTimeoutException(message);
7272
}
73+
public static <T> T throwMongoTimeoutException() {
74+
throw new MongoOperationTimeoutException("The operation exceeded the timeout limit.");
75+
}
7376

7477
public static MongoOperationTimeoutException createMongoTimeoutException(final Throwable cause) {
7578
return createMongoTimeoutException("Operation exceeded the timeout limit: " + cause.getMessage(), cause);

driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public String getMechanismName() {
6868
}
6969

7070
@Override
71-
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
71+
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
7272
return new AwsSaslClient(getMongoCredential());
7373
}
7474

driver-core/src/main/com/mongodb/internal/connection/GSSAPIAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public String getMechanismName() {
6767
}
6868

6969
@Override
70-
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
70+
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
7171
MongoCredential credential = getMongoCredential();
7272
try {
7373
Map<String, Object> saslClientProperties = credential.getMechanismProperty(JAVA_SASL_CLIENT_PROPERTIES_KEY, null);

driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.mongodb.connection.ClusterConnectionMode;
3030
import com.mongodb.connection.ConnectionDescription;
3131
import com.mongodb.internal.Locks;
32+
import com.mongodb.internal.TimeoutContext;
3233
import com.mongodb.internal.VisibleForTesting;
3334
import com.mongodb.internal.async.SingleResultCallback;
3435
import com.mongodb.internal.authentication.AzureCredentialHelper;
@@ -45,10 +46,12 @@
4546
import java.nio.file.Files;
4647
import java.nio.file.Paths;
4748
import java.time.Duration;
49+
import java.time.temporal.ChronoUnit;
4850
import java.util.Arrays;
4951
import java.util.Collections;
5052
import java.util.List;
5153
import java.util.Map;
54+
import java.util.concurrent.TimeUnit;
5255
import java.util.stream.Collectors;
5356

5457
import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC;
@@ -64,11 +67,14 @@
6467
import static com.mongodb.assertions.Assertions.assertFalse;
6568
import static com.mongodb.assertions.Assertions.assertNotNull;
6669
import static com.mongodb.assertions.Assertions.assertTrue;
70+
import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
6771
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
6872
import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse;
6973
import static java.lang.String.format;
7074

7175
/**
76+
* Created per connection, and exists until connection is closed.
77+
*
7278
* <p>This class is not part of the public API and may be removed or changed at any time</p>
7379
*/
7480
public final class OidcAuthenticator extends SaslAuthenticator {
@@ -118,8 +124,21 @@ public OidcAuthenticator(final MongoCredentialWithCache credential,
118124
}
119125
}
120126

121-
private Duration getCallbackTimeout() {
122-
return isHumanCallback() ? HUMAN_CALLBACK_TIMEOUT : CALLBACK_TIMEOUT;
127+
private Duration getCallbackTimeout(final TimeoutContext timeoutContext) {
128+
if (isHumanCallback()) {
129+
return HUMAN_CALLBACK_TIMEOUT;
130+
}
131+
132+
if (timeoutContext.hasTimeoutMS()) {
133+
return assertNotNull(timeoutContext.getTimeout()).call(TimeUnit.MILLISECONDS,
134+
() ->
135+
// we can get here if server selection timeout was set to infinite.
136+
ChronoUnit.FOREVER.getDuration(),
137+
(renamingMs) -> Duration.ofMillis(renamingMs),
138+
() -> throwMongoTimeoutException());
139+
140+
}
141+
return CALLBACK_TIMEOUT;
123142
}
124143

125144
@Override
@@ -128,10 +147,10 @@ public String getMechanismName() {
128147
}
129148

130149
@Override
131-
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
150+
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
132151
this.serverAddress = assertNotNull(serverAddress);
133152
MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache();
134-
return new OidcSaslClient(mongoCredentialWithCache);
153+
return new OidcSaslClient(mongoCredentialWithCache, operationContext.getTimeoutContext());
135154
}
136155

137156
@Override
@@ -322,7 +341,7 @@ private void authenticationLoopAsync(final InternalConnection connection, final
322341
).finish(callback);
323342
}
324343

325-
private byte[] evaluate(final byte[] challenge) {
344+
private byte[] evaluate(final byte[] challenge, final TimeoutContext timeoutContext) {
326345
byte[][] jwt = new byte[1][];
327346
Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> {
328347
OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry();
@@ -343,7 +362,7 @@ private byte[] evaluate(final byte[] challenge) {
343362
// Invoke Callback using cached Refresh Token
344363
fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN;
345364
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
346-
getCallbackTimeout(), cachedIdpInfo, cachedRefreshToken, userName));
365+
getCallbackTimeout(timeoutContext), cachedIdpInfo, cachedRefreshToken, userName));
347366
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result);
348367
} else {
349368
// cache is empty
@@ -352,7 +371,7 @@ private byte[] evaluate(final byte[] challenge) {
352371
// no principal request
353372
fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN;
354373
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
355-
getCallbackTimeout(), userName));
374+
getCallbackTimeout(timeoutContext), userName));
356375
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result);
357376
if (result.getRefreshToken() != null) {
358377
throw new MongoConfigurationException(
@@ -382,7 +401,7 @@ private byte[] evaluate(final byte[] challenge) {
382401
// there is no cached refresh token
383402
fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN;
384403
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
385-
getCallbackTimeout(), idpInfo, null, userName));
404+
getCallbackTimeout(timeoutContext), idpInfo, null, userName));
386405
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result);
387406
}
388407
}
@@ -501,14 +520,18 @@ OidcCacheEntry clearRefreshToken() {
501520
}
502521

503522
private final class OidcSaslClient extends SaslClientImpl {
523+
private final TimeoutContext timeoutContext;
504524

505-
private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) {
525+
private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache,
526+
final TimeoutContext timeoutContext) {
506527
super(mongoCredentialWithCache.getCredential());
528+
529+
this.timeoutContext = timeoutContext;
507530
}
508531

509532
@Override
510533
public byte[] evaluateChallenge(final byte[] challenge) {
511-
return evaluate(challenge);
534+
return evaluate(challenge, timeoutContext);
512535
}
513536

514537
@Override

driver-core/src/main/com/mongodb/internal/connection/PlainAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public String getMechanismName() {
4747
}
4848

4949
@Override
50-
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
50+
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
5151
MongoCredential credential = getMongoCredential();
5252
isTrue("mechanism is PLAIN", credential.getAuthenticationMechanism() == PLAIN);
5353
try {

driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut
6565
public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription,
6666
final OperationContext operationContext) {
6767
doAsSubject(() -> {
68-
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
68+
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext);
6969
throwIfSaslClientIsNull(saslClient);
7070
try {
7171
BsonDocument responseDocument = getNextSaslResponse(saslClient, connection, operationContext);
@@ -105,7 +105,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc
105105
final OperationContext operationContext, final SingleResultCallback<Void> callback) {
106106
try {
107107
doAsSubject(() -> {
108-
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
108+
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext);
109109
throwIfSaslClientIsNull(saslClient);
110110
getNextSaslResponseAsync(saslClient, connection, operationContext, callback);
111111
return null;
@@ -117,7 +117,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc
117117

118118
public abstract String getMechanismName();
119119

120-
protected abstract SaslClient createSaslClient(ServerAddress serverAddress);
120+
protected abstract SaslClient createSaslClient(ServerAddress serverAddress, OperationContext operationContext);
121121

122122
protected void appendSaslStartOptions(final BsonDocument saslStartCommand) {
123123
}

driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,17 @@ protected void appendSaslStartOptions(final BsonDocument saslStartCommand) {
9090

9191

9292
@Override
93-
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
93+
protected SaslClient createSaslClient(final ServerAddress serverAddress, @Nullable final OperationContext operationContext) {
9494
if (speculativeSaslClient != null) {
9595
return speculativeSaslClient;
9696
}
9797
return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator);
9898
}
9999

100+
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
101+
return createSaslClient(serverAddress, null);
102+
}
103+
100104
@Override
101105
public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) {
102106
try {

driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
import org.junit.jupiter.api.AfterEach;
4343
import org.junit.jupiter.api.BeforeEach;
4444
import org.junit.jupiter.api.Test;
45+
import org.junit.jupiter.params.ParameterizedTest;
46+
import org.junit.jupiter.params.provider.Arguments;
47+
import org.junit.jupiter.params.provider.MethodSource;
48+
import org.junit.jupiter.params.provider.ValueSource;
4549

4650
import java.io.IOException;
4751
import java.lang.reflect.Field;
@@ -50,6 +54,7 @@
5054
import java.nio.file.Path;
5155
import java.nio.file.Paths;
5256
import java.time.Duration;
57+
import java.time.temporal.ChronoUnit;
5358
import java.util.ArrayList;
5459
import java.util.Arrays;
5560
import java.util.Collections;
@@ -58,9 +63,11 @@
5863
import java.util.concurrent.CompletableFuture;
5964
import java.util.concurrent.ConcurrentLinkedQueue;
6065
import java.util.concurrent.ExecutionException;
66+
import java.util.concurrent.TimeUnit;
6167
import java.util.concurrent.atomic.AtomicInteger;
6268
import java.util.function.Supplier;
6369
import java.util.stream.Collectors;
70+
import java.util.stream.Stream;
6471

6572
import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY;
6673
import static com.mongodb.MongoCredential.ENVIRONMENT_KEY;
@@ -72,9 +79,12 @@
7279
import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY;
7380
import static com.mongodb.assertions.Assertions.assertNotNull;
7481
import static com.mongodb.testing.MongoAssertions.assertCause;
82+
import static java.lang.Math.min;
83+
import static java.lang.String.format;
7584
import static java.lang.System.getenv;
7685
import static java.util.Arrays.asList;
7786
import static org.junit.jupiter.api.Assertions.assertEquals;
87+
import static org.junit.jupiter.api.Assertions.assertFalse;
7888
import static org.junit.jupiter.api.Assertions.assertNull;
7989
import static org.junit.jupiter.api.Assertions.assertThrows;
8090
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -198,6 +208,91 @@ public void test2p1ValidCallbackInputs() {
198208
}
199209
}
200210

211+
// Not a prose test
212+
@ParameterizedTest(name = "{0}. "
213+
+ "Parameters: timeoutMs={1}, "
214+
+ "serverSelectionTimeoutMS={2},"
215+
+ " expectedTimeoutThreshold={3}")
216+
@MethodSource
217+
void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
218+
final int timeoutMs,
219+
final int serverSelectionTimeoutMS,
220+
final int expectedTimeoutThreshold) {
221+
TestCallback callback1 = createCallback();
222+
223+
OidcCallback callback2 = (context) -> {
224+
assertTrue(context.getTimeout().toMillis() < expectedTimeoutThreshold,
225+
format("Expected timeout to be less than %d, but was %d",
226+
expectedTimeoutThreshold,
227+
context.getTimeout().toMillis()));
228+
return callback1.onRequest(context);
229+
};
230+
231+
MongoClientSettings clientSettings = MongoClientSettings.builder(createSettings(callback2))
232+
.applyToClusterSettings(builder ->
233+
builder.serverSelectionTimeout(
234+
serverSelectionTimeoutMS,
235+
TimeUnit.MILLISECONDS))
236+
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
237+
.build();
238+
239+
try (MongoClient mongoClient = createMongoClient(clientSettings)) {
240+
long start = System.nanoTime();
241+
performFind(mongoClient);
242+
assertEquals(1, callback1.getInvocations());
243+
long elapsed = msElapsedSince(start);
244+
245+
assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)),
246+
format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
247+
+ "This indicates that the callback was not called with the expected timeout.",
248+
min(serverSelectionTimeoutMS, timeoutMs),
249+
elapsed));
250+
}
251+
}
252+
253+
private static Stream<Arguments> testValidCallbackInputsTimeoutWhenTimeoutMsIsSet() {
254+
return Stream.of(
255+
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS",
256+
1000, // timeoutMS
257+
500, // serverSelectionTimeoutMS
258+
499), // expectedTimeoutThreshold
259+
Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS",
260+
500, // timeoutMS
261+
1000, // serverSelectionTimeoutMS
262+
499), // expectedTimeoutThreshold
263+
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0",
264+
0, // infinite timeoutMS
265+
500, // serverSelectionTimeoutMS
266+
499) // expectedTimeoutThreshold
267+
);
268+
}
269+
270+
// Not a prose test
271+
@ParameterizedTest(name = "test callback timeout when server selection timeout is "
272+
+ "infinite and timeoutMs is set to {0}")
273+
@ValueSource(ints = {0, 100})
274+
void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) {
275+
TestCallback callback1 = createCallback();
276+
277+
OidcCallback callback2 = (context) -> {
278+
assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration());
279+
return callback1.onRequest(context);
280+
};
281+
282+
MongoClientSettings clientSettings = MongoClientSettings.builder(createSettings(callback2))
283+
.applyToClusterSettings(builder ->
284+
builder.serverSelectionTimeout(
285+
-1, // -1 means infinite
286+
TimeUnit.MILLISECONDS))
287+
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
288+
.build();
289+
290+
try (MongoClient mongoClient = createMongoClient(clientSettings)) {
291+
performFind(mongoClient);
292+
assertEquals(1, callback1.getInvocations());
293+
}
294+
}
295+
201296
@Test
202297
public void test2p2RequestCallbackReturnsNull() {
203298
//noinspection ConstantConditions
@@ -1143,4 +1238,8 @@ public TestCallback createHumanCallback() {
11431238
.setPathSupplier(() -> oidcTokenDirectory() + "test_user1")
11441239
.setRefreshToken("refreshToken");
11451240
}
1241+
1242+
private long msElapsedSince(final long timeOfStart) {
1243+
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart);
1244+
}
11461245
}

0 commit comments

Comments
 (0)