Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -725,11 +725,12 @@ public DatabricksConfig newWithWorkspaceHost(String host) {

/**
* Gets the default OAuth redirect URL. If one is not provided explicitly, uses
* http://localhost:8080/callback
* http://localhost:8020, which is the default redirect URL for the default client ID
* (databricks-cli).
*
* @return The OAuth redirect URL to use
*/
public String getEffectiveOAuthRedirectUrl() {
return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback";
return redirectUrl != null ? redirectUrl : "http://localhost:8020";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import com.databricks.sdk.core.DatabricksException;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -51,12 +53,15 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
// Use the utility class to resolve client ID and client secret
String clientId = OAuthClientUtils.resolveClientId(config);
String clientSecret = OAuthClientUtils.resolveClientSecret(config);
List<String> scopes = config.getScopes();
if (scopes == null) {
scopes = Arrays.asList("all-apis", "offline_access");
}

try {
if (tokenCache == null) {
// Create a default FileTokenCache based on config
Path cachePath =
TokenCacheUtils.getCacheFilePath(config.getHost(), clientId, config.getScopes());
Path cachePath = TokenCacheUtils.getCacheFilePath(config.getHost(), clientId, scopes);
tokenCache = new FileTokenCache(cachePath);
}

Expand Down Expand Up @@ -89,7 +94,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {

// If no cached token or refresh failed, perform browser auth
SessionCredentials credentials =
performBrowserAuth(config, clientId, clientSecret, tokenCache);
performBrowserAuth(config, clientId, clientSecret, scopes, tokenCache);
tokenCache.save(credentials.getToken());
return credentials.configure(config);
} catch (IOException | DatabricksException e) {
Expand All @@ -99,7 +104,11 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
}

SessionCredentials performBrowserAuth(
DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache)
DatabricksConfig config,
String clientId,
String clientSecret,
List<String> scopes,
TokenCache tokenCache)
throws IOException {
LOGGER.debug("Performing browser authentication");
OAuthClient client =
Expand All @@ -109,7 +118,7 @@ SessionCredentials performBrowserAuth(
.withClientSecret(clientSecret)
.withHost(config.getHost())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withScopes(config.getScopes())
.withScopes(scopes)
.build();
Consent consent = client.initiateConsent();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.databricks.sdk.core.http.HttpClient;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
Expand Down Expand Up @@ -105,12 +106,7 @@ private OAuthClient(Builder b) throws IOException {

List<String> scopes = b.scopes;
if (scopes == null) {
scopes = Arrays.asList("offline_access", "clusters", "sql");
}
if (config.isAzure()) {
scopes =
Arrays.asList(
config.getEffectiveAzureLoginAppId() + "/user_impersonation", "offline_access");
scopes = Arrays.asList("all-apis", "offline_access");
}
this.scopes = scopes;
}
Expand Down Expand Up @@ -169,9 +165,18 @@ private static byte[] sha256(byte[] input) {
private static String urlEncode(String urlBase, Map<String, String> params) {
String queryParams =
params.entrySet().stream()
.map(entry -> entry.getKey() + "=" + entry.getValue())
.map(
entry -> {
try {
return URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8.toString())
+ "="
+ URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8.toString());
} catch (Exception e) {
throw new DatabricksException("Failed to URL encode parameters", e);
}
})
.collect(Collectors.joining("&"));
return urlBase + "?" + queryParams.replaceAll(" ", "%20");
return urlBase + "?" + queryParams;
}

public Consent initiateConsent() throws MalformedURLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -60,8 +61,8 @@ void clientAndConsentTest() throws IOException {
assertNotNull(authUrl);
assertTrue(authUrl.contains("response_type=code"));
assertTrue(authUrl.contains("client_id=test-client-id"));
assertTrue(authUrl.contains("redirect_uri=http://localhost:8080/callback"));
assertTrue(authUrl.contains("scope=offline_access%20clusters%20sql"));
assertTrue(authUrl.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8020"));
assertTrue(authUrl.contains("scope=all-apis+offline_access"));
}
}

Expand Down Expand Up @@ -105,7 +106,7 @@ void clientAndConsentTestWithCustomRedirectUrl() throws IOException {
assertNotNull(authUrl);
assertTrue(authUrl.contains("response_type=code"));
assertTrue(authUrl.contains("client_id=test-client-id"));
assertTrue(authUrl.contains("redirect_uri=http://localhost:8010"));
assertTrue(authUrl.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8010"));
assertTrue(authUrl.contains("scope=sql"));
}
}
Expand Down Expand Up @@ -281,6 +282,7 @@ void cacheWithValidTokenTest() throws IOException {
any(DatabricksConfig.class),
any(String.class),
any(String.class),
any(List.class),
any(TokenCache.class));

// Verify token was saved back to cache
Expand Down Expand Up @@ -362,6 +364,7 @@ void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException {
any(DatabricksConfig.class),
any(String.class),
any(String.class),
any(List.class),
any(TokenCache.class));

// Verify token was saved back to cache
Expand Down Expand Up @@ -433,7 +436,8 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException {
Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache));
Mockito.doReturn(browserAuthCreds)
.when(provider)
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class), any(), any(), any(List.class), any(TokenCache.class));

// Spy on the config to inject the endpoints
DatabricksConfig spyConfig = Mockito.spy(config);
Expand All @@ -451,7 +455,8 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException {

// Verify performBrowserAuth was called since refresh failed
Mockito.verify(provider, Mockito.times(1))
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class), any(), any(), any(List.class), any(TokenCache.class));

// Verify token was saved after browser auth (for the new token)
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
Expand Down Expand Up @@ -494,7 +499,8 @@ void cacheWithInvalidTokensTest() throws IOException {
Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache));
Mockito.doReturn(browserAuthCreds)
.when(provider)
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class), any(), any(), any(List.class), any(TokenCache.class));

// Configure provider
HeaderFactory headerFactory = provider.configure(config);
Expand All @@ -507,7 +513,8 @@ void cacheWithInvalidTokensTest() throws IOException {

// Verify performBrowserAuth was called since we had an invalid token
Mockito.verify(provider, Mockito.times(1))
.performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class));
.performBrowserAuth(
any(DatabricksConfig.class), any(), any(), any(List.class), any(TokenCache.class));

// Verify token was saved after browser auth (for the new token)
Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class));
Expand Down
Loading