diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java index addacd92722c8..25653636b403d 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java @@ -690,7 +690,7 @@ public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) { double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble() * pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously; sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * pctToUse); - clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse; + clientSessionReauthenticationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(sessionLifetimeMsToUse)); log.debug( "Finished {} with session expiration in {} ms and session re-authentication on or after {} ms", authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse); diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java index e2ebaa31cd260..8f1e16b0b116d 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java @@ -681,7 +681,7 @@ else if (!maxReauthSet) else retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs)); - sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs; + sessionExpirationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(retvalSessionLifetimeMs)); } if (credentialExpirationMs != null) { diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java index 02a62ee4524b8..76ebe2dac4e17 100644 --- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java +++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java @@ -1697,4 +1697,17 @@ public static ConfigDef mergeConfigs(List configDefs) { public interface ThrowingRunnable { void run() throws Exception; } + + /** + * convert millisecond to nanosecond, or throw exception if overflow + * @param timeMs the time in millisecond + * @return the converted nanosecond + */ + public static long msToNs(long timeMs) { + try { + return Math.multiplyExact(1000 * 1000, timeMs); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Cannot convert " + timeMs + " millisecond to nanosecond due to arithmetic overflow", e); + } + } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java index 3b1e54dee2c62..2025cc7cc0cc8 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java @@ -155,6 +155,7 @@ public class SaslAuthenticatorTest { private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L; private static final int BUFFER_SIZE = 4 * 1024; private static Time time = Time.SYSTEM; + private static boolean needLargeExpiration = false; private NioEchoServer server; private Selector selector; @@ -178,6 +179,7 @@ public void setup() throws Exception { @AfterEach public void teardown() throws Exception { + needLargeExpiration = false; if (server != null) this.server.close(); if (selector != null) @@ -1607,6 +1609,42 @@ public void testCannotReauthenticateWithDifferentPrincipal() throws Exception { server.verifyReauthenticationMetrics(0, 1); } + @Test + public void testReauthenticateWithLargeReauthValue() throws Exception { + // enable it, we'll get a large expiration timestamp token + needLargeExpiration = true; + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + + configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + List.of(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)); + // set a large re-auth timeout in server side + saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_CONFIG, Long.MAX_VALUE); + server = createEchoServer(securityProtocol); + + // set to default value for sasl login configs for initialization in ExpiringCredentialRefreshConfig + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR); + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER); + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS); + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS); + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, AlternateLoginCallbackHandler.class); + + createCustomClientConnection(securityProtocol, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, node, true); + + // channel should be not null before sasl handshake + assertNotNull(selector.channel(node)); + + TestUtils.waitForCondition(() -> { + selector.poll(1000); + // this channel should be closed due to session timeout calculation overflow + return selector.channel(node) == null; + }, "channel didn't close with large re-authentication value"); + + // ensure metrics are as expected + server.verifyAuthenticationMetrics(0, 0); + server.verifyReauthenticationMetrics(0, 0); + } + @Test public void testCorrelationId() { SaslClientAuthenticator authenticator = new SaslClientAuthenticator( @@ -1936,7 +1974,7 @@ private void createClientConnection(SecurityProtocol securityProtocol, String sa if (enableSaslAuthenticateHeader) createClientConnection(securityProtocol, node); else - createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism, node); + createCustomClientConnection(securityProtocol, saslMechanism, node, false); } private NioEchoServer startServerApiVersionsUnsupportedByClient(final SecurityProtocol securityProtocol, String saslMechanism) throws Exception { @@ -2024,15 +2062,13 @@ protected void enableKafkaSaslAuthenticateHeaders(boolean flag) { return server; } - private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol, - final String saslMechanism, String node) throws Exception { - - final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); - final Map configs = Collections.emptyMap(); - final JaasContext jaasContext = JaasContext.loadClientContext(configs); - final Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); - - SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts, + private SaslChannelBuilder saslChannelBuilderWithoutHeader( + final SecurityProtocol securityProtocol, + final String saslMechanism, + final Map jaasContexts, + final ListenerName listenerName + ) { + return new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts, securityProtocol, listenerName, false, saslMechanism, null, null, null, time, new LogContext(), null) { @@ -2059,6 +2095,42 @@ protected void setSaslAuthenticateAndHandshakeVersions(ApiVersionsResponse apiVe }; } }; + } + + private void createCustomClientConnection( + final SecurityProtocol securityProtocol, + final String saslMechanism, + String node, + boolean withSaslAuthenticateHeader + ) throws Exception { + + final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + final Map configs = Collections.emptyMap(); + final JaasContext jaasContext = JaasContext.loadClientContext(configs); + final Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); + + SaslChannelBuilder clientChannelBuilder; + if (!withSaslAuthenticateHeader) { + clientChannelBuilder = saslChannelBuilderWithoutHeader(securityProtocol, saslMechanism, jaasContexts, listenerName); + } else { + clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts, + securityProtocol, listenerName, false, saslMechanism, + null, null, null, time, new LogContext(), null) { + + @Override + protected SaslClientAuthenticator buildClientAuthenticator(Map configs, + AuthenticateCallbackHandler callbackHandler, + String id, + String serverHost, + String servicePrincipal, + TransportLayer transportLayer, + Subject subject) { + + return new SaslClientAuthenticator(configs, callbackHandler, id, subject, + servicePrincipal, serverHost, saslMechanism, transportLayer, time, new LogContext()); + } + }; + } clientChannelBuilder.configure(saslClientConfigs); this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); @@ -2507,10 +2579,11 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback + ++numInvocations; String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}"; /* - * Use a short lifetime so the background refresh thread replaces it before we + * If we're testing large expiration scenario, use a large lifetime. + * Otherwise, use a short lifetime so the background refresh thread replaces it before we * re-authenticate */ - String lifetimeSecondsValueToUse = "1"; + String lifetimeSecondsValueToUse = needLargeExpiration ? String.valueOf(Long.MAX_VALUE) : "1"; String claimsJson; try { claimsJson = String.format("{%s,%s,%s}", diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java index 81df34f85f4b9..5213f03187e82 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java @@ -269,6 +269,35 @@ public void testSessionExpiresAtTokenExpiry() throws IOException { } } + @Test + public void testSessionWontExpireWithLargeExpirationTime() throws IOException { + String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; + SaslServer saslServer = mock(SaslServer.class); + MockTime time = new MockTime(0, 1, 1000); + // set a Long.MAX_VALUE as the expiration time + Duration largeExpirationTime = Duration.ofMillis(Long.MAX_VALUE); + + try ( + MockedStatic ignored = mockSaslServer(saslServer, mechanism, time, largeExpirationTime); + MockedStatic ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name"); + TransportLayer transportLayer = mockTransportLayer() + ) { + + SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, largeExpirationTime.toMillis()); + + mockRequest(saslHandshakeRequest(mechanism), transportLayer); + authenticator.authenticate(); + + when(saslServer.isComplete()).thenReturn(false).thenReturn(true); + mockRequest(saslAuthenticateRequest(), transportLayer); + + Throwable t = assertThrows(IllegalArgumentException.class, () -> authenticator.authenticate()); + assertEquals(ArithmeticException.class, t.getCause().getClass()); + assertEquals("Cannot convert " + Long.MAX_VALUE + " millisecond to nanosecond due to arithmetic overflow", + t.getMessage()); + } + } + private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) { Map configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Collections.singletonList(mechanism)); diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java index 16fc6af154b20..78f3f1486c2f7 100755 --- a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java +++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java @@ -1109,6 +1109,13 @@ public void testTryAll() throws Throwable { assertEquals(expected, recorded); } + @Test + public void testMsToNs() { + assertEquals(1000000, Utils.msToNs(1)); + assertEquals(0, Utils.msToNs(0)); + assertThrows(IllegalArgumentException.class, () -> Utils.msToNs(Long.MAX_VALUE)); + } + private Callable recordingCallable(Map recordingMap, String success, TestException failure) { return () -> { if (success == null)