Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use cmd context in sparkctl #2447

Merged
merged 1 commit into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions cmd/sparkctl/app/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ var createCmd = &cobra.Command{
Short: "Create a SparkApplication object",
Long: `Create a SparkApplication from a given YAML file storing the application specification.`,
Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()

if From != "" && len(args) != 1 {
fmt.Fprintln(os.Stderr, "must specify the name of a ScheduledSparkApplication")
return
Expand All @@ -80,11 +82,11 @@ var createCmd = &cobra.Command{
}

if From != "" {
if err := createFromScheduledSparkApplication(args[0], kubeClient, crdClient); err != nil {
if err := createFromScheduledSparkApplication(ctx, args[0], kubeClient, crdClient); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
} else {
if err := createFromYaml(args[0], kubeClient, crdClient); err != nil {
if err := createFromYaml(ctx, args[0], kubeClient, crdClient); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
}
Expand Down Expand Up @@ -114,20 +116,20 @@ func init() {
"the name of ScheduledSparkApplication from which a forced SparkApplication run is created")
}

func createFromYaml(yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
func createFromYaml(ctx context.Context, yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
app, err := loadFromYAML(yamlFile)
if err != nil {
return fmt.Errorf("failed to read a SparkApplication from %s: %v", yamlFile, err)
}

if err := createSparkApplication(app, kubeClient, crdClient); err != nil {
if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
}

return nil
}

func createFromScheduledSparkApplication(name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
func createFromScheduledSparkApplication(ctx context.Context, name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
sapp, err := crdClient.SparkoperatorV1beta2().ScheduledSparkApplications(Namespace).Get(context.TODO(), From, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("failed to get ScheduledSparkApplication %s: %v", From, err)
Expand All @@ -149,14 +151,14 @@ func createFromScheduledSparkApplication(name string, kubeClient clientset.Inter
Spec: *sapp.Spec.Template.DeepCopy(),
}

if err := createSparkApplication(app, kubeClient, crdClient); err != nil {
if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
}

return nil
}

func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
func createSparkApplication(ctx context.Context, app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
if DeleteIfExists {
if err := deleteSparkApplication(app.Name, crdClient); err != nil {
return err
Expand Down Expand Up @@ -190,7 +192,7 @@ func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.
fmt.Printf("SparkApplication \"%s\" created\n", app.Name)

if LogsEnabled {
if err := doLog(app.Name, true, kubeClient, crdClient); err != nil {
if err := doLog(ctx, app.Name, true, kubeClient, crdClient); err != nil {
return nil
}
}
Expand Down
27 changes: 16 additions & 11 deletions cmd/sparkctl/app/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ var logCommand = &cobra.Command{
Short: "log is a sub-command of sparkctl that fetches logs of a Spark application.",
Long: ``,
Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()

if len(args) != 1 {
fmt.Fprintln(os.Stderr, "must specify a SparkApplication name")
return
Expand All @@ -56,7 +58,7 @@ var logCommand = &cobra.Command{
return
}

if err := doLog(args[0], FollowLogs, kubeClientset, crdClientset); err != nil {
if err := doLog(ctx, args[0], FollowLogs, kubeClientset, crdClientset); err != nil {
fmt.Fprintf(os.Stderr, "failed to get driver logs of SparkApplication %s: %v\n", args[0], err)
}
},
Expand All @@ -69,13 +71,14 @@ func init() {
}

func doLog(
ctx context.Context,
name string,
followLogs bool,
kubeClient clientset.Interface,
crdClient crdclientset.Interface) error {
timeout := 30 * time.Second

podNameChannel := getPodNameChannel(name, crdClient)
podNameChannel := getPodNameChannel(ctx, name, crdClient)
var podName string

select {
Expand All @@ -84,7 +87,7 @@ func doLog(
return fmt.Errorf("not found pod name")
}

waitLogsChannel := waitForLogsFromPodChannel(podName, kubeClient, crdClient)
waitLogsChannel := waitForLogsFromPodChannel(ctx, podName, kubeClient, crdClient)

select {
case <-waitLogsChannel:
Expand All @@ -93,19 +96,20 @@ func doLog(
}

if followLogs {
return streamLogs(os.Stdout, kubeClient, podName)
return streamLogs(ctx, os.Stdout, kubeClient, podName)
}
return printLogs(os.Stdout, kubeClient, podName)
return printLogs(ctx, os.Stdout, kubeClient, podName)
}

func getPodNameChannel(
ctx context.Context,
sparkApplicationName string,
crdClient crdclientset.Interface) chan string {
channel := make(chan string, 1)
go func() {
for {
app, _ := crdClient.SparkoperatorV1beta2().SparkApplications(Namespace).Get(
context.TODO(),
ctx,
sparkApplicationName,
metav1.GetOptions{})

Expand All @@ -119,13 +123,14 @@ func getPodNameChannel(
}

func waitForLogsFromPodChannel(
ctx context.Context,
podName string,
kubeClient clientset.Interface,
_ crdclientset.Interface) chan bool {
channel := make(chan bool, 1)
go func() {
for {
_, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw()
_, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()

if err == nil {
channel <- true
Expand All @@ -137,8 +142,8 @@ func waitForLogsFromPodChannel(
}

// printLogs is a one time operation that prints the fetched logs of the given pod.
func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error {
rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw()
func printLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()
if err != nil {
return err
}
Expand All @@ -147,9 +152,9 @@ func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string)
}

// streamLogs streams the logs of the given pod until there are no more logs available.
func streamLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error {
func streamLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
request := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{Follow: true})
reader, err := request.Stream(context.TODO())
reader, err := request.Stream(ctx)
if err != nil {
return err
}
Expand Down
5 changes: 4 additions & 1 deletion cmd/sparkctl/app/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package app

import (
"context"
"fmt"
"os"

Expand Down Expand Up @@ -52,7 +53,9 @@ func init() {
}

func Execute() {
if err := rootCmd.Execute(); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Fprintf(os.Stderr, "%v", err)
}
}