diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 6bb746c576..d98f689552 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -145,16 +145,19 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { func createSparkPodSpec( taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, + objectMeta *metav1.ObjectMeta, container *v1.Container, k8sPod *core.K8SPod, ) *sparkOp.SparkPodSpec { annotations := pluginsUtils.UnionMaps( config.GetK8sPluginConfig().DefaultAnnotations, + objectMeta.GetAnnotations(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), ) labels := pluginsUtils.UnionMaps( config.GetK8sPluginConfig().DefaultLabels, + objectMeta.GetLabels(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), ) if k8sPod.GetMetadata().GetAnnotations() != nil { @@ -195,7 +198,7 @@ type driverSpec struct { func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) { // Spark driver pods should always run as non-interruptible nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false)) - podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { return nil, err } @@ -226,8 +229,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, driverPod) - + sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, objectMeta, primaryContainer, driverPod) spec := driverSpec{ &sparkOp.DriverSpec{ SparkPodSpec: *sparkPodSpec, @@ -250,7 +252,7 @@ type executorSpec struct { } func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) { - podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, err } @@ -280,7 +282,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, sparkJob.GetExecutorPod()) + sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, objectMeta, primaryContainer, sparkJob.GetExecutorPod()) serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) spec := executorSpec{ primaryContainer, diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 886396be09..678b5def88 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -467,7 +467,7 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string } } -func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate { +func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec, podMetadata *core.K8SObjectMetadata) *core.TaskTemplate { // add driver/executor pod below sparkJob := dummySparkCustomObj(sparkConf) sparkJobJSON, err := utils.MarshalToString(sparkJob) @@ -492,7 +492,8 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec * Type: "k8s_pod", Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ - PodSpec: podSpecPb, + Metadata: podMetadata, + PodSpec: podSpecPb, }, }, Config: map[string]string{ @@ -949,8 +950,12 @@ func TestBuildResourcePodTemplate(t *testing.T) { podSpec := dummyPodSpec() podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) podSpec.NodeSelector = map[string]string{"x/custom": "foo"} - taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) - taskTemplate.GetK8SPod() + podMetadata := &core.K8SObjectMetadata{ + Annotations: map[string]string{"annotation-2": "val2"}, + Labels: map[string]string{"label-2": "val2"}, + } + + taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, podMetadata) sparkResourceHandler := sparkResourceHandler{} taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) @@ -980,8 +985,8 @@ func TestBuildResourcePodTemplate(t *testing.T) { assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) // Driver - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Driver.Labels) assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) @@ -1018,8 +1023,8 @@ func TestBuildResourcePodTemplate(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) // Executor - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Executor.Labels) assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) @@ -1065,7 +1070,7 @@ func TestBuildResourcePriorityClassName(t *testing.T) { const priorityClassName = "high-priority" podSpec := dummyPodSpec() podSpec.PriorityClassName = priorityClassName - taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) + taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, nil) sparkResourceHandler := sparkResourceHandler{} taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})