Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(amazonq): Enable SageMaker SSO Users to use AmazonQ Chat & Code completion using their Credentials & Q Pro-tier profile #6338

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ src.gen/*
**/src/shared/telemetry/clienttelemetry.d.ts
**/src/codewhisperer/client/codewhispererclient.d.ts
**/src/codewhisperer/client/codewhispereruserclient.d.ts
**/src/shared/sagemaker/client/sagemakerclient.d.ts
**/src/amazonqFeatureDev/client/featuredevproxyclient.d.ts
**/src/auth/sso/oidcclientpkce.d.ts

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "Feature",
"description": "Enable SageMaker SSO user access token & Profile ARN to be used for accessing AmazonQ & CodeWhisperer features"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"CodeWhisperer" is no more

Suggested change
"description": "Enable SageMaker SSO user access token & Profile ARN to be used for accessing AmazonQ & CodeWhisperer features"
"description": "Auth: Enable SSO access for SageMaker users."

}
4 changes: 4 additions & 0 deletions packages/core/scripts/build/generateServiceClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ void (async () => {
serviceJsonPath: 'src/amazonqFeatureDev/client/codewhispererruntime-2022-11-11.json',
serviceName: 'FeatureDevProxyClient',
},
{
serviceJsonPath: 'src/shared/sagemaker/client/service-2.json',
serviceName: 'SageMakerClient',
},
]
await generateServiceClients(serviceClientDefinitions)
})()
18 changes: 4 additions & 14 deletions packages/core/src/auth/activation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,21 @@
* SPDX-License-Identifier: Apache-2.0
*/

import * as vscode from 'vscode'
import { Auth } from './auth'
import { LoginManager } from './deprecated/loginManager'
import { fromString } from './providers/credentials'
import { getLogger } from '../shared/logger'
import { ExtensionUse, initializeCredentialsProviderManager } from './utils'
import { isAmazonQ, isCloud9, isSageMaker } from '../shared/extensionUtilities'
import { isCloud9, isSageMaker } from '../shared/extensionUtilities'
import { isInDevEnv } from '../shared/vscode/env'
import { isWeb } from '../shared/extensionGlobals'

interface SagemakerCookie {
authMode?: 'Sso' | 'Iam'
}

export async function initialize(loginManager: LoginManager): Promise<void> {
if (isAmazonQ() && isSageMaker()) {
// The command `sagemaker.parseCookies` is registered in VS Code Sagemaker environment.
const result = (await vscode.commands.executeCommand('sagemaker.parseCookies')) as SagemakerCookie
if (result.authMode !== 'Sso') {
initializeCredentialsProviderManager()
}
}
await initializeCredentialsProviderManager()

Auth.instance.onDidChangeActiveConnection(async (conn) => {
// This logic needs to be moved to `Auth.useConnection` to correctly record `passive`
if (conn?.type === 'iam' && conn.state === 'valid') {
if (conn?.state === 'valid' && (isSageMaker() || conn?.type === 'iam')) {
await loginManager.login({ passive: true, providerId: fromString(conn.id) })
} else {
await loginManager.logout()
Expand Down
64 changes: 59 additions & 5 deletions packages/core/src/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ const localize = nls.loadMessageBundle()
import * as vscode from 'vscode'
import * as localizedText from '../shared/localizedText'
import { Credentials } from '@aws-sdk/types'
import { SsoAccessTokenProvider } from './sso/ssoAccessTokenProvider'
import { SsoAccessTokenProvider, SsoTokenProvider } from './sso/ssoAccessTokenProvider'
import { Timeout } from '../shared/utilities/timeoutUtils'
import { errorCode, isAwsError, isNetworkError, ToolkitError, UnknownError } from '../shared/errors'
import { getCache, getCacheFileWatcher } from './sso/cache'
import { getCache, getCacheFileWatcher, SsoCache } from './sso/cache'
import { isNonNullable, Mutable } from '../shared/utilities/tsUtils'
import { SsoToken, truncateStartUrl } from './sso/model'
import { SsoClient } from './sso/clients'
Expand Down Expand Up @@ -69,6 +69,7 @@ import { withTelemetryContext } from '../shared/telemetry/util'
import { DiskCacheError } from '../shared/utilities/cacheUtils'
import { setContext } from '../shared/vscode/setContext'
import { builderIdStartUrl, internalStartUrl } from './sso/constants'
import { SageMakerSsoTokenProvider } from './sso/sageMakerAccessTokenProvider'

interface AuthService {
/**
Expand Down Expand Up @@ -121,6 +122,10 @@ function keyedDebounce<T, U extends any[], K extends string = string>(
}
}

export function useSageMakerSsoProfile() {
return isSageMaker() && isAmazonQ()
}

export interface ConnectionStateChangeEvent {
readonly id: Connection['id']
readonly state: ProfileMetadata['connectionState']
Expand All @@ -141,17 +146,29 @@ export class Auth implements AuthService, ConnectionManager {
readonly #onDidChangeConnectionState = new vscode.EventEmitter<ConnectionStateChangeEvent>()
readonly #onDidUpdateConnection = new vscode.EventEmitter<StatefulConnection>()
readonly #onDidDeleteConnection = new vscode.EventEmitter<DeletedConnection>()
readonly #onDidPrecreateActiveConnection = new vscode.EventEmitter<StatefulConnection>()
public readonly onDidChangeActiveConnection = this.#onDidChangeActiveConnection.event
public readonly onDidChangeConnectionState = this.#onDidChangeConnectionState.event
public readonly onDidUpdateConnection = this.#onDidUpdateConnection.event
/** Fired when a connection and its metadata has been completely deleted */
public readonly onDidDeleteConnection = this.#onDidDeleteConnection.event
public readonly onDidPrecreateActiveConnection = this.#onDidPrecreateActiveConnection.event

public constructor(
private readonly store: ProfileStore,
private readonly iamProfileProvider = CredentialsProviderManager.getInstance(),
private readonly createSsoClient = SsoClient.create.bind(SsoClient),
private readonly createSsoTokenProvider = SsoAccessTokenProvider.create.bind(SsoAccessTokenProvider)
private readonly createSsoTokenProvider: (
profile: {
readonly startUrl: string
readonly region: string
readonly identifier?: string
readonly scopes: string[]
},
cache?: SsoCache
) => SsoTokenProvider = useSageMakerSsoProfile()
? SageMakerSsoTokenProvider.create.bind(SageMakerSsoTokenProvider)
: SsoAccessTokenProvider.create.bind(SsoAccessTokenProvider)
) {}

#activeConnection: Mutable<StatefulConnection> | undefined
Expand Down Expand Up @@ -324,6 +341,29 @@ export class Auth implements AuthService, ConnectionManager {
return toCollection(load.bind(this))
}

private async createSageMakerSsoConnection(): Promise<StatefulConnection | undefined> {
if (!useSageMakerSsoProfile) {
return undefined
}
const id = SageMakerSsoTokenProvider.sagemakerConectionId
const { startUrl, region, scopes } = SageMakerSsoTokenProvider.getSagemakerProfile()
const profile = createSsoProfile(startUrl, region, scopes)
const tokenProvider = this.getSsoTokenProvider(id, {
...profile,
metadata: { connectionState: 'unauthenticated' },
})

const token = await tokenProvider.getToken()
if (!token) {
return undefined
}

const storedProfile = await this.store.addProfile(id, profile)
await this.updateConnectionState(id, 'valid')
const connection = this.getSsoConnection(id, storedProfile)
return connection
}

public async createConnection(profile: SsoProfile): Promise<SsoConnection>
@withTelemetryContext({ name: 'createConnection', class: authClassName })
public async createConnection(profile: Profile): Promise<Connection> {
Expand Down Expand Up @@ -786,7 +826,7 @@ export class Auth implements AuthService, ConnectionManager {
{
identifier: tokenIdentifier,
startUrl: profile.startUrl,
scopes: profile.scopes,
scopes: profile.scopes ?? [],
region: profile.ssoRegion,
},
this.#ssoCache
Expand Down Expand Up @@ -859,7 +899,7 @@ export class Auth implements AuthService, ConnectionManager {

private readonly getToken = keyedDebounce(this._getToken.bind(this))
@withTelemetryContext({ name: '_getToken', class: authClassName })
private async _getToken(id: Connection['id'], provider: SsoAccessTokenProvider): Promise<SsoToken> {
private async _getToken(id: Connection['id'], provider: SsoTokenProvider): Promise<SsoToken> {
const token = await provider.getToken().catch((err) => {
this.throwOnRecoverableError(err)

Expand Down Expand Up @@ -963,6 +1003,20 @@ export class Auth implements AuthService, ConnectionManager {
return this.authenticate(id, refresh)
}

public async tryAutoConnectSageMaker(): Promise<StatefulConnection | undefined> {
try {
const sagemakerConnection = await this.createSageMakerSsoConnection()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the connection already exist in a saved state from a prior session?

if (!sagemakerConnection) {
return undefined
}

await this.useConnection({ id: SageMakerSsoTokenProvider.sagemakerConectionId })
return sagemakerConnection
} catch (err) {
getLogger().warn(`auth: failed to connect using SageMaker auth token: %s`, err)
}
}

public readonly tryAutoConnect = once(async () => this._tryAutoConnect())
@withTelemetryContext({ name: 'tryAutoConnect', class: authClassName })
private async _tryAutoConnect() {
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/auth/providers/ssoCredentialsProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import { CredentialType } from '../../shared/telemetry/telemetry.gen'
import { getStringHash } from '../../shared/utilities/textUtilities'
import { CredentialsId, CredentialsProvider, CredentialsProviderType } from './credentials'
import { SsoClient } from '../sso/clients'
import { SsoAccessTokenProvider } from '../sso/ssoAccessTokenProvider'
import { SsoTokenProvider } from '../sso/ssoAccessTokenProvider'

export class SsoCredentialsProvider implements CredentialsProvider {
public constructor(
private readonly id: CredentialsId,
private readonly client: SsoClient,
private readonly tokenProvider: SsoAccessTokenProvider,
private readonly tokenProvider: SsoTokenProvider,
private readonly accountId: string,
private readonly roleName: string
) {}
Expand Down
25 changes: 23 additions & 2 deletions packages/core/src/auth/secondaryAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import * as vscode from 'vscode'
import { getLogger } from '../shared/logger'
import { cast, Optional } from '../shared/utilities/typeConstructors'
import { Auth } from './auth'
import { Auth, useSageMakerSsoProfile } from './auth'
import { onceChanged } from '../shared/utilities/functionUtils'
import { isNonNullable } from '../shared/utilities/tsUtils'
import { ToolIdStateKey } from '../shared/globalState'
Expand All @@ -24,7 +24,7 @@ let currentConn: Auth['activeConnection']
const auths = new Map<string, SecondaryAuth>()
const multiConnectionListeners = new WeakMap<Auth, vscode.Disposable>()
const registerAuthListener = (auth: Auth) => {
return auth.onDidChangeActiveConnection(async (newConn) => {
const activeConnectionChangeListener = auth.onDidChangeActiveConnection(async (newConn) => {
// When we change the active connection, there may be
// secondary auths that were dependent on the previous active connection.
// To ensure secondary auths still work, when we change to a new active connection,
Expand All @@ -38,6 +38,21 @@ const registerAuthListener = (auth: Auth) => {
}
currentConn = newConn
})

const precreatedConnectionCreatedListener = auth.onDidPrecreateActiveConnection(async (newConn) => {
await Promise.all(
Array.from(auths.values())
.filter((a) => !a.hasSavedConnection && a.isUsable(newConn))
.map((a) => a.saveConnection(newConn))
)
})

return {
dispose: () => {
activeConnectionChangeListener.dispose()
precreatedConnectionCreatedListener.dispose()
},
}
}

export function getSecondaryAuth<T extends Connection>(
Expand Down Expand Up @@ -306,6 +321,12 @@ export class SecondaryAuth<T extends Connection = Connection> {
id: 'undefined',
connectionState: 'undefined',
})
if (useSageMakerSsoProfile()) {
const connection = await this.auth.tryAutoConnectSageMaker()
if (connection) {
this.saveConnection(connection as unknown as T)
}
}
await this.auth.tryAutoConnect()
this.#savedConnection = await this._loadSavedConnection(span)
this.#onDidChangeActiveConnection.fire(this.activeConnection)
Expand Down
6 changes: 3 additions & 3 deletions packages/core/src/auth/sso/clients.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import { AsyncCollection } from '../../shared/utilities/asyncCollection'
import { pageableToCollection, partialClone } from '../../shared/utilities/collectionUtils'
import { assertHasProps, isNonNullable, RequiredProps, selectFrom } from '../../shared/utilities/tsUtils'
import { getLogger } from '../../shared/logger'
import { SsoAccessTokenProvider } from './ssoAccessTokenProvider'
import { SsoTokenProvider } from './ssoAccessTokenProvider'
import { AwsClientResponseError, isClientFault } from '../../shared/errors'
import { DevSettings } from '../../shared/settings'
import { SdkError } from '@aws-sdk/types'
Expand Down Expand Up @@ -158,7 +158,7 @@ export class SsoClient {

public constructor(
private readonly client: PromisifyClient<SSO>,
private readonly provider: SsoAccessTokenProvider
private readonly provider: SsoTokenProvider
) {}

@withTelemetryContext({ name: 'listAccounts', class: ssoClientClassName })
Expand Down Expand Up @@ -236,7 +236,7 @@ export class SsoClient {
throw error
}

public static create(region: string, provider: SsoAccessTokenProvider) {
public static create(region: string, provider: SsoTokenProvider) {
return new this(
new SSO({
region,
Expand Down
Loading
Loading