Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ import com.azure.compute.batch.models.BatchJobCreateContent
import com.azure.compute.batch.models.BatchJobConstraints
import com.azure.compute.batch.models.BatchJobUpdateContent
import com.azure.compute.batch.models.BatchNodeFillType
import com.azure.compute.batch.models.BatchNodeIdentityReference
import com.azure.compute.batch.models.BatchPool
import com.azure.compute.batch.models.BatchPoolCreateContent
import com.azure.compute.batch.models.BatchPoolIdentity
import com.azure.compute.batch.models.BatchPoolInfo
import com.azure.compute.batch.models.BatchPoolState
import com.azure.compute.batch.models.BatchStartTask
Expand All @@ -60,6 +62,7 @@ import com.azure.compute.batch.models.OutputFileDestination
import com.azure.compute.batch.models.OutputFileUploadCondition
import com.azure.compute.batch.models.OutputFileUploadConfig
import com.azure.compute.batch.models.ResourceFile
import com.azure.compute.batch.models.UserAssignedIdentity
import com.azure.compute.batch.models.UserIdentity
import com.azure.compute.batch.models.VirtualMachineConfiguration
import com.azure.core.credential.AzureNamedKeyCredential
Expand Down Expand Up @@ -113,6 +116,8 @@ class AzBatchService implements Closeable {
static private final long _1GB = 1 << 30

static final private Map<String,AzVmPoolSpec> allPools = new HashMap<>(50)

static final private Map<String,String> poolManagedIdentityIds = new HashMap<>(50)

AzConfig config

Expand Down Expand Up @@ -474,6 +479,7 @@ class AzBatchService implements Closeable {
assert jobId, 'Missing Azure Batch jobId argument'
assert task, 'Missing Azure Batch task argument'

// SAS token is always required for output file uploads via azcopy
final sas = config.storage().sasToken
if( !sas )
throw new IllegalArgumentException("Missing Azure Blob storage SAS token")
Expand Down Expand Up @@ -545,11 +551,20 @@ class AzBatchService implements Closeable {
final constraints = taskConstraints(task)

log.trace "[AZURE BATCH] Submitting task: $taskId, cpus=${task.config.getCpus()}, mem=${task.config.getMemory()?:'-'}, slots: $slots"

// Check if we should use managed identity and identify the resource ID (ARM)
final poolIdentityClientId = config.batch().poolIdentityClientId
final poolManagedIdentityResourceId = poolIdentityClientId ? getPoolManagedIdentityResourceId(poolId, poolIdentityClientId) : null
if( poolIdentityClientId && !poolManagedIdentityResourceId ) {
// Log a warning when managed identity was requested but no matching identity could be located on the pool
log.warn "[AZURE BATCH] No managed identity found for pool '$poolId' with client ID '${poolIdentityClientId}'. Falling back to SAS token authentication."
}

return new BatchTaskCreateContent(taskId, cmd)
.setUserIdentity(userIdentity(pool.opts.privileged, pool.opts.runAs, AutoUserScope.TASK))
.setContainerSettings(containerOpts)
.setResourceFiles(resourceFileUrls(task, sas))
.setOutputFiles(outputFileUrls(task, sas))
.setResourceFiles(resourceFileUrls(task, poolManagedIdentityResourceId, sas))
.setOutputFiles(outputFileUrls(task, poolManagedIdentityResourceId, sas))
.setRequiredSlots(slots)
.setConstraints(constraints)
}
Expand Down Expand Up @@ -591,7 +606,7 @@ class AzBatchService implements Closeable {
return result
}

protected List<ResourceFile> resourceFileUrls(TaskRun task, String sas) {
protected List<ResourceFile> resourceFileUrls(TaskRun task, String poolManagedIdentityResourceId, String sas) {
final cmdRun = (AzPath) task.workDir.resolve(TaskRun.CMD_RUN)
final cmdScript = (AzPath) task.workDir.resolve(TaskRun.CMD_SCRIPT)

Expand All @@ -604,44 +619,94 @@ class AzBatchService implements Closeable {
.setFilePath('.nextflow-bin/azcopy')
}

resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdRun, sas))
.setFilePath(TaskRun.CMD_RUN)
// Create resource files with or without managed identity
if( poolManagedIdentityResourceId ) {
// When using managed identity, build a BatchNodeIdentityReference carrying the resource ID
// of the user-assigned identity that was resolved from the pool, and attach it to every
// resource file below. The file URLs are generated without a SAS token (null is passed to
// AzHelper.toHttpUrl) since access is expected to be granted through the identity reference.
final identityRef = new BatchNodeIdentityReference()
.setResourceId(poolManagedIdentityResourceId)
log.debug "[AZURE BATCH] Using managed identity with resource ID: ${poolManagedIdentityResourceId}"

resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdRun, null))
.setFilePath(TaskRun.CMD_RUN)
.setIdentityReference(identityRef)

resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas))
.setFilePath(TaskRun.CMD_SCRIPT)
resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdScript, null))
.setFilePath(TaskRun.CMD_SCRIPT)
.setIdentityReference(identityRef)

if( task.stdin ) {
resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdScript, null))
.setFilePath(TaskRun.CMD_INFILE)
.setIdentityReference(identityRef)
}
}
else {
// Use traditional SAS token approach
resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdRun, sas))
.setFilePath(TaskRun.CMD_RUN)

if( task.stdin ) {
resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas))
.setFilePath(TaskRun.CMD_INFILE)
.setFilePath(TaskRun.CMD_SCRIPT)

if( task.stdin ) {
resFiles << new ResourceFile()
.setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas))
.setFilePath(TaskRun.CMD_INFILE)
}
}

return resFiles
}

protected List<OutputFile> outputFileUrls(TaskRun task, String sas) {
protected List<OutputFile> outputFileUrls(TaskRun task, String poolManagedIdentityResourceId, String sas) {
List<OutputFile> result = new ArrayList<>(20)
result << destFile(TaskRun.CMD_EXIT, task.workDir, sas)
result << destFile(TaskRun.CMD_LOG, task.workDir, sas)
result << destFile(TaskRun.CMD_OUTFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_ERRFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_SCRIPT, task.workDir, sas)
result << destFile(TaskRun.CMD_RUN, task.workDir, sas)
result << destFile(TaskRun.CMD_STAGE, task.workDir, sas)
result << destFile(TaskRun.CMD_TRACE, task.workDir, sas)
result << destFile(TaskRun.CMD_ENV, task.workDir, sas)
result << destFile(TaskRun.CMD_EXIT, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_LOG, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_OUTFILE, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_ERRFILE, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_SCRIPT, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_RUN, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_STAGE, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_TRACE, task.workDir, poolManagedIdentityResourceId, sas)
result << destFile(TaskRun.CMD_ENV, task.workDir, poolManagedIdentityResourceId, sas)
return result
}

protected OutputFile destFile(String localPath, Path targetDir, String sas) {
protected OutputFile destFile(String localPath, Path targetDir, String poolManagedIdentityResourceId, String sas) {
log.debug "Task output path: $localPath -> ${targetDir.toUriString()}"
def target = targetDir.resolve(localPath)
final dest = new OutputFileBlobContainerDestination(AzHelper.toContainerUrl(targetDir,sas))
.setPath(target.subpath(1,target.nameCount).toString())

return new OutputFile(localPath, new OutputFileDestination().setContainer(dest), new OutputFileUploadConfig(OutputFileUploadCondition.TASK_COMPLETION))

// Calculate the target blob path
def targetPath = targetDir.resolve(localPath)
def blobPath = targetPath.subpath(1, targetPath.nameCount).toString()

// Create the destination with appropriate authentication
def containerUrl = AzHelper.toContainerUrl(targetDir, poolManagedIdentityResourceId ? null : sas)
def destination = new OutputFileBlobContainerDestination(containerUrl)
.setPath(blobPath)

// Add identity reference if using managed identity
if( poolManagedIdentityResourceId ) {
log.debug "[AZURE BATCH] Setting identity reference for $localPath with resource ID: $poolManagedIdentityResourceId"
destination.setIdentityReference(
new BatchNodeIdentityReference().setResourceId(poolManagedIdentityResourceId)
)
}

// Create and return the output file configuration
return new OutputFile(
localPath,
new OutputFileDestination().setContainer(destination),
new OutputFileUploadConfig(OutputFileUploadCondition.TASK_COMPLETION)
)
}

protected BatchSupportedImage getImage(AzPoolOpts opts) {
Expand Down Expand Up @@ -744,6 +809,63 @@ class AzBatchService implements Closeable {

}

/**
 * Resolve the resource ID of a user-assigned managed identity attached to a Batch pool.
 *
 * The lookup is best-effort: any failure while retrieving or inspecting the pool is
 * logged at debug level (with the stack trace) and reported to the caller as {@code null}.
 *
 * @param poolId The ID of the pool to inspect
 * @param poolIdentityClientId Selects the identity to use:
 *      {@code null}, {@code false} or an empty string disable the lookup;
 *      {@code true} or the string {@code 'auto'} (any letter case) pick the first
 *      user-assigned identity reported by the pool;
 *      any other string (String or GString) is treated as a client ID and matched
 *      against the pool identities ignoring letter case and surrounding whitespace
 * @return The resource ID of the selected identity, or {@code null} when the lookup is
 *      disabled, the pool or a matching identity cannot be found, the selector has an
 *      unsupported type, or an error occurs
 */
protected String getPoolManagedIdentityResourceId(String poolId, poolIdentityClientId) {
    // Groovy truth: null, false and '' all mean "managed identity not requested"
    if( !poolIdentityClientId ) {
        return null
    }

    // TODO: Should we throw an error if we can't find the identity attached to the pool?

    try {
        final pool = getPool(poolId)
        final poolIdentity = pool?.getIdentity()
        final List<UserAssignedIdentity> identities = poolIdentity?.getUserAssignedIdentities()
        // covers a missing pool, a pool without identity and an empty identity list
        if( !identities ) {
            return null
        }

        // Accept any CharSequence so that GString values produced by interpolated
        // configuration strings are handled the same way as plain Strings
        final String requested = poolIdentityClientId instanceof CharSequence
                ? poolIdentityClientId.toString().trim()
                : null

        // Handle 'auto' - use the first available identity
        if( Boolean.TRUE.equals(poolIdentityClientId) || 'auto'.equalsIgnoreCase(requested) ) {
            log.debug "[AZURE BATCH] Using managed identity for pool '$poolId'"
            return identities.first().getResourceId()
        }

        // Handle specific client ID - client IDs are GUIDs, hence compare ignoring case
        if( requested ) {
            final matchingIdentity = identities.find { requested.equalsIgnoreCase(it.getClientId()) }
            if( matchingIdentity ) {
                log.debug "[AZURE BATCH] Found managed identity for pool '$poolId'"
                return matchingIdentity.getResourceId()
            }
        }

        // No matching identity, or unsupported selector type
        return null
    }
    catch( Exception e ) {
        // Deliberately best-effort: keep returning null, but preserve the cause in the log
        log.debug "[AZURE BATCH] Error getting managed identity for pool '$poolId': ${e.message}", e
        return null
    }
}

synchronized String getOrCreatePool(TaskRun task) {

final spec = specForTask(task)
Expand Down Expand Up @@ -1108,3 +1230,4 @@ class AzBatchService implements Closeable {
return Failsafe.with(policy).get(action)
}
}

Loading