Skip to content

Commit 717da9c

Browse files
authored
Revert "Resolve TLS channel address before opening socket" (#1979)
This reverts commit 13d4aef.
1 parent 58122f6 commit 717da9c

2 files changed

Lines changed: 12 additions & 108 deletions

File tree

driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import javax.net.ssl.SSLParameters;
3838
import java.io.Closeable;
3939
import java.io.IOException;
40-
import java.net.InetSocketAddress;
4140
import java.net.StandardSocketOptions;
4241
import java.nio.ByteBuffer;
4342
import java.nio.channels.CompletionHandler;
@@ -210,60 +209,35 @@ private static class TlsChannelStream extends AsynchronousChannelStream {
210209
@Override
211210
public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler<Void> handler) {
212211
isTrue("unopened", getChannel() == null);
213-
SocketChannel socketChannel = null;
214-
SelectorMonitor.SocketRegistration socketRegistration = null;
215-
boolean registered = false;
216212
try {
217-
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
218-
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
219-
InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0);
220-
SocketChannel openedSocketChannel = SocketChannel.open();
221-
socketChannel = openedSocketChannel;
222-
openedSocketChannel.configureBlocking(false);
213+
SocketChannel socketChannel = SocketChannel.open();
214+
socketChannel.configureBlocking(false);
223215

224-
openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
225-
openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
216+
socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
217+
socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
226218
if (getSettings().getReceiveBufferSize() > 0) {
227-
openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
219+
socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
228220
}
229221
if (getSettings().getSendBufferSize() > 0) {
230-
openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
222+
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
231223
}
232-
openedSocketChannel.connect(socketAddress);
233-
socketRegistration = new SelectorMonitor.SocketRegistration(
234-
openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel));
224+
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
225+
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
226+
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
227+
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
228+
socketChannel, () -> initializeTslChannel(handler, socketChannel));
235229

236230
if (connectTimeoutMs > 0) {
237231
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
238232
}
239233
selectorMonitor.register(socketRegistration);
240-
registered = true;
241234
} catch (IOException e) {
242-
closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, e);
243235
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
244236
} catch (Throwable t) {
245-
closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, t);
246237
handler.failed(t);
247238
}
248239
}
249240

250-
private void closeUnregisteredSocketChannel(@Nullable final SocketChannel socketChannel,
251-
@Nullable final SelectorMonitor.SocketRegistration socketRegistration,
252-
final boolean registered, final Throwable failure) {
253-
if (!registered) {
254-
if (socketRegistration != null) {
255-
socketRegistration.tryCancelPendingConnection();
256-
}
257-
if (socketChannel != null) {
258-
try {
259-
socketChannel.close();
260-
} catch (IOException e) {
261-
failure.addSuppressed(e);
262-
}
263-
}
264-
}
265-
}
266-
267241
private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
268242
final SelectorMonitor.SocketRegistration socketRegistration,
269243
final int connectTimeoutMs) {
@@ -410,3 +384,4 @@ public void close() throws IOException {
410384
}
411385
}
412386
}
387+

driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,18 @@
1717
package com.mongodb.internal.connection;
1818

1919
import com.mongodb.ClusterFixture;
20-
import com.mongodb.MongoSocketException;
2120
import com.mongodb.MongoSocketOpenException;
2221
import com.mongodb.ServerAddress;
23-
import com.mongodb.connection.AsyncCompletionHandler;
2422
import com.mongodb.connection.SocketSettings;
2523
import com.mongodb.connection.SslSettings;
2624
import com.mongodb.internal.TimeoutContext;
2725
import com.mongodb.internal.TimeoutSettings;
28-
import com.mongodb.spi.dns.InetAddressResolver;
2926
import org.bson.ByteBuf;
3027
import org.bson.ByteBufNIO;
3128
import org.junit.jupiter.api.DisplayName;
3229
import org.junit.jupiter.api.Test;
3330
import org.junit.jupiter.params.ParameterizedTest;
3431
import org.junit.jupiter.params.provider.ValueSource;
35-
import org.mockito.ArgumentCaptor;
3632
import org.mockito.MockedStatic;
3733
import org.mockito.Mockito;
3834
import org.mockito.invocation.InvocationOnMock;
@@ -41,13 +37,11 @@
4137
import javax.net.ssl.SSLContext;
4238
import javax.net.ssl.SSLEngine;
4339
import java.io.IOException;
44-
import java.net.InetAddress;
4540
import java.net.ServerSocket;
4641
import java.nio.ByteBuffer;
4742
import java.nio.channels.InterruptedByTimeoutException;
4843
import java.nio.channels.SocketChannel;
4944
import java.util.Collections;
50-
import java.util.List;
5145
import java.util.concurrent.TimeUnit;
5246

5347
import static com.mongodb.ClusterFixture.getPrimaryServerDescription;
@@ -58,12 +52,10 @@
5852
import static org.junit.jupiter.api.Assertions.assertFalse;
5953
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
6054
import static org.junit.jupiter.api.Assertions.assertNotNull;
61-
import static org.junit.jupiter.api.Assertions.assertSame;
6255
import static org.junit.jupiter.api.Assertions.assertThrows;
6356
import static org.junit.jupiter.api.Assertions.assertTrue;
6457
import static org.junit.jupiter.api.Assertions.fail;
6558
import static org.junit.jupiter.api.Assumptions.assumeTrue;
66-
import static org.mockito.ArgumentMatchers.any;
6759
import static org.mockito.ArgumentMatchers.anyInt;
6860
import static org.mockito.ArgumentMatchers.anyString;
6961
import static org.mockito.Mockito.atLeast;
@@ -76,69 +68,6 @@ class TlsChannelStreamFunctionalTest {
7668
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1";
7769
private static final int UNREACHABLE_PORT = 65333;
7870

79-
@Test
80-
void shouldFailAsyncCompletionHandlerWithoutOpeningSocketChannelIfNameResolutionFails() {
81-
//given
82-
ServerAddress serverAddress = new ServerAddress();
83-
MongoSocketException exception = new MongoSocketException("Temporary failure in name resolution", serverAddress);
84-
InetAddressResolver inetAddressResolver = new InetAddressResolver() {
85-
@Override
86-
public List<InetAddress> lookupByName(final String host) {
87-
throw exception;
88-
}
89-
};
90-
91-
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
92-
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
93-
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
94-
.connectTimeout(100, TimeUnit.MILLISECONDS)
95-
.build(), SSL_SETTINGS);
96-
Stream stream = streamFactory.create(serverAddress);
97-
@SuppressWarnings("unchecked")
98-
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);
99-
100-
//when
101-
stream.openAsync(createOperationContext(100), handler);
102-
103-
//then
104-
verify(handler).failed(exception);
105-
verify(handler, times(0)).completed(null);
106-
socketChannelMockedStatic.verify(SocketChannel::open, times(0));
107-
}
108-
}
109-
110-
@Test
111-
void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException {
112-
//given
113-
ServerAddress serverAddress = new ServerAddress();
114-
IOException exception = new IOException("connect failed");
115-
InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress());
116-
117-
try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open());
118-
StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
119-
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
120-
socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel);
121-
Mockito.doThrow(exception).when(socketChannel).connect(any());
122-
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
123-
.connectTimeout(100, TimeUnit.MILLISECONDS)
124-
.build(), SSL_SETTINGS);
125-
Stream stream = streamFactory.create(serverAddress);
126-
@SuppressWarnings("unchecked")
127-
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);
128-
ArgumentCaptor<Throwable> failureCaptor = ArgumentCaptor.forClass(Throwable.class);
129-
130-
//when
131-
stream.openAsync(createOperationContext(100), handler);
132-
133-
//then
134-
verify(handler).failed(failureCaptor.capture());
135-
MongoSocketOpenException actualException = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue());
136-
assertSame(exception, actualException.getCause());
137-
verify(handler, times(0)).completed(null);
138-
verify(socketChannel).close();
139-
}
140-
}
141-
14271
@ParameterizedTest
14372
@ValueSource(ints = {500, 1000, 2000})
14473
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {

0 commit comments

Comments
 (0)