diff --git a/flyteplugins/go/tasks/plugins/k8s/clustered/build.go b/flyteplugins/go/tasks/plugins/k8s/clustered/build.go index 2c847f8994..331c25f8f8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/clustered/build.go +++ b/flyteplugins/go/tasks/plugins/k8s/clustered/build.go @@ -13,6 +13,7 @@ import ( flyteerr "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils" clusteredpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/plugins" ) @@ -45,6 +46,18 @@ func (clusteredResourceHandler) BuildResource(ctx context.Context, taskCtx plugi podSpec = applyInterconnect(ctx, spec.GetInterconnect(), podSpec) + // Propagate the node-execution labels/annotations onto the pod template. The plugin + // manager's addObjectMetadata only stamps these (incl. execution-id/node-id) on the + // top-level JobSet, and the JobSet controller does not copy arbitrary parent labels + // down to child pods. Without this, child pods lack execution-id/node-id and the + // node-execution-scoped K8sReader.List in getLogContext returns nothing, so no + // LogContext reaches the UI. Mirrors ray's buildWorkerPodTemplate. + cfg := config.GetK8sPluginConfig() + objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, + utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) + objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, + utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) + // The SDK is responsible for setting container.Command to the entrypoint module // (python -m flyte.distributed._entrypoint) at serde time. The plugin stays // module-path-agnostic so SDK renames do not require a backend release. diff --git a/flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go b/flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go index e0212260cb..a73a4762c2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go @@ -12,8 +12,9 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" k8sscheme "k8s.io/client-go/kubernetes/scheme" - jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core" coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks" @@ -99,8 +100,8 @@ func dummyTaskCtx(taskTemplate *core.TaskTemplate) *coreMocks.TaskExecutionConte meta := &coreMocks.TaskExecutionMetadata{} meta.EXPECT().GetTaskExecutionID().Return(tID) meta.EXPECT().GetNamespace().Return(testNS) - meta.EXPECT().GetAnnotations().Return(map[string]string{}) - meta.EXPECT().GetLabels().Return(map[string]string{}) + meta.EXPECT().GetAnnotations().Return(map[string]string{"flyte.org/test-annotation": "av"}) + meta.EXPECT().GetLabels().Return(map[string]string{"execution-id": "my-exec", "node-id": "n1"}) meta.EXPECT().GetOwnerReference().Return(metav1.OwnerReference{Kind: "node", Name: "n1"}) meta.EXPECT().IsInterruptible().Return(false) meta.EXPECT().GetOverrides().Return(overrides) @@ -154,6 +155,14 @@ func TestBuildResource_HappyPath(t *testing.T) { assert.Equal(t, int32(4), *jobSpec.Completions) assert.Equal(t, batchv1.IndexedCompletion, *jobSpec.CompletionMode) assert.Equal(t, int32(0), *jobSpec.BackoffLimit) + + // The node-execution labels/annotations must be propagated onto the pod template so + // JobSet child pods carry execution-id/node-id; otherwise the node-execution-scoped + // K8sReader.List in getLogContext returns nothing and no logs reach the UI. + podMeta := jobSpec.Template.ObjectMeta + assert.Equal(t, "my-exec", podMeta.Labels["execution-id"]) + assert.Equal(t, "n1", podMeta.Labels["node-id"]) + assert.Equal(t, "av", podMeta.Annotations["flyte.org/test-annotation"]) } func TestBuildResource_PrimaryContainerPreserved(t *testing.T) { @@ -320,13 +329,21 @@ func makeJobSet(condType jobsetv1alpha2.JobSetConditionType, status metav1.Condi return js } -func dummyPluginCtx(taskTemplate *core.TaskTemplate) *k8smocks.PluginContext { +// emptyK8sReader returns a fake client with no objects, for tests that don't +// exercise pod inspection (getLogContext just yields an empty pod list -> nil LogContext). +func emptyK8sReader() client.Reader { + return fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).Build() +} + +func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k8smocks.PluginContext { pCtx := &k8smocks.PluginContext{} taskReader := &coreMocks.TaskReader{} taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil) pCtx.EXPECT().TaskReader().Return(taskReader) + pCtx.EXPECT().K8sReader().Return(k8sReader) + tID := &coreMocks.TaskExecutionID{} tID.EXPECT().GetID().Return(&core.TaskExecutionIdentifier{ NodeExecutionId: &core.NodeExecutionIdentifier{ @@ -352,7 +369,7 @@ func TestGetTaskPhase_Initializing(t *testing.T) { js := makeJobSet("", "", suspend) spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -364,7 +381,7 @@ func TestGetTaskPhase_Success(t *testing.T) { js := makeJobSet(jobsetv1alpha2.JobSetCompleted, metav1.ConditionTrue, false) spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -376,7 +393,7 @@ func TestGetTaskPhase_Failure(t *testing.T) { js := makeJobSet(jobsetv1alpha2.JobSetFailed, metav1.ConditionTrue, false) spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -397,7 +414,7 @@ func TestGetTaskPhase_Running(t *testing.T) { } spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -424,7 +441,7 @@ func TestGetTaskPhase_FastFail_NoJobsFailed(t *testing.T) { } spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -442,7 +459,7 @@ func TestGetTaskPhase_MaintenanceRetry_FlagFalse(t *testing.T) { NprocPerNode: 1, FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: false}, } - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader()) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -484,8 +501,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) { fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build() spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) - pCtx.EXPECT().K8sReader().Return(fakeClient) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -517,8 +533,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) { NprocPerNode: 1, FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: true}, } - pCtx := dummyPluginCtx(buildTaskTemplate(spec)) - pCtx.EXPECT().K8sReader().Return(fakeClient) + pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient) handler := clusteredResourceHandler{} phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) @@ -527,6 +542,108 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) { assert.Equal(t, core.ExecutionError_SYSTEM, phase.Err().GetKind()) } +func TestGetTaskPhase_LogContext(t *testing.T) { + const primaryContainer = "primary" + const sidecarContainer = "sidecar" + + // mkPod builds a realistic JobSet child pod: a primary container plus a sidecar, + // with matching container statuses so BuildPodLogContext produces real container + // contexts. Pending pods carry no statuses. + mkPod := func(name string, phase corev1.PodPhase) *corev1.Pod { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNS, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: primaryContainer}, {Name: sidecarContainer}}, + }, + Status: corev1.PodStatus{Phase: phase}, + } + if phase == corev1.PodRunning { + running := corev1.ContainerState{Running: &corev1.ContainerStateRunning{StartedAt: metav1.NewTime(time.Now())}} + pod.Status.ContainerStatuses = []corev1.ContainerStatus{ + {Name: primaryContainer, State: running}, + {Name: sidecarContainer, State: running}, + } + } + return pod + } + + // jobSet annotates the authoritative primary container name at build time. + makeRunningJobSet := func() *jobsetv1alpha2.JobSet { + js := makeJobSet("", "", false) + js.Annotations = map[string]string{primaryContainerAnnotation: primaryContainer} + js.Status.Conditions = []metav1.Condition{ + {Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())}, + } + return js + } + + // Real JobSet pods carry a random suffix after the "-workers--" stem. + rank0 := rank0PodName(testJobName) + "-x1y2z" + rank1 := testJobName + "-workers-0-1-a9b8c" + rank2 := testJobName + "-workers-0-2-pppp" + + t.Run("primary pod and container resolved from live pods", func(t *testing.T) { + js := makeRunningJobSet() + fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme). + WithObjects( + mkPod(rank0, corev1.PodRunning), + mkPod(rank1, corev1.PodRunning), + mkPod(rank2, corev1.PodPending), + ).Build() + + spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} + pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient) + + handler := clusteredResourceHandler{} + phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase()) + + lc := phase.Info().LogContext + assert.NotNil(t, lc) + assert.Equal(t, rank0, lc.PrimaryPodName) + // Pending pod is excluded → only the two running pods remain. + assert.Len(t, lc.Pods, 2) + names := []string{lc.Pods[0].GetPodName(), lc.Pods[1].GetPodName()} + assert.Contains(t, names, rank0) + assert.Contains(t, names, rank1) + + // Each pod's primary container comes from the JobSet annotation (not the + // sidecar / first container), and container contexts are populated. + for _, p := range lc.Pods { + assert.Equal(t, primaryContainer, p.GetPrimaryContainerName()) + assert.GreaterOrEqual(t, len(p.GetContainers()), 1) + } + }) + + t.Run("primary falls back when rank-0 pod is pending", func(t *testing.T) { + js := makeRunningJobSet() + fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme). + WithObjects( + mkPod(rank0, corev1.PodPending), + mkPod(rank1, corev1.PodRunning), + ).Build() + + spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1} + pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient) + + handler := clusteredResourceHandler{} + phase, err := handler.GetTaskPhase(context.Background(), pCtx, js) + assert.NoError(t, err) + + lc := phase.Info().LogContext + assert.NotNil(t, lc) + // rank-0 is pending and excluded → PrimaryPodName must still reference an + // included pod so downstream log streaming can resolve it. + assert.Len(t, lc.Pods, 1) + assert.Equal(t, rank1, lc.PrimaryPodName) + assert.Equal(t, lc.Pods[0].GetPodName(), lc.PrimaryPodName) + }) +} + // --- IsTerminal / GetCompletionTime --- func TestIsTerminal(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/clustered/logs.go b/flyteplugins/go/tasks/plugins/k8s/clustered/logs.go index af3df4f4b3..e4d2fd91f7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/clustered/logs.go +++ b/flyteplugins/go/tasks/plugins/k8s/clustered/logs.go @@ -3,13 +3,17 @@ package clustered import ( "context" "fmt" + "strings" "time" + v1 "k8s.io/api/core/v1" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) @@ -76,3 +80,55 @@ func getTaskLogs(ctx context.Context, pluginContext k8s.PluginContext, jobSet *j } return taskLogs, nil } + +// getLogContext builds the structured LogContext from the JobSet's live child pods. +// +// Unlike getTaskLogs (which synthesizes templated URIs from *predicted* pod names and +// requires a pod-log template to be configured in cluster config), this uses the *real* +// pods — actual names (including the Job-assigned random suffix), namespace, primary +// container, and per-container names + process timestamps — so the console can fetch +// logs natively regardless +// of log-template config. Best-effort: returns nil on list error or when no pods are +// ready yet, leaving the templated Logs path as the fallback. +func getLogContext(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet) *core.LogContext { + // The plugin's K8sReader already scopes List calls to this node execution's + // namespace and execution-id/node-id labels, so no extra filters are needed. + podList := &v1.PodList{} + if err := pluginContext.K8sReader().List(ctx, podList); err != nil { + logger.Warnf(ctx, "failed to list pods for JobSet %s/%s log context: %v", jobSet.Namespace, jobSet.Name, err) + return nil + } + + // rank0PodName returns "-workers-0-0"; the real pod carries an additional + // random suffix, so match on prefix to identify the primary (rank-0) pod. + primaryPrefix := rank0PodName(jobSet.Name) + // The authoritative primary container name is stored on the JobSet at build time + // (see build.go). Child pods don't carry the annotations BuildPodLogContext infers + // from, so set it explicitly to avoid resolving to the wrong container (e.g. a sidecar). + primaryContainerName := jobSet.Annotations[primaryContainerAnnotation] + logCtx := &core.LogContext{Pods: make([]*core.PodLogContext, 0, len(podList.Items))} + for i := range podList.Items { + pod := &podList.Items[i] + // Pending pods have no logs yet and no container statuses to build contexts from. + if pod.Status.Phase == v1.PodPending { + continue + } + if strings.HasPrefix(pod.Name, primaryPrefix) { + logCtx.PrimaryPodName = pod.Name + } + podLogCtx := flytek8s.BuildPodLogContext(pod) + if primaryContainerName != "" { + podLogCtx.PrimaryContainerName = primaryContainerName + } + logCtx.Pods = append(logCtx.Pods, podLogCtx) + } + if len(logCtx.Pods) == 0 { + return nil + } + // Guarantee PrimaryPodName references a pod in Pods: if rank-0 was pending/absent, + // fall back to the first included pod so downstream log streaming can resolve it. + if logCtx.PrimaryPodName == "" { + logCtx.PrimaryPodName = logCtx.Pods[0].GetPodName() + } + return logCtx +} diff --git a/flyteplugins/go/tasks/plugins/k8s/clustered/phase.go b/flyteplugins/go/tasks/plugins/k8s/clustered/phase.go index 343456c387..1ce1209fe6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/clustered/phase.go +++ b/flyteplugins/go/tasks/plugins/k8s/clustered/phase.go @@ -45,6 +45,7 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext } taskInfo := pluginsCore.TaskInfo{ Logs: taskLogs, + LogContext: getLogContext(ctx, pluginContext, jobSet), OccurredAt: &occurredAt, CustomInfo: statusDetails, }