Skip to content

Commit 9c9e4d9

Browse files
authored
fix(clustered): correctly reconcile JobSet phase during restarts (#7517)
* fix Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> * nit Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> * updated docstr Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> * add max attempts Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> --------- Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com>
1 parent c10f8d9 commit 9c9e4d9

5 files changed

Lines changed: 538 additions & 37 deletions

File tree

flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go

Lines changed: 309 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package clustered
22

33
import (
44
"context"
5+
"errors"
56
"testing"
67
"time"
78

@@ -19,6 +20,7 @@ import (
1920
pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
2021
coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
2122
pluginIOMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/io/mocks"
23+
plugink8s "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
2224
k8smocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
2325
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
2426
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
@@ -225,6 +227,8 @@ func TestInjectTorchRunEnv_Static(t *testing.T) {
225227
assert.Equal(t, "8", envMap["NPROC_PER_NODE"])
226228
assert.Equal(t, "29500", envMap["MASTER_PORT"])
227229
assert.Equal(t, "static", envMap["RDZV_BACKEND"])
230+
// No failure policy set → budget defaults to 0 (every failure is terminal).
231+
assert.Equal(t, "0", envMap["JOBSET_MAX_RESTARTS"])
228232

229233
// Downward API env vars should be present.
230234
names := make(map[string]bool)
@@ -233,9 +237,28 @@ func TestInjectTorchRunEnv_Static(t *testing.T) {
233237
}
234238
assert.True(t, names["JOBSET_NAME"])
235239
assert.True(t, names["JOBSET_RESTART_ATTEMPT"])
240+
assert.True(t, names["JOBSET_MAX_RESTARTS"])
236241
assert.True(t, names["POD_NAMESPACE"])
237242
}
238243

244+
func TestInjectTorchRunEnv_MaxRestarts(t *testing.T) {
245+
spec := &clusteredpb.ClusteredTaskSpec{
246+
Replicas: 2,
247+
NprocPerNode: 4,
248+
FailurePolicy: &clusteredpb.ClusterFailurePolicy{MaxRestarts: 3},
249+
}
250+
container := &corev1.Container{}
251+
injectTorchRunEnv(container, spec)
252+
253+
for _, e := range container.Env {
254+
if e.Name == "JOBSET_MAX_RESTARTS" {
255+
assert.Equal(t, "3", e.Value)
256+
return
257+
}
258+
}
259+
t.Fatal("JOBSET_MAX_RESTARTS not found")
260+
}
261+
239262
func TestInjectTorchRunEnv_C10D(t *testing.T) {
240263
spec := &clusteredpb.ClusteredTaskSpec{
241264
Replicas: 2,
@@ -336,6 +359,15 @@ func emptyK8sReader() client.Reader {
336359
}
337360

338361
func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k8smocks.PluginContext {
362+
return dummyPluginCtxWithState(taskTemplate, k8sReader, plugink8s.PluginState{}, nil)
363+
}
364+
365+
func dummyPluginCtxWithState(
366+
taskTemplate *core.TaskTemplate,
367+
k8sReader client.Reader,
368+
pluginState plugink8s.PluginState,
369+
pluginStateErr error,
370+
) *k8smocks.PluginContext {
339371
pCtx := &k8smocks.PluginContext{}
340372

341373
taskReader := &coreMocks.TaskReader{}
@@ -358,7 +390,15 @@ func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k
358390
pCtx.EXPECT().TaskExecutionMetadata().Return(meta)
359391

360392
pluginStateReader := &coreMocks.PluginStateReader{}
361-
pluginStateReader.EXPECT().Get(mock.Anything).Return(uint8(0), nil)
393+
pluginStateReader.EXPECT().Get(mock.Anything).RunAndReturn(func(t interface{}) (uint8, error) {
394+
if pluginStateErr != nil {
395+
return 0, pluginStateErr
396+
}
397+
if s, ok := t.(*plugink8s.PluginState); ok {
398+
*s = pluginState
399+
}
400+
return 0, nil
401+
})
362402
pCtx.EXPECT().PluginStateReader().Return(pluginStateReader)
363403

364404
return pCtx
@@ -482,7 +522,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) {
482522

483523
pod := &corev1.Pod{
484524
ObjectMeta: metav1.ObjectMeta{
485-
Name: rank0PodName(testJobName),
525+
Name: rank0PodName(testJobName) + "-abc12",
486526
Namespace: testNS,
487527
},
488528
Status: corev1.PodStatus{
@@ -518,7 +558,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
518558

519559
pod := &corev1.Pod{
520560
ObjectMeta: metav1.ObjectMeta{
521-
Name: rank0PodName(testJobName),
561+
Name: rank0PodName(testJobName) + "-abc12",
522562
Namespace: testNS,
523563
},
524564
Status: corev1.PodStatus{
@@ -542,6 +582,272 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
542582
assert.Equal(t, core.ExecutionError_SYSTEM, phase.Err().GetKind())
543583
}
544584

585+
func TestFindRank0Pod_SuffixedAndDeterministic(t *testing.T) {
586+
oldFailed := &corev1.Pod{
587+
ObjectMeta: metav1.ObjectMeta{
588+
Name: rank0PodName(testJobName) + "-aaaa1",
589+
Namespace: testNS,
590+
CreationTimestamp: metav1.NewTime(time.Now().Add(-2 * time.Minute)),
591+
},
592+
Status: corev1.PodStatus{Phase: corev1.PodFailed},
593+
}
594+
newRunning := &corev1.Pod{
595+
ObjectMeta: metav1.ObjectMeta{
596+
Name: rank0PodName(testJobName) + "-bbbb2",
597+
Namespace: testNS,
598+
CreationTimestamp: metav1.NewTime(time.Now()),
599+
},
600+
Status: corev1.PodStatus{Phase: corev1.PodRunning},
601+
}
602+
otherPod := &corev1.Pod{
603+
ObjectMeta: metav1.ObjectMeta{
604+
Name: testJobName + "-workers-0-1-ccccc",
605+
Namespace: testNS,
606+
},
607+
Status: corev1.PodStatus{Phase: corev1.PodRunning},
608+
}
609+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(oldFailed, newRunning, otherPod).Build()
610+
611+
pCtx := &k8smocks.PluginContext{}
612+
pCtx.EXPECT().K8sReader().Return(fakeClient)
613+
614+
js := makeJobSet("", "", false)
615+
pod := findRank0Pod(context.Background(), pCtx, js)
616+
assert.NotNil(t, pod)
617+
assert.Equal(t, newRunning.Name, pod.Name)
618+
}
619+
620+
func TestGetTaskPhase_FastFail_FailedWithBudgetRemainingReturnsRunning(t *testing.T) {
621+
js := makeJobSet("", "", false)
622+
js.Spec.FailurePolicy = &jobsetv1alpha2.FailurePolicy{MaxRestarts: 2}
623+
js.Status.Restarts = 1
624+
js.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
625+
{Name: workersReplicatedJobName, Failed: 1, Active: 1},
626+
}
627+
js.Status.Conditions = []metav1.Condition{
628+
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
629+
}
630+
631+
pod := &corev1.Pod{
632+
ObjectMeta: metav1.ObjectMeta{Name: rank0PodName(testJobName) + "-abc12", Namespace: testNS},
633+
Status: corev1.PodStatus{
634+
Phase: corev1.PodFailed,
635+
ContainerStatuses: []corev1.ContainerStatus{
636+
{Name: "primary", State: corev1.ContainerState{Terminated: &corev1.ContainerStateTerminated{ExitCode: 1}}},
637+
},
638+
},
639+
}
640+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()
641+
642+
spec := &clusteredpb.ClusteredTaskSpec{
643+
Replicas: 2,
644+
NprocPerNode: 1,
645+
FailurePolicy: &clusteredpb.ClusterFailurePolicy{MaxRestarts: 2},
646+
}
647+
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)
648+
649+
handler := clusteredResourceHandler{}
650+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
651+
assert.NoError(t, err)
652+
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
653+
}
654+
655+
func TestGetTaskPhase_FastFail_FailedWithBudgetExhaustedReturnsRetryableFailure(t *testing.T) {
656+
js := makeJobSet("", "", false)
657+
js.Spec.FailurePolicy = &jobsetv1alpha2.FailurePolicy{MaxRestarts: 1}
658+
js.Status.Restarts = 1
659+
js.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
660+
{Name: workersReplicatedJobName, Failed: 1, Active: 1},
661+
}
662+
js.Status.Conditions = []metav1.Condition{
663+
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
664+
}
665+
666+
pod := &corev1.Pod{
667+
ObjectMeta: metav1.ObjectMeta{Name: rank0PodName(testJobName) + "-abc12", Namespace: testNS},
668+
Status: corev1.PodStatus{
669+
Phase: corev1.PodFailed,
670+
ContainerStatuses: []corev1.ContainerStatus{
671+
{Name: "primary", State: corev1.ContainerState{Terminated: &corev1.ContainerStateTerminated{ExitCode: 1}}},
672+
},
673+
},
674+
}
675+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()
676+
677+
spec := &clusteredpb.ClusteredTaskSpec{
678+
Replicas: 2,
679+
NprocPerNode: 1,
680+
FailurePolicy: &clusteredpb.ClusterFailurePolicy{MaxRestarts: 1},
681+
}
682+
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)
683+
684+
handler := clusteredResourceHandler{}
685+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
686+
assert.NoError(t, err)
687+
assert.Equal(t, pluginsCore.PhaseRetryableFailure, phase.Phase())
688+
}
689+
690+
func TestGetTaskPhase_FastFail_PendingImagePullRegardlessBudget(t *testing.T) {
691+
js := makeJobSet("", "", false)
692+
js.Spec.FailurePolicy = &jobsetv1alpha2.FailurePolicy{MaxRestarts: 3}
693+
js.Status.Restarts = 1
694+
js.Status.Conditions = []metav1.Condition{
695+
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
696+
}
697+
698+
oldTransition := metav1.NewTime(time.Now().Add(-24 * time.Hour))
699+
pod := &corev1.Pod{
700+
ObjectMeta: metav1.ObjectMeta{Name: rank0PodName(testJobName) + "-abc12", Namespace: testNS},
701+
Status: corev1.PodStatus{
702+
Phase: corev1.PodPending,
703+
Conditions: []corev1.PodCondition{
704+
{Type: corev1.PodReady, Status: corev1.ConditionFalse, Reason: "ContainersNotReady", LastTransitionTime: oldTransition},
705+
},
706+
ContainerStatuses: []corev1.ContainerStatus{
707+
{
708+
Name: "primary",
709+
Ready: false,
710+
State: corev1.ContainerState{
711+
Waiting: &corev1.ContainerStateWaiting{
712+
Reason: "ImagePullBackOff",
713+
Message: "Back-off pulling image",
714+
},
715+
},
716+
},
717+
},
718+
},
719+
}
720+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()
721+
722+
spec := &clusteredpb.ClusteredTaskSpec{
723+
Replicas: 2,
724+
NprocPerNode: 1,
725+
FailurePolicy: &clusteredpb.ClusterFailurePolicy{MaxRestarts: 3},
726+
}
727+
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)
728+
729+
handler := clusteredResourceHandler{}
730+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
731+
assert.NoError(t, err)
732+
assert.Equal(t, pluginsCore.PhaseRetryableFailure, phase.Phase())
733+
}
734+
735+
func TestGetTaskPhase_NoCondition_ZeroBudgetFailureFastFails(t *testing.T) {
736+
// maxRestarts == 0 and a worker has failed, but the JobSet controller has not yet
737+
// written any condition. hasJobSetStarted must still treat this as started so the
738+
// failure is surfaced via maybeFastFailWorker0 instead of falling back to Initializing.
739+
js := makeJobSet("", "", false)
740+
js.Spec.FailurePolicy = &jobsetv1alpha2.FailurePolicy{MaxRestarts: 0}
741+
js.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
742+
{Name: workersReplicatedJobName, Failed: 1},
743+
}
744+
745+
pod := &corev1.Pod{
746+
ObjectMeta: metav1.ObjectMeta{Name: rank0PodName(testJobName) + "-abc12", Namespace: testNS},
747+
Status: corev1.PodStatus{
748+
Phase: corev1.PodFailed,
749+
ContainerStatuses: []corev1.ContainerStatus{
750+
{Name: "primary", State: corev1.ContainerState{Terminated: &corev1.ContainerStateTerminated{ExitCode: 1}}},
751+
},
752+
},
753+
}
754+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()
755+
756+
spec := &clusteredpb.ClusteredTaskSpec{
757+
Replicas: 2,
758+
NprocPerNode: 1,
759+
FailurePolicy: &clusteredpb.ClusterFailurePolicy{MaxRestarts: 0},
760+
}
761+
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)
762+
763+
handler := clusteredResourceHandler{}
764+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
765+
assert.NoError(t, err)
766+
assert.Equal(t, pluginsCore.PhaseRetryableFailure, phase.Phase())
767+
}
768+
769+
func TestGetTaskPhase_RestartingCondition_ReportsRunningWithAttempt(t *testing.T) {
770+
js := makeJobSet("", "", false)
771+
js.Spec.FailurePolicy = &jobsetv1alpha2.FailurePolicy{MaxRestarts: 1}
772+
js.Status.Restarts = 1
773+
js.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
774+
{Name: workersReplicatedJobName, Failed: 1},
775+
}
776+
js.Status.Conditions = []metav1.Condition{
777+
{Type: jobSetRestartingConditionType, Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
778+
}
779+
780+
pod := &corev1.Pod{
781+
ObjectMeta: metav1.ObjectMeta{Name: rank0PodName(testJobName) + "-abc12", Namespace: testNS},
782+
Status: corev1.PodStatus{
783+
Phase: corev1.PodFailed,
784+
ContainerStatuses: []corev1.ContainerStatus{
785+
{Name: "primary", State: corev1.ContainerState{Terminated: &corev1.ContainerStateTerminated{ExitCode: 1}}},
786+
},
787+
},
788+
}
789+
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()
790+
pCtx := dummyPluginCtx(buildTaskTemplate(&clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}), fakeClient)
791+
792+
handler := clusteredResourceHandler{}
793+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
794+
assert.NoError(t, err)
795+
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
796+
assert.Contains(t, phase.Reason(), "restart in progress (attempt 1)")
797+
}
798+
799+
func TestGetTaskPhase_NoTrueConditionWithRestarts_ReportsRunning(t *testing.T) {
800+
js := makeJobSet("", "", false)
801+
js.Status.Restarts = 1
802+
js.Status.Conditions = []metav1.Condition{
803+
{Type: string(jobsetv1alpha2.JobSetSuspended), Status: metav1.ConditionFalse, LastTransitionTime: metav1.NewTime(time.Now())},
804+
{Type: string(jobsetv1alpha2.JobSetCompleted), Status: metav1.ConditionFalse, LastTransitionTime: metav1.NewTime(time.Now())},
805+
}
806+
807+
pCtx := dummyPluginCtx(buildTaskTemplate(&clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}), emptyK8sReader())
808+
809+
handler := clusteredResourceHandler{}
810+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
811+
assert.NoError(t, err)
812+
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
813+
assert.Contains(t, phase.Reason(), "restart attempt 1")
814+
}
815+
816+
func TestGetTaskPhase_NoConditionWithPriorRunningState_ReportsRunning(t *testing.T) {
817+
js := makeJobSet("", "", false)
818+
pCtx := dummyPluginCtxWithState(
819+
buildTaskTemplate(&clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}),
820+
emptyK8sReader(),
821+
plugink8s.PluginState{Phase: pluginsCore.PhaseRunning, PhaseVersion: 1},
822+
nil,
823+
)
824+
825+
handler := clusteredResourceHandler{}
826+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
827+
assert.NoError(t, err)
828+
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
829+
}
830+
831+
func TestGetTaskPhase_NoTrueCondition_StateReadErrorFallsBackToStatus(t *testing.T) {
832+
js := makeJobSet("", "", false)
833+
js.Status.Restarts = 1
834+
js.Status.Conditions = []metav1.Condition{
835+
{Type: string(jobsetv1alpha2.JobSetCompleted), Status: metav1.ConditionFalse, LastTransitionTime: metav1.NewTime(time.Now())},
836+
}
837+
838+
pCtx := dummyPluginCtxWithState(
839+
buildTaskTemplate(&clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}),
840+
emptyK8sReader(),
841+
plugink8s.PluginState{},
842+
errors.New("state read failed"),
843+
)
844+
845+
handler := clusteredResourceHandler{}
846+
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
847+
assert.NoError(t, err)
848+
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
849+
}
850+
545851
func TestGetTaskPhase_LogContext(t *testing.T) {
546852
const primaryContainer = "primary"
547853
const sidecarContainer = "sidecar"

0 commit comments

Comments
 (0)