Skip to content

Commit acae606

Browse files
committed
Support pod template
Signed-off-by: Yi Chen <[email protected]>
1 parent 0a9c591 commit acae606

File tree

8 files changed

+209
-26
lines changed

8 files changed

+209
-26
lines changed

go.mod

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/stretchr/testify v1.9.0
2020
go.uber.org/zap v1.27.0
2121
gocloud.dev v0.40.0
22+
golang.org/x/mod v0.20.0
2223
golang.org/x/net v0.30.0
2324
golang.org/x/time v0.7.0
2425
helm.sh/helm/v3 v3.16.2
@@ -30,6 +31,7 @@ require (
3031
k8s.io/utils v0.0.0-20240711033017-18e509b52bc8
3132
sigs.k8s.io/controller-runtime v0.17.5
3233
sigs.k8s.io/scheduler-plugins v0.29.8
34+
sigs.k8s.io/yaml v1.4.0
3335
volcano.sh/apis v1.9.0
3436
)
3537

@@ -229,7 +231,6 @@ require (
229231
sigs.k8s.io/kustomize/api v0.17.2 // indirect
230232
sigs.k8s.io/kustomize/kyaml v0.17.1 // indirect
231233
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
232-
sigs.k8s.io/yaml v1.4.0 // indirect
233234
)
234235

235236
replace (

internal/controller/sparkapplication/controller.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package sparkapplication
1919
import (
2020
"context"
2121
"fmt"
22+
"os"
2223
"strconv"
2324
"time"
2425

@@ -718,6 +719,9 @@ func (r *Reconciler) submitSparkApplication(app *v1beta2.SparkApplication) error
718719
// Try submitting the application by running spark-submit.
719720
logger.Info("Running spark-submit for SparkApplication", "name", app.Name, "namespace", app.Namespace, "arguments", sparkSubmitArgs)
720721
submitted, err := runSparkSubmit(newSubmission(sparkSubmitArgs, app))
722+
if err := r.cleanUpPodTemplateFiles(app); err != nil {
723+
return fmt.Errorf("failed to clean up pod template files: %v", err)
724+
}
721725
if err != nil {
722726
r.recordSparkApplicationEvent(app)
723727
return fmt.Errorf("failed to run spark-submit: %v", err)
@@ -1228,3 +1232,18 @@ func (r *Reconciler) cleanUpOnTermination(_, newApp *v1beta2.SparkApplication) e
12281232
}
12291233
return nil
12301234
}
1235+
1236+
// cleanUpPodTemplateFiles cleans up the driver and executor pod template files.
1237+
func (r *Reconciler) cleanUpPodTemplateFiles(app *v1beta2.SparkApplication) error {
1238+
if app.Spec.Driver.Template == nil && app.Spec.Executor.Template == nil {
1239+
return nil
1240+
}
1241+
path := fmt.Sprintf("/tmp/spark/%s", app.Status.SubmissionID)
1242+
if err := os.RemoveAll(path); err != nil {
1243+
if !os.IsNotExist(err) {
1244+
return err
1245+
}
1246+
}
1247+
logger.V(1).Info("Deleted pod template files", "path", path)
1248+
return nil
1249+
}

internal/controller/sparkapplication/submission.go

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,17 @@ func buildSparkSubmitArgs(app *v1beta2.SparkApplication) ([]string, error) {
8383
submissionWaitAppCompletionOption,
8484
sparkConfOption,
8585
hadoopConfOption,
86+
driverPodTemplateOption,
8687
driverPodNameOption,
8788
driverConfOption,
88-
driverSecretOption,
8989
driverEnvOption,
90+
driverSecretOption,
9091
driverVolumeMountsOption,
92+
executorPodTemplateOption,
9193
executorConfOption,
94+
executorEnvOption,
9295
executorSecretOption,
9396
executorVolumeMountsOption,
94-
executorEnvOption,
9597
nodeSelectorOption,
9698
dynamicAllocationOption,
9799
proxyUserOption,
@@ -303,6 +305,12 @@ func driverConfOption(app *v1beta2.SparkApplication) ([]string, error) {
303305
property = fmt.Sprintf(common.SparkKubernetesDriverLabelTemplate, common.LabelLaunchedBySparkOperator)
304306
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, "true"))
305307

308+
// If Spark version is less than 3.0.0 or driver pod template is not defined, then the driver pod needs to be mutated by the webhook.
309+
if util.CompareSemanticVersion(app.Spec.SparkVersion, "3.0.0") < 0 || app.Spec.Driver.Template == nil {
310+
property = fmt.Sprintf(common.SparkKubernetesDriverLabelTemplate, common.LabelMutatedBySparkOperator)
311+
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, "true"))
312+
}
313+
306314
property = fmt.Sprintf(common.SparkKubernetesDriverLabelTemplate, common.LabelSubmissionID)
307315
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, app.Status.SubmissionID))
308316

@@ -646,6 +654,12 @@ func executorConfOption(app *v1beta2.SparkApplication) ([]string, error) {
646654
property = fmt.Sprintf(common.SparkKubernetesExecutorLabelTemplate, common.LabelLaunchedBySparkOperator)
647655
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, "true"))
648656

657+
// If Spark version is less than 3.0.0 or executor pod template is not defined, then the executor pods need to be mutated by the webhook.
658+
if util.CompareSemanticVersion(app.Spec.SparkVersion, "3.0.0") < 0 || app.Spec.Executor.Template == nil {
659+
property = fmt.Sprintf(common.SparkKubernetesExecutorLabelTemplate, common.LabelMutatedBySparkOperator)
660+
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, "true"))
661+
}
662+
649663
property = fmt.Sprintf(common.SparkKubernetesExecutorLabelTemplate, common.LabelSubmissionID)
650664
args = append(args, "--conf", fmt.Sprintf("%s=%s", property, app.Status.SubmissionID))
651665

@@ -1022,3 +1036,45 @@ func mainApplicationFileOption(app *v1beta2.SparkApplication) ([]string, error)
10221036
func applicationOption(app *v1beta2.SparkApplication) ([]string, error) {
10231037
return app.Spec.Arguments, nil
10241038
}
1039+
1040+
// driverPodTemplateOption returns the driver pod template arguments.
1041+
func driverPodTemplateOption(app *v1beta2.SparkApplication) ([]string, error) {
1042+
if app.Spec.Driver.Template == nil {
1043+
return []string{}, nil
1044+
}
1045+
1046+
podTemplateFile := fmt.Sprintf("/tmp/spark/%s/driver-pod-template.yaml", app.Status.SubmissionID)
1047+
if err := util.WriteObjectToFile(app.Spec.Driver.Template, podTemplateFile); err != nil {
1048+
return []string{}, err
1049+
}
1050+
logger.V(1).Info("Created driver pod template file for SparkApplication", "name", app.Name, "namespace", app.Namespace, "file", podTemplateFile)
1051+
1052+
args := []string{
1053+
"--conf",
1054+
fmt.Sprintf("%s=%s", common.SparkKubernetesDriverPodTemplateFile, podTemplateFile),
1055+
"--conf",
1056+
fmt.Sprintf("%s=%s", common.SparkKubernetesDriverPodTemplateContainerName, common.SparkDriverContainerName),
1057+
}
1058+
return args, nil
1059+
}
1060+
1061+
// executorPodTemplateOption returns the executor pod template arguments.
1062+
func executorPodTemplateOption(app *v1beta2.SparkApplication) ([]string, error) {
1063+
if app.Spec.Executor.Template == nil {
1064+
return []string{}, nil
1065+
}
1066+
1067+
podTemplateFile := fmt.Sprintf("/tmp/spark/%s/executor-pod-template.yaml", app.Status.SubmissionID)
1068+
if err := util.WriteObjectToFile(app.Spec.Executor.Template, podTemplateFile); err != nil {
1069+
return []string{}, err
1070+
}
1071+
logger.V(1).Info("Created executor pod template file for SparkApplication", "name", app.Name, "namespace", app.Namespace, "file", podTemplateFile)
1072+
1073+
args := []string{
1074+
"--conf",
1075+
fmt.Sprintf("%s=%s", common.SparkKubernetesExecutorPodTemplateFile, podTemplateFile),
1076+
"--conf",
1077+
fmt.Sprintf("%s=%s", common.SparkKubernetesExecutorPodTemplateContainerName, common.Spark3DefaultExecutorContainerName),
1078+
}
1079+
return args, nil
1080+
}

internal/webhook/sparkapplication_defaulter.go

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,9 @@ func defaultSparkApplication(app *v1beta2.SparkApplication) {
8383
}
8484

8585
func defaultDriverSpec(app *v1beta2.SparkApplication) {
86-
if app.Spec.Driver.Cores == nil {
87-
if app.Spec.SparkConf == nil || app.Spec.SparkConf[common.SparkDriverCores] == "" {
88-
app.Spec.Driver.Cores = util.Int32Ptr(1)
89-
}
90-
}
91-
92-
if app.Spec.Driver.Memory == nil {
93-
if app.Spec.SparkConf == nil || app.Spec.SparkConf[common.SparkDriverMemory] == "" {
94-
app.Spec.Driver.Memory = util.StringPtr("1g")
95-
}
96-
}
9786
}
9887

9988
func defaultExecutorSpec(app *v1beta2.SparkApplication) {
100-
if app.Spec.Executor.Cores == nil {
101-
if app.Spec.SparkConf == nil || app.Spec.SparkConf[common.SparkExecutorCores] == "" {
102-
app.Spec.Executor.Cores = util.Int32Ptr(1)
103-
}
104-
}
105-
106-
if app.Spec.Executor.Memory == nil {
107-
if app.Spec.SparkConf == nil || app.Spec.SparkConf[common.SparkExecutorMemory] == "" {
108-
app.Spec.Executor.Memory = util.StringPtr("1g")
109-
}
110-
}
111-
11289
if app.Spec.Executor.Instances == nil {
11390
// Check whether dynamic allocation is enabled in application spec.
11491
enableDynamicAllocation := app.Spec.DynamicAllocation != nil && app.Spec.DynamicAllocation.Enabled

internal/webhook/sparkapplication_validator.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ func (v *SparkApplicationValidator) ValidateDelete(ctx context.Context, obj runt
117117
func (v *SparkApplicationValidator) validateSpec(_ context.Context, app *v1beta2.SparkApplication) error {
118118
logger.V(1).Info("Validating SparkApplication spec", "name", app.Name, "namespace", app.Namespace, "state", util.GetApplicationState(app))
119119

120+
if err := v.validateSparkVersion(app); err != nil {
121+
return err
122+
}
123+
120124
if app.Spec.NodeSelector != nil && (app.Spec.Driver.NodeSelector != nil || app.Spec.Executor.NodeSelector != nil) {
121125
return fmt.Errorf("node selector cannot be defined at both SparkApplication and Driver/Executor")
122126
}
@@ -144,6 +148,16 @@ func (v *SparkApplicationValidator) validateSpec(_ context.Context, app *v1beta2
144148
return nil
145149
}
146150

151+
func (v *SparkApplicationValidator) validateSparkVersion(app *v1beta2.SparkApplication) error {
152+
// The pod template feature requires Spark version 3.0.0 or higher.
153+
if app.Spec.Driver.Template != nil || app.Spec.Executor.Template != nil {
154+
if util.CompareSemanticVersion(app.Spec.SparkVersion, "3.0.0") < 0 {
155+
return fmt.Errorf("pod template feature requires Spark version 3.0.0 or higher")
156+
}
157+
}
158+
return nil
159+
}
160+
147161
func (v *SparkApplicationValidator) validateResourceUsage(ctx context.Context, app *v1beta2.SparkApplication) error {
148162
logger.V(1).Info("Validating SparkApplication resource usage", "name", app.Name, "namespace", app.Namespace, "state", util.GetApplicationState(app))
149163

pkg/common/spark.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ const (
307307
// LabelLaunchedBySparkOperator is a label on Spark pods launched through the Spark Operator.
308308
LabelLaunchedBySparkOperator = LabelAnnotationPrefix + "launched-by-spark-operator"
309309

310+
// LabelMutatedBySparkOperator is a label on Spark pods that need to be mutated by webhook.
311+
LabelMutatedBySparkOperator = LabelAnnotationPrefix + "mutated-by-spark-operator"
312+
310313
// LabelSubmissionID is the label that records the submission ID of the current run of an application.
311314
LabelSubmissionID = LabelAnnotationPrefix + "submission-id"
312315

pkg/util/util.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ package util
1919
import (
2020
"fmt"
2121
"os"
22+
"path/filepath"
2223
"strings"
2324

25+
"golang.org/x/mod/semver"
26+
"sigs.k8s.io/yaml"
27+
2428
"github.com/kubeflow/spark-operator/pkg/common"
2529
)
2630

@@ -77,3 +81,40 @@ func Int64Ptr(n int64) *int64 {
7781
func StringPtr(s string) *string {
7882
return &s
7983
}
84+
85+
// CompareSemanticVersion compares two semantic versions.
86+
func CompareSemanticVersion(v1, v2 string) int {
87+
// Add 'v' prefix if needed
88+
addPrefix := func(s string) string {
89+
if !strings.HasPrefix(s, "v") {
90+
return "v" + s
91+
}
92+
return s
93+
}
94+
return semver.Compare(addPrefix(v1), addPrefix(v2))
95+
}
96+
97+
// WriteObjectToFile marshals the given object into a YAML document and writes it to the given file.
98+
func WriteObjectToFile(obj interface{}, filePath string) error {
99+
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
100+
return err
101+
}
102+
103+
file, err := os.Create(filePath)
104+
if err != nil {
105+
return err
106+
}
107+
defer file.Close()
108+
109+
data, err := yaml.Marshal(obj)
110+
if err != nil {
111+
return err
112+
}
113+
114+
_, err = file.Write(data)
115+
if err != nil {
116+
return err
117+
}
118+
119+
return nil
120+
}

pkg/util/util_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121

2222
. "github.com/onsi/ginkgo/v2"
2323
. "github.com/onsi/gomega"
24+
corev1 "k8s.io/api/core/v1"
25+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2426

2527
"github.com/kubeflow/spark-operator/pkg/common"
2628
"github.com/kubeflow/spark-operator/pkg/util"
@@ -129,3 +131,73 @@ var _ = Describe("StringPtr", func() {
129131
Expect(util.StringPtr(s)).To(Equal(&s))
130132
})
131133
})
134+
135+
var _ = Describe("CompareSemanticVersions", func() {
136+
It("Should return 0 if the two versions are equal", func() {
137+
Expect(util.CompareSemanticVersion("1.2.3", "1.2.3"))
138+
Expect(util.CompareSemanticVersion("1.2.3", "v1.2.3")).To(Equal(0))
139+
})
140+
141+
It("Should return -1 if the first version is less than the second version", func() {
142+
Expect(util.CompareSemanticVersion("2.3.4", "2.4.5")).To(Equal(-1))
143+
Expect(util.CompareSemanticVersion("2.4.5", "2.4.8")).To(Equal(-1))
144+
Expect(util.CompareSemanticVersion("2.4.8", "3.5.2")).To(Equal(-1))
145+
})
146+
147+
It("Should return +1 if the first version is greater than the second version", func() {
148+
Expect(util.CompareSemanticVersion("2.4.5", "2.3.4")).To(Equal(1))
149+
Expect(util.CompareSemanticVersion("2.4.8", "2.4.5")).To(Equal(1))
150+
Expect(util.CompareSemanticVersion("3.5.2", "2.4.8")).To(Equal(1))
151+
})
152+
})
153+
154+
var _ = Describe("WriteObjectToFile", func() {
155+
It("Should write the object to the file", func() {
156+
podTemplate := &corev1.PodTemplateSpec{
157+
ObjectMeta: metav1.ObjectMeta{
158+
Name: "test-pod",
159+
Labels: map[string]string{
160+
"key1": "value1",
161+
"key2": "value2",
162+
},
163+
Annotations: map[string]string{
164+
"key3": "value3",
165+
"key4": "value4",
166+
},
167+
},
168+
Spec: corev1.PodSpec{
169+
Containers: []corev1.Container{
170+
{
171+
Name: "test-container",
172+
Image: "test-image",
173+
},
174+
},
175+
},
176+
}
177+
178+
expected := `metadata:
179+
annotations:
180+
key3: value3
181+
key4: value4
182+
creationTimestamp: null
183+
labels:
184+
key1: value1
185+
key2: value2
186+
name: test-pod
187+
spec:
188+
containers:
189+
- image: test-image
190+
name: test-container
191+
resources: {}
192+
`
193+
file := "pod-template.yaml"
194+
Expect(util.WriteObjectToFile(podTemplate, file)).To(Succeed())
195+
196+
data, err := os.ReadFile(file)
197+
Expect(err).NotTo(HaveOccurred())
198+
actual := string(data)
199+
200+
Expect(actual).To(Equal(expected))
201+
Expect(os.Remove(file)).NotTo(HaveOccurred())
202+
})
203+
})

0 commit comments

Comments
 (0)