@@ -2,6 +2,7 @@ package clustered
22
33import (
44 "context"
5+ "errors"
56 "testing"
67 "time"
78
@@ -19,6 +20,7 @@ import (
1920 pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
2021 coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
2122 pluginIOMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/io/mocks"
23+ plugink8s "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
2224 k8smocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
2325 "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
2426 "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
@@ -225,6 +227,8 @@ func TestInjectTorchRunEnv_Static(t *testing.T) {
225227 assert .Equal (t , "8" , envMap ["NPROC_PER_NODE" ])
226228 assert .Equal (t , "29500" , envMap ["MASTER_PORT" ])
227229 assert .Equal (t , "static" , envMap ["RDZV_BACKEND" ])
230+ // No failure policy set → budget defaults to 0 (every failure is terminal).
231+ assert .Equal (t , "0" , envMap ["JOBSET_MAX_RESTARTS" ])
228232
229233 // Downward API env vars should be present.
230234 names := make (map [string ]bool )
@@ -233,9 +237,28 @@ func TestInjectTorchRunEnv_Static(t *testing.T) {
233237 }
234238 assert .True (t , names ["JOBSET_NAME" ])
235239 assert .True (t , names ["JOBSET_RESTART_ATTEMPT" ])
240+ assert .True (t , names ["JOBSET_MAX_RESTARTS" ])
236241 assert .True (t , names ["POD_NAMESPACE" ])
237242}
238243
244+ func TestInjectTorchRunEnv_MaxRestarts (t * testing.T ) {
245+ spec := & clusteredpb.ClusteredTaskSpec {
246+ Replicas : 2 ,
247+ NprocPerNode : 4 ,
248+ FailurePolicy : & clusteredpb.ClusterFailurePolicy {MaxRestarts : 3 },
249+ }
250+ container := & corev1.Container {}
251+ injectTorchRunEnv (container , spec )
252+
253+ for _ , e := range container .Env {
254+ if e .Name == "JOBSET_MAX_RESTARTS" {
255+ assert .Equal (t , "3" , e .Value )
256+ return
257+ }
258+ }
259+ t .Fatal ("JOBSET_MAX_RESTARTS not found" )
260+ }
261+
239262func TestInjectTorchRunEnv_C10D (t * testing.T ) {
240263 spec := & clusteredpb.ClusteredTaskSpec {
241264 Replicas : 2 ,
@@ -336,6 +359,15 @@ func emptyK8sReader() client.Reader {
336359}
337360
338361func dummyPluginCtx (taskTemplate * core.TaskTemplate , k8sReader client.Reader ) * k8smocks.PluginContext {
362+ return dummyPluginCtxWithState (taskTemplate , k8sReader , plugink8s.PluginState {}, nil )
363+ }
364+
365+ func dummyPluginCtxWithState (
366+ taskTemplate * core.TaskTemplate ,
367+ k8sReader client.Reader ,
368+ pluginState plugink8s.PluginState ,
369+ pluginStateErr error ,
370+ ) * k8smocks.PluginContext {
339371 pCtx := & k8smocks.PluginContext {}
340372
341373 taskReader := & coreMocks.TaskReader {}
@@ -358,7 +390,15 @@ func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k
358390 pCtx .EXPECT ().TaskExecutionMetadata ().Return (meta )
359391
360392 pluginStateReader := & coreMocks.PluginStateReader {}
361- pluginStateReader .EXPECT ().Get (mock .Anything ).Return (uint8 (0 ), nil )
393+ pluginStateReader .EXPECT ().Get (mock .Anything ).RunAndReturn (func (t interface {}) (uint8 , error ) {
394+ if pluginStateErr != nil {
395+ return 0 , pluginStateErr
396+ }
397+ if s , ok := t .(* plugink8s.PluginState ); ok {
398+ * s = pluginState
399+ }
400+ return 0 , nil
401+ })
362402 pCtx .EXPECT ().PluginStateReader ().Return (pluginStateReader )
363403
364404 return pCtx
@@ -482,7 +522,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) {
482522
483523 pod := & corev1.Pod {
484524 ObjectMeta : metav1.ObjectMeta {
485- Name : rank0PodName (testJobName ),
525+ Name : rank0PodName (testJobName ) + "-abc12" ,
486526 Namespace : testNS ,
487527 },
488528 Status : corev1.PodStatus {
@@ -518,7 +558,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
518558
519559 pod := & corev1.Pod {
520560 ObjectMeta : metav1.ObjectMeta {
521- Name : rank0PodName (testJobName ),
561+ Name : rank0PodName (testJobName ) + "-abc12" ,
522562 Namespace : testNS ,
523563 },
524564 Status : corev1.PodStatus {
@@ -542,6 +582,272 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
542582 assert .Equal (t , core .ExecutionError_SYSTEM , phase .Err ().GetKind ())
543583}
544584
585+ func TestFindRank0Pod_SuffixedAndDeterministic (t * testing.T ) {
586+ oldFailed := & corev1.Pod {
587+ ObjectMeta : metav1.ObjectMeta {
588+ Name : rank0PodName (testJobName ) + "-aaaa1" ,
589+ Namespace : testNS ,
590+ CreationTimestamp : metav1 .NewTime (time .Now ().Add (- 2 * time .Minute )),
591+ },
592+ Status : corev1.PodStatus {Phase : corev1 .PodFailed },
593+ }
594+ newRunning := & corev1.Pod {
595+ ObjectMeta : metav1.ObjectMeta {
596+ Name : rank0PodName (testJobName ) + "-bbbb2" ,
597+ Namespace : testNS ,
598+ CreationTimestamp : metav1 .NewTime (time .Now ()),
599+ },
600+ Status : corev1.PodStatus {Phase : corev1 .PodRunning },
601+ }
602+ otherPod := & corev1.Pod {
603+ ObjectMeta : metav1.ObjectMeta {
604+ Name : testJobName + "-workers-0-1-ccccc" ,
605+ Namespace : testNS ,
606+ },
607+ Status : corev1.PodStatus {Phase : corev1 .PodRunning },
608+ }
609+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (oldFailed , newRunning , otherPod ).Build ()
610+
611+ pCtx := & k8smocks.PluginContext {}
612+ pCtx .EXPECT ().K8sReader ().Return (fakeClient )
613+
614+ js := makeJobSet ("" , "" , false )
615+ pod := findRank0Pod (context .Background (), pCtx , js )
616+ assert .NotNil (t , pod )
617+ assert .Equal (t , newRunning .Name , pod .Name )
618+ }
619+
620+ func TestGetTaskPhase_FastFail_FailedWithBudgetRemainingReturnsRunning (t * testing.T ) {
621+ js := makeJobSet ("" , "" , false )
622+ js .Spec .FailurePolicy = & jobsetv1alpha2.FailurePolicy {MaxRestarts : 2 }
623+ js .Status .Restarts = 1
624+ js .Status .ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus {
625+ {Name : workersReplicatedJobName , Failed : 1 , Active : 1 },
626+ }
627+ js .Status .Conditions = []metav1.Condition {
628+ {Type : "SomeActiveCondition" , Status : metav1 .ConditionTrue , LastTransitionTime : metav1 .NewTime (time .Now ())},
629+ }
630+
631+ pod := & corev1.Pod {
632+ ObjectMeta : metav1.ObjectMeta {Name : rank0PodName (testJobName ) + "-abc12" , Namespace : testNS },
633+ Status : corev1.PodStatus {
634+ Phase : corev1 .PodFailed ,
635+ ContainerStatuses : []corev1.ContainerStatus {
636+ {Name : "primary" , State : corev1.ContainerState {Terminated : & corev1.ContainerStateTerminated {ExitCode : 1 }}},
637+ },
638+ },
639+ }
640+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (pod ).Build ()
641+
642+ spec := & clusteredpb.ClusteredTaskSpec {
643+ Replicas : 2 ,
644+ NprocPerNode : 1 ,
645+ FailurePolicy : & clusteredpb.ClusterFailurePolicy {MaxRestarts : 2 },
646+ }
647+ pCtx := dummyPluginCtx (buildTaskTemplate (spec ), fakeClient )
648+
649+ handler := clusteredResourceHandler {}
650+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
651+ assert .NoError (t , err )
652+ assert .Equal (t , pluginsCore .PhaseRunning , phase .Phase ())
653+ }
654+
655+ func TestGetTaskPhase_FastFail_FailedWithBudgetExhaustedReturnsRetryableFailure (t * testing.T ) {
656+ js := makeJobSet ("" , "" , false )
657+ js .Spec .FailurePolicy = & jobsetv1alpha2.FailurePolicy {MaxRestarts : 1 }
658+ js .Status .Restarts = 1
659+ js .Status .ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus {
660+ {Name : workersReplicatedJobName , Failed : 1 , Active : 1 },
661+ }
662+ js .Status .Conditions = []metav1.Condition {
663+ {Type : "SomeActiveCondition" , Status : metav1 .ConditionTrue , LastTransitionTime : metav1 .NewTime (time .Now ())},
664+ }
665+
666+ pod := & corev1.Pod {
667+ ObjectMeta : metav1.ObjectMeta {Name : rank0PodName (testJobName ) + "-abc12" , Namespace : testNS },
668+ Status : corev1.PodStatus {
669+ Phase : corev1 .PodFailed ,
670+ ContainerStatuses : []corev1.ContainerStatus {
671+ {Name : "primary" , State : corev1.ContainerState {Terminated : & corev1.ContainerStateTerminated {ExitCode : 1 }}},
672+ },
673+ },
674+ }
675+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (pod ).Build ()
676+
677+ spec := & clusteredpb.ClusteredTaskSpec {
678+ Replicas : 2 ,
679+ NprocPerNode : 1 ,
680+ FailurePolicy : & clusteredpb.ClusterFailurePolicy {MaxRestarts : 1 },
681+ }
682+ pCtx := dummyPluginCtx (buildTaskTemplate (spec ), fakeClient )
683+
684+ handler := clusteredResourceHandler {}
685+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
686+ assert .NoError (t , err )
687+ assert .Equal (t , pluginsCore .PhaseRetryableFailure , phase .Phase ())
688+ }
689+
690+ func TestGetTaskPhase_FastFail_PendingImagePullRegardlessBudget (t * testing.T ) {
691+ js := makeJobSet ("" , "" , false )
692+ js .Spec .FailurePolicy = & jobsetv1alpha2.FailurePolicy {MaxRestarts : 3 }
693+ js .Status .Restarts = 1
694+ js .Status .Conditions = []metav1.Condition {
695+ {Type : "SomeActiveCondition" , Status : metav1 .ConditionTrue , LastTransitionTime : metav1 .NewTime (time .Now ())},
696+ }
697+
698+ oldTransition := metav1 .NewTime (time .Now ().Add (- 24 * time .Hour ))
699+ pod := & corev1.Pod {
700+ ObjectMeta : metav1.ObjectMeta {Name : rank0PodName (testJobName ) + "-abc12" , Namespace : testNS },
701+ Status : corev1.PodStatus {
702+ Phase : corev1 .PodPending ,
703+ Conditions : []corev1.PodCondition {
704+ {Type : corev1 .PodReady , Status : corev1 .ConditionFalse , Reason : "ContainersNotReady" , LastTransitionTime : oldTransition },
705+ },
706+ ContainerStatuses : []corev1.ContainerStatus {
707+ {
708+ Name : "primary" ,
709+ Ready : false ,
710+ State : corev1.ContainerState {
711+ Waiting : & corev1.ContainerStateWaiting {
712+ Reason : "ImagePullBackOff" ,
713+ Message : "Back-off pulling image" ,
714+ },
715+ },
716+ },
717+ },
718+ },
719+ }
720+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (pod ).Build ()
721+
722+ spec := & clusteredpb.ClusteredTaskSpec {
723+ Replicas : 2 ,
724+ NprocPerNode : 1 ,
725+ FailurePolicy : & clusteredpb.ClusterFailurePolicy {MaxRestarts : 3 },
726+ }
727+ pCtx := dummyPluginCtx (buildTaskTemplate (spec ), fakeClient )
728+
729+ handler := clusteredResourceHandler {}
730+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
731+ assert .NoError (t , err )
732+ assert .Equal (t , pluginsCore .PhaseRetryableFailure , phase .Phase ())
733+ }
734+
735+ func TestGetTaskPhase_NoCondition_ZeroBudgetFailureFastFails (t * testing.T ) {
736+ // maxRestarts == 0 and a worker has failed, but the JobSet controller has not yet
737+ // written any condition. hasJobSetStarted must still treat this as started so the
738+ // failure is surfaced via maybeFastFailWorker0 instead of falling back to Initializing.
739+ js := makeJobSet ("" , "" , false )
740+ js .Spec .FailurePolicy = & jobsetv1alpha2.FailurePolicy {MaxRestarts : 0 }
741+ js .Status .ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus {
742+ {Name : workersReplicatedJobName , Failed : 1 },
743+ }
744+
745+ pod := & corev1.Pod {
746+ ObjectMeta : metav1.ObjectMeta {Name : rank0PodName (testJobName ) + "-abc12" , Namespace : testNS },
747+ Status : corev1.PodStatus {
748+ Phase : corev1 .PodFailed ,
749+ ContainerStatuses : []corev1.ContainerStatus {
750+ {Name : "primary" , State : corev1.ContainerState {Terminated : & corev1.ContainerStateTerminated {ExitCode : 1 }}},
751+ },
752+ },
753+ }
754+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (pod ).Build ()
755+
756+ spec := & clusteredpb.ClusteredTaskSpec {
757+ Replicas : 2 ,
758+ NprocPerNode : 1 ,
759+ FailurePolicy : & clusteredpb.ClusterFailurePolicy {MaxRestarts : 0 },
760+ }
761+ pCtx := dummyPluginCtx (buildTaskTemplate (spec ), fakeClient )
762+
763+ handler := clusteredResourceHandler {}
764+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
765+ assert .NoError (t , err )
766+ assert .Equal (t , pluginsCore .PhaseRetryableFailure , phase .Phase ())
767+ }
768+
769+ func TestGetTaskPhase_RestartingCondition_ReportsRunningWithAttempt (t * testing.T ) {
770+ js := makeJobSet ("" , "" , false )
771+ js .Spec .FailurePolicy = & jobsetv1alpha2.FailurePolicy {MaxRestarts : 1 }
772+ js .Status .Restarts = 1
773+ js .Status .ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus {
774+ {Name : workersReplicatedJobName , Failed : 1 },
775+ }
776+ js .Status .Conditions = []metav1.Condition {
777+ {Type : jobSetRestartingConditionType , Status : metav1 .ConditionTrue , LastTransitionTime : metav1 .NewTime (time .Now ())},
778+ }
779+
780+ pod := & corev1.Pod {
781+ ObjectMeta : metav1.ObjectMeta {Name : rank0PodName (testJobName ) + "-abc12" , Namespace : testNS },
782+ Status : corev1.PodStatus {
783+ Phase : corev1 .PodFailed ,
784+ ContainerStatuses : []corev1.ContainerStatus {
785+ {Name : "primary" , State : corev1.ContainerState {Terminated : & corev1.ContainerStateTerminated {ExitCode : 1 }}},
786+ },
787+ },
788+ }
789+ fakeClient := fake .NewClientBuilder ().WithScheme (k8sscheme .Scheme ).WithObjects (pod ).Build ()
790+ pCtx := dummyPluginCtx (buildTaskTemplate (& clusteredpb.ClusteredTaskSpec {Replicas : 2 , NprocPerNode : 1 }), fakeClient )
791+
792+ handler := clusteredResourceHandler {}
793+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
794+ assert .NoError (t , err )
795+ assert .Equal (t , pluginsCore .PhaseRunning , phase .Phase ())
796+ assert .Contains (t , phase .Reason (), "restart in progress (attempt 1)" )
797+ }
798+
799+ func TestGetTaskPhase_NoTrueConditionWithRestarts_ReportsRunning (t * testing.T ) {
800+ js := makeJobSet ("" , "" , false )
801+ js .Status .Restarts = 1
802+ js .Status .Conditions = []metav1.Condition {
803+ {Type : string (jobsetv1alpha2 .JobSetSuspended ), Status : metav1 .ConditionFalse , LastTransitionTime : metav1 .NewTime (time .Now ())},
804+ {Type : string (jobsetv1alpha2 .JobSetCompleted ), Status : metav1 .ConditionFalse , LastTransitionTime : metav1 .NewTime (time .Now ())},
805+ }
806+
807+ pCtx := dummyPluginCtx (buildTaskTemplate (& clusteredpb.ClusteredTaskSpec {Replicas : 2 , NprocPerNode : 1 }), emptyK8sReader ())
808+
809+ handler := clusteredResourceHandler {}
810+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
811+ assert .NoError (t , err )
812+ assert .Equal (t , pluginsCore .PhaseRunning , phase .Phase ())
813+ assert .Contains (t , phase .Reason (), "restart attempt 1" )
814+ }
815+
816+ func TestGetTaskPhase_NoConditionWithPriorRunningState_ReportsRunning (t * testing.T ) {
817+ js := makeJobSet ("" , "" , false )
818+ pCtx := dummyPluginCtxWithState (
819+ buildTaskTemplate (& clusteredpb.ClusteredTaskSpec {Replicas : 2 , NprocPerNode : 1 }),
820+ emptyK8sReader (),
821+ plugink8s.PluginState {Phase : pluginsCore .PhaseRunning , PhaseVersion : 1 },
822+ nil ,
823+ )
824+
825+ handler := clusteredResourceHandler {}
826+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
827+ assert .NoError (t , err )
828+ assert .Equal (t , pluginsCore .PhaseRunning , phase .Phase ())
829+ }
830+
831+ func TestGetTaskPhase_NoTrueCondition_StateReadErrorFallsBackToStatus (t * testing.T ) {
832+ js := makeJobSet ("" , "" , false )
833+ js .Status .Restarts = 1
834+ js .Status .Conditions = []metav1.Condition {
835+ {Type : string (jobsetv1alpha2 .JobSetCompleted ), Status : metav1 .ConditionFalse , LastTransitionTime : metav1 .NewTime (time .Now ())},
836+ }
837+
838+ pCtx := dummyPluginCtxWithState (
839+ buildTaskTemplate (& clusteredpb.ClusteredTaskSpec {Replicas : 2 , NprocPerNode : 1 }),
840+ emptyK8sReader (),
841+ plugink8s.PluginState {},
842+ errors .New ("state read failed" ),
843+ )
844+
845+ handler := clusteredResourceHandler {}
846+ phase , err := handler .GetTaskPhase (context .Background (), pCtx , js )
847+ assert .NoError (t , err )
848+ assert .Equal (t , pluginsCore .PhaseRunning , phase .Phase ())
849+ }
850+
545851func TestGetTaskPhase_LogContext (t * testing.T ) {
546852 const primaryContainer = "primary"
547853 const sidecarContainer = "sidecar"
0 commit comments