Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
e72b69c
docs: add ClusterPlugin interface + Ray migration design spec
pingsutw May 29, 2026
e182c14
docs: add ClusterPlugin implementation plan
pingsutw May 29, 2026
b4b455a
feat(k8s): add ClusterPlugin interface for shared-cluster plugins
pingsutw May 29, 2026
8b2c71d
feat(k8s): add ClusterPlugin/ClusterResourceToWatch to PluginEntry
pingsutw May 29, 2026
a75f09e
feat(k8s): validate ClusterPlugin registration shape
pingsutw May 29, 2026
4ed7354
refactor(k8s): generalize addObjectMetadata for reuse by cluster manager
pingsutw May 29, 2026
fbafcf7
feat(k8s): add ClusterPluginManager state machine
pingsutw May 29, 2026
ddd5ddc
feat(executor): route ClusterPlugin entries to ClusterPluginManager
pingsutw May 29, 2026
7b43ccf
test(k8s): cover ClusterPluginManager cluster/job lifecycle
pingsutw May 29, 2026
28b8a08
feat(ray): migrate Ray plugin to ClusterPlugin with shared cluster
pingsutw May 29, 2026
3925429
test(ray): cover ClusterPlugin split (cluster/job/readiness/name)
pingsutw May 29, 2026
1b17caa
fix(k8s): guard against background job deletion in ClusterPluginManager
pingsutw May 29, 2026
b9f8f20
refactor(ray): remove dead submitter pod template helpers
pingsutw May 29, 2026
6d83315
chore(ray): add compile-time ClusterPlugin interface assertion
pingsutw May 29, 2026
4f0f40c
feat(executor): activate ray ClusterPlugin and fix CRD scheme registr…
pingsutw May 29, 2026
5a6ec95
fix(ray): strip run-scoped env from shared cluster pods
pingsutw May 29, 2026
453d878
fix(ray): inject per-job identity into shared-cluster RayJob runtime_env
pingsutw May 29, 2026
ccabe76
docs: remove internal ClusterPlugin planning docs from PR
pingsutw May 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
410 changes: 410 additions & 0 deletions executor/pkg/plugin/k8s/cluster_plugin_manager.go

Large diffs are not rendered by default.

253 changes: 253 additions & 0 deletions executor/pkg/plugin/k8s/cluster_plugin_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package k8s

import (
"context"
"testing"
"time"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8stypes "k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"

pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
coremocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
)

// fakeClusterPlugin is a hand-rolled k8s.ClusterPlugin implementation that the tests use to steer
// the ClusterPluginManager state machine through its phases.
type fakeClusterPlugin struct {
	// clusterReady is the value IsClusterReady reports.
	clusterReady bool
	// jobPhase selects which PhaseInfo GetJobPhase reports.
	jobPhase pluginsCore.Phase
}

// Compile-time check that *fakeClusterPlugin satisfies k8s.ClusterPlugin.
var _ k8s.ClusterPlugin = (*fakeClusterPlugin)(nil)

// GetClusterName ignores its arguments and always yields the fixed name "shared-cluster".
func (f *fakeClusterPlugin) GetClusterName(context.Context, pluginsCore.TaskExecutionContext) (string, error) {
	const name = "shared-cluster"
	return name, nil
}

// rayClusterShell returns a new RayCluster whose only populated fields are the Kind and
// APIVersion in its TypeMeta.
func rayClusterShell() *rayv1.RayCluster {
	typeMeta := metav1.TypeMeta{
		Kind:       "RayCluster",
		APIVersion: rayv1.SchemeGroupVersion.String(),
	}
	return &rayv1.RayCluster{TypeMeta: typeMeta}
}

// rayJobShell returns a new RayJob whose only populated fields are the Kind and APIVersion in
// its TypeMeta.
func rayJobShell() *rayv1.RayJob {
	typeMeta := metav1.TypeMeta{
		Kind:       "RayJob",
		APIVersion: rayv1.SchemeGroupVersion.String(),
	}
	return &rayv1.RayJob{TypeMeta: typeMeta}
}

// BuildClusterResource ignores its arguments and returns a fresh RayCluster shell.
func (f *fakeClusterPlugin) BuildClusterResource(context.Context, pluginsCore.TaskExecutionContext) (client.Object, error) {
	cluster := rayClusterShell()
	return cluster, nil
}

// BuildClusterIdentityResource ignores its arguments and returns a fresh RayCluster shell.
func (f *fakeClusterPlugin) BuildClusterIdentityResource(context.Context, pluginsCore.TaskExecutionMetadata) (client.Object, error) {
	identity := rayClusterShell()
	return identity, nil
}

// IsClusterReady reports the fake's clusterReady toggle; the passed object is never inspected.
func (f *fakeClusterPlugin) IsClusterReady(context.Context, k8s.PluginContext, client.Object) (bool, error) {
	ready := f.clusterReady
	return ready, nil
}

// BuildJobResource returns a RayJob shell whose spec carries only a cluster selector mapping
// "ray.io/cluster" to clusterName.
func (f *fakeClusterPlugin) BuildJobResource(_ context.Context, _ pluginsCore.TaskExecutionContext, clusterName string) (client.Object, error) {
	selector := map[string]string{"ray.io/cluster": clusterName}

	rayJob := rayJobShell()
	rayJob.Spec = rayv1.RayJobSpec{ClusterSelector: selector}
	return rayJob, nil
}

// BuildJobIdentityResource ignores its arguments and returns a fresh RayJob shell.
func (f *fakeClusterPlugin) BuildJobIdentityResource(context.Context, pluginsCore.TaskExecutionMetadata) (client.Object, error) {
	identity := rayJobShell()
	return identity, nil
}

// GetJobPhase translates the fake's jobPhase toggle into a PhaseInfo: PhaseSuccess and
// PhaseRunning map to their matching PhaseInfo, and every other value is reported as queued.
// The passed object is never inspected.
func (f *fakeClusterPlugin) GetJobPhase(context.Context, k8s.PluginContext, client.Object) (pluginsCore.PhaseInfo, error) {
	if f.jobPhase == pluginsCore.PhaseSuccess {
		return pluginsCore.PhaseInfoSuccess(nil), nil
	}
	if f.jobPhase == pluginsCore.PhaseRunning {
		return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, nil), nil
	}
	// Anything else (including the zero value) is treated as still queued.
	return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "queued"), nil
}

// GetProperties returns the zero-value PluginProperties.
func (f *fakeClusterPlugin) GetProperties() k8s.PluginProperties {
	var props k8s.PluginProperties
	return props
}

// newFakeKubeClient builds a pluginsCore.KubeClient backed by a controller-runtime fake client that is
// seeded with objs and aware of the ray CRD types.
//
// The ray types are registered into the package-level client-go scheme.Scheme, so the registration
// is shared by every caller in the test binary. GetClient returns the fake client and GetCache
// returns nil; both expectations are optional.
func newFakeKubeClient(t *testing.T, objs ...client.Object) pluginsCore.KubeClient {
	// Mark as a helper so a failed scheme registration is reported at the calling test's line.
	t.Helper()

	require.NoError(t, rayv1.AddToScheme(scheme.Scheme))
	c := fake.NewClientBuilder().WithScheme(scheme.Scheme).WithObjects(objs...).Build()
	kc := &coremocks.KubeClient{}
	kc.EXPECT().GetClient().Return(c).Maybe()
	kc.EXPECT().GetCache().Return(nil).Maybe()
	return kc
}

// stateCapture records the last plugin state written by the manager.
type stateCapture struct {
	// state holds the most recent ClusterPluginState observed by the mocked PluginStateWriter;
	// it stays at the zero value until a write happens.
	state ClusterPluginState
}

// taskContextForState returns a TaskExecutionContext whose PluginStateReader returns the provided
// initial state and whose PluginStateWriter records the written state into the returned stateCapture.
//
// The mocked metadata reports namespace "ns", generated name "job-name", an empty owner reference,
// and empty label/annotation maps. Every expectation is optional (Maybe), so a test only pays for
// the accessors the manager actually calls.
func taskContextForState(t *testing.T, state ClusterPluginState) (*coremocks.TaskExecutionContext, *stateCapture) {
	// Mark as a helper so any failure raised through t is attributed to the calling test.
	t.Helper()

	capture := &stateCapture{}

	tCtx := &coremocks.TaskExecutionContext{}

	// Reader: copy the seeded state into the caller's pointer and report state version 0.
	reader := &coremocks.PluginStateReader{}
	reader.EXPECT().Get(mock.Anything).RunAndReturn(func(v interface{}) (uint8, error) {
		if ptr, ok := v.(*ClusterPluginState); ok {
			*ptr = state
		}
		return 0, nil
	}).Maybe()
	tCtx.EXPECT().PluginStateReader().Return(reader).Maybe()

	// Writer: remember the last ClusterPluginState handed to Put so tests can assert on it.
	writer := &coremocks.PluginStateWriter{}
	writer.EXPECT().Put(mock.Anything, mock.Anything).RunAndReturn(func(_ uint8, v interface{}) error {
		if ptr, ok := v.(*ClusterPluginState); ok {
			capture.state = *ptr
		}
		return nil
	}).Maybe()
	tCtx.EXPECT().PluginStateWriter().Return(writer).Maybe()

	tID := &coremocks.TaskExecutionID{}
	tID.EXPECT().GetGeneratedName().Return("job-name").Maybe()

	md := &coremocks.TaskExecutionMetadata{}
	md.EXPECT().GetNamespace().Return("ns").Maybe()
	md.EXPECT().GetTaskExecutionID().Return(tID).Maybe()
	md.EXPECT().GetOwnerReference().Return(metav1.OwnerReference{}).Maybe()
	md.EXPECT().GetLabels().Return(map[string]string{}).Maybe()
	md.EXPECT().GetAnnotations().Return(map[string]string{}).Maybe()
	tCtx.EXPECT().TaskExecutionMetadata().Return(md).Maybe()

	// DataStore / OutputWriter are only reached on the success path; add Maybe stubs so accidental
	// access does not panic.
	tCtx.EXPECT().OutputWriter().Return(nil).Maybe()
	tCtx.EXPECT().DataStore().Return(nil).Maybe()

	return tCtx, capture
}

// TestClusterPluginManager_NotStarted_CreatesClusterNoOwnerRef checks that handling a NotStarted
// state reports PhaseInitializing, persists ClusterPhaseClusterWait together with the cluster name,
// and leaves a RayCluster in the fake client that has neither owner references nor finalizers.
func TestClusterPluginManager_NotStarted_CreatesClusterNoOwnerRef(t *testing.T) {
	ctx := context.Background()
	kubeClient := newFakeKubeClient(t)
	manager := NewClusterPluginManager("test-id", &fakeClusterPlugin{}, kubeClient)

	initial := ClusterPluginState{Phase: ClusterPhaseNotStarted}
	tCtx, written := taskContextForState(t, initial)

	transition, err := manager.Handle(ctx, tCtx)
	require.NoError(t, err)
	assert.Equal(t, pluginsCore.PhaseInitializing, transition.Info().Phase())

	// The manager must have advanced its persisted state.
	assert.Equal(t, ClusterPhaseClusterWait, written.state.Phase)
	assert.Equal(t, "shared-cluster", written.state.ClusterName)

	// The shared cluster must exist and must not be tied to this task execution.
	clusterKey := k8stypes.NamespacedName{Namespace: "ns", Name: "shared-cluster"}
	var created rayv1.RayCluster
	require.NoError(t, kubeClient.GetClient().Get(ctx, clusterKey, &created))
	assert.Empty(t, created.GetOwnerReferences())
	assert.Empty(t, created.GetFinalizers())
}

// TestClusterPluginManager_ClusterWait_NotReady_StaysInitializing checks that while the plugin
// reports the cluster as not ready, Handle keeps returning PhaseInitializing and no RayJob is
// present in the fake client.
func TestClusterPluginManager_ClusterWait_NotReady_StaysInitializing(t *testing.T) {
	ctx := context.Background()

	sharedCluster := rayClusterShell()
	sharedCluster.SetNamespace("ns")
	sharedCluster.SetName("shared-cluster")
	kubeClient := newFakeKubeClient(t, sharedCluster)

	manager := NewClusterPluginManager("test-id", &fakeClusterPlugin{clusterReady: false}, kubeClient)

	waiting := ClusterPluginState{Phase: ClusterPhaseClusterWait, ClusterName: "shared-cluster"}
	tCtx, _ := taskContextForState(t, waiting)

	transition, err := manager.Handle(ctx, tCtx)
	require.NoError(t, err)
	assert.Equal(t, pluginsCore.PhaseInitializing, transition.Info().Phase())

	// No job may be submitted before the cluster is ready.
	var jobs rayv1.RayJobList
	require.NoError(t, kubeClient.GetClient().List(ctx, &jobs))
	assert.Empty(t, jobs.Items)
}

// TestClusterPluginManager_ClusterWait_Ready_CreatesJobWithSelector checks that once the plugin
// reports the cluster as ready, Handle returns PhaseQueued, persists ClusterPhaseJobStarted, and a
// RayJob named "job-name" exists whose cluster selector points at "shared-cluster".
func TestClusterPluginManager_ClusterWait_Ready_CreatesJobWithSelector(t *testing.T) {
	ctx := context.Background()

	sharedCluster := rayClusterShell()
	sharedCluster.SetNamespace("ns")
	sharedCluster.SetName("shared-cluster")
	kubeClient := newFakeKubeClient(t, sharedCluster)

	manager := NewClusterPluginManager("test-id", &fakeClusterPlugin{clusterReady: true}, kubeClient)

	waiting := ClusterPluginState{Phase: ClusterPhaseClusterWait, ClusterName: "shared-cluster"}
	tCtx, written := taskContextForState(t, waiting)

	transition, err := manager.Handle(ctx, tCtx)
	require.NoError(t, err)
	assert.Equal(t, pluginsCore.PhaseQueued, transition.Info().Phase())
	assert.Equal(t, ClusterPhaseJobStarted, written.state.Phase)

	// The submitted job must target the shared cluster through its selector.
	jobKey := k8stypes.NamespacedName{Namespace: "ns", Name: "job-name"}
	var submitted rayv1.RayJob
	require.NoError(t, kubeClient.GetClient().Get(ctx, jobKey, &submitted))
	assert.Equal(t, "shared-cluster", submitted.Spec.ClusterSelector["ray.io/cluster"])
}

// TestClusterPluginManager_JobStarted_MapsRunning checks that in the JobStarted state, a plugin
// job phase of PhaseRunning surfaces as PhaseRunning on the transition returned by Handle.
func TestClusterPluginManager_JobStarted_MapsRunning(t *testing.T) {
	ctx := context.Background()

	runningJob := rayJobShell()
	runningJob.SetNamespace("ns")
	runningJob.SetName("job-name")
	kubeClient := newFakeKubeClient(t, runningJob)

	manager := NewClusterPluginManager("test-id", &fakeClusterPlugin{jobPhase: pluginsCore.PhaseRunning}, kubeClient)

	started := ClusterPluginState{Phase: ClusterPhaseJobStarted, ClusterName: "shared-cluster"}
	tCtx, _ := taskContextForState(t, started)

	transition, err := manager.Handle(ctx, tCtx)
	require.NoError(t, err)
	assert.Equal(t, pluginsCore.PhaseRunning, transition.Info().Phase())
}

// TestClusterPluginManager_Abort_DeletesJobNotCluster checks that Abort removes the task's RayJob
// from the fake client while the shared RayCluster remains retrievable.
func TestClusterPluginManager_Abort_DeletesJobNotCluster(t *testing.T) {
	ctx := context.Background()

	sharedCluster := rayClusterShell()
	sharedCluster.SetNamespace("ns")
	sharedCluster.SetName("shared-cluster")

	taskJob := rayJobShell()
	taskJob.SetNamespace("ns")
	taskJob.SetName("job-name")

	kubeClient := newFakeKubeClient(t, sharedCluster, taskJob)
	manager := NewClusterPluginManager("test-id", &fakeClusterPlugin{}, kubeClient)

	started := ClusterPluginState{Phase: ClusterPhaseJobStarted, ClusterName: "shared-cluster"}
	tCtx, _ := taskContextForState(t, started)
	require.NoError(t, manager.Abort(ctx, tCtx))

	// The job belongs to this task execution and must be gone.
	jobKey := k8stypes.NamespacedName{Namespace: "ns", Name: "job-name"}
	var deletedJob rayv1.RayJob
	err := kubeClient.GetClient().Get(ctx, jobKey, &deletedJob)
	assert.True(t, k8serrors.IsNotFound(err), "expected job to be deleted, got err: %v", err)

	// The cluster is shared and must survive the abort.
	clusterKey := k8stypes.NamespacedName{Namespace: "ns", Name: "shared-cluster"}
	var survivingCluster rayv1.RayCluster
	require.NoError(t, kubeClient.GetClient().Get(ctx, clusterKey, &survivingCluster))
}
15 changes: 12 additions & 3 deletions executor/pkg/plugin/k8s/plugin_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,26 @@ func (pm *PluginManager) GetProperties() pluginsCore.PluginProperties {
}
}

// addObjectMetadata prepares an object that IS owned by this task execution. It delegates to
// addObjectMetadataWithName, passing the task execution's generated name as the object name and
// requesting ownership injection.
func (pm *PluginManager) addObjectMetadata(taskCtx pluginsCore.TaskExecutionMetadata, o client.Object, cfg *config.K8sPluginConfig) {
	generatedName := taskCtx.GetTaskExecutionID().GetGeneratedName()
	const injectOwnership = true
	pm.addObjectMetadataWithName(taskCtx, o, cfg, generatedName, injectOwnership)
}

// addObjectMetadataWithName is the general form. name overrides the object name; when injectOwnership
// is false, owner references and finalizers are never added (used for shared cluster resources that
// must outlive a single task execution).
func (pm *PluginManager) addObjectMetadataWithName(taskCtx pluginsCore.TaskExecutionMetadata, o client.Object, cfg *config.K8sPluginConfig, name string, injectOwnership bool) {
o.SetNamespace(taskCtx.GetNamespace())
o.SetAnnotations(pluginsUtils.UnionMaps(cfg.DefaultAnnotations, o.GetAnnotations(), pluginsUtils.CopyMap(taskCtx.GetAnnotations())))
o.SetLabels(pluginsUtils.UnionMaps(cfg.DefaultLabels, o.GetLabels(), pluginsUtils.CopyMap(taskCtx.GetLabels())))
o.SetName(taskCtx.GetTaskExecutionID().GetGeneratedName())
o.SetName(name)

if !pm.plugin.GetProperties().DisableInjectOwnerReferences && !cfg.DisableInjectOwnerReferences {
if injectOwnership && !pm.plugin.GetProperties().DisableInjectOwnerReferences && !cfg.DisableInjectOwnerReferences {
o.SetOwnerReferences([]metav1.OwnerReference{taskCtx.GetOwnerReference()})
}

if cfg.InjectFinalizer && !pm.plugin.GetProperties().DisableInjectFinalizer {
if injectOwnership && cfg.InjectFinalizer && !pm.plugin.GetProperties().DisableInjectFinalizer {
f := append(o.GetFinalizers(), "flyte/flytek8s")
o.SetFinalizers(f)
}
Expand Down
28 changes: 19 additions & 9 deletions executor/pkg/plugin/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,39 @@ func (r *Registry) Initialize(ctx context.Context) error {

// Load k8s plugins
for _, entry := range r.pluginRegistry.GetK8sPlugins() {
pm := executorK8s.NewPluginManager(
entry.ID,
entry.Plugin,
r.setupCtx.KubeClient(),
)
if err := pm.InitializeObjectEventWatcher(ctx); err != nil {
return fmt.Errorf("failed to initialize k8s object event watcher for plugin %s: %w", entry.ID, err)
var plugin pluginsCore.Plugin
if entry.ClusterPlugin != nil {
plugin = executorK8s.NewClusterPluginManager(
entry.ID,
entry.ClusterPlugin,
r.setupCtx.KubeClient(),
)
} else {
pm := executorK8s.NewPluginManager(
entry.ID,
entry.Plugin,
r.setupCtx.KubeClient(),
)
if err := pm.InitializeObjectEventWatcher(ctx); err != nil {
return fmt.Errorf("failed to initialize k8s object event watcher for plugin %s: %w", entry.ID, err)
}
plugin = pm
}

for _, taskType := range entry.RegisteredTaskTypes {
if existing, ok := r.plugins[taskType]; ok {
logger.Warnf(ctx, "Task type %q already registered by plugin %q, overwriting with %q",
taskType, existing.GetID(), entry.ID)
}
r.plugins[taskType] = pm
r.plugins[taskType] = plugin
}

if entry.IsDefault {
if r.defaultPlugin != nil {
logger.Warnf(ctx, "Multiple default plugins found, overwriting %q with %q",
r.defaultPlugin.GetID(), entry.ID)
}
r.defaultPlugin = pm
r.defaultPlugin = plugin
}

logger.Infof(ctx, "Registered k8s plugin [%s] for task types %v", entry.ID, entry.RegisteredTaskTypes)
Expand Down
6 changes: 6 additions & 0 deletions executor/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"net/http"
"os"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/kubernetes"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/healthz"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand All @@ -36,13 +38,17 @@ import (
// plugins with the global registry.
_ "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/plugins/k8s/clustered"
_ "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/plugins/k8s/pod"
_ "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/plugins/k8s/ray"
)

var scheme = runtime.NewScheme()

func init() {
utilruntime.Must(clientgoscheme.AddToScheme(scheme))
utilruntime.Must(flyteorgv1.AddToScheme(scheme))
// Register CRD types for plugins that use CRDs (must match the plugin imports below).
utilruntime.Must(rayv1.AddToScheme(scheme))
utilruntime.Must(jobsetv1alpha2.AddToScheme(scheme))
}

// Scheme returns the runtime.Scheme with executor CRDs registered.
Expand Down
Loading
Loading