diff --git a/azure-client-authentication/pom.xml b/azure-client-authentication/pom.xml
index a136a8f8da..9db2e9db57 100644
--- a/azure-client-authentication/pom.xml
+++ b/azure-client-authentication/pom.xml
@@ -69,6 +69,12 @@
junit
test
+
+ org.mockito
+ mockito-core
+ 3.4.4
+ test
+
diff --git a/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java b/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java
index ae77578750..a21137e671 100644
--- a/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java
+++ b/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java
@@ -27,10 +27,12 @@
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Date;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -39,15 +41,19 @@
*/
public class ApplicationTokenCredentials extends AzureTokenCredentials {
/** A mapping from resource endpoint to its cached access token. */
- private Map tokens;
+ private final ConcurrentHashMap tokens;
+ /** A mapping from resource endpoint to its current authentication locks. */
+ private final ConcurrentHashMap authenticationLocks;
/** The active directory application client id. */
- private String clientId;
+ private final String clientId;
/** The authentication secret for the application. */
private String clientSecret;
/** The PKCS12 certificate byte array. */
private byte[] clientCertificate;
/** The certificate password. */
private String clientCertificatePassword;
+ /** The timeout for authentication calls. */
+ private long timeoutInSeconds = 60;
/**
* Initializes a new instance of the ApplicationTokenCredentials.
@@ -62,7 +68,8 @@ public ApplicationTokenCredentials(String clientId, String domain, String secret
super(environment, domain); // defer token acquisition
this.clientId = clientId;
this.clientSecret = secret;
- this.tokens = new HashMap<>();
+ this.tokens = new ConcurrentHashMap<>();
+ this.authenticationLocks = new ConcurrentHashMap<>();
}
/**
@@ -80,7 +87,8 @@ public ApplicationTokenCredentials(String clientId, String domain, byte[] certif
this.clientId = clientId;
this.clientCertificate = certificate;
this.clientCertificatePassword = password;
- this.tokens = new HashMap<>();
+ this.tokens = new ConcurrentHashMap<>();
+ this.authenticationLocks = new ConcurrentHashMap<>();
}
/**
@@ -132,19 +140,61 @@ String clientCertificatePassword() {
return clientCertificatePassword;
}
+ /**
+ * Gets the timeout for AAD authentication calls. Default is 60 seconds.
+ *
+ * @return the timeout in seconds.
+ */
+ public long timeoutInSeconds() {
+ return timeoutInSeconds;
+ }
+
+ /**
+ * Sets the timeout for AAD authentication calls. Default is 60 seconds.
+ *
+ * @param timeoutInSeconds the timeout in seconds.
+ * @return the modified ApplicationTokenCredentials instance
+ */
+ public ApplicationTokenCredentials withTimeoutInSeconds(long timeoutInSeconds) {
+ this.timeoutInSeconds = timeoutInSeconds;
+ return this;
+ }
+
@Override
- public synchronized String getToken(String resource) throws IOException {
+ public String getToken(String resource) throws IOException {
AuthenticationResult authenticationResult = tokens.get(resource);
if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) {
- authenticationResult = acquireAccessToken(resource);
+ ReentrantLock lock;
+ synchronized (authenticationLocks) {
+ lock = authenticationLocks.get(resource);
+ if (lock == null) {
+ lock = new ReentrantLock();
+ authenticationLocks.put(resource, lock);
+ }
+ }
+ lock.lock();
+ try {
+ authenticationResult = tokens.get(resource);
+ if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) {
+ ExecutorService executor = Executors.newSingleThreadExecutor();
+ try {
+ authenticationResult = acquireAccessToken(resource, executor).get(timeoutInSeconds(), TimeUnit.SECONDS);
+ tokens.put(resource, authenticationResult);
+ } finally {
+ executor.shutdown();
+ }
+ }
+ } catch (Exception e) {
+ throw new IOException(e.getMessage(), e);
+ } finally {
+ lock.unlock();
+ }
}
- tokens.put(resource, authenticationResult);
return authenticationResult.getAccessToken();
}
- private AuthenticationResult acquireAccessToken(String resource) throws IOException {
+ Future acquireAccessToken(String resource, ExecutorService executor) throws IOException {
String authorityUrl = this.environment().activeDirectoryEndpoint() + this.domain();
- ExecutorService executor = Executors.newSingleThreadExecutor();
AuthenticationContext context = new AuthenticationContext(authorityUrl, false, executor);
if (proxy() != null) {
context.setProxy(proxy());
@@ -157,23 +207,21 @@ private AuthenticationResult acquireAccessToken(String resource) throws IOExcept
return context.acquireToken(
resource,
new ClientCredential(this.clientId(), clientSecret),
- null).get();
+ null);
} else if (clientCertificate != null && clientCertificatePassword != null) {
return context.acquireToken(
resource,
AsymmetricKeyCredential.create(clientId, new ByteArrayInputStream(clientCertificate), clientCertificatePassword),
- null).get();
+ null);
} else if (clientCertificate != null) {
return context.acquireToken(
resource,
AsymmetricKeyCredential.create(clientId(), privateKeyFromPem(new String(clientCertificate)), publicKeyFromPem(new String(clientCertificate))),
- null).get();
+ null);
}
throw new AuthenticationException("Please provide either a non-null secret or a non-null certificate.");
} catch (Exception e) {
throw new IOException(e.getMessage(), e);
- } finally {
- executor.shutdown();
}
}