Skip to content

Commit 1c13cb8

Browse files
committed
Runtime loader of ManagedChannelFactory
1 parent 7ba41bb commit 1c13cb8

File tree

5 files changed

+124
-35
lines changed

5 files changed

+124
-35
lines changed

core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
import com.google.common.net.HostAndPort;
1616
import com.google.common.util.concurrent.MoreExecutors;
1717
import io.grpc.ManagedChannel;
18-
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
1918

2019
import tech.ydb.auth.AuthRpcProvider;
2120
import tech.ydb.auth.NopAuthProvider;
2221
import tech.ydb.core.impl.YdbSchedulerFactory;
2322
import tech.ydb.core.impl.YdbTransportImpl;
2423
import tech.ydb.core.impl.auth.GrpcAuthRpc;
25-
import tech.ydb.core.impl.pool.DefaultChannelFactory;
24+
import tech.ydb.core.impl.pool.ChannelFactoryLoader;
2625
import tech.ydb.core.impl.pool.ManagedChannelFactory;
2726
import tech.ydb.core.utils.Version;
2827

@@ -69,7 +68,7 @@ public enum InitMode {
6968

7069
private byte[] cert = null;
7170
private boolean useTLS = false;
72-
private ManagedChannelFactory.Builder channelFactoryBuilder = DefaultChannelFactory::build;
71+
private ManagedChannelFactory.Builder channelFactoryBuilder = null;
7372
private Supplier<ScheduledExecutorService> schedulerFactory = YdbSchedulerFactory::createScheduler;
7473
private String localDc;
7574
private BalancingSettings balancingSettings;
@@ -177,6 +176,10 @@ public boolean useDefaultGrpcResolver() {
177176
}
178177

179178
public ManagedChannelFactory getManagedChannelFactory() {
179+
if (channelFactoryBuilder == null) {
180+
channelFactoryBuilder = ChannelFactoryLoader.load();
181+
}
182+
180183
return channelFactoryBuilder.buildFactory(this);
181184
}
182185

@@ -193,18 +196,20 @@ public GrpcTransportBuilder withChannelFactoryBuilder(ManagedChannelFactory.Buil
193196
}
194197

195198
/**
196-
* Set a custom initialization of {@link NettyChannelBuilder} <br>
199+
* Set a custom initialization of {@link io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder} <br>
197200
* This method is deprecated. Use
198201
* {@link GrpcTransportBuilder#withChannelFactoryBuilder(tech.ydb.core.impl.pool.ManagedChannelFactory.Builder)}
199202
* instead
200203
*
201-
* @param channelInitializer custom NettyChannelBuilder initializator
204+
* @param ci custom NettyChannelBuilder initializator
202205
* @return this
203206
* @deprecated
204207
*/
205208
@Deprecated
206-
public GrpcTransportBuilder withChannelInitializer(Consumer<NettyChannelBuilder> channelInitializer) {
207-
this.channelFactoryBuilder = gtb -> DefaultChannelFactory.build(gtb, channelInitializer);
209+
public GrpcTransportBuilder withChannelInitializer(
210+
Consumer<io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder> ci
211+
) {
212+
this.channelFactoryBuilder = tech.ydb.core.impl.pool.ShadedNettyChannelFactory.withInterceptor(ci);
208213
return this;
209214
}
210215

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package tech.ydb.core.impl.pool;
2+
3+
4+
import org.slf4j.Logger;
5+
import org.slf4j.LoggerFactory;
6+
7+
8+
/**
9+
*
10+
* @author Aleksandr Gorshenin
11+
*/
12+
public class ChannelFactoryLoader {
13+
private static final Logger logger = LoggerFactory.getLogger(ChannelFactoryLoader.class);
14+
15+
private ChannelFactoryLoader() { }
16+
17+
public static ManagedChannelFactory.Builder load() {
18+
return FactoryLoader.factory;
19+
}
20+
21+
private static class FactoryLoader {
22+
private static final String SHADED_DEPS = "io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder";
23+
private static final String NETTY_DEPS = "io.grpc.netty.NettyChannelBuilder";
24+
25+
private static ManagedChannelFactory.Builder factory;
26+
27+
static {
28+
boolean ok = tryLoad(SHADED_DEPS, ShadedNettyChannelFactory.build())
29+
|| tryLoad(NETTY_DEPS, NettyChannelFactory.build());
30+
if (!ok) {
31+
throw new IllegalStateException("Cannot load any ManagedChannelFactory!! "
32+
+ "Classpath must contain grpc-netty or grpc-netty-shaded");
33+
}
34+
}
35+
36+
private static boolean tryLoad(String name, ManagedChannelFactory.Builder f) {
37+
try {
38+
Class.forName(name);
39+
logger.info("class {} is found, use {}", name, f);
40+
factory = f;
41+
return true;
42+
} catch (ClassNotFoundException ex) {
43+
logger.info("class {} is not found", name);
44+
return false;
45+
}
46+
}
47+
}
48+
}

core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.io.ByteArrayInputStream;
44
import java.util.concurrent.TimeUnit;
5+
import java.util.function.Consumer;
56

67
import javax.net.ssl.SSLException;
78

@@ -39,7 +40,7 @@ public class NettyChannelFactory implements ManagedChannelFactory {
3940
private final boolean useDefaultGrpcResolver;
4041
private final Long grpcKeepAliveTimeMillis;
4142

42-
public NettyChannelFactory(GrpcTransportBuilder builder) {
43+
private NettyChannelFactory(GrpcTransportBuilder builder) {
4344
this.database = builder.getDatabase();
4445
this.version = builder.getVersionString();
4546
this.useTLS = builder.getUseTls();
@@ -120,4 +121,29 @@ private SslContext createSslContext() {
120121
throw new RuntimeException("cannot create ssl context", e);
121122
}
122123
}
124+
125+
public static ManagedChannelFactory.Builder build() {
126+
return new Builder() {
127+
@Override
128+
public ManagedChannelFactory buildFactory(GrpcTransportBuilder builder) {
129+
return new NettyChannelFactory(builder);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
return "NettyChannelFactory";
135+
}
136+
};
137+
}
138+
139+
public static ManagedChannelFactory.Builder withInterceptor(Consumer<NettyChannelBuilder> ci) {
140+
return builder -> new NettyChannelFactory(builder) {
141+
@Override
142+
protected void configure(NettyChannelBuilder channelBuilder) {
143+
if (ci != null) {
144+
ci.accept(channelBuilder);
145+
}
146+
}
147+
};
148+
}
123149
}

core/src/main/java/tech/ydb/core/impl/pool/DefaultChannelFactory.java renamed to core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
* @author Nikolay Perfilov
2828
* @author Aleksandr Gorshenin
2929
*/
30-
public class DefaultChannelFactory implements ManagedChannelFactory {
30+
public class ShadedNettyChannelFactory implements ManagedChannelFactory {
3131
static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB
3232
static final String DEFAULT_BALANCER_POLICY = "round_robin";
3333

@@ -40,7 +40,7 @@ public class DefaultChannelFactory implements ManagedChannelFactory {
4040
private final boolean useDefaultGrpcResolver;
4141
private final Long grpcKeepAliveTimeMillis;
4242

43-
private DefaultChannelFactory(GrpcTransportBuilder builder) {
43+
public ShadedNettyChannelFactory(GrpcTransportBuilder builder) {
4444
this.database = builder.getDatabase();
4545
this.version = builder.getVersionString();
4646
this.useTLS = builder.getUseTls();
@@ -122,12 +122,22 @@ private SslContext createSslContext() {
122122
}
123123
}
124124

125-
public static ManagedChannelFactory build(GrpcTransportBuilder builder) {
126-
return new DefaultChannelFactory(builder);
125+
public static ManagedChannelFactory.Builder build() {
126+
return new Builder() {
127+
@Override
128+
public ManagedChannelFactory buildFactory(GrpcTransportBuilder builder) {
129+
return new ShadedNettyChannelFactory(builder);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
return "ShadedNettyChannelFactory";
135+
}
136+
};
127137
}
128138

129-
public static ManagedChannelFactory build(GrpcTransportBuilder builder, Consumer<NettyChannelBuilder> ci) {
130-
return new DefaultChannelFactory(builder) {
139+
public static ManagedChannelFactory.Builder withInterceptor(Consumer<NettyChannelBuilder> ci) {
140+
return builder -> new ShadedNettyChannelFactory(builder) {
131141
@Override
132142
protected void configure(NettyChannelBuilder channelBuilder) {
133143
if (ci != null) {

core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,18 @@
2020
import org.junit.Assert;
2121
import org.junit.Before;
2222
import org.junit.Test;
23+
import static org.mockito.ArgumentMatchers.any;
24+
import static org.mockito.ArgumentMatchers.anyInt;
2325
import org.mockito.MockedStatic;
2426
import org.mockito.Mockito;
27+
import static org.mockito.Mockito.times;
28+
import static org.mockito.Mockito.verify;
29+
import static org.mockito.Mockito.when;
2530
import org.mockito.MockitoAnnotations;
2631

2732
import tech.ydb.core.grpc.GrpcTransport;
2833
import tech.ydb.core.grpc.GrpcTransportBuilder;
2934

30-
import static org.mockito.ArgumentMatchers.any;
31-
import static org.mockito.ArgumentMatchers.anyInt;
32-
import static org.mockito.Mockito.times;
33-
import static org.mockito.Mockito.verify;
34-
import static org.mockito.Mockito.when;
35-
3635
/**
3736
*
3837
* @author Aleksandr Gorshenin
@@ -73,7 +72,7 @@ public void tearDown() throws Exception {
7372
@Test
7473
public void defaultParams() {
7574
GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root");
76-
ManagedChannelFactory factory = DefaultChannelFactory.build(builder);
75+
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
7776
channelStaticMock.verify(FOR_ADDRESS, times(0));
7877

7978
Assert.assertEquals(30_000l, factory.getConnectTimeoutMs());
@@ -83,8 +82,8 @@ public void defaultParams() {
8382

8483
verify(channelBuilderMock, times(0)).negotiationType(NegotiationType.TLS);
8584
verify(channelBuilderMock, times(1)).negotiationType(NegotiationType.PLAINTEXT);
86-
verify(channelBuilderMock, times(1)).maxInboundMessageSize(DefaultChannelFactory.INBOUND_MESSAGE_SIZE);
87-
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(DefaultChannelFactory.DEFAULT_BALANCER_POLICY);
85+
verify(channelBuilderMock, times(1)).maxInboundMessageSize(ShadedNettyChannelFactory.INBOUND_MESSAGE_SIZE);
86+
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(ShadedNettyChannelFactory.DEFAULT_BALANCER_POLICY);
8887
verify(channelBuilderMock, times(1)).withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
8988
verify(channelBuilderMock, times(0)).enableRetry();
9089
verify(channelBuilderMock, times(1)).disableRetry();
@@ -97,7 +96,7 @@ public void defaultSslFactory() {
9796
.withGrpcRetry(true)
9897
.withConnectTimeout(Duration.ofMinutes(1));
9998

100-
ManagedChannelFactory factory = DefaultChannelFactory.build(builder);
99+
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
101100
channelStaticMock.verify(FOR_ADDRESS, times(0));
102101

103102
Assert.assertEquals(60000l, factory.getConnectTimeoutMs());
@@ -107,8 +106,8 @@ public void defaultSslFactory() {
107106

108107
verify(channelBuilderMock, times(1)).negotiationType(NegotiationType.TLS);
109108
verify(channelBuilderMock, times(0)).negotiationType(NegotiationType.PLAINTEXT);
110-
verify(channelBuilderMock, times(1)).maxInboundMessageSize(DefaultChannelFactory.INBOUND_MESSAGE_SIZE);
111-
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(DefaultChannelFactory.DEFAULT_BALANCER_POLICY);
109+
verify(channelBuilderMock, times(1)).maxInboundMessageSize(ShadedNettyChannelFactory.INBOUND_MESSAGE_SIZE);
110+
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(ShadedNettyChannelFactory.DEFAULT_BALANCER_POLICY);
112111
verify(channelBuilderMock, times(1)).withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
113112
verify(channelBuilderMock, times(1)).enableRetry();
114113
verify(channelBuilderMock, times(0)).disableRetry();
@@ -119,18 +118,19 @@ public void customChannelInitializer() {
119118
GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root")
120119
.withUseDefaultGrpcResolver(true);
121120

122-
ManagedChannelFactory factory = DefaultChannelFactory.build(
123-
builder, cb -> cb.withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE)
124-
);
121+
ManagedChannelFactory factory = ShadedNettyChannelFactory
122+
.withInterceptor(cb -> cb.withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE))
123+
.buildFactory(builder);
124+
125125
channelStaticMock.verify(FOR_ADDRESS, times(0));
126126

127127
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
128128

129129
channelStaticMock.verify(FOR_ADDRESS, times(1));
130130

131131
verify(channelBuilderMock, times(1)).negotiationType(NegotiationType.PLAINTEXT);
132-
verify(channelBuilderMock, times(1)).maxInboundMessageSize(DefaultChannelFactory.INBOUND_MESSAGE_SIZE);
133-
verify(channelBuilderMock, times(0)).defaultLoadBalancingPolicy(DefaultChannelFactory.DEFAULT_BALANCER_POLICY);
132+
verify(channelBuilderMock, times(1)).maxInboundMessageSize(ShadedNettyChannelFactory.INBOUND_MESSAGE_SIZE);
133+
verify(channelBuilderMock, times(0)).defaultLoadBalancingPolicy(ShadedNettyChannelFactory.DEFAULT_BALANCER_POLICY);
134134
verify(channelBuilderMock, times(1)).withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
135135
verify(channelBuilderMock, times(1)).withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE);
136136
}
@@ -147,7 +147,7 @@ public void customSslFactory() throws CertificateException, IOException {
147147
.withGrpcRetry(false)
148148
.withConnectTimeout(4, TimeUnit.SECONDS);
149149

150-
ManagedChannelFactory factory = DefaultChannelFactory.build(builder);
150+
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
151151

152152
Assert.assertEquals(4000l, factory.getConnectTimeoutMs());
153153
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
@@ -160,8 +160,8 @@ public void customSslFactory() throws CertificateException, IOException {
160160

161161
verify(channelBuilderMock, times(1)).negotiationType(NegotiationType.TLS);
162162
verify(channelBuilderMock, times(0)).negotiationType(NegotiationType.PLAINTEXT);
163-
verify(channelBuilderMock, times(1)).maxInboundMessageSize(DefaultChannelFactory.INBOUND_MESSAGE_SIZE);
164-
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(DefaultChannelFactory.DEFAULT_BALANCER_POLICY);
163+
verify(channelBuilderMock, times(1)).maxInboundMessageSize(ShadedNettyChannelFactory.INBOUND_MESSAGE_SIZE);
164+
verify(channelBuilderMock, times(1)).defaultLoadBalancingPolicy(ShadedNettyChannelFactory.DEFAULT_BALANCER_POLICY);
165165
verify(channelBuilderMock, times(1)).withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
166166
verify(channelBuilderMock, times(0)).enableRetry();
167167
verify(channelBuilderMock, times(1)).disableRetry();
@@ -173,7 +173,7 @@ public void invalidSslCert() {
173173
GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root")
174174
.withSecureConnection(cert);
175175

176-
ManagedChannelFactory factory = DefaultChannelFactory.build(builder);
176+
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
177177

178178
RuntimeException ex = Assert.assertThrows(RuntimeException.class,
179179
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));

0 commit comments

Comments
 (0)