Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,19 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string {
func createSparkPodSpec(
taskCtx pluginsCore.TaskExecutionContext,
podSpec *v1.PodSpec,
objectMeta *metav1.ObjectMeta,
container *v1.Container,
k8sPod *core.K8SPod,
) *sparkOp.SparkPodSpec {

annotations := pluginsUtils.UnionMaps(
config.GetK8sPluginConfig().DefaultAnnotations,
objectMeta.GetAnnotations(),
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()),
)
labels := pluginsUtils.UnionMaps(
config.GetK8sPluginConfig().DefaultLabels,
objectMeta.GetLabels(),
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()),
)
if k8sPod.GetMetadata().GetAnnotations() != nil {
Expand Down Expand Up @@ -195,7 +198,7 @@ type driverSpec struct {
func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) {
// Spark driver pods should always run as non-interruptible
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -226,8 +229,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
if err != nil {
return nil, err
}
sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, driverPod)

sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, objectMeta, primaryContainer, driverPod)
spec := driverSpec{
&sparkOp.DriverSpec{
SparkPodSpec: *sparkPodSpec,
Expand All @@ -250,7 +252,7 @@ type executorSpec struct {
}

func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) {
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -280,7 +282,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
if err != nil {
return nil, err
}
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, sparkJob.GetExecutorPod())
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, objectMeta, primaryContainer, sparkJob.GetExecutorPod())
serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata())
spec := executorSpec{
primaryContainer,
Expand Down
23 changes: 14 additions & 9 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string
}
}

func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate {
func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec, podMetadata *core.K8SObjectMetadata) *core.TaskTemplate {
// add driver/executor pod below
sparkJob := dummySparkCustomObj(sparkConf)
sparkJobJSON, err := utils.MarshalToString(sparkJob)
Expand All @@ -492,7 +492,8 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *
Type: "k8s_pod",
Target: &core.TaskTemplate_K8SPod{
K8SPod: &core.K8SPod{
PodSpec: podSpecPb,
Metadata: podMetadata,
PodSpec: podSpecPb,
},
},
Config: map[string]string{
Expand Down Expand Up @@ -949,8 +950,12 @@ func TestBuildResourcePodTemplate(t *testing.T) {
podSpec := dummyPodSpec()
podSpec.Tolerations = append(podSpec.Tolerations, extraToleration)
podSpec.NodeSelector = map[string]string{"x/custom": "foo"}
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
taskTemplate.GetK8SPod()
podMetadata := &core.K8SObjectMetadata{
Annotations: map[string]string{"annotation-2": "val2"},
Labels: map[string]string{"label-2": "val2"},
}

taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, podMetadata)
sparkResourceHandler := sparkResourceHandler{}

taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
Expand Down Expand Up @@ -980,8 +985,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)

// Driver
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Driver.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Driver.Labels)
assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
Expand Down Expand Up @@ -1018,8 +1023,8 @@ func TestBuildResourcePodTemplate(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)

// Executor
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1", "annotation-2": "val2"}), sparkApp.Spec.Executor.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1", "label-2": "val2"}), sparkApp.Spec.Executor.Labels)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
Expand Down Expand Up @@ -1065,7 +1070,7 @@ func TestBuildResourcePriorityClassName(t *testing.T) {
const priorityClassName = "high-priority"
podSpec := dummyPodSpec()
podSpec.PriorityClassName = priorityClassName
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec, nil)

sparkResourceHandler := sparkResourceHandler{}
taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
Expand Down
Loading