Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sscheme "k8s.io/client-go/kubernetes/scheme"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
Expand Down Expand Up @@ -320,13 +321,21 @@ func makeJobSet(condType jobsetv1alpha2.JobSetConditionType, status metav1.Condi
return js
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate) *k8smocks.PluginContext {
// emptyK8sReader builds a fake controller-runtime client that holds no objects.
// It is meant for tests that do not care about pod inspection: listing pods
// through it yields an empty list, so getLogContext produces a nil LogContext.
func emptyK8sReader() client.Reader {
	builder := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme)
	return builder.Build()
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k8smocks.PluginContext {
pCtx := &k8smocks.PluginContext{}

taskReader := &coreMocks.TaskReader{}
taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil)
pCtx.EXPECT().TaskReader().Return(taskReader)

pCtx.EXPECT().K8sReader().Return(k8sReader)

tID := &coreMocks.TaskExecutionID{}
tID.EXPECT().GetID().Return(&core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
Expand All @@ -352,7 +361,7 @@ func TestGetTaskPhase_Initializing(t *testing.T) {
js := makeJobSet("", "", suspend)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -364,7 +373,7 @@ func TestGetTaskPhase_Success(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetCompleted, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -376,7 +385,7 @@ func TestGetTaskPhase_Failure(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetFailed, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -397,7 +406,7 @@ func TestGetTaskPhase_Running(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -424,7 +433,7 @@ func TestGetTaskPhase_FastFail_NoJobsFailed(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -442,7 +451,7 @@ func TestGetTaskPhase_MaintenanceRetry_FlagFalse(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: false},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -484,8 +493,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) {
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -517,8 +525,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: true},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -527,6 +534,54 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
assert.Equal(t, core.ExecutionError_SYSTEM, phase.Err().GetKind())
}

func TestGetTaskPhase_LogContext(t *testing.T) {
// A running JobSet with live worker pods → LogContext is built from the real pods:
// the rank-0 pod is marked primary (matched by name prefix despite its random
// suffix), and Pending pods are excluded.
js := makeJobSet("", "", false)
js.Status.Conditions = []metav1.Condition{
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
}

mkPod := func(name string, phase corev1.PodPhase) *corev1.Pod {
return &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: testNS,
Labels: map[string]string{jobSetNameLabel: testJobName},
},
Status: corev1.PodStatus{Phase: phase},
}
Comment thread
AdilFayyaz marked this conversation as resolved.
}
Comment thread
AdilFayyaz marked this conversation as resolved.
// Real JobSet pods carry a random suffix after the "<jobset>-workers-<job>-<idx>" stem.
rank0 := rank0PodName(testJobName) + "-x1y2z"
rank1 := testJobName + "-workers-0-1-a9b8c"
pending := testJobName + "-workers-0-2-pppp"
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).
WithObjects(
mkPod(rank0, corev1.PodRunning),
mkPod(rank1, corev1.PodRunning),
mkPod(pending, corev1.PodPending),
).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())

lc := phase.Info().LogContext
assert.NotNil(t, lc)
assert.Equal(t, rank0, lc.PrimaryPodName)
// Pending pod is excluded → only the two running pods remain.
assert.Len(t, lc.Pods, 2)
names := []string{lc.Pods[0].GetPodName(), lc.Pods[1].GetPodName()}
assert.Contains(t, names, rank0)
assert.Contains(t, names, rank1)
Comment thread
AdilFayyaz marked this conversation as resolved.
Outdated
}

// --- IsTerminal / GetCompletionTime ---

func TestIsTerminal(t *testing.T) {
Expand Down
48 changes: 48 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@ package clustered
import (
"context"
"fmt"
"strings"
"time"

v1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
)

// jobSetNameLabel is the pod label key whose value is the owning JobSet's name.
// It is stamped by the JobSet controller on every child pod, so we can list a
// JobSet's pods by label selector without depending on predicted pod names.
const jobSetNameLabel = "jobset.sigs.k8s.io/jobset-name"

// getTaskLogs synthesizes per-rank log URLs.
//
// JobSet pod-name pattern: <jobsetName>-<replicatedJob>-<jobIdx>-<podIdx>.
Expand Down Expand Up @@ -76,3 +85,42 @@ func getTaskLogs(ctx context.Context, pluginContext k8s.PluginContext, jobSet *j
}
return taskLogs, nil
}

// getLogContext builds the structured LogContext from the JobSet's live child pods.
//
// Unlike getTaskLogs (which synthesizes templated URIs from *predicted* pod names and
// requires a pod-log template to be configured in cluster config), this uses the *real*
// pods — actual names (including the Job-assigned random suffix), namespace, primary
// container, and per-container IDs — so the console can fetch logs natively regardless
// of log-template config. Best-effort: returns nil on list error or when no pods are
Comment thread
AdilFayyaz marked this conversation as resolved.
// ready yet, leaving the templated Logs path as the fallback.
func getLogContext(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet) *core.LogContext {
podList := &v1.PodList{}
if err := pluginContext.K8sReader().List(ctx, podList,
client.InNamespace(jobSet.Namespace),
Comment thread
AdilFayyaz marked this conversation as resolved.
Outdated
client.MatchingLabels{jobSetNameLabel: jobSet.Name},
); err != nil {
logger.Warnf(ctx, "failed to list pods for JobSet %s/%s log context: %v", jobSet.Namespace, jobSet.Name, err)
Comment thread
AdilFayyaz marked this conversation as resolved.
return nil
}

// rank0PodName returns "<jobset>-workers-0-0"; the real pod carries an additional
// random suffix, so match on prefix to identify the primary (rank-0) pod.
primaryPrefix := rank0PodName(jobSet.Name)
logCtx := &core.LogContext{Pods: make([]*core.PodLogContext, 0, len(podList.Items))}
for i := range podList.Items {
pod := &podList.Items[i]
// Pending pods have no logs yet and no container IDs to address them by.
Comment thread
AdilFayyaz marked this conversation as resolved.
Outdated
if pod.Status.Phase == v1.PodPending {
continue
}
Comment thread
pingsutw marked this conversation as resolved.
if strings.HasPrefix(pod.Name, primaryPrefix) {
logCtx.PrimaryPodName = pod.Name
}
logCtx.Pods = append(logCtx.Pods, flytek8s.BuildPodLogContext(pod))
}
if len(logCtx.Pods) == 0 {
return nil
}
return logCtx
Comment thread
AdilFayyaz marked this conversation as resolved.
}
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
}
taskInfo := pluginsCore.TaskInfo{
Logs: taskLogs,
LogContext: getLogContext(ctx, pluginContext, jobSet),
OccurredAt: &occurredAt,
CustomInfo: statusDetails,
}
Expand Down
Loading