Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 122 additions & 12 deletions flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sscheme "k8s.io/client-go/kubernetes/scheme"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
Expand Down Expand Up @@ -320,13 +321,21 @@ func makeJobSet(condType jobsetv1alpha2.JobSetConditionType, status metav1.Condi
return js
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate) *k8smocks.PluginContext {
// emptyK8sReader returns a fake client with no objects, for tests that don't
// exercise pod inspection (getLogContext just yields an empty pod list -> nil LogContext).
func emptyK8sReader() client.Reader {
return fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).Build()
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k8smocks.PluginContext {
pCtx := &k8smocks.PluginContext{}

taskReader := &coreMocks.TaskReader{}
taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil)
pCtx.EXPECT().TaskReader().Return(taskReader)

pCtx.EXPECT().K8sReader().Return(k8sReader)

tID := &coreMocks.TaskExecutionID{}
tID.EXPECT().GetID().Return(&core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
Expand All @@ -352,7 +361,7 @@ func TestGetTaskPhase_Initializing(t *testing.T) {
js := makeJobSet("", "", suspend)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -364,7 +373,7 @@ func TestGetTaskPhase_Success(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetCompleted, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -376,7 +385,7 @@ func TestGetTaskPhase_Failure(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetFailed, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -397,7 +406,7 @@ func TestGetTaskPhase_Running(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -424,7 +433,7 @@ func TestGetTaskPhase_FastFail_NoJobsFailed(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -442,7 +451,7 @@ func TestGetTaskPhase_MaintenanceRetry_FlagFalse(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: false},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -484,8 +493,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) {
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -517,8 +525,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: true},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -527,6 +534,109 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
assert.Equal(t, core.ExecutionError_SYSTEM, phase.Err().GetKind())
}

func TestGetTaskPhase_LogContext(t *testing.T) {
const primaryContainer = "primary"
const sidecarContainer = "sidecar"

// mkPod builds a realistic JobSet child pod: a primary container plus a sidecar,
// with matching container statuses so BuildPodLogContext produces real container
// contexts. Pending pods carry no statuses.
mkPod := func(name string, phase corev1.PodPhase) *corev1.Pod {
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: testNS,
Labels: map[string]string{jobSetNameLabel: testJobName},
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{{Name: primaryContainer}, {Name: sidecarContainer}},
},
Status: corev1.PodStatus{Phase: phase},
}
Comment thread
AdilFayyaz marked this conversation as resolved.
if phase == corev1.PodRunning {
running := corev1.ContainerState{Running: &corev1.ContainerStateRunning{StartedAt: metav1.NewTime(time.Now())}}
pod.Status.ContainerStatuses = []corev1.ContainerStatus{
{Name: primaryContainer, State: running},
{Name: sidecarContainer, State: running},
}
}
return pod
}
Comment thread
AdilFayyaz marked this conversation as resolved.

// jobSet annotates the authoritative primary container name at build time.
makeRunningJobSet := func() *jobsetv1alpha2.JobSet {
js := makeJobSet("", "", false)
js.Annotations = map[string]string{primaryContainerAnnotation: primaryContainer}
js.Status.Conditions = []metav1.Condition{
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
}
return js
}

// Real JobSet pods carry a random suffix after the "<jobset>-workers-<job>-<idx>" stem.
rank0 := rank0PodName(testJobName) + "-x1y2z"
rank1 := testJobName + "-workers-0-1-a9b8c"
rank2 := testJobName + "-workers-0-2-pppp"

t.Run("primary pod and container resolved from live pods", func(t *testing.T) {
js := makeRunningJobSet()
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).
WithObjects(
mkPod(rank0, corev1.PodRunning),
mkPod(rank1, corev1.PodRunning),
mkPod(rank2, corev1.PodPending),
).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())

lc := phase.Info().LogContext
assert.NotNil(t, lc)
assert.Equal(t, rank0, lc.PrimaryPodName)
// Pending pod is excluded → only the two running pods remain.
assert.Len(t, lc.Pods, 2)
names := []string{lc.Pods[0].GetPodName(), lc.Pods[1].GetPodName()}
assert.Contains(t, names, rank0)
assert.Contains(t, names, rank1)

// Each pod's primary container comes from the JobSet annotation (not the
// sidecar / first container), and container contexts are populated.
for _, p := range lc.Pods {
assert.Equal(t, primaryContainer, p.GetPrimaryContainerName())
assert.GreaterOrEqual(t, len(p.GetContainers()), 1)
}
})

t.Run("primary falls back when rank-0 pod is pending", func(t *testing.T) {
js := makeRunningJobSet()
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).
WithObjects(
mkPod(rank0, corev1.PodPending),
mkPod(rank1, corev1.PodRunning),
).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
assert.NoError(t, err)

lc := phase.Info().LogContext
assert.NotNil(t, lc)
// rank-0 is pending and excluded → PrimaryPodName must still reference an
// included pod so downstream log streaming can resolve it.
assert.Len(t, lc.Pods, 1)
assert.Equal(t, rank1, lc.PrimaryPodName)
assert.Equal(t, lc.Pods[0].GetPodName(), lc.PrimaryPodName)
})
}

// --- IsTerminal / GetCompletionTime ---

func TestIsTerminal(t *testing.T) {
Expand Down
61 changes: 61 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@ package clustered
import (
"context"
"fmt"
"strings"
"time"

v1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
)

// jobSetNameLabel is stamped by the JobSet controller on every child pod, so we
// can list a JobSet's pods without depending on predicted pod names.
const jobSetNameLabel = "jobset.sigs.k8s.io/jobset-name"

// getTaskLogs synthesizes per-rank log URLs.
//
// JobSet pod-name pattern: <jobsetName>-<replicatedJob>-<jobIdx>-<podIdx>.
Expand Down Expand Up @@ -76,3 +85,55 @@ func getTaskLogs(ctx context.Context, pluginContext k8s.PluginContext, jobSet *j
}
return taskLogs, nil
}

// getLogContext builds the structured LogContext from the JobSet's live child pods.
//
// Unlike getTaskLogs (which synthesizes templated URIs from *predicted* pod names and
// requires a pod-log template to be configured in cluster config), this uses the *real*
// pods — actual names (including the Job-assigned random suffix), namespace, primary
// container, and per-container IDs — so the console can fetch logs natively regardless
// of log-template config. Best-effort: returns nil on list error or when no pods are
Comment thread
AdilFayyaz marked this conversation as resolved.
// ready yet, leaving the templated Logs path as the fallback.
func getLogContext(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet) *core.LogContext {
podList := &v1.PodList{}
if err := pluginContext.K8sReader().List(ctx, podList,
client.InNamespace(jobSet.Namespace),
Comment thread
AdilFayyaz marked this conversation as resolved.
Outdated
client.MatchingLabels{jobSetNameLabel: jobSet.Name},
); err != nil {
logger.Warnf(ctx, "failed to list pods for JobSet %s/%s log context: %v", jobSet.Namespace, jobSet.Name, err)
Comment thread
AdilFayyaz marked this conversation as resolved.
return nil
}

// rank0PodName returns "<jobset>-workers-0-0"; the real pod carries an additional
// random suffix, so match on prefix to identify the primary (rank-0) pod.
primaryPrefix := rank0PodName(jobSet.Name)
// The authoritative primary container name is stored on the JobSet at build time
// (see build.go). Child pods don't carry the annotations BuildPodLogContext infers
// from, so set it explicitly to avoid resolving to the wrong container (e.g. a sidecar).
primaryContainerName := jobSet.Annotations[primaryContainerAnnotation]
logCtx := &core.LogContext{Pods: make([]*core.PodLogContext, 0, len(podList.Items))}
for i := range podList.Items {
pod := &podList.Items[i]
// Pending pods have no logs yet and no container IDs to address them by.
Comment thread
AdilFayyaz marked this conversation as resolved.
Outdated
if pod.Status.Phase == v1.PodPending {
continue
}
Comment thread
pingsutw marked this conversation as resolved.
if strings.HasPrefix(pod.Name, primaryPrefix) {
logCtx.PrimaryPodName = pod.Name
}
podLogCtx := flytek8s.BuildPodLogContext(pod)
if primaryContainerName != "" {
podLogCtx.PrimaryContainerName = primaryContainerName
}
logCtx.Pods = append(logCtx.Pods, podLogCtx)
}
if len(logCtx.Pods) == 0 {
return nil
}
// Guarantee PrimaryPodName references a pod in Pods: if rank-0 was pending/absent,
// fall back to the first included pod so downstream log streaming can resolve it.
if logCtx.PrimaryPodName == "" {
logCtx.PrimaryPodName = logCtx.Pods[0].GetPodName()
}
return logCtx
Comment thread
AdilFayyaz marked this conversation as resolved.
}
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
}
taskInfo := pluginsCore.TaskInfo{
Logs: taskLogs,
LogContext: getLogContext(ctx, pluginContext, jobSet),
OccurredAt: &occurredAt,
CustomInfo: statusDetails,
}
Expand Down
Loading