Skip to content

Commit 7dc9b56

Browse files
authored
Apply pod template labels and annotations to Spark applications (#7514)
* Propagate annotations and labels * Fix build * Include labels and metadata in TestBuildResourcePodTemplate * Autoformat
1 parent fb49a7f commit 7dc9b56

2 files changed

Lines changed: 21 additions & 14 deletions

File tree

flyteplugins/go/tasks/plugins/k8s/spark/spark.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,19 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string {
145145
func createSparkPodSpec(
146146
taskCtx pluginsCore.TaskExecutionContext,
147147
podSpec *v1.PodSpec,
148+
objectMeta *metav1.ObjectMeta,
148149
container *v1.Container,
149150
k8sPod *core.K8SPod,
150151
) *sparkOp.SparkPodSpec {
151152

152153
annotations := pluginsUtils.UnionMaps(
153154
config.GetK8sPluginConfig().DefaultAnnotations,
155+
objectMeta.GetAnnotations(),
154156
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()),
155157
)
156158
labels := pluginsUtils.UnionMaps(
157159
config.GetK8sPluginConfig().DefaultLabels,
160+
objectMeta.GetLabels(),
158161
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()),
159162
)
160163
if k8sPod.GetMetadata().GetAnnotations() != nil {
@@ -195,7 +198,7 @@ type driverSpec struct {
195198
func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) {
196199
// Spark driver pods should always run as non-interruptible
197200
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
198-
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
201+
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
199202
if err != nil {
200203
return nil, err
201204
}
@@ -226,8 +229,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
226229
if err != nil {
227230
return nil, err
228231
}
229-
sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, driverPod)
230-
232+
sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, objectMeta, primaryContainer, driverPod)
231233
spec := driverSpec{
232234
&sparkOp.DriverSpec{
233235
SparkPodSpec: *sparkPodSpec,
@@ -250,7 +252,7 @@ type executorSpec struct {
250252
}
251253

252254
func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) {
253-
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
255+
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
254256
if err != nil {
255257
return nil, err
256258
}
@@ -280,7 +282,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
280282
if err != nil {
281283
return nil, err
282284
}
283-
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, sparkJob.GetExecutorPod())
285+
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, objectMeta, primaryContainer, sparkJob.GetExecutorPod())
284286
serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata())
285287
spec := executorSpec{
286288
primaryContainer,

flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string
467467
}
468468
}
469469

470-
func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate {
470+
func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec, podMetadata *core.K8SObjectMetadata) *core.TaskTemplate {
471471
// add driver/executor pod below
472472
sparkJob := dummySparkCustomObj(sparkConf)
473473
sparkJobJSON, err := utils.MarshalToString(sparkJob)
@@ -492,7 +492,8 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *
492492
Type: "k8s_pod",
493493
Target: &core.TaskTemplate_K8SPod{
494494
K8SPod: &core.K8SPod{
495-
PodSpec: podSpecPb,
495+
Metadata: podMetadata,
496+
PodSpec: podSpecPb,
496497
},
497498
},
498499
Config: map[string]string{
@@ -949,8 +950,12 @@ func TestBuildResourcePodTemplate(t *testing.T) {
949950
podSpec := dummyPodSpec()
950951
podSpec.Tolerations = append(podSpec.Tolerations, extraToleration)
951952
podSpec.NodeSelector = map[string]string{"x/custom": "foo"}
952-
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
953-
taskTemplate.GetK8SPod()
953+
podMetadata := &core.K8SObjectMetadata{
954+
Annotations: map[string]string{"annotation-2": "val2"},
955+
Labels: map[string]string{"label-2": "val2"},
956+
}
957+
958+
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, podMetadata)
954959
sparkResourceHandler := sparkResourceHandler{}
955960

956961
taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
@@ -980,8 +985,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
980985
assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)
981986

982987
// Driver
983-
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
984-
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
988+
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Driver.Annotations)
989+
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Driver.Labels)
985990
assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
986991
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
987992
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
@@ -1018,8 +1023,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
10181023
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
10191024

10201025
// Executor
1021-
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
1022-
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
1026+
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Executor.Annotations)
1027+
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Executor.Labels)
10231028
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
10241029
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
10251030
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
@@ -1065,7 +1070,7 @@ func TestBuildResourcePriorityClassName(t *testing.T) {
10651070
const priorityClassName = "high-priority"
10661071
podSpec := dummyPodSpec()
10671072
podSpec.PriorityClassName = priorityClassName
1068-
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
1073+
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, nil)
10691074

10701075
sparkResourceHandler := sparkResourceHandler{}
10711076
taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})

0 commit comments

Comments
 (0)