Skip to content

Commit

Permalink
Add timeout to ApplicationTokenCredentials (#676)
Browse files Browse the repository at this point in the history
* Add timeout to ApplicationTokenCredentials

* Fix executor and add unit tests
  • Loading branch information
jianghaolu authored Jul 27, 2020
1 parent dbc615c commit c152e03
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 16 deletions.
6 changes: 6 additions & 0 deletions azure-client-authentication/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.4.4</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -39,15 +41,19 @@
*/
public class ApplicationTokenCredentials extends AzureTokenCredentials {
/** A mapping from resource endpoint to its cached access token. */
private Map<String, AuthenticationResult> tokens;
private final ConcurrentHashMap<String, AuthenticationResult> tokens;
/** A mapping from resource endpoint to its current authentication locks. */
private final ConcurrentHashMap<String, ReentrantLock> authenticationLocks;
/** The active directory application client id. */
private String clientId;
private final String clientId;
/** The authentication secret for the application. */
private String clientSecret;
/** The PKCS12 certificate byte array. */
private byte[] clientCertificate;
/** The certificate password. */
private String clientCertificatePassword;
/** The timeout for authentication calls. */
private long timeoutInSeconds = 60;

/**
* Initializes a new instance of the ApplicationTokenCredentials.
Expand All @@ -62,7 +68,8 @@ public ApplicationTokenCredentials(String clientId, String domain, String secret
super(environment, domain); // defer token acquisition
this.clientId = clientId;
this.clientSecret = secret;
this.tokens = new HashMap<>();
this.tokens = new ConcurrentHashMap<>();
this.authenticationLocks = new ConcurrentHashMap<>();
}

/**
Expand All @@ -80,7 +87,8 @@ public ApplicationTokenCredentials(String clientId, String domain, byte[] certif
this.clientId = clientId;
this.clientCertificate = certificate;
this.clientCertificatePassword = password;
this.tokens = new HashMap<>();
this.tokens = new ConcurrentHashMap<>();
this.authenticationLocks = new ConcurrentHashMap<>();
}

/**
Expand Down Expand Up @@ -132,19 +140,61 @@ String clientCertificatePassword() {
return clientCertificatePassword;
}

/**
* Gets the timeout for AAD authentication calls. Default is 60 seconds.
*
* @return the timeout in seconds.
*/
public long timeoutInSeconds() {
return timeoutInSeconds;
}

/**
* Sets the timeout for AAD authentication calls. Default is 60 seconds.
*
* @param timeoutInSeconds the timeout in seconds.
* @return the modified ApplicationTokenCredentials instance
*/
public ApplicationTokenCredentials withTimeoutInSeconds(long timeoutInSeconds) {
this.timeoutInSeconds = timeoutInSeconds;
return this;
}

@Override
public synchronized String getToken(String resource) throws IOException {
public String getToken(String resource) throws IOException {
AuthenticationResult authenticationResult = tokens.get(resource);
if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) {
authenticationResult = acquireAccessToken(resource);
ReentrantLock lock;
synchronized (authenticationLocks) {
lock = authenticationLocks.get(resource);
if (lock == null) {
lock = new ReentrantLock();
authenticationLocks.put(resource, lock);
}
}
lock.lock();
try {
authenticationResult = tokens.get(resource);
if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) {
ExecutorService executor = Executors.newSingleThreadExecutor();
try {
authenticationResult = acquireAccessToken(resource, executor).get(timeoutInSeconds(), TimeUnit.SECONDS);
tokens.put(resource, authenticationResult);
} finally {
executor.shutdown();
}
}
} catch (Exception e) {
throw new IOException(e.getMessage(), e);
} finally {
lock.unlock();
}
}
tokens.put(resource, authenticationResult);
return authenticationResult.getAccessToken();
}

private AuthenticationResult acquireAccessToken(String resource) throws IOException {
Future<AuthenticationResult> acquireAccessToken(String resource, ExecutorService executor) throws IOException {
String authorityUrl = this.environment().activeDirectoryEndpoint() + this.domain();
ExecutorService executor = Executors.newSingleThreadExecutor();
AuthenticationContext context = new AuthenticationContext(authorityUrl, false, executor);
if (proxy() != null) {
context.setProxy(proxy());
Expand All @@ -157,23 +207,21 @@ private AuthenticationResult acquireAccessToken(String resource) throws IOExcept
return context.acquireToken(
resource,
new ClientCredential(this.clientId(), clientSecret),
null).get();
null);
} else if (clientCertificate != null && clientCertificatePassword != null) {
return context.acquireToken(
resource,
AsymmetricKeyCredential.create(clientId, new ByteArrayInputStream(clientCertificate), clientCertificatePassword),
null).get();
null);
} else if (clientCertificate != null) {
return context.acquireToken(
resource,
AsymmetricKeyCredential.create(clientId(), privateKeyFromPem(new String(clientCertificate)), publicKeyFromPem(new String(clientCertificate))),
null).get();
null);
}
throw new AuthenticationException("Please provide either a non-null secret or a non-null certificate.");
} catch (Exception e) {
throw new IOException(e.getMessage(), e);
} finally {
executor.shutdown();
}
}

Expand Down

0 comments on commit c152e03

Please sign in to comment.