diff --git a/azure-client-authentication/pom.xml b/azure-client-authentication/pom.xml index a136a8f8da..9db2e9db57 100644 --- a/azure-client-authentication/pom.xml +++ b/azure-client-authentication/pom.xml @@ -69,6 +69,12 @@ junit test + + org.mockito + mockito-core + 3.4.4 + test + diff --git a/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java b/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java index ae77578750..a21137e671 100644 --- a/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java +++ b/azure-client-authentication/src/main/java/com/microsoft/azure/credentials/ApplicationTokenCredentials.java @@ -27,10 +27,12 @@ import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.util.Date; -import java.util.HashMap; -import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -39,15 +41,19 @@ */ public class ApplicationTokenCredentials extends AzureTokenCredentials { /** A mapping from resource endpoint to its cached access token. */ - private Map tokens; + private final ConcurrentHashMap tokens; + /** A mapping from resource endpoint to its current authentication locks. */ + private final ConcurrentHashMap authenticationLocks; /** The active directory application client id. */ - private String clientId; + private final String clientId; /** The authentication secret for the application. */ private String clientSecret; /** The PKCS12 certificate byte array. */ private byte[] clientCertificate; /** The certificate password. */ private String clientCertificatePassword; + /** The timeout for authentication calls. */ + private long timeoutInSeconds = 60; /** * Initializes a new instance of the ApplicationTokenCredentials. @@ -62,7 +68,8 @@ public ApplicationTokenCredentials(String clientId, String domain, String secret super(environment, domain); // defer token acquisition this.clientId = clientId; this.clientSecret = secret; - this.tokens = new HashMap<>(); + this.tokens = new ConcurrentHashMap<>(); + this.authenticationLocks = new ConcurrentHashMap<>(); } /** @@ -80,7 +87,8 @@ public ApplicationTokenCredentials(String clientId, String domain, byte[] certif this.clientId = clientId; this.clientCertificate = certificate; this.clientCertificatePassword = password; - this.tokens = new HashMap<>(); + this.tokens = new ConcurrentHashMap<>(); + this.authenticationLocks = new ConcurrentHashMap<>(); } /** @@ -132,19 +140,61 @@ String clientCertificatePassword() { return clientCertificatePassword; } + /** + * Gets the timeout for AAD authentication calls. Default is 60 seconds. + * + * @return the timeout in seconds. + */ + public long timeoutInSeconds() { + return timeoutInSeconds; + } + + /** + * Sets the timeout for AAD authentication calls. Default is 60 seconds. + * + * @param timeoutInSeconds the timeout in seconds. + * @return the modified ApplicationTokenCredentials instance + */ + public ApplicationTokenCredentials withTimeoutInSeconds(long timeoutInSeconds) { + this.timeoutInSeconds = timeoutInSeconds; + return this; + } + @Override - public synchronized String getToken(String resource) throws IOException { + public String getToken(String resource) throws IOException { AuthenticationResult authenticationResult = tokens.get(resource); if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) { - authenticationResult = acquireAccessToken(resource); + ReentrantLock lock; + synchronized (authenticationLocks) { + lock = authenticationLocks.get(resource); + if (lock == null) { + lock = new ReentrantLock(); + authenticationLocks.put(resource, lock); + } + } + lock.lock(); + try { + authenticationResult = tokens.get(resource); + if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + authenticationResult = acquireAccessToken(resource, executor).get(timeoutInSeconds(), TimeUnit.SECONDS); + tokens.put(resource, authenticationResult); + } finally { + executor.shutdown(); + } + } + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } finally { + lock.unlock(); + } } - tokens.put(resource, authenticationResult); return authenticationResult.getAccessToken(); } - private AuthenticationResult acquireAccessToken(String resource) throws IOException { + Future acquireAccessToken(String resource, ExecutorService executor) throws IOException { String authorityUrl = this.environment().activeDirectoryEndpoint() + this.domain(); - ExecutorService executor = Executors.newSingleThreadExecutor(); AuthenticationContext context = new AuthenticationContext(authorityUrl, false, executor); if (proxy() != null) { context.setProxy(proxy()); @@ -157,23 +207,21 @@ private AuthenticationResult acquireAccessToken(String resource) throws IOExcept return context.acquireToken( resource, new ClientCredential(this.clientId(), clientSecret), - null).get(); + null); } else if (clientCertificate != null && clientCertificatePassword != null) { return context.acquireToken( resource, AsymmetricKeyCredential.create(clientId, new ByteArrayInputStream(clientCertificate), clientCertificatePassword), - null).get(); + null); } else if (clientCertificate != null) { return context.acquireToken( resource, AsymmetricKeyCredential.create(clientId(), privateKeyFromPem(new String(clientCertificate)), publicKeyFromPem(new String(clientCertificate))), - null).get(); + null); } throw new AuthenticationException("Please provide either a non-null secret or a non-null certificate."); } catch (Exception e) { throw new IOException(e.getMessage(), e); - } finally { - executor.shutdown(); } }