Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
flyteerr "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/errors"
pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
clusteredpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/plugins"
)
Expand Down Expand Up @@ -45,6 +46,18 @@ func (clusteredResourceHandler) BuildResource(ctx context.Context, taskCtx plugi

podSpec = applyInterconnect(ctx, spec.GetInterconnect(), podSpec)

// Propagate the node-execution labels/annotations onto the pod template. The plugin
// manager's addObjectMetadata only stamps these (incl. execution-id/node-id) on the
// top-level JobSet, and the JobSet controller does not copy arbitrary parent labels
// down to child pods. Without this, child pods lack execution-id/node-id and the
// node-execution-scoped K8sReader.List in getLogContext returns nothing, so no
// LogContext reaches the UI. Mirrors ray's buildWorkerPodTemplate.
cfg := config.GetK8sPluginConfig()
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels,
utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations,
utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))

// The SDK is responsible for setting container.Command to the entrypoint module
// (python -m flyte.distributed._entrypoint) at serde time. The plugin stays
// module-path-agnostic so SDK renames do not require a backend release.
Expand Down
145 changes: 131 additions & 14 deletions flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sscheme "k8s.io/client-go/kubernetes/scheme"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
coreMocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
Expand Down Expand Up @@ -99,8 +100,8 @@ func dummyTaskCtx(taskTemplate *core.TaskTemplate) *coreMocks.TaskExecutionConte
meta := &coreMocks.TaskExecutionMetadata{}
meta.EXPECT().GetTaskExecutionID().Return(tID)
meta.EXPECT().GetNamespace().Return(testNS)
meta.EXPECT().GetAnnotations().Return(map[string]string{})
meta.EXPECT().GetLabels().Return(map[string]string{})
meta.EXPECT().GetAnnotations().Return(map[string]string{"flyte.org/test-annotation": "av"})
meta.EXPECT().GetLabels().Return(map[string]string{"execution-id": "my-exec", "node-id": "n1"})
meta.EXPECT().GetOwnerReference().Return(metav1.OwnerReference{Kind: "node", Name: "n1"})
meta.EXPECT().IsInterruptible().Return(false)
meta.EXPECT().GetOverrides().Return(overrides)
Expand Down Expand Up @@ -154,6 +155,14 @@ func TestBuildResource_HappyPath(t *testing.T) {
assert.Equal(t, int32(4), *jobSpec.Completions)
assert.Equal(t, batchv1.IndexedCompletion, *jobSpec.CompletionMode)
assert.Equal(t, int32(0), *jobSpec.BackoffLimit)

// The node-execution labels/annotations must be propagated onto the pod template so
// JobSet child pods carry execution-id/node-id; otherwise the node-execution-scoped
// K8sReader.List in getLogContext returns nothing and no logs reach the UI.
podMeta := jobSpec.Template.ObjectMeta
assert.Equal(t, "my-exec", podMeta.Labels["execution-id"])
assert.Equal(t, "n1", podMeta.Labels["node-id"])
assert.Equal(t, "av", podMeta.Annotations["flyte.org/test-annotation"])
}

func TestBuildResource_PrimaryContainerPreserved(t *testing.T) {
Expand Down Expand Up @@ -320,13 +329,21 @@ func makeJobSet(condType jobsetv1alpha2.JobSetConditionType, status metav1.Condi
return js
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate) *k8smocks.PluginContext {
// emptyK8sReader builds a fake controller-runtime client that holds no objects.
// It is meant for tests that do not care about pod inspection: listing pods through
// it yields an empty list, for which getLogContext returns a nil LogContext.
func emptyK8sReader() client.Reader {
	builder := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme)
	return builder.Build()
}

func dummyPluginCtx(taskTemplate *core.TaskTemplate, k8sReader client.Reader) *k8smocks.PluginContext {
pCtx := &k8smocks.PluginContext{}

taskReader := &coreMocks.TaskReader{}
taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil)
pCtx.EXPECT().TaskReader().Return(taskReader)

pCtx.EXPECT().K8sReader().Return(k8sReader)

tID := &coreMocks.TaskExecutionID{}
tID.EXPECT().GetID().Return(&core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
Expand All @@ -352,7 +369,7 @@ func TestGetTaskPhase_Initializing(t *testing.T) {
js := makeJobSet("", "", suspend)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -364,7 +381,7 @@ func TestGetTaskPhase_Success(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetCompleted, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -376,7 +393,7 @@ func TestGetTaskPhase_Failure(t *testing.T) {
js := makeJobSet(jobsetv1alpha2.JobSetFailed, metav1.ConditionTrue, false)

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -397,7 +414,7 @@ func TestGetTaskPhase_Running(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -424,7 +441,7 @@ func TestGetTaskPhase_FastFail_NoJobsFailed(t *testing.T) {
}

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -442,7 +459,7 @@ func TestGetTaskPhase_MaintenanceRetry_FlagFalse(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: false},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx := dummyPluginCtx(buildTaskTemplate(spec), emptyK8sReader())

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -484,8 +501,7 @@ func TestGetTaskPhase_FastFail_Worker0Failed(t *testing.T) {
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).WithObjects(pod).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand Down Expand Up @@ -517,8 +533,7 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
NprocPerNode: 1,
FailurePolicy: &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: true},
}
pCtx := dummyPluginCtx(buildTaskTemplate(spec))
pCtx.EXPECT().K8sReader().Return(fakeClient)
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
Expand All @@ -527,6 +542,108 @@ func TestGetTaskPhase_MaintenanceRetry_SystemFailure(t *testing.T) {
assert.Equal(t, core.ExecutionError_SYSTEM, phase.Err().GetKind())
}

func TestGetTaskPhase_LogContext(t *testing.T) {
const primaryContainer = "primary"
const sidecarContainer = "sidecar"

// mkPod builds a realistic JobSet child pod: a primary container plus a sidecar,
// with matching container statuses so BuildPodLogContext produces real container
// contexts. Pending pods carry no statuses.
mkPod := func(name string, phase corev1.PodPhase) *corev1.Pod {
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: testNS,
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{{Name: primaryContainer}, {Name: sidecarContainer}},
},
Status: corev1.PodStatus{Phase: phase},
}
Comment thread
AdilFayyaz marked this conversation as resolved.
if phase == corev1.PodRunning {
running := corev1.ContainerState{Running: &corev1.ContainerStateRunning{StartedAt: metav1.NewTime(time.Now())}}
pod.Status.ContainerStatuses = []corev1.ContainerStatus{
{Name: primaryContainer, State: running},
{Name: sidecarContainer, State: running},
}
}
return pod
}

// jobSet annotates the authoritative primary container name at build time.
makeRunningJobSet := func() *jobsetv1alpha2.JobSet {
js := makeJobSet("", "", false)
js.Annotations = map[string]string{primaryContainerAnnotation: primaryContainer}
js.Status.Conditions = []metav1.Condition{
{Type: "SomeActiveCondition", Status: metav1.ConditionTrue, LastTransitionTime: metav1.NewTime(time.Now())},
}
return js
}
Comment thread
AdilFayyaz marked this conversation as resolved.

// Real JobSet pods carry a random suffix after the "<jobset>-workers-<job>-<idx>" stem.
rank0 := rank0PodName(testJobName) + "-x1y2z"
rank1 := testJobName + "-workers-0-1-a9b8c"
rank2 := testJobName + "-workers-0-2-pppp"

t.Run("primary pod and container resolved from live pods", func(t *testing.T) {
js := makeRunningJobSet()
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).
WithObjects(
mkPod(rank0, corev1.PodRunning),
mkPod(rank1, corev1.PodRunning),
mkPod(rank2, corev1.PodPending),
).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())

lc := phase.Info().LogContext
assert.NotNil(t, lc)
assert.Equal(t, rank0, lc.PrimaryPodName)
// Pending pod is excluded → only the two running pods remain.
assert.Len(t, lc.Pods, 2)
names := []string{lc.Pods[0].GetPodName(), lc.Pods[1].GetPodName()}
assert.Contains(t, names, rank0)
assert.Contains(t, names, rank1)

// Each pod's primary container comes from the JobSet annotation (not the
// sidecar / first container), and container contexts are populated.
for _, p := range lc.Pods {
assert.Equal(t, primaryContainer, p.GetPrimaryContainerName())
assert.GreaterOrEqual(t, len(p.GetContainers()), 1)
}
})

t.Run("primary falls back when rank-0 pod is pending", func(t *testing.T) {
js := makeRunningJobSet()
fakeClient := fake.NewClientBuilder().WithScheme(k8sscheme.Scheme).
WithObjects(
mkPod(rank0, corev1.PodPending),
mkPod(rank1, corev1.PodRunning),
).Build()

spec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
pCtx := dummyPluginCtx(buildTaskTemplate(spec), fakeClient)

handler := clusteredResourceHandler{}
phase, err := handler.GetTaskPhase(context.Background(), pCtx, js)
assert.NoError(t, err)

lc := phase.Info().LogContext
assert.NotNil(t, lc)
// rank-0 is pending and excluded → PrimaryPodName must still reference an
// included pod so downstream log streaming can resolve it.
assert.Len(t, lc.Pods, 1)
assert.Equal(t, rank1, lc.PrimaryPodName)
assert.Equal(t, lc.Pods[0].GetPodName(), lc.PrimaryPodName)
})
}

// --- IsTerminal / GetCompletionTime ---

func TestIsTerminal(t *testing.T) {
Expand Down
56 changes: 56 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package clustered
import (
"context"
"fmt"
"strings"
"time"

v1 "k8s.io/api/core/v1"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
)

Expand Down Expand Up @@ -76,3 +80,55 @@ func getTaskLogs(ctx context.Context, pluginContext k8s.PluginContext, jobSet *j
}
return taskLogs, nil
}

// getLogContext builds the structured LogContext from the JobSet's live child pods.
//
// Unlike getTaskLogs (which synthesizes templated URIs from *predicted* pod names and
// requires a pod-log template to be configured in cluster config), this uses the *real*
// pods — actual names (including the Job-assigned random suffix), namespace, primary
// container, and per-container names + process timestamps — so the console can fetch
// logs natively regardless
// of log-template config. Best-effort: returns nil on list error or when no pods are
Comment thread
AdilFayyaz marked this conversation as resolved.
// ready yet, leaving the templated Logs path as the fallback.
func getLogContext(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet) *core.LogContext {
// The plugin's K8sReader already scopes List calls to this node execution's
// namespace and execution-id/node-id labels, so no extra filters are needed.
Comment thread
pingsutw marked this conversation as resolved.
podList := &v1.PodList{}
if err := pluginContext.K8sReader().List(ctx, podList); err != nil {
logger.Warnf(ctx, "failed to list pods for JobSet %s/%s log context: %v", jobSet.Namespace, jobSet.Name, err)
Comment thread
AdilFayyaz marked this conversation as resolved.
return nil
}

// rank0PodName returns "<jobset>-workers-0-0"; the real pod carries an additional
// random suffix, so match on prefix to identify the primary (rank-0) pod.
primaryPrefix := rank0PodName(jobSet.Name)
// The authoritative primary container name is stored on the JobSet at build time
// (see build.go). Child pods don't carry the annotations BuildPodLogContext infers
// from, so set it explicitly to avoid resolving to the wrong container (e.g. a sidecar).
primaryContainerName := jobSet.Annotations[primaryContainerAnnotation]
logCtx := &core.LogContext{Pods: make([]*core.PodLogContext, 0, len(podList.Items))}
for i := range podList.Items {
pod := &podList.Items[i]
// Pending pods have no logs yet and no container statuses to build contexts from.
if pod.Status.Phase == v1.PodPending {
continue
}
Comment thread
pingsutw marked this conversation as resolved.
if strings.HasPrefix(pod.Name, primaryPrefix) {
logCtx.PrimaryPodName = pod.Name
}
podLogCtx := flytek8s.BuildPodLogContext(pod)
if primaryContainerName != "" {
podLogCtx.PrimaryContainerName = primaryContainerName
}
logCtx.Pods = append(logCtx.Pods, podLogCtx)
}
if len(logCtx.Pods) == 0 {
return nil
}
// Guarantee PrimaryPodName references a pod in Pods: if rank-0 was pending/absent,
// fall back to the first included pod so downstream log streaming can resolve it.
if logCtx.PrimaryPodName == "" {
logCtx.PrimaryPodName = logCtx.Pods[0].GetPodName()
}
return logCtx
Comment thread
AdilFayyaz marked this conversation as resolved.
}
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
}
taskInfo := pluginsCore.TaskInfo{
Logs: taskLogs,
LogContext: getLogContext(ctx, pluginContext, jobSet),
OccurredAt: &occurredAt,
CustomInfo: statusDetails,
}
Expand Down
Loading