@@ -467,7 +467,7 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string
467467 }
468468}
469469
470- func dummySparkTaskTemplatePod (id string , sparkConf map [string ]string , podSpec * corev1.PodSpec ) * core.TaskTemplate {
470+ func dummySparkTaskTemplatePod (id string , sparkConf map [string ]string , podSpec * corev1.PodSpec , podMetadata * core. K8SObjectMetadata ) * core.TaskTemplate {
471471 // add driver/executor pod below
472472 sparkJob := dummySparkCustomObj (sparkConf )
473473 sparkJobJSON , err := utils .MarshalToString (sparkJob )
@@ -492,7 +492,8 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *
492492 Type : "k8s_pod" ,
493493 Target : & core.TaskTemplate_K8SPod {
494494 K8SPod : & core.K8SPod {
495- PodSpec : podSpecPb ,
495+ Metadata : podMetadata ,
496+ PodSpec : podSpecPb ,
496497 },
497498 },
498499 Config : map [string ]string {
@@ -949,8 +950,12 @@ func TestBuildResourcePodTemplate(t *testing.T) {
949950 podSpec := dummyPodSpec ()
950951 podSpec .Tolerations = append (podSpec .Tolerations , extraToleration )
951952 podSpec .NodeSelector = map [string ]string {"x/custom" : "foo" }
952- taskTemplate := dummySparkTaskTemplatePod ("blah-1" , dummySparkConf , podSpec )
953- taskTemplate .GetK8SPod ()
953+ podMetadata := & core.K8SObjectMetadata {
954+ Annotations : map [string ]string {"annotation-2" : "val2" },
955+ Labels : map [string ]string {"label-2" : "val2" },
956+ }
957+
958+ taskTemplate := dummySparkTaskTemplatePod ("blah-1" , dummySparkConf , podSpec , podMetadata )
954959 sparkResourceHandler := sparkResourceHandler {}
955960
956961 taskCtx := dummySparkTaskContext (taskTemplate , true , k8s.PluginState {})
@@ -980,8 +985,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
980985 assert .Equal (t , sparkApplicationFile , * sparkApp .Spec .MainApplicationFile )
981986
982987 // Driver
983- assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultAnnotations , map [string ]string {"annotation-1" : "val1" }), sparkApp .Spec .Driver .Annotations )
984- assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultLabels , map [string ]string {"label-1" : "val1" }), sparkApp .Spec .Driver .Labels )
988+ assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultAnnotations , map [string ]string {"annotation-1" : "val1" , "annotation-2" : "val2" }), sparkApp .Spec .Driver .Annotations )
989+ assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultLabels , map [string ]string {"label-1" : "val1" , "label-2" : "val2" }), sparkApp .Spec .Driver .Labels )
985990 assert .Equal (t , len (findEnvVarByName (sparkApp .Spec .Driver .Env , "FLYTE_MAX_ATTEMPTS" ).Value ), 1 )
986991 assert .Equal (t , defaultConfig .DefaultEnvVars ["foo" ], findEnvVarByName (sparkApp .Spec .Driver .Env , "foo" ).Value )
987992 assert .Equal (t , defaultConfig .DefaultEnvVars ["fooEnv" ], findEnvVarByName (sparkApp .Spec .Driver .Env , "fooEnv" ).Value )
@@ -1018,8 +1023,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
10181023 assert .Equal (t , dummySparkConf ["spark.driver.memory" ], * sparkApp .Spec .Driver .Memory )
10191024
10201025 // Executor
1021- assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultAnnotations , map [string ]string {"annotation-1" : "val1" }), sparkApp .Spec .Executor .Annotations )
1022- assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultLabels , map [string ]string {"label-1" : "val1" }), sparkApp .Spec .Executor .Labels )
1026+ assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultAnnotations , map [string ]string {"annotation-1" : "val1" , "annotation-2" : "val2" }), sparkApp .Spec .Executor .Annotations )
1027+ assert .Equal (t , utils .UnionMaps (defaultConfig .DefaultLabels , map [string ]string {"label-1" : "val1" , "label-2" : "val2" }), sparkApp .Spec .Executor .Labels )
10231028 assert .Equal (t , defaultConfig .DefaultEnvVars ["foo" ], findEnvVarByName (sparkApp .Spec .Executor .Env , "foo" ).Value )
10241029 assert .Equal (t , defaultConfig .DefaultEnvVars ["fooEnv" ], findEnvVarByName (sparkApp .Spec .Executor .Env , "fooEnv" ).Value )
10251030 assert .Equal (t , findEnvVarByName (dummyEnvVarsWithSecretRef , "SECRET" ), findEnvVarByName (sparkApp .Spec .Executor .Env , "SECRET" ))
@@ -1065,7 +1070,7 @@ func TestBuildResourcePriorityClassName(t *testing.T) {
10651070 const priorityClassName = "high-priority"
10661071 podSpec := dummyPodSpec ()
10671072 podSpec .PriorityClassName = priorityClassName
1068- taskTemplate := dummySparkTaskTemplatePod ("blah-1" , dummySparkConf , podSpec )
1073+ taskTemplate := dummySparkTaskTemplatePod ("blah-1" , dummySparkConf , podSpec , nil )
10691074
10701075 sparkResourceHandler := sparkResourceHandler {}
10711076 taskCtx := dummySparkTaskContext (taskTemplate , true , k8s.PluginState {})
0 commit comments