Skip to content

Commit 1893c78

Browse files
committed
Apply client-side encryption in transactions on sharded clusters
This fixes a bug in both sync and async drivers where client-side encryption is not applied when in a transaction. JAVA-3752
1 parent 086de66 commit 1893c78

File tree

10 files changed

+437
-81
lines changed

10 files changed

+437
-81
lines changed

driver-core/src/main/com/mongodb/internal/async/client/AsyncCryptBinding.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
package com.mongodb.internal.async.client;
1818

1919
import com.mongodb.ReadPreference;
20-
import com.mongodb.internal.async.SingleResultCallback;
20+
import com.mongodb.ServerAddress;
2121
import com.mongodb.connection.ServerDescription;
22+
import com.mongodb.internal.async.SingleResultCallback;
2223
import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
2324
import com.mongodb.internal.binding.AsyncConnectionSource;
2425
import com.mongodb.internal.binding.AsyncReadWriteBinding;
@@ -74,6 +75,20 @@ public void onResult(final AsyncConnectionSource result, final Throwable t) {
7475
});
7576
}
7677

78+
@Override
79+
public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback<AsyncConnectionSource> callback) {
80+
wrapped.getConnectionSource(serverAddress, new SingleResultCallback<AsyncConnectionSource>() {
81+
@Override
82+
public void onResult(final AsyncConnectionSource result, final Throwable t) {
83+
if (t != null) {
84+
callback.onResult(null, t);
85+
} else {
86+
callback.onResult(new AsyncCryptConnectionSource(result), null);
87+
}
88+
}
89+
});
90+
}
91+
7792
@Override
7893
public int getCount() {
7994
return wrapped.getCount();

driver-core/src/main/com/mongodb/internal/async/client/ClientSessionBinding.java

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
2525
import com.mongodb.internal.binding.AsyncConnectionSource;
2626
import com.mongodb.internal.binding.AsyncReadWriteBinding;
27-
import com.mongodb.internal.binding.AsyncSingleServerBinding;
2827
import com.mongodb.internal.connection.AsyncConnection;
2928
import com.mongodb.internal.connection.Server;
3029
import com.mongodb.internal.selector.ReadPreferenceServerSelector;
@@ -53,77 +52,45 @@ public ReadPreference getReadPreference() {
5352

5453
@Override
5554
public void getReadConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
56-
wrapped.getReadConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
57-
@Override
58-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
59-
if (t != null) {
60-
callback.onResult(null, t);
61-
} else {
62-
wrapConnectionSource(result, callback);
63-
}
64-
}
65-
});
55+
if (isActiveShardedTxn()) {
56+
getPinnedConnectionSource(callback);
57+
} else {
58+
wrapped.getReadConnectionSource(new WrappingCallback(callback));
59+
}
6660
}
6761

6862
public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
69-
wrapped.getWriteConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
70-
@Override
71-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
72-
if (t != null) {
73-
callback.onResult(null, t);
74-
} else {
75-
wrapConnectionSource(result, callback);
76-
}
77-
}
78-
});
63+
if (isActiveShardedTxn()) {
64+
getPinnedConnectionSource(callback);
65+
} else {
66+
wrapped.getWriteConnectionSource(new WrappingCallback(callback));
67+
}
7968
}
8069

8170
@Override
8271
public SessionContext getSessionContext() {
8372
return sessionContext;
8473
}
8574

86-
private void wrapConnectionSource(final AsyncConnectionSource connectionSource,
87-
final SingleResultCallback<AsyncConnectionSource> callback) {
88-
if (isActiveShardedTxn()) {
89-
if (session.getPinnedServerAddress() == null) {
90-
wrapped.getCluster().selectServerAsync(
91-
new ReadPreferenceServerSelector(wrapped.getReadPreference()),
92-
new SingleResultCallback<Server>() {
93-
@Override
94-
public void onResult(final Server server, final Throwable t) {
95-
if (t != null) {
96-
callback.onResult(null, t);
97-
} else {
98-
session.setPinnedServerAddress(server.getDescription().getAddress());
99-
setSingleServerBindingConnectionSource(callback);
100-
}
75+
private void getPinnedConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
76+
if (session.getPinnedServerAddress() == null) {
77+
wrapped.getCluster().selectServerAsync(
78+
new ReadPreferenceServerSelector(wrapped.getReadPreference()), new SingleResultCallback<Server>() {
79+
@Override
80+
public void onResult(final Server server, final Throwable t) {
81+
if (t != null) {
82+
callback.onResult(null, t);
83+
} else {
84+
session.setPinnedServerAddress(server.getDescription().getAddress());
85+
wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback));
10186
}
102-
});
103-
} else {
104-
setSingleServerBindingConnectionSource(callback);
105-
}
87+
}
88+
});
10689
} else {
107-
callback.onResult(new SessionBindingAsyncConnectionSource(connectionSource), null);
90+
wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback));
10891
}
10992
}
11093

111-
private void setSingleServerBindingConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
112-
final AsyncSingleServerBinding binding =
113-
new AsyncSingleServerBinding(wrapped.getCluster(), session.getPinnedServerAddress(), wrapped.getReadPreference());
114-
binding.getWriteConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
115-
@Override
116-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
117-
binding.release();
118-
if (t != null) {
119-
callback.onResult(null, t);
120-
} else {
121-
callback.onResult(new SessionBindingAsyncConnectionSource(result), null);
122-
}
123-
}
124-
});
125-
}
126-
12794
@Override
12895
public int getCount() {
12996
return wrapped.getCount();
@@ -225,4 +192,21 @@ public ReadConcern getReadConcern() {
225192
}
226193
}
227194
}
195+
196+
private class WrappingCallback implements SingleResultCallback<AsyncConnectionSource> {
197+
private final SingleResultCallback<AsyncConnectionSource> callback;
198+
199+
WrappingCallback(final SingleResultCallback<AsyncConnectionSource> callback) {
200+
this.callback = callback;
201+
}
202+
203+
@Override
204+
public void onResult(final AsyncConnectionSource result, final Throwable t) {
205+
if (t != null) {
206+
callback.onResult(null, t);
207+
} else {
208+
callback.onResult(new SessionBindingAsyncConnectionSource(result), null);
209+
}
210+
}
211+
}
228212
}

driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@
1616

1717
package com.mongodb.internal.binding;
1818

19+
import com.mongodb.ServerAddress;
20+
import com.mongodb.internal.async.SingleResultCallback;
1921
import com.mongodb.internal.connection.Cluster;
2022

2123
/**
2224
* This interface is not part of the public API and may be removed or changed at any time.
2325
*/
2426
public interface AsyncClusterAwareReadWriteBinding extends AsyncReadWriteBinding {
2527
Cluster getCluster();
28+
29+
/**
30+
* Returns a connection source to the specified server
31+
*
32+
* @param serverAddress the server address
33+
* @param callback the to be passed the connection source
34+
*/
35+
void getConnectionSource(ServerAddress serverAddress, SingleResultCallback<AsyncConnectionSource> callback);
2636
}

driver-core/src/main/com/mongodb/internal/binding/AsyncClusterBinding.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@
1818

1919
import com.mongodb.ReadConcern;
2020
import com.mongodb.ReadPreference;
21-
import com.mongodb.internal.async.SingleResultCallback;
21+
import com.mongodb.ServerAddress;
2222
import com.mongodb.connection.ServerDescription;
23+
import com.mongodb.internal.async.SingleResultCallback;
2324
import com.mongodb.internal.connection.AsyncConnection;
2425
import com.mongodb.internal.connection.Cluster;
2526
import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext;
2627
import com.mongodb.internal.connection.Server;
2728
import com.mongodb.internal.selector.ReadPreferenceServerSelector;
29+
import com.mongodb.internal.selector.ServerAddressSelector;
2830
import com.mongodb.internal.selector.WritableServerSelector;
29-
import com.mongodb.selector.ServerSelector;
3031
import com.mongodb.internal.session.SessionContext;
32+
import com.mongodb.selector.ServerSelector;
3133

3234
import static com.mongodb.assertions.Assertions.notNull;
3335

@@ -87,6 +89,11 @@ public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionS
8789
getAsyncClusterBindingConnectionSource(new WritableServerSelector(), callback);
8890
}
8991

92+
@Override
93+
public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback<AsyncConnectionSource> callback) {
94+
getAsyncClusterBindingConnectionSource(new ServerAddressSelector(serverAddress), callback);
95+
}
96+
9097
private void getAsyncClusterBindingConnectionSource(final ServerSelector serverSelector,
9198
final SingleResultCallback<AsyncConnectionSource> callback) {
9299
cluster.selectServerAsync(serverSelector, new SingleResultCallback<Server>() {

driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,18 @@
1616

1717
package com.mongodb.internal.binding;
1818

19+
import com.mongodb.ServerAddress;
1920
import com.mongodb.internal.connection.Cluster;
2021

2122
/**
2223
* This interface is not part of the public API and may be removed or changed at any time.
2324
*/
2425
public interface ClusterAwareReadWriteBinding extends ReadWriteBinding {
2526
Cluster getCluster();
27+
28+
/**
29+
* Returns a connection source to the specified server address.
30+
* @return the connection source
31+
*/
32+
ConnectionSource getConnectionSource(ServerAddress serverAddress);
2633
}

driver-core/src/main/com/mongodb/internal/binding/ClusterBinding.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818

1919
import com.mongodb.ReadConcern;
2020
import com.mongodb.ReadPreference;
21+
import com.mongodb.ServerAddress;
2122
import com.mongodb.connection.ServerDescription;
2223
import com.mongodb.internal.connection.Cluster;
2324
import com.mongodb.internal.connection.Connection;
2425
import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext;
2526
import com.mongodb.internal.connection.Server;
2627
import com.mongodb.internal.selector.ReadPreferenceServerSelector;
28+
import com.mongodb.internal.selector.ServerAddressSelector;
2729
import com.mongodb.internal.selector.WritableServerSelector;
28-
import com.mongodb.selector.ServerSelector;
2930
import com.mongodb.internal.session.SessionContext;
31+
import com.mongodb.selector.ServerSelector;
3032

3133
import static com.mongodb.assertions.Assertions.notNull;
3234

@@ -75,20 +77,25 @@ public ReadPreference getReadPreference() {
7577
}
7678

7779
@Override
78-
public ConnectionSource getReadConnectionSource() {
79-
return new ClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference));
80+
public SessionContext getSessionContext() {
81+
return new ReadConcernAwareNoOpSessionContext(readConcern);
8082
}
8183

8284
@Override
83-
public SessionContext getSessionContext() {
84-
return new ReadConcernAwareNoOpSessionContext(readConcern);
85+
public ConnectionSource getReadConnectionSource() {
86+
return new ClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference));
8587
}
8688

8789
@Override
8890
public ConnectionSource getWriteConnectionSource() {
8991
return new ClusterBindingConnectionSource(new WritableServerSelector());
9092
}
9193

94+
@Override
95+
public ConnectionSource getConnectionSource(final ServerAddress serverAddress) {
96+
return new ClusterBindingConnectionSource(new ServerAddressSelector(serverAddress));
97+
}
98+
9299
private final class ClusterBindingConnectionSource extends AbstractReferenceCounted implements ConnectionSource {
93100
private final Server server;
94101

0 commit comments

Comments
 (0)