Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KAFAK-14604: SASL session expiration time will be overflowed when calculation #18526

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) {
double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble()
* pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously;
sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * pctToUse);
clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse;
clientSessionReauthenticationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(sessionLifetimeMsToUse));
log.debug(
"Finished {} with session expiration in {} ms and session re-authentication on or after {} ms",
authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ else if (!maxReauthSet)
else
retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));

sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
sessionExpirationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(retvalSessionLifetimeMs));
}

if (credentialExpirationMs != null) {
Expand Down
13 changes: 13 additions & 0 deletions clients/src/main/java/org/apache/kafka/common/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -1697,4 +1697,17 @@ public static ConfigDef mergeConfigs(List<ConfigDef> configDefs) {
public interface ThrowingRunnable {
void run() throws Exception;
}

/**
* convert millisecond to nanosecond, or throw exception if overflow
* @param timeMs the time in millisecond
* @return the converted nanosecond
*/
public static long msToNs(long timeMs) {
try {
return Math.multiplyExact(1000 * 1000, timeMs);
} catch (ArithmeticException e) {
throw new IllegalArgumentException("Cannot convert " + timeMs + " millisecond to nanosecond due to arithmetic overflow", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ public class SaslAuthenticatorTest {
private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L;
private static final int BUFFER_SIZE = 4 * 1024;
private static Time time = Time.SYSTEM;
private static boolean needLargeExpiration = false;

private NioEchoServer server;
private Selector selector;
Expand All @@ -178,6 +179,7 @@ public void setup() throws Exception {

@AfterEach
public void teardown() throws Exception {
needLargeExpiration = false;
if (server != null)
this.server.close();
if (selector != null)
Expand Down Expand Up @@ -1607,6 +1609,42 @@ public void testCannotReauthenticateWithDifferentPrincipal() throws Exception {
server.verifyReauthenticationMetrics(0, 1);
}

@Test
public void testReauthenticateWithLargeReauthValue() throws Exception {
// enable it, we'll get a large expiration timestamp token
needLargeExpiration = true;
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;

configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
List.of(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
// set a large re-auth timeout in server side
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_CONFIG, Long.MAX_VALUE);
server = createEchoServer(securityProtocol);

// set to default value for sasl login configs for initialization in ExpiringCredentialRefreshConfig
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, AlternateLoginCallbackHandler.class);

createCustomClientConnection(securityProtocol, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, node, true);

// channel should be not null before sasl handshake
assertNotNull(selector.channel(node));

TestUtils.waitForCondition(() -> {
selector.poll(1000);
// this channel should be closed due to session timeout calculation overflow
return selector.channel(node) == null;
}, "channel didn't close with large re-authentication value");

// ensure metrics are as expected
server.verifyAuthenticationMetrics(0, 0);
server.verifyReauthenticationMetrics(0, 0);
}

@Test
public void testCorrelationId() {
SaslClientAuthenticator authenticator = new SaslClientAuthenticator(
Expand Down Expand Up @@ -1936,7 +1974,7 @@ private void createClientConnection(SecurityProtocol securityProtocol, String sa
if (enableSaslAuthenticateHeader)
createClientConnection(securityProtocol, node);
else
createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism, node);
createCustomClientConnection(securityProtocol, saslMechanism, node, false);
}

private NioEchoServer startServerApiVersionsUnsupportedByClient(final SecurityProtocol securityProtocol, String saslMechanism) throws Exception {
Expand Down Expand Up @@ -2024,15 +2062,13 @@ protected void enableKafkaSaslAuthenticateHeaders(boolean flag) {
return server;
}

private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol,
final String saslMechanism, String node) throws Exception {

final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
final Map<String, ?> configs = Collections.emptyMap();
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);

SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
private SaslChannelBuilder saslChannelBuilderWithoutHeader(
final SecurityProtocol securityProtocol,
final String saslMechanism,
final Map<String, JaasContext> jaasContexts,
final ListenerName listenerName
) {
return new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
securityProtocol, listenerName, false, saslMechanism,
null, null, null, time, new LogContext(), null) {

Expand All @@ -2059,6 +2095,42 @@ protected void setSaslAuthenticateAndHandshakeVersions(ApiVersionsResponse apiVe
};
}
};
}

private void createCustomClientConnection(
final SecurityProtocol securityProtocol,
final String saslMechanism,
String node,
boolean withSaslAuthenticateHeader
) throws Exception {

final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
final Map<String, ?> configs = Collections.emptyMap();
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);

SaslChannelBuilder clientChannelBuilder;
if (!withSaslAuthenticateHeader) {
clientChannelBuilder = saslChannelBuilderWithoutHeader(securityProtocol, saslMechanism, jaasContexts, listenerName);
} else {
clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
securityProtocol, listenerName, false, saslMechanism,
null, null, null, time, new LogContext(), null) {

@Override
protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
AuthenticateCallbackHandler callbackHandler,
String id,
String serverHost,
String servicePrincipal,
TransportLayer transportLayer,
Subject subject) {

return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
servicePrincipal, serverHost, saslMechanism, transportLayer, time, new LogContext());
}
};
}
clientChannelBuilder.configure(saslClientConfigs);
this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
Expand Down Expand Up @@ -2507,10 +2579,11 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
+ ++numInvocations;
String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}";
/*
* Use a short lifetime so the background refresh thread replaces it before we
* If we're testing large expiration scenario, use a large lifetime.
* Otherwise, use a short lifetime so the background refresh thread replaces it before we
* re-authenticate
*/
String lifetimeSecondsValueToUse = "1";
String lifetimeSecondsValueToUse = needLargeExpiration ? String.valueOf(Long.MAX_VALUE) : "1";
String claimsJson;
try {
claimsJson = String.format("{%s,%s,%s}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,35 @@ public void testSessionExpiresAtTokenExpiry() throws IOException {
}
}

@Test
public void testSessionWontExpireWithLargeExpirationTime() throws IOException {
String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
SaslServer saslServer = mock(SaslServer.class);
MockTime time = new MockTime(0, 1, 1000);
// set a Long.MAX_VALUE as the expiration time
Duration largeExpirationTime = Duration.ofMillis(Long.MAX_VALUE);

try (
MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, largeExpirationTime);
MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
TransportLayer transportLayer = mockTransportLayer()
) {

SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, largeExpirationTime.toMillis());

mockRequest(saslHandshakeRequest(mechanism), transportLayer);
authenticator.authenticate();

when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
mockRequest(saslAuthenticateRequest(), transportLayer);

Throwable t = assertThrows(IllegalArgumentException.class, () -> authenticator.authenticate());
assertEquals(ArithmeticException.class, t.getCause().getClass());
assertEquals("Cannot convert " + Long.MAX_VALUE + " millisecond to nanosecond due to arithmetic overflow",
t.getMessage());
}
}

private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
Collections.singletonList(mechanism));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,13 @@ public void testTryAll() throws Throwable {
assertEquals(expected, recorded);
}

@Test
public void testMsToNs() {
assertEquals(1000000, Utils.msToNs(1));
assertEquals(0, Utils.msToNs(0));
assertThrows(IllegalArgumentException.class, () -> Utils.msToNs(Long.MAX_VALUE));
}

private Callable<Void> recordingCallable(Map<String, Object> recordingMap, String success, TestException failure) {
return () -> {
if (success == null)
Expand Down
Loading