184 changes: 184 additions & 0 deletions test/e2e/modules/resources/rd/distributed_batch_job.go
@@ -0,0 +1,184 @@
// Copyright 2026 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package rd

import (
    "context"
    "fmt"
    "maps"
    "time"

    batchv1 "k8s.io/api/batch/v1"
    v1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/errors"
    "k8s.io/apimachinery/pkg/types"
    "k8s.io/apimachinery/pkg/util/wait"
    "k8s.io/utils/ptr"
    runtimeClient "sigs.k8s.io/controller-runtime/pkg/client"

    v2 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2"
    "github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
    pgconstants "github.com/kai-scheduler/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/constants"
)

const (
    // JobNameLabel is the label the k8s Job controller sets on every pod it creates.
    JobNameLabel = "batch.kubernetes.io/job-name"

    podGroupFetchTimeout = 30 * time.Second
    podGroupFetchPoll    = 250 * time.Millisecond
)

// DistributedBatchJobOptions configures CreateDistributedBatchJob. Every field is optional;
// pass DistributedBatchJobOptions{} to get a single-pod gang Job with no resource requests.
type DistributedBatchJobOptions struct {
    // Parallelism is the number of pods the Job spawns. nil means 1.
    Parallelism *int32
    // MinMember is written to the PodGroup's Spec.MinMember. nil means Parallelism (gang).
    // Gang: MinMember == Parallelism
    // Elastic: 1 <= MinMember < Parallelism
    MinMember *int32
    // Resources is applied to each pod. The zero value means no requests/limits.
    Resources v1.ResourceRequirements
    // NamePrefix is prepended to the generated Job name.
    NamePrefix string
    // TopologyConstraint is propagated to the auto-created PodGroup via annotations.
    TopologyConstraint *v2alpha2.TopologyConstraint
    // PriorityClassName is set on the pod template; the podgrouper reads it onto the PodGroup.
    PriorityClassName string
    // Preemptibility is set as a Job label; the podgrouper reads it onto the PodGroup.
    Preemptibility v2alpha2.Preemptibility
    // ExtraLabels are merged into the pod template labels (e.g. for test filtering).
    ExtraLabels map[string]string
    // PodSpecMutator is applied to the pod template spec after defaults are set. Scale
    // tests use this to inject KWOK tolerations/affinity without importing scale into rd.
    PodSpecMutator func(*v1.PodSpec)
}
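
// Illustrative option shapes (a sketch, not exhaustive): per the field docs above,
// leaving MinMember nil yields a gang, and setting it below Parallelism yields an
// elastic workload.
//
//    gang := DistributedBatchJobOptions{Parallelism: ptr.To(int32(8))}
//    elastic := DistributedBatchJobOptions{
//        Parallelism: ptr.To(int32(8)),
//        MinMember:   ptr.To(int32(2)),
//    }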

// CreateDistributedBatchJob submits a batch Job annotated with kai.scheduler/batch-min-member
// so the podgrouper produces a single PodGroup with Spec.MinMember=opts.MinMember. It returns
// the Job, the PodGroup (once the podgrouper has created it), and the pods the Job spawned.
func CreateDistributedBatchJob(
    ctx context.Context,
    kubeClient runtimeClient.Client,
    jobQueue *v2.Queue,
    opts DistributedBatchJobOptions,
) (*batchv1.Job, *v2alpha2.PodGroup, []*v1.Pod, error) {
    parallelism := ptr.Deref(opts.Parallelism, 1)
    minMember := ptr.Deref(opts.MinMember, parallelism)

    job := buildDistributedBatchJob(jobQueue, opts, parallelism, minMember)
    if err := kubeClient.Create(ctx, job); err != nil {
        return nil, nil, nil, fmt.Errorf("create Job: %w", err)
    }

    podGroup, err := waitForPodGroup(ctx, kubeClient, job)
    if err != nil {
        return job, nil, nil, err
    }

    pods, err := waitForJobPods(ctx, kubeClient, job, parallelism)
    if err != nil {
        return job, podGroup, nil, err
    }

    return job, podGroup, pods, nil
}
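
// A minimal caller sketch, assuming a controller-runtime client and a pre-created
// Queue (the variable names are hypothetical):
//
//    job, pg, pods, err := rd.CreateDistributedBatchJob(ctx, kubeClient, jobQueue,
//        rd.DistributedBatchJobOptions{Parallelism: ptr.To(int32(4))})
//    if err != nil {
//        // Partial returns allow cleanup: job is non-nil once created, and pg
//        // is non-nil once the podgrouper produced it.
//    }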

// buildDistributedBatchJob assembles the Job object: it pins Parallelism/Completions,
// stamps the min-member override annotation, and copies the optional topology,
// preemptibility, priority-class, label, and pod-spec customizations from opts.
func buildDistributedBatchJob(
    jobQueue *v2.Queue, opts DistributedBatchJobOptions, parallelism, minMember int32,
) *batchv1.Job {
    job := CreateBatchJobObject(jobQueue, opts.Resources)
    job.Name = opts.NamePrefix + job.Name
    job.Spec.Parallelism = ptr.To(parallelism)
    job.Spec.Completions = ptr.To(parallelism)

    if job.Annotations == nil {
        job.Annotations = map[string]string{}
    }
    job.Annotations[pgconstants.MinMemberOverrideKey] = fmt.Sprintf("%d", minMember)

    if tc := opts.TopologyConstraint; tc != nil {
        if tc.Topology != "" {
            job.Annotations[pgconstants.TopologyKey] = tc.Topology
        }
        if tc.RequiredTopologyLevel != "" {
            job.Annotations[pgconstants.TopologyRequiredPlacementKey] = tc.RequiredTopologyLevel
        }
        if tc.PreferredTopologyLevel != "" {
            job.Annotations[pgconstants.TopologyPreferredPlacementKey] = tc.PreferredTopologyLevel
        }
    }

    if opts.Preemptibility != "" {
        job.Labels[pgconstants.PreemptibilityLabelKey] = string(opts.Preemptibility)
    }

    if opts.PriorityClassName != "" {
        job.Spec.Template.Spec.PriorityClassName = opts.PriorityClassName
    }

    maps.Copy(job.Spec.Template.ObjectMeta.Labels, opts.ExtraLabels)

    if opts.PodSpecMutator != nil {
        opts.PodSpecMutator(&job.Spec.Template.Spec)
    }

    return job
}
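
// For example, with minMember=2 and a required topology constraint, the returned
// Job's metadata would carry annotations along these lines (the min-member key is
// kai.scheduler/batch-min-member per the doc comment above; the topology keys are
// shown symbolically since their literals live in pgconstants):
//
//    annotations:
//      kai.scheduler/batch-min-member: "2"
//      <pgconstants.TopologyKey>: <tc.Topology>
//      <pgconstants.TopologyRequiredPlacementKey>: <tc.RequiredTopologyLevel>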

// waitForPodGroup polls until the podgrouper has created the PodGroup that
// corresponds to job, treating NotFound as "keep waiting" and any other error
// as fatal.
func waitForPodGroup(
    ctx context.Context, kubeClient runtimeClient.Client, job *batchv1.Job,
) (*v2alpha2.PodGroup, error) {
    name := PodGroupNameForJob(job)
    pg := &v2alpha2.PodGroup{}
    key := types.NamespacedName{Namespace: job.Namespace, Name: name}

    err := wait.PollUntilContextTimeout(ctx, podGroupFetchPoll, podGroupFetchTimeout, true,
        func(ctx context.Context) (bool, error) {
            err := kubeClient.Get(ctx, key, pg)
            if errors.IsNotFound(err) {
                return false, nil
            }
            return err == nil, err
        })
    if err != nil {
        return nil, fmt.Errorf("wait for PodGroup %s: %w", name, err)
    }
    return pg, nil
}

// waitForJobPods polls until the Job controller has spawned at least expected
// pods carrying the batch.kubernetes.io/job-name label, then returns pointers
// to the listed items.
func waitForJobPods(
    ctx context.Context, kubeClient runtimeClient.Client, job *batchv1.Job, expected int32,
) ([]*v1.Pod, error) {
    var pods []*v1.Pod
    err := wait.PollUntilContextTimeout(ctx, podGroupFetchPoll, podGroupFetchTimeout, true,
        func(ctx context.Context) (bool, error) {
            list := &v1.PodList{}
            err := kubeClient.List(ctx, list,
                runtimeClient.InNamespace(job.Namespace),
                runtimeClient.MatchingLabels{JobNameLabel: job.Name},
            )
            if err != nil {
                return false, err
            }
            if int32(len(list.Items)) < expected {
                return false, nil
            }
            // Address the slice elements directly so each pointer refers to a
            // distinct Pod rather than a reused loop variable.
            pods = make([]*v1.Pod, 0, len(list.Items))
            for i := range list.Items {
                pods = append(pods, &list.Items[i])
            }
            return true, nil
        })
    if err != nil {
        return nil, fmt.Errorf("wait for %d pods of Job %s: %w", expected, job.Name, err)
    }
    return pods, nil
}

// PodGroupNameForJob returns the deterministic name the podgrouper uses for a Job-owned PodGroup.
func PodGroupNameForJob(job *batchv1.Job) string {
    return fmt.Sprintf("%s-%s-%s", pgconstants.PodGroupNamePrefix, job.Name, job.UID)
}
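
// For a Job named "train-abc", for instance, the result has the shape
// "<pgconstants.PodGroupNamePrefix>-train-abc-<uid>"; the prefix literal is
// left symbolic here since it is defined in pgconstants.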
85 changes: 16 additions & 69 deletions test/e2e/scale/kwok_job_creation.go
@@ -5,12 +5,8 @@ package scale

import (
    "context"
    "fmt"
    "sync"

    . "github.com/onsi/gomega"
    "go.uber.org/multierr"
    "golang.org/x/exp/maps"
    batchv1 "k8s.io/api/batch/v1"
    v1 "k8s.io/api/core/v1"
    "k8s.io/utils/ptr"
@@ -19,9 +15,6 @@ import (
    "github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
    testcontext "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/context"
    "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd"
    "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/pod_group"
    "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/queue"
    "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/utils"
)

func createJobObjectForKwok(
@@ -30,74 +23,28 @@ func createJobObjectForKwok(
    resources v1.ResourceRequirements,
    extraLabels map[string]string,
) *batchv1.Job {
    job := rd.CreateBatchJobObject(jobQueue, resources)
    addKWOKTaintsAndAffinity(&job.Spec.Template.Spec)

    maps.Copy(job.Spec.Template.ObjectMeta.Labels, extraLabels)

    Expect(createObjectWithRetries(ctx, testCtx.ControllerClient, job)).To(Succeed())

    job, _, _, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, jobQueue,
        rd.DistributedBatchJobOptions{
            Resources:      resources,
            ExtraLabels:    extraLabels,
            PodSpecMutator: addKWOKTaintsAndAffinity,
        })
    Expect(err).To(Succeed())
    return job
}
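
// With Parallelism and MinMember left nil, CreateDistributedBatchJob defaults both
// to 1, so each KWOK job is a single-pod gang with its own PodGroup.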

// createDistributedJobForKwok creates one distributed batch Job whose numberOfTasks pods are gang-scheduled under a single PodGroup.
func createDistributedJobForKwok(
    ctx context.Context, testCtx *testcontext.TestContext,
    jobQueue *v2.Queue, resourcesPerPod v1.ResourceRequirements, numberOfTasks int,
    extraLabels map[string]string, topologyConstraint *v2alpha2.TopologyConstraint,
) (*v2alpha2.PodGroup, []*v1.Pod, error) {
    namespace := queue.GetConnectedNamespaceToQueue(jobQueue)
    podGroup := pod_group.Create(
        namespace, "distributed-job-"+utils.GenerateRandomK8sName(10), jobQueue.Name,
    )
    podGroup.Spec.MinMember = ptr.To(int32(numberOfTasks))
    maps.Copy(podGroup.Labels, extraLabels)
    if topologyConstraint != nil {
        podGroup.Spec.TopologyConstraint = *topologyConstraint
    }

    err := createObjectWithRetries(ctx, testCtx.ControllerClient, podGroup)
    if err != nil {
        return nil, nil, err
    }

    var pods []*v1.Pod
    var creationError error
    podsLock := sync.Mutex{}
    var wg sync.WaitGroup

    for i := range numberOfTasks {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()

            pod := rd.CreatePodObject(jobQueue, resourcesPerPod)
            pod.Name = fmt.Sprintf("distributed-pod-%d-%s", i, utils.GenerateRandomK8sName(10))

            if pod.Annotations == nil {
                pod.Annotations = map[string]string{}
            }
            pod.Annotations[pod_group.PodGroupNameAnnotation] = podGroup.Name

            maps.Copy(pod.Labels, extraLabels)
            addKWOKTaintsAndAffinity(&pod.Spec)

            err := createObjectWithRetries(ctx, testCtx.ControllerClient, pod)

            podsLock.Lock()
            if err != nil {
                creationError = multierr.Append(creationError, err)
            } else {
                pods = append(pods, pod)
            }
            podsLock.Unlock()
        }(i)
    }
    wg.Wait()

    if creationError != nil {
        return nil, nil, fmt.Errorf("failed to create some pods: %w", creationError)
    }

    return podGroup, pods, nil
    _, pg, pods, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, jobQueue,
        rd.DistributedBatchJobOptions{
            Parallelism:        ptr.To(int32(numberOfTasks)),
            Resources:          resourcesPerPod,
            ExtraLabels:        extraLabels,
            TopologyConstraint: topologyConstraint,
            PodSpecMutator:     addKWOKTaintsAndAffinity,
        })
    return pg, pods, err
}
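
// A PodSpecMutator is a plain func(*v1.PodSpec). A hypothetical KWOK mutator
// (the real addKWOKTaintsAndAffinity is defined elsewhere in this package) might
// tolerate KWOK's fake-node taint roughly like this:
//
//    func tolerateKwokNodes(spec *v1.PodSpec) {
//        spec.Tolerations = append(spec.Tolerations, v1.Toleration{
//            Key:      "kwok.x-k8s.io/node",
//            Operator: v1.TolerationOpExists,
//            Effect:   v1.TaintEffectNoSchedule,
//        })
//    }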
18 changes: 0 additions & 18 deletions test/e2e/scale/kwok_test_utils.go
@@ -42,24 +42,6 @@ var (
    }
)

func createObjectWithRetries(ctx context.Context, kubeClient runtimeClient.Client, obj runtimeClient.Object) error {
    key := runtimeClient.ObjectKeyFromObject(obj)
    err := kubeClient.Get(ctx, key, obj)
    if err == nil {
        // object is not expected to exist in the cluster
        return fmt.Errorf("object %v already exists in the cluster", key)
    }

    for i := 0; i < operationAttemptsRetries; i++ {
        err = kubeClient.Create(ctx, obj)
        if err == nil || errors.IsAlreadyExists(err) {
            return nil
        }
        time.Sleep(retryInterval)
    }
    return err
}

func deleteObjectWithRetries(
    ctx context.Context, kubeClient runtimeClient.Client,
    obj runtimeClient.Object, opts ...runtimeClient.DeleteOption) error {
28 changes: 8 additions & 20 deletions test/e2e/suites/allocate/topology/topology_test.go
@@ -15,7 +15,6 @@ import (
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/configurations/feature_flags"
testcontext "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/context"
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd"
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/pod_group"
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/queue"
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/utils"
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/wait"
@@ -260,24 +259,13 @@ var _ = Describe("Topology", Ordered, func() {

func createDistributedWorkload(ctx context.Context, testCtx *testcontext.TestContext,
    podCount int, podResource v1.ResourceList, topologyConstraint v2alpha2.TopologyConstraint) []*v1.Pod {
    namespace := queue.GetConnectedNamespaceToQueue(testCtx.Queues[0])
    queueName := testCtx.Queues[0].Name

    podGroup := pod_group.Create(namespace, "distributed-pod-group"+utils.GenerateRandomK8sName(10), queueName)
    podGroup.Spec.MinMember = ptr.To(int32(podCount))
    podGroup.Spec.TopologyConstraint = topologyConstraint

    pods := []*v1.Pod{}
    Expect(testCtx.ControllerClient.Create(ctx, podGroup)).To(Succeed())
    for i := 0; i < podCount; i++ {
        pod := rd.CreatePodObject(testCtx.Queues[0], v1.ResourceRequirements{Requests: podResource, Limits: podResource})
        pod.Name = "distributed-pod-" + utils.GenerateRandomK8sName(10)
        pod.Annotations[pod_group.PodGroupNameAnnotation] = podGroup.Name
        pod.Labels[pod_group.PodGroupNameAnnotation] = podGroup.Name
        _, err := rd.CreatePod(ctx, testCtx.KubeClientset, pod)
        Expect(err).To(Succeed())
        pods = append(pods, pod)
    }

    _, _, pods, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, testCtx.Queues[0],
        rd.DistributedBatchJobOptions{
            Parallelism:        ptr.To(int32(podCount)),
            NamePrefix:         "distributed-" + utils.GenerateRandomK8sName(5) + "-",
            Resources:          v1.ResourceRequirements{Requests: podResource, Limits: podResource},
            TopologyConstraint: &topologyConstraint,
        })
    Expect(err).To(Succeed())
    return pods
}
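
// Note: MinMember is left nil above, so the PodGroup defaults to a gang of
// podCount pods scheduled all-or-nothing under the topology constraint.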