diff --git a/cmd/apply/cmdapply.go b/cmd/apply/cmdapply.go index 759b43b6..cdc782b9 100644 --- a/cmd/apply/cmdapply.go +++ b/cmd/apply/cmdapply.go @@ -52,13 +52,15 @@ func GetApplyRunner(factory cmdutil.Factory, invFactory inventory.InventoryClien "Timeout threshold for waiting for all resources to reach the Current status.") cmd.Flags().BoolVar(&r.noPrune, "no-prune", r.noPrune, "If true, do not prune previously applied objects.") - cmd.Flags().StringVar(&r.prunePropagationPolicy, "prune-propagation-policy", - "Background", "Propagation policy for pruning") - cmd.Flags().DurationVar(&r.pruneTimeout, "prune-timeout", time.Duration(0), - "Timeout threshold for waiting for all pruned resources to be deleted") + cmd.Flags().StringVar(&r.deletePolicy, "delete-propagation-policy", + "Background", "Propagation policy for deletion") + cmd.Flags().DurationVar(&r.deleteTimeout, "delete-timeout", time.Duration(0), + "Timeout threshold for waiting for all deleted resources to complete deletion") cmd.Flags().StringVar(&r.inventoryPolicy, flagutils.InventoryPolicyFlag, flagutils.InventoryPolicyStrict, "It determines the behavior when the resources don't belong to current inventory. Available options "+ fmt.Sprintf("%q and %q.", flagutils.InventoryPolicyStrict, flagutils.InventoryPolicyAdopt)) + cmd.Flags().DurationVar(&r.timeout, "timeout", time.Duration(0), + "Timeout threshold for command execution") r.Command = cmd return r @@ -77,18 +79,30 @@ type ApplyRunner struct { invFactory inventory.InventoryClientFactory loader manifestreader.ManifestLoader - serverSideOptions common.ServerSideOptions - output string - period time.Duration - reconcileTimeout time.Duration - noPrune bool - prunePropagationPolicy string - pruneTimeout time.Duration - inventoryPolicy string + serverSideOptions common.ServerSideOptions + output string + period time.Duration + reconcileTimeout time.Duration + noPrune bool + deletePolicy string + deleteTimeout time.Duration + inventoryPolicy string + timeout time.Duration } func (r *ApplyRunner) RunE(cmd *cobra.Command, args []string) error { - prunePropPolicy, err := flagutils.ConvertPropagationPolicy(r.prunePropagationPolicy) + // If specified, cancel with timeout. + // Otherwise, cancel when completed to clean up timer. + ctx := cmd.Context() + var cancel func() + if r.timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, r.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + + deletePolicy, err := flagutils.ConvertPropagationPolicy(r.deletePolicy) if err != nil { return err } @@ -102,7 +116,7 @@ func (r *ApplyRunner) RunE(cmd *cobra.Command, args []string) error { // we do need status events event if we are not waiting for status. The // printers should be updated to handle this. var printStatusEvents bool - if r.reconcileTimeout != time.Duration(0) || r.pruneTimeout != time.Duration(0) { + if r.reconcileTimeout != time.Duration(0) || r.deleteTimeout != time.Duration(0) { printStatusEvents = true } @@ -147,18 +161,19 @@ func (r *ApplyRunner) RunE(cmd *cobra.Command, args []string) error { if err != nil { return err } - ch := a.Run(context.Background(), inv, objs, apply.Options{ + + ch := a.Run(ctx, inv, objs, apply.Options{ ServerSideOptions: r.serverSideOptions, PollInterval: r.period, ReconcileTimeout: r.reconcileTimeout, // If we are not waiting for status, tell the applier to not // emit the events. 
- EmitStatusEvents: printStatusEvents, - NoPrune: r.noPrune, - DryRunStrategy: common.DryRunNone, - PrunePropagationPolicy: prunePropPolicy, - PruneTimeout: r.pruneTimeout, - InventoryPolicy: inventoryPolicy, + EmitStatusEvents: printStatusEvents, + NoPrune: r.noPrune, + DryRunStrategy: common.DryRunNone, + DeletionPropagationPolicy: deletePolicy, + DeleteTimeout: r.deleteTimeout, + InventoryPolicy: inventoryPolicy, }) // The printer will print updates from the channel. It will block diff --git a/cmd/destroy/cmddestroy.go b/cmd/destroy/cmddestroy.go index 6935aa28..80b34b53 100644 --- a/cmd/destroy/cmddestroy.go +++ b/cmd/destroy/cmddestroy.go @@ -4,6 +4,7 @@ package destroy import ( + "context" "fmt" "strings" "time" @@ -46,6 +47,8 @@ func GetDestroyRunner(factory cmdutil.Factory, invFactory inventory.InventoryCli "Timeout threshold for waiting for all deleted resources to complete deletion") cmd.Flags().StringVar(&r.deletePropagationPolicy, "delete-propagation-policy", "Background", "Propagation policy for deletion") + cmd.Flags().DurationVar(&r.timeout, "timeout", time.Duration(0), + "Timeout threshold for command execution") r.Command = cmd return r @@ -70,9 +73,21 @@ type DestroyRunner struct { deleteTimeout time.Duration deletePropagationPolicy string inventoryPolicy string + timeout time.Duration } func (r *DestroyRunner) RunE(cmd *cobra.Command, args []string) error { + // If specified, cancel with timeout. + // Otherwise, cancel when completed to clean up timer. + ctx := cmd.Context() + var cancel func() + if r.timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, r.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + deletePropPolicy, err := flagutils.ConvertPropagationPolicy(r.deletePropagationPolicy) if err != nil { return err @@ -117,7 +132,8 @@ func (r *DestroyRunner) RunE(cmd *cobra.Command, args []string) error { // Run the destroyer. It will return a channel where we can receive updates // to keep track of progress and any issues. printStatusEvents := r.deleteTimeout != time.Duration(0) - ch := d.Run(inv, apply.DestroyerOptions{ + + ch := d.Run(ctx, inv, apply.DestroyerOptions{ DeleteTimeout: r.deleteTimeout, DeletePropagationPolicy: deletePropPolicy, InventoryPolicy: inventoryPolicy, diff --git a/cmd/preview/cmdpreview.go b/cmd/preview/cmdpreview.go index fe404a3a..4a622f86 100644 --- a/cmd/preview/cmdpreview.go +++ b/cmd/preview/cmdpreview.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/spf13/cobra" "k8s.io/cli-runtime/pkg/genericclioptions" @@ -57,6 +58,8 @@ func GetPreviewRunner(factory cmdutil.Factory, invFactory inventory.InventoryCli cmd.Flags().StringVar(&r.inventoryPolicy, flagutils.InventoryPolicyFlag, flagutils.InventoryPolicyStrict, "It determines the behavior when the resources don't belong to current inventory. Available options "+ fmt.Sprintf("%q and %q.", flagutils.InventoryPolicyStrict, flagutils.InventoryPolicyAdopt)) + cmd.Flags().DurationVar(&r.timeout, "timeout", time.Duration(0), + "Timeout threshold for command execution") r.Command = cmd return r @@ -80,10 +83,22 @@ type PreviewRunner struct { serverSideOptions common.ServerSideOptions output string inventoryPolicy string + timeout time.Duration } // RunE is the function run from the cobra command. func (r *PreviewRunner) RunE(cmd *cobra.Command, args []string) error { + // If specified, cancel with timeout. + // Otherwise, cancel when completed to clean up timer. 
+ ctx := cmd.Context() + var cancel func() + if r.timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, r.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + var ch <-chan event.Event drs := common.DryRunClient @@ -139,9 +154,6 @@ func (r *PreviewRunner) RunE(cmd *cobra.Command, args []string) error { return err } - // Create a context - ctx := context.Background() - // Run the applier. It will return a channel where we can receive updates // to keep track of progress and any issues. ch = a.Run(ctx, inv, objs, apply.Options{ @@ -156,7 +168,7 @@ func (r *PreviewRunner) RunE(cmd *cobra.Command, args []string) error { if err != nil { return err } - ch = d.Run(inv, apply.DestroyerOptions{ + ch = d.Run(ctx, inv, apply.DestroyerOptions{ InventoryPolicy: inventoryPolicy, DryRunStrategy: drs, }) diff --git a/cmd/status/cmdstatus.go b/cmd/status/cmdstatus.go index 3cbabe35..2df77379 100644 --- a/cmd/status/cmdstatus.go +++ b/cmd/status/cmdstatus.go @@ -72,6 +72,17 @@ type StatusRunner struct { // poller to compute status for each of the resources. One of the printer // implementations takes care of printing the output. func (r *StatusRunner) runE(cmd *cobra.Command, args []string) error { + // If specified, cancel with timeout. + // Otherwise, cancel when completed to clean up timer. + ctx := cmd.Context() + var cancel func() + if r.timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, r.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + _, err := common.DemandOneDirectory(args) if err != nil { return err @@ -125,17 +136,6 @@ func (r *StatusRunner) runE(cmd *cobra.Command, args []string) error { return fmt.Errorf("error creating printer: %w", err) } - // If the user has specified a timeout, we create a context with timeout, - // otherwise we create a context with cancel. - ctx := context.Background() - var cancel func() - if r.timeout != 0 { - ctx, cancel = context.WithTimeout(ctx, r.timeout) - } else { - ctx, cancel = context.WithCancel(ctx) - } - defer cancel() - // Choose the appropriate ObserverFunc based on the criteria for when // the command should exit. var cancelFunc collector.ObserverFunc diff --git a/cmd/status/cmdstatus_test.go b/cmd/status/cmdstatus_test.go index 38d012ae..bc0a12bf 100644 --- a/cmd/status/cmdstatus_test.go +++ b/cmd/status/cmdstatus_test.go @@ -236,12 +236,15 @@ deployment.apps/foo is InProgress: inProgress timeout: tc.timeout, } - cmd := &cobra.Command{} + cmd := &cobra.Command{ + RunE: runner.runE, + } cmd.SetIn(strings.NewReader(tc.input)) var buf bytes.Buffer cmd.SetOut(&buf) - err := runner.runE(cmd, []string{}) + // execute with cobra to handle setting the default context + err := cmd.Execute() if tc.expectedErrMsg != "" { if !assert.Error(t, err) { diff --git a/pkg/apply/applier.go b/pkg/apply/applier.go index 2d145f2f..da160c35 100644 --- a/pkg/apply/applier.go +++ b/pkg/apply/applier.go @@ -30,7 +30,7 @@ import ( // NewApplier returns a new Applier. 
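// All four commands (apply, destroy, preview, status) now derive their
// context the same way: start from cmd.Context(), wrap it with
// context.WithTimeout when --timeout is set, and otherwise with
// context.WithCancel so the deferred cancel still releases resources when
// the command finishes. cmd.Context() is set by cobra during Execute,
// which is why cmdstatus_test.go now runs the command through
// cmd.Execute() instead of calling runE directly. A minimal sketch of the
// shared pattern; the commandContext helper is hypothetical and not part
// of this change:

package sketch

import (
	"context"
	"time"

	"github.com/spf13/cobra"
)

// commandContext bounds the cobra command context with an optional timeout.
func commandContext(cmd *cobra.Command, timeout time.Duration) (context.Context, context.CancelFunc) {
	ctx := cmd.Context()
	if timeout != 0 {
		return context.WithTimeout(ctx, timeout)
	}
	return context.WithCancel(ctx)
}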
func NewApplier(factory cmdutil.Factory, invClient inventory.InventoryClient, statusPoller poller.Poller) (*Applier, error) { - pruneOpts, err := prune.NewPruneOptions(factory, invClient) + pruneOpts, err := prune.NewPruner(factory, invClient) if err != nil { return nil, err } @@ -39,7 +39,7 @@ func NewApplier(factory cmdutil.Factory, invClient inventory.InventoryClient, st return nil, err } return &Applier{ - pruneOptions: pruneOpts, + pruner: pruneOpts, statusPoller: statusPoller, factory: factory, invClient: invClient, @@ -58,7 +58,7 @@ func NewApplier(factory cmdutil.Factory, invClient inventory.InventoryClient, st // parameters and/or the set of resources that needs to be applied to the // cluster, different sets of tasks might be needed. type Applier struct { - pruneOptions *prune.PruneOptions + pruner *prune.Pruner statusPoller poller.Poller factory cmdutil.Factory invClient inventory.InventoryClient @@ -67,8 +67,12 @@ type Applier struct { // prepareObjects returns the set of objects to apply and to prune or // an error if one occurred. -func (a *Applier) prepareObjects(localInv inventory.InventoryInfo, localObjs []*unstructured.Unstructured, - o Options) ([]*unstructured.Unstructured, []*unstructured.Unstructured, error) { +func (a *Applier) prepareObjects( + ctx context.Context, + localInv inventory.InventoryInfo, + localObjs []*unstructured.Unstructured, + opts Options, +) ([]*unstructured.Unstructured, []*unstructured.Unstructured, error) { if localInv == nil { return nil, nil, fmt.Errorf("the local inventory can't be nil") } @@ -99,9 +103,12 @@ func (a *Applier) prepareObjects(localInv inventory.InventoryInfo, localObjs []* } } } - pruneObjs, err := a.pruneOptions.GetPruneObjs(localInv, localObjs, prune.Options{ - DryRunStrategy: o.DryRunStrategy, - }) + pruneObjs, err := a.pruner.GetPruneObjs( + ctx, + localInv, + localObjs, + opts.PruneOptions(), + ) if err != nil { return nil, nil, err } @@ -117,10 +124,10 @@ func (a *Applier) prepareObjects(localInv inventory.InventoryInfo, localObjs []* // before all the given resources have been applied to the cluster. Any // cancellation or timeout will only affect how long we Wait for the // resources to become current. -func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, objects []*unstructured.Unstructured, options Options) <-chan event.Event { +func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, objects []*unstructured.Unstructured, opts Options) <-chan event.Event { klog.V(4).Infof("apply run for %d objects", len(objects)) eventChannel := make(chan event.Event) - setDefaults(&options) + setDefaults(&opts) go func() { defer close(eventChannel) @@ -144,7 +151,7 @@ func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, obje return } - applyObjs, pruneObjs, err := a.prepareObjects(invInfo, objects, options) + applyObjs, pruneObjs, err := a.prepareObjects(ctx, invInfo, objects, opts) if err != nil { handleError(eventChannel, err) return @@ -154,29 +161,21 @@ func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, obje // Fetch the queue (channel) of tasks that should be executed. 
klog.V(4).Infoln("applier building task queue...") taskBuilder := &solver.TaskQueueBuilder{ - PruneOptions: a.pruneOptions, - Factory: a.factory, - InfoHelper: a.infoHelper, - Mapper: mapper, - InvClient: a.invClient, - Destroy: false, - } - opts := solver.Options{ - ServerSideOptions: options.ServerSideOptions, - ReconcileTimeout: options.ReconcileTimeout, - Prune: !options.NoPrune, - DryRunStrategy: options.DryRunStrategy, - PrunePropagationPolicy: options.PrunePropagationPolicy, - PruneTimeout: options.PruneTimeout, - InventoryPolicy: options.InventoryPolicy, + Pruner: a.pruner, + Factory: a.factory, + InfoHelper: a.infoHelper, + Mapper: mapper, + InvClient: a.invClient, + Destroy: false, // DO NOT remove pruned resources from inventory } + solverOpts := opts.SolverOptions() // Build list of apply validation filters. applyFilters := []filter.ValidationFilter{ filter.InventoryPolicyApplyFilter{ Client: client, Mapper: mapper, Inv: invInfo, - InvPolicy: options.InventoryPolicy, + InvPolicy: opts.InventoryPolicy, }, } // Build list of prune validation filters. @@ -184,19 +183,21 @@ func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, obje filter.PreventRemoveFilter{}, filter.InventoryPolicyFilter{ Inv: invInfo, - InvPolicy: options.InventoryPolicy, + InvPolicy: opts.InventoryPolicy, }, filter.LocalNamespacesFilter{ LocalNamespaces: localNamespaces(invInfo, object.UnstructuredsToObjMetasOrDie(objects)), }, } - // Build the task queue by appending tasks in the proper order. - taskQueue, err := taskBuilder. - AppendInvAddTask(invInfo, applyObjs, options.DryRunStrategy). - AppendApplyWaitTasks(applyObjs, applyFilters, opts). - AppendPruneWaitTasks(pruneObjs, pruneFilters, opts). - AppendInvSetTask(invInfo, options.DryRunStrategy). - Build() + + // Build the ordered set of tasks to execute. + taskBuilder.AppendInvAddTask(invInfo, applyObjs, opts.DryRunStrategy) + taskBuilder.AppendApplyWaitTasks(applyObjs, applyFilters, solverOpts) + if !opts.NoPrune { + taskBuilder.AppendPruneWaitTasks(pruneObjs, pruneFilters, solverOpts) + } + taskBuilder.AppendInvSetTask(invInfo, opts.DryRunStrategy) + taskQueue, err := taskBuilder.Build() if err != nil { handleError(eventChannel, err) } @@ -213,11 +214,7 @@ func (a *Applier) Run(ctx context.Context, invInfo inventory.InventoryInfo, obje allIds := object.UnstructuredsToObjMetasOrDie(append(applyObjs, pruneObjs...)) runner := taskrunner.NewTaskStatusRunner(allIds, a.statusPoller) klog.V(4).Infoln("applier running TaskStatusRunner...") - err = runner.Run(ctx, taskQueue.ToChannel(), eventChannel, taskrunner.Options{ - PollInterval: options.PollInterval, - UseCache: true, - EmitStatusEvents: options.EmitStatusEvents, - }) + err = runner.Run(ctx, taskQueue.ToChannel(), eventChannel, opts.TaskRunnerOptions()) if err != nil { handleError(eventChannel, err) } @@ -250,28 +247,54 @@ type Options struct { // or if it is just talk and no action. DryRunStrategy common.DryRunStrategy - // PrunePropagationPolicy defines the deletion propagation policy + // DeletionPropagationPolicy defines the deletion propagation policy // that should be used for pruning. If this is not provided, the // default is to use the Background policy. - PrunePropagationPolicy metav1.DeletionPropagation + DeletionPropagationPolicy metav1.DeletionPropagation - // PruneTimeout defines whether we should wait for all resources + // DeleteTimeout defines whether we should wait for all resources // to be fully deleted after pruning, and if so, how long we should // wait. 
- PruneTimeout time.Duration + DeleteTimeout time.Duration // InventoryPolicy defines the inventory policy of apply. InventoryPolicy inventory.InventoryPolicy } +func (o Options) PruneOptions() prune.Options { + return prune.Options{ + DryRunStrategy: o.DryRunStrategy, + DeleteTimeout: o.DeleteTimeout, + } +} + +func (o Options) SolverOptions() solver.Options { + return solver.Options{ + ServerSideOptions: o.ServerSideOptions, + ReconcileTimeout: o.ReconcileTimeout, + DryRunStrategy: o.DryRunStrategy, + DeletePropagationPolicy: o.DeletionPropagationPolicy, + DeleteTimeout: o.DeleteTimeout, + InventoryPolicy: o.InventoryPolicy, + } +} + +func (o Options) TaskRunnerOptions() taskrunner.Options { + return taskrunner.Options{ + UseCache: true, + PollInterval: o.PollInterval, + EmitStatusEvents: o.EmitStatusEvents, + } +} + // setDefaults set the options to the default values if they // have not been provided. func setDefaults(o *Options) { if o.PollInterval == time.Duration(0) { o.PollInterval = poller.DefaultPollInterval } - if o.PrunePropagationPolicy == "" { - o.PrunePropagationPolicy = metav1.DeletePropagationBackground + if o.DeletionPropagationPolicy == "" { + o.DeletionPropagationPolicy = metav1.DeletePropagationBackground } } diff --git a/pkg/apply/applier_test.go b/pkg/apply/applier_test.go index 0e3a71f9..8f3bd47d 100644 --- a/pkg/apply/applier_test.go +++ b/pkg/apply/applier_test.go @@ -733,7 +733,7 @@ func TestReadAndPrepareObjects(t *testing.T) { } // Create applier with fake inventory client, and call prepareObjects applier := Applier{ - pruneOptions: &prune.PruneOptions{ + pruner: &prune.Pruner{ InvClient: fakeInvClient, Client: dynamicfake.NewSimpleDynamicClient(scheme.Scheme, objs...), Mapper: testrestmapper.TestOnlyStaticRESTMapper(scheme.Scheme, @@ -741,7 +741,12 @@ func TestReadAndPrepareObjects(t *testing.T) { }, invClient: fakeInvClient, } - applyObjs, pruneObjs, err := applier.prepareObjects(tc.inventory, tc.localObjs, Options{}) + applyObjs, pruneObjs, err := applier.prepareObjects( + context.TODO(), + tc.inventory, + tc.localObjs, + Options{}, + ) if tc.isError { assert.Error(t, err) return diff --git a/pkg/apply/destroyer.go b/pkg/apply/destroyer.go index b0b5b62a..ace7fcd4 100644 --- a/pkg/apply/destroyer.go +++ b/pkg/apply/destroyer.go @@ -30,12 +30,12 @@ import ( // handled by a separate printer with the KubectlPrinterAdapter bridging // between the two. func NewDestroyer(factory cmdutil.Factory, invClient inventory.InventoryClient, statusPoller poller.Poller) (*Destroyer, error) { - pruneOpts, err := prune.NewPruneOptions(factory, invClient) + pruner, err := prune.NewPruner(factory, invClient) if err != nil { return nil, fmt.Errorf("error setting up PruneOptions: %w", err) } return &Destroyer{ - pruneOptions: pruneOpts, + pruner: pruner, statusPoller: statusPoller, factory: factory, invClient: invClient, @@ -45,7 +45,7 @@ func NewDestroyer(factory cmdutil.Factory, invClient inventory.InventoryClient, // Destroyer performs the step of grabbing all the previous inventory objects and // prune them. 
This also deletes all the previous inventory objects type Destroyer struct { - pruneOptions *prune.PruneOptions + pruner *prune.Pruner statusPoller poller.Poller factory cmdutil.Factory invClient inventory.InventoryClient @@ -77,6 +77,32 @@ type DestroyerOptions struct { PollInterval time.Duration } +func (do DestroyerOptions) PruneOptions() prune.Options { + return prune.Options{ + DryRunStrategy: do.DryRunStrategy, + DeleteTimeout: do.DeleteTimeout, + DeletionPropagationPolicy: do.DeletePropagationPolicy, + // Always remove pruned resources from inventory when destroying. + Destroy: true, + } +} + +func (do DestroyerOptions) SolverOptions() solver.Options { + return solver.Options{ + DeleteTimeout: do.DeleteTimeout, + DryRunStrategy: do.DryRunStrategy, + DeletePropagationPolicy: do.DeletePropagationPolicy, + } +} + +func (do DestroyerOptions) TaskRunnerOptions() taskrunner.Options { + return taskrunner.Options{ + UseCache: true, + PollInterval: do.PollInterval, + EmitStatusEvents: do.EmitStatusEvents, + } +} + func setDestroyerDefaults(o *DestroyerOptions) { if o.PollInterval == time.Duration(0) { o.PollInterval = poller.DefaultPollInterval @@ -89,17 +115,20 @@ func setDestroyerDefaults(o *DestroyerOptions) { // Run performs the destroy step. Passes the inventory object. This // happens asynchronously on progress and any errors are reported // back on the event channel. -func (d *Destroyer) Run(inv inventory.InventoryInfo, options DestroyerOptions) <-chan event.Event { +func (d *Destroyer) Run(ctx context.Context, inv inventory.InventoryInfo, opts DestroyerOptions) <-chan event.Event { eventChannel := make(chan event.Event) - setDestroyerDefaults(&options) + setDestroyerDefaults(&opts) go func() { defer close(eventChannel) // Retrieve the objects to be deleted from the cluster. Second parameter is empty // because no local objects returns all inventory objects for deletion. emptyLocalObjs := []*unstructured.Unstructured{} - deleteObjs, err := d.pruneOptions.GetPruneObjs(inv, emptyLocalObjs, prune.Options{ - DryRunStrategy: options.DryRunStrategy, - }) + deleteObjs, err := d.pruner.GetPruneObjs( + ctx, + inv, + emptyLocalObjs, + opts.PruneOptions(), + ) if err != nil { handleError(eventChannel, err) return @@ -111,33 +140,30 @@ func (d *Destroyer) Run(inv inventory.InventoryInfo, options DestroyerOptions) < } klog.V(4).Infoln("destroyer building task queue...") taskBuilder := &solver.TaskQueueBuilder{ - PruneOptions: d.pruneOptions, - Factory: d.factory, - Mapper: mapper, - InvClient: d.invClient, - Destroy: true, - } - opts := solver.Options{ - Prune: true, - PruneTimeout: options.DeleteTimeout, - DryRunStrategy: options.DryRunStrategy, - PrunePropagationPolicy: options.DeletePropagationPolicy, + Pruner: d.pruner, + Factory: d.factory, + Mapper: mapper, + InvClient: d.invClient, + Destroy: opts.PruneOptions().Destroy, } + solverOpts := opts.SolverOptions() deleteFilters := []filter.ValidationFilter{ filter.PreventRemoveFilter{}, filter.InventoryPolicyFilter{ Inv: inv, - InvPolicy: options.InventoryPolicy, + InvPolicy: opts.InventoryPolicy, }, } + // Build the ordered set of tasks to execute. - taskQueue, err := taskBuilder. - AppendPruneWaitTasks(deleteObjs, deleteFilters, opts). - AppendDeleteInvTask(inv, options.DryRunStrategy). 
- Build() + // Destroyer always prunes + taskBuilder.AppendPruneWaitTasks(deleteObjs, deleteFilters, solverOpts) + taskBuilder.AppendDeleteInvTask(inv, opts.DryRunStrategy) + taskQueue, err := taskBuilder.Build() if err != nil { handleError(eventChannel, err) } + // Send event to inform the caller about the resources that // will be pruned. eventChannel <- event.Event{ @@ -152,11 +178,12 @@ func (d *Destroyer) Run(inv inventory.InventoryInfo, options DestroyerOptions) < runner := taskrunner.NewTaskStatusRunner(deleteIds, d.statusPoller) klog.V(4).Infoln("destroyer running TaskStatusRunner...") // TODO(seans): Make the poll interval configurable like the applier. - err = runner.Run(context.Background(), taskQueue.ToChannel(), eventChannel, taskrunner.Options{ - UseCache: true, - PollInterval: options.PollInterval, - EmitStatusEvents: options.EmitStatusEvents, - }) + err = runner.Run( + ctx, + taskQueue.ToChannel(), + eventChannel, + opts.TaskRunnerOptions(), + ) if err != nil { handleError(eventChannel, err) } diff --git a/pkg/apply/filter/current-uids-filter.go b/pkg/apply/filter/current-uids-filter.go index a73bf68c..13907ae5 100644 --- a/pkg/apply/filter/current-uids-filter.go +++ b/pkg/apply/filter/current-uids-filter.go @@ -4,6 +4,7 @@ package filter import ( + "context" "fmt" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -26,7 +27,7 @@ func (cuf CurrentUIDFilter) Name() string { // because the it is a namespace that objects still reside in; otherwise // returns false. This filter should not be added to the list of filters // for "destroying", since every object is being deletet. Never returns an error. -func (cuf CurrentUIDFilter) Filter(obj *unstructured.Unstructured) (bool, string, error) { +func (cuf CurrentUIDFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) { uid := string(obj.GetUID()) if cuf.CurrentUIDs.Has(uid) { reason := fmt.Sprintf("object removal prevented; UID just applied: %s", uid) diff --git a/pkg/apply/filter/current-uids-filter_test.go b/pkg/apply/filter/current-uids-filter_test.go index 40f1a939..e32f3e95 100644 --- a/pkg/apply/filter/current-uids-filter_test.go +++ b/pkg/apply/filter/current-uids-filter_test.go @@ -4,6 +4,7 @@ package filter import ( + "context" "testing" "k8s.io/apimachinery/pkg/types" @@ -50,7 +51,8 @@ func TestCurrentUIDFilter(t *testing.T) { } obj := defaultObj.DeepCopy() obj.SetUID(types.UID(tc.objUID)) - actual, reason, err := filter.Filter(obj) + ctx := context.TODO() + actual, reason, err := filter.Filter(ctx, obj) if err != nil { t.Fatalf("CurrentUIDFilter unexpected error (%s)", err) } diff --git a/pkg/apply/filter/filter.go b/pkg/apply/filter/filter.go index 298cda48..bb3e204d 100644 --- a/pkg/apply/filter/filter.go +++ b/pkg/apply/filter/filter.go @@ -4,6 +4,8 @@ package filter import ( + "context" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) @@ -17,5 +19,5 @@ type ValidationFilter interface { // Filter returns true if validation fails. If true a // reason string is included in the return. If an error happens // during filtering it is returned. 
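// With these changes both entry points take the caller's context, so the
// command-level --timeout (or any other cancellation) propagates through
// applying, pruning, waiting, and the final inventory update. A usage
// sketch, not taken from this change: factory, invClient, statusPoller,
// inv, and objs are assumed to be constructed elsewhere, and the import
// paths follow the ones used in this repository.

package sketch

import (
	"context"
	"time"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
	cmdutil "k8s.io/kubectl/pkg/cmd/util"

	"sigs.k8s.io/cli-utils/pkg/apply"
	"sigs.k8s.io/cli-utils/pkg/apply/poller"
	"sigs.k8s.io/cli-utils/pkg/inventory"
)

func applyWithTimeout(factory cmdutil.Factory, invClient inventory.InventoryClient,
	statusPoller poller.Poller, inv inventory.InventoryInfo,
	objs []*unstructured.Unstructured) error {
	applier, err := apply.NewApplier(factory, invClient, statusPoller)
	if err != nil {
		return err
	}
	// Bound the whole run, mirroring what the --timeout flag now does.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	events := applier.Run(ctx, inv, objs, apply.Options{
		ReconcileTimeout:          2 * time.Minute,
		DeleteTimeout:             time.Minute,
		DeletionPropagationPolicy: metav1.DeletePropagationBackground,
	})
	// The channel must be drained; a real caller hands the events to a printer.
	for range events {
	}
	return nil
}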
- Filter(obj *unstructured.Unstructured) (bool, string, error) + Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) } diff --git a/pkg/apply/filter/inventory-policy-apply-filter.go b/pkg/apply/filter/inventory-policy-apply-filter.go index aa7802ea..63c88ff4 100644 --- a/pkg/apply/filter/inventory-policy-apply-filter.go +++ b/pkg/apply/filter/inventory-policy-apply-filter.go @@ -34,12 +34,12 @@ func (ipaf InventoryPolicyApplyFilter) Name() string { // Filter returns true if the passed object should be filtered (NOT applied) and // a filter reason string; false otherwise. Returns an error if one occurred // during the filter calculation -func (ipaf InventoryPolicyApplyFilter) Filter(obj *unstructured.Unstructured) (bool, string, error) { +func (ipaf InventoryPolicyApplyFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) { if obj == nil { return true, "missing object", nil } // Object must be retrieved from the cluster to get the inventory id. - clusterObj, err := ipaf.getObject(object.UnstructuredToObjMetaOrDie(obj)) + clusterObj, err := ipaf.getObject(ctx, object.UnstructuredToObjMetaOrDie(obj)) if err != nil { if apierrors.IsNotFound(err) { // This simply means the object hasn't been created yet. @@ -60,7 +60,7 @@ func (ipaf InventoryPolicyApplyFilter) Filter(obj *unstructured.Unstructured) (b } // getObject retrieves the passed object from the cluster, or an error if one occurred. -func (ipaf InventoryPolicyApplyFilter) getObject(id object.ObjMetadata) (*unstructured.Unstructured, error) { +func (ipaf InventoryPolicyApplyFilter) getObject(ctx context.Context, id object.ObjMetadata) (*unstructured.Unstructured, error) { mapping, err := ipaf.Mapper.RESTMapping(id.GroupKind) if err != nil { return nil, err @@ -69,5 +69,5 @@ func (ipaf InventoryPolicyApplyFilter) getObject(id object.ObjMetadata) (*unstru if err != nil { return nil, err } - return namespacedClient.Get(context.TODO(), id.Name, metav1.GetOptions{}) + return namespacedClient.Get(ctx, id.Name, metav1.GetOptions{}) } diff --git a/pkg/apply/filter/inventory-policy-apply-filter_test.go b/pkg/apply/filter/inventory-policy-apply-filter_test.go index e8935c1c..5dc71204 100644 --- a/pkg/apply/filter/inventory-policy-apply-filter_test.go +++ b/pkg/apply/filter/inventory-policy-apply-filter_test.go @@ -4,6 +4,7 @@ package filter import ( + "context" "testing" "k8s.io/apimachinery/pkg/api/meta/testrestmapper" @@ -103,7 +104,8 @@ func TestInventoryPolicyApplyFilter(t *testing.T) { Inv: inventory.WrapInventoryInfoObj(invObj), InvPolicy: tc.policy, } - actual, reason, err := filter.Filter(obj) + ctx := context.TODO() + actual, reason, err := filter.Filter(ctx, obj) if tc.isError != (err != nil) { t.Fatalf("Expected InventoryPolicyFilter error (%v), got (%v)", tc.isError, (err != nil)) } diff --git a/pkg/apply/filter/inventory-policy-filter.go b/pkg/apply/filter/inventory-policy-filter.go index 476cb36f..f021b4d3 100644 --- a/pkg/apply/filter/inventory-policy-filter.go +++ b/pkg/apply/filter/inventory-policy-filter.go @@ -4,6 +4,7 @@ package filter import ( + "context" "fmt" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -26,7 +27,7 @@ func (ipf InventoryPolicyFilter) Name() string { // Filter returns true if the passed object should NOT be pruned (deleted) // because the "prevent remove" annotation is present; otherwise returns // false. Never returns an error. 
-func (ipf InventoryPolicyFilter) Filter(obj *unstructured.Unstructured) (bool, string, error) { +func (ipf InventoryPolicyFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) { // Check the inventory id "match" and the adopt policy to determine // if an object should be pruned (deleted). if !inventory.CanPrune(ipf.Inv, obj, ipf.InvPolicy) { diff --git a/pkg/apply/filter/inventory-policy-filter_test.go b/pkg/apply/filter/inventory-policy-filter_test.go index 8f7378fc..947a02c8 100644 --- a/pkg/apply/filter/inventory-policy-filter_test.go +++ b/pkg/apply/filter/inventory-policy-filter_test.go @@ -4,6 +4,7 @@ package filter import ( + "context" "testing" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -89,7 +90,8 @@ func TestInventoryPolicyFilter(t *testing.T) { } obj := defaultObj.DeepCopy() obj.SetAnnotations(objIDAnnotation) - actual, reason, err := filter.Filter(obj) + ctx := context.TODO() + actual, reason, err := filter.Filter(ctx, obj) if err != nil { t.Fatalf("InventoryPolicyFilter unexpected error (%s)", err) } diff --git a/pkg/apply/filter/local-namespaces-filter.go b/pkg/apply/filter/local-namespaces-filter.go index d9488164..ee795b9e 100644 --- a/pkg/apply/filter/local-namespaces-filter.go +++ b/pkg/apply/filter/local-namespaces-filter.go @@ -4,6 +4,7 @@ package filter import ( + "context" "fmt" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -27,7 +28,7 @@ func (lnf LocalNamespacesFilter) Name() string { // because the it is a namespace that objects still reside in; otherwise // returns false. This filter should not be added to the list of filters // for "destroying", since every object is being delete. Never returns an error. -func (lnf LocalNamespacesFilter) Filter(obj *unstructured.Unstructured) (bool, string, error) { +func (lnf LocalNamespacesFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) { id := object.UnstructuredToObjMetaOrDie(obj) if id.GroupKind == object.CoreV1Namespace.GroupKind() && lnf.LocalNamespaces.Has(id.Name) { diff --git a/pkg/apply/filter/local-namespaces-filter_test.go b/pkg/apply/filter/local-namespaces-filter_test.go index 0ccec89b..aaf38529 100644 --- a/pkg/apply/filter/local-namespaces-filter_test.go +++ b/pkg/apply/filter/local-namespaces-filter_test.go @@ -4,6 +4,7 @@ package filter import ( + "context" "testing" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -50,7 +51,8 @@ func TestLocalNamespacesFilter(t *testing.T) { } namespace := testNamespace.DeepCopy() namespace.SetName(tc.namespace) - actual, reason, err := filter.Filter(namespace) + ctx := context.TODO() + actual, reason, err := filter.Filter(ctx, namespace) if err != nil { t.Fatalf("LocalNamespacesFilter unexpected error (%s)", err) } diff --git a/pkg/apply/filter/prevent-remove-filter.go b/pkg/apply/filter/prevent-remove-filter.go index 7c873059..0dce433f 100644 --- a/pkg/apply/filter/prevent-remove-filter.go +++ b/pkg/apply/filter/prevent-remove-filter.go @@ -4,6 +4,7 @@ package filter import ( + "context" "fmt" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -24,7 +25,7 @@ func (prf PreventRemoveFilter) Name() string { // Filter returns true if the passed object should NOT be pruned (deleted) // because the "prevent remove" annotation is present; otherwise returns // false. Never returns an error. 
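// The ValidationFilter contract now passes a context.Context through to
// every Filter call, so filters that reach the API server (like
// InventoryPolicyApplyFilter above) inherit the task context and its
// timeout. A minimal sketch written against the new signature; NoopFilter
// is illustrative and not part of this change:

package sketch

import (
	"context"

	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

// NoopFilter never filters anything; it only shows the Name/Filter pair
// that the filters in this change implement.
type NoopFilter struct{}

func (NoopFilter) Name() string { return "NoopFilter" }

func (NoopFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) {
	// Respect cancellation before doing any work, e.g. a cluster lookup.
	if err := ctx.Err(); err != nil {
		return false, "", err
	}
	return false, "", nil
}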
-func (prf PreventRemoveFilter) Filter(obj *unstructured.Unstructured) (bool, string, error) { +func (prf PreventRemoveFilter) Filter(ctx context.Context, obj *unstructured.Unstructured) (bool, string, error) { for annotation, value := range obj.GetAnnotations() { if common.NoDeletion(annotation, value) { reason := fmt.Sprintf("object removal prevented; delete annotation: %s/%s", annotation, value) diff --git a/pkg/apply/filter/prevent-remove-filter_test.go b/pkg/apply/filter/prevent-remove-filter_test.go index 742abd49..9a269373 100644 --- a/pkg/apply/filter/prevent-remove-filter_test.go +++ b/pkg/apply/filter/prevent-remove-filter_test.go @@ -4,6 +4,7 @@ package filter import ( + "context" "testing" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -71,7 +72,8 @@ func TestPreventDeleteAnnotation(t *testing.T) { filter := PreventRemoveFilter{} obj := defaultObj.DeepCopy() obj.SetAnnotations(tc.annotations) - actual, reason, err := filter.Filter(obj) + ctx := context.TODO() + actual, reason, err := filter.Filter(ctx, obj) if err != nil { t.Fatalf("PreventRemoveFilter unexpected error (%s)", err) } diff --git a/pkg/apply/prune/prune.go b/pkg/apply/prune/prune.go index 47503ff6..4b2e3202 100644 --- a/pkg/apply/prune/prune.go +++ b/pkg/apply/prune/prune.go @@ -14,6 +14,7 @@ package prune import ( "context" "sort" + "time" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -29,18 +30,18 @@ import ( "sigs.k8s.io/cli-utils/pkg/ordering" ) -// PruneOptions encapsulates the necessary information to +// Pruner encapsulates the necessary information to // implement the prune functionality. -type PruneOptions struct { +type Pruner struct { InvClient inventory.InventoryClient Client dynamic.Interface Mapper meta.RESTMapper } -// NewPruneOptions returns a struct (PruneOptions) encapsulating the necessary +// NewPruner returns a struct (PruneOptions) encapsulating the necessary // information to run the prune. Returns an error if an error occurs // gathering this information. -func NewPruneOptions(factory util.Factory, invClient inventory.InventoryClient) (*PruneOptions, error) { +func NewPruner(factory util.Factory, invClient inventory.InventoryClient) (*Pruner, error) { // Client/Builder fields from the Factory. client, err := factory.DynamicClient() if err != nil { @@ -50,7 +51,7 @@ func NewPruneOptions(factory util.Factory, invClient inventory.InventoryClient) if err != nil { return nil, err } - return &PruneOptions{ + return &Pruner{ InvClient: invClient, Client: client, Mapper: mapper, @@ -64,7 +65,14 @@ type Options struct { // we should just print what would happen without actually doing it. DryRunStrategy common.DryRunStrategy - PropagationPolicy metav1.DeletionPropagation + // DeleteTimeout defines how long we should wait for resources to be fully + // deleted. If propegating deletion to dependencies, this timeout applies to + // all of the deletions collectively. + DeleteTimeout time.Duration + + // DeletionPropagationPolicy specifies how to deal with dependencies: + // Orphan, Background, or Foreground. + DeletionPropagationPolicy metav1.DeletionPropagation // True if we are destroying, which deletes the inventory object // as well (possibly) the inventory namespace. @@ -81,15 +89,17 @@ type Options struct { // automatically prune/delete). 
// // Parameters: +// taskContext - context for apply/prune // pruneObjs - objects to prune (delete) // pruneFilters - list of filters for deletion permission -// taskContext - task for apply/prune -// o - options for dry-run -func (po *PruneOptions) Prune(pruneObjs []*unstructured.Unstructured, - pruneFilters []filter.ValidationFilter, +// opts - options for dry-run +func (po *Pruner) Prune( taskContext *taskrunner.TaskContext, - o Options) error { - eventFactory := CreateEventFactory(o.Destroy) + pruneObjs []*unstructured.Unstructured, + pruneFilters []filter.ValidationFilter, + opts Options, +) error { + eventFactory := CreateEventFactory(opts.Destroy) // Iterate through objects to prune (delete). If an object is not pruned // and we need to keep it in the inventory, we must capture the prune failure. for _, pruneObj := range pruneObjs { @@ -101,7 +111,7 @@ func (po *PruneOptions) Prune(pruneObjs []*unstructured.Unstructured, var err error for _, filter := range pruneFilters { klog.V(6).Infof("prune filter %s: %s", filter.Name(), pruneID) - filtered, reason, err = filter.Filter(pruneObj) + filtered, reason, err = filter.Filter(taskContext.Context(), pruneObj) if err != nil { if klog.V(5).Enabled() { klog.Errorf("error during %s, (%s): %s", filter.Name(), pruneID, err) @@ -121,20 +131,9 @@ func (po *PruneOptions) Prune(pruneObjs []*unstructured.Unstructured, continue } // Filters passed--actually delete object if not dry run. - if !o.DryRunStrategy.ClientOrServerDryRun() { + if !opts.DryRunStrategy.ClientOrServerDryRun() { klog.V(4).Infof("prune object delete: %s", pruneID) - namespacedClient, err := po.namespacedClient(pruneID) - if err != nil { - if klog.V(4).Enabled() { - klog.Errorf("prune failed for %s (%s)", pruneID, err) - } - taskContext.EventChannel() <- eventFactory.CreateFailedEvent(pruneID, err) - taskContext.CapturePruneFailure(pruneID) - continue - } - err = namespacedClient.Delete(context.TODO(), pruneID.Name, metav1.DeleteOptions{ - PropagationPolicy: &o.PropagationPolicy, - }) + err = po.delete(taskContext, pruneID, opts) if err != nil { if klog.V(4).Enabled() { klog.Errorf("prune failed for %s (%s)", pruneID, err) @@ -149,21 +148,52 @@ func (po *PruneOptions) Prune(pruneObjs []*unstructured.Unstructured, return nil } +// delete an object with timeout +func (po *Pruner) delete( + taskContext *taskrunner.TaskContext, + objMeta object.ObjMetadata, + opts Options, +) error { + namespacedClient, err := po.namespacedClient(objMeta) + if err != nil { + return err + } + + // Cancel the delete when the delete timeout is reached, or when the parent + // context is cancelled, whichever comes first. + ctx, cancelFunc := context.WithTimeout(taskContext.Context(), opts.DeleteTimeout) + // Clean up on completion, whether timeout occurred or not. + // This should not affect the parent context. + defer cancelFunc() + + return namespacedClient.Delete( + ctx, + objMeta.Name, + metav1.DeleteOptions{ + PropagationPolicy: &opts.DeletionPropagationPolicy, + }, + ) +} + // GetPruneObjs calculates the set of prune objects, and retrieves them // from the cluster. Set of prune objects equals the set of inventory // objects minus the set of currently applied objects. Returns an error // if one occurs. 
-func (po *PruneOptions) GetPruneObjs(inv inventory.InventoryInfo, - localObjs []*unstructured.Unstructured, o Options) ([]*unstructured.Unstructured, error) { +func (po *Pruner) GetPruneObjs( + ctx context.Context, + inv inventory.InventoryInfo, + localObjs []*unstructured.Unstructured, + opts Options, +) ([]*unstructured.Unstructured, error) { localIds := object.UnstructuredsToObjMetasOrDie(localObjs) - prevInvIds, err := po.InvClient.GetClusterObjs(inv, o.DryRunStrategy) + prevInvIds, err := po.InvClient.GetClusterObjs(inv, opts.DryRunStrategy) if err != nil { return nil, err } pruneIds := object.SetDiff(prevInvIds, localIds) pruneObjs := []*unstructured.Unstructured{} for _, pruneID := range pruneIds { - pruneObj, err := po.GetObject(pruneID) + pruneObj, err := po.getObject(ctx, pruneID) if err != nil { return nil, err } @@ -173,17 +203,17 @@ func (po *PruneOptions) GetPruneObjs(inv inventory.InventoryInfo, return pruneObjs, nil } -// GetObject uses the passed object data to retrieve the object +// getObject uses the passed object data to retrieve the object // from the cluster (or an error if one occurs). -func (po *PruneOptions) GetObject(obj object.ObjMetadata) (*unstructured.Unstructured, error) { +func (po *Pruner) getObject(ctx context.Context, obj object.ObjMetadata) (*unstructured.Unstructured, error) { namespacedClient, err := po.namespacedClient(obj) if err != nil { return nil, err } - return namespacedClient.Get(context.TODO(), obj.Name, metav1.GetOptions{}) + return namespacedClient.Get(ctx, obj.Name, metav1.GetOptions{}) } -func (po *PruneOptions) namespacedClient(obj object.ObjMetadata) (dynamic.ResourceInterface, error) { +func (po *Pruner) namespacedClient(obj object.ObjMetadata) (dynamic.ResourceInterface, error) { mapping, err := po.Mapper.RESTMapping(obj.GroupKind) if err != nil { return nil, err diff --git a/pkg/apply/prune/prune_test.go b/pkg/apply/prune/prune_test.go index 2b619ab6..7581cd22 100644 --- a/pkg/apply/prune/prune_test.go +++ b/pkg/apply/prune/prune_test.go @@ -144,17 +144,17 @@ var preventDelete = &unstructured.Unstructured{ // Options with different dry-run values. 
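// Pruner.delete derives a per-object context from the task context, so a
// single slow deletion cannot hang past --delete-timeout. Note that
// context.WithTimeout with a zero duration produces an already-expired
// context, so a standalone version of the same idea would normally guard
// the zero value; whether the code above needs that guard depends on how
// DeleteTimeout is defaulted by its callers. A sketch under that
// assumption; deleteWithTimeout is illustrative and not part of this
// change:

package sketch

import (
	"context"
	"time"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/dynamic"
)

// deleteWithTimeout issues a delete bounded by timeout when one is set,
// falling back to the caller's context otherwise.
func deleteWithTimeout(ctx context.Context, client dynamic.ResourceInterface,
	name string, policy metav1.DeletionPropagation, timeout time.Duration) error {
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	return client.Delete(ctx, name, metav1.DeleteOptions{
		PropagationPolicy: &policy,
	})
}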
var ( defaultOptions = Options{ - DryRunStrategy: common.DryRunNone, - PropagationPolicy: metav1.DeletePropagationBackground, + DryRunStrategy: common.DryRunNone, + DeletionPropagationPolicy: metav1.DeletePropagationBackground, } defaultOptionsDestroy = Options{ - DryRunStrategy: common.DryRunNone, - PropagationPolicy: metav1.DeletePropagationBackground, - Destroy: true, + DryRunStrategy: common.DryRunNone, + DeletionPropagationPolicy: metav1.DeletePropagationBackground, + Destroy: true, } clientDryRunOptions = Options{ - DryRunStrategy: common.DryRunClient, - PropagationPolicy: metav1.DeletePropagationBackground, + DryRunStrategy: common.DryRunClient, + DeletionPropagationPolicy: metav1.DeletePropagationBackground, } ) @@ -257,9 +257,9 @@ func TestPrune(t *testing.T) { "Server dry run still deleted event": { pruneObjs: []*unstructured.Unstructured{pod}, options: Options{ - DryRunStrategy: common.DryRunServer, - PropagationPolicy: metav1.DeletePropagationBackground, - Destroy: true, + DryRunStrategy: common.DryRunServer, + DeletionPropagationPolicy: metav1.DeletePropagationBackground, + Destroy: true, }, expectedEvents: []testutil.ExpEvent{ { @@ -386,7 +386,7 @@ func TestPrune(t *testing.T) { pruneIds, err := object.UnstructuredsToObjMetas(tc.pruneObjs) require.NoError(t, err) - po := PruneOptions{ + po := Pruner{ InvClient: inventory.NewFakeInventoryClient(pruneIds), Client: fake.NewSimpleDynamicClient(scheme.Scheme, objs...), Mapper: testrestmapper.TestOnlyStaticRESTMapper(scheme.Scheme, @@ -395,11 +395,11 @@ func TestPrune(t *testing.T) { // The event channel can not block; make sure its bigger than all // the events that can be put on it. eventChannel := make(chan event.Event, len(tc.pruneObjs)+1) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) err = func() error { defer close(eventChannel) // Run the prune and validate. - return po.Prune(tc.pruneObjs, tc.pruneFilters, taskContext, tc.options) + return po.Prune(taskContext, tc.pruneObjs, tc.pruneFilters, tc.options) }() if err != nil { @@ -473,7 +473,7 @@ func TestPruneWithErrors(t *testing.T) { t.Run(name, func(t *testing.T) { pruneIds, err := object.UnstructuredsToObjMetas(tc.pruneObjs) require.NoError(t, err) - po := PruneOptions{ + po := Pruner{ InvClient: inventory.NewFakeInventoryClient(pruneIds), // Set up the fake dynamic client to recognize all objects, and the RESTMapper. Client: &fakeDynamicClient{ @@ -485,7 +485,7 @@ func TestPruneWithErrors(t *testing.T) { // The event channel can not block; make sure its bigger than all // the events that can be put on it. eventChannel := make(chan event.Event, len(tc.pruneObjs)) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) err = func() error { defer close(eventChannel) var opts Options @@ -495,7 +495,7 @@ func TestPruneWithErrors(t *testing.T) { opts = defaultOptions } // Run the prune and validate. 
- return po.Prune(tc.pruneObjs, []filter.ValidationFilter{}, taskContext, opts) + return po.Prune(taskContext, tc.pruneObjs, []filter.ValidationFilter{}, opts) }() if err != nil { t.Fatalf("Unexpected error during Prune(): %#v", err) @@ -558,14 +558,19 @@ func TestGetPruneObjs(t *testing.T) { for _, obj := range tc.prevInventory { objs = append(objs, obj) } - po := PruneOptions{ + po := Pruner{ InvClient: inventory.NewFakeInventoryClient(object.UnstructuredsToObjMetasOrDie(tc.prevInventory)), Client: fake.NewSimpleDynamicClient(scheme.Scheme, objs...), Mapper: testrestmapper.TestOnlyStaticRESTMapper(scheme.Scheme, scheme.Scheme.PrioritizedVersionsAllGroups()...), } currentInventory := createInventoryInfo(tc.prevInventory...) - actualObjs, err := po.GetPruneObjs(currentInventory, tc.localObjs, Options{}) + actualObjs, err := po.GetPruneObjs( + context.TODO(), + currentInventory, + tc.localObjs, + Options{}, + ) if err != nil { t.Fatalf("unexpected error %s returned", err) } @@ -611,7 +616,7 @@ func TestPrune_PropagationPolicy(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { captureClient := &optionsCaptureNamespaceClient{} - po := PruneOptions{ + po := Pruner{ InvClient: inventory.NewFakeInventoryClient([]object.ObjMetadata{}), Client: &fakeDynamicClient{ resourceInterface: captureClient, @@ -621,10 +626,15 @@ func TestPrune_PropagationPolicy(t *testing.T) { } eventChannel := make(chan event.Event, 1) - taskContext := taskrunner.NewTaskContext(eventChannel) - err := po.Prune([]*unstructured.Unstructured{pdb}, []filter.ValidationFilter{}, taskContext, Options{ - PropagationPolicy: tc.propagationPolicy, - }) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) + err := po.Prune( + taskContext, + []*unstructured.Unstructured{pdb}, + []filter.ValidationFilter{}, + Options{ + DeletionPropagationPolicy: tc.propagationPolicy, + }, + ) assert.NoError(t, err) require.NotNil(t, captureClient.options.PropagationPolicy) assert.Equal(t, tc.propagationPolicy, *captureClient.options.PropagationPolicy) diff --git a/pkg/apply/solver/solver.go b/pkg/apply/solver/solver.go index 8728dc0b..81e3d865 100644 --- a/pkg/apply/solver/solver.go +++ b/pkg/apply/solver/solver.go @@ -38,11 +38,11 @@ import ( const defaultWaitTimeout = 1 * time.Minute type TaskQueueBuilder struct { - PruneOptions *prune.PruneOptions - InfoHelper info.InfoHelper - Factory util.Factory - Mapper meta.RESTMapper - InvClient inventory.InventoryClient + Pruner *prune.Pruner + InfoHelper info.InfoHelper + Factory util.Factory + Mapper meta.RESTMapper + InvClient inventory.InventoryClient // True if we are destroying, which deletes the inventory object // as well (possibly) the inventory namespace. Destroy bool @@ -83,13 +83,12 @@ func (tq *TaskQueue) ToActionGroups() []event.ActionGroup { } type Options struct { - ServerSideOptions common.ServerSideOptions - ReconcileTimeout time.Duration - Prune bool - DryRunStrategy common.DryRunStrategy - PrunePropagationPolicy metav1.DeletionPropagation - PruneTimeout time.Duration - InventoryPolicy inventory.InventoryPolicy + ServerSideOptions common.ServerSideOptions + ReconcileTimeout time.Duration + DryRunStrategy common.DryRunStrategy + DeletePropagationPolicy metav1.DeletionPropagation + DeleteTimeout time.Duration + InventoryPolicy inventory.InventoryPolicy } // Build returns the queue of tasks that have been created. 
@@ -156,15 +155,18 @@ func (t *TaskQueueBuilder) AppendDeleteInvTask(inv inventory.InventoryInfo, dryR // AppendInvAddTask appends a task to the task queue to apply the passed objects // to the cluster. Returns a pointer to the Builder to chain function calls. -func (t *TaskQueueBuilder) AppendApplyTask(applyObjs []*unstructured.Unstructured, - applyFilters []filter.ValidationFilter, o Options) *TaskQueueBuilder { +func (t *TaskQueueBuilder) AppendApplyTask( + applyObjs []*unstructured.Unstructured, + applyFilters []filter.ValidationFilter, + opts Options, +) *TaskQueueBuilder { klog.V(2).Infof("adding apply task (%d objects)", len(applyObjs)) t.tasks = append(t.tasks, &task.ApplyTask{ TaskName: fmt.Sprintf("apply-%d", t.applyCounter), Objects: applyObjs, Filters: applyFilters, - ServerSideOptions: o.ServerSideOptions, - DryRunStrategy: o.DryRunStrategy, + ServerSideOptions: opts.ServerSideOptions, + DryRunStrategy: opts.DryRunStrategy, InfoHelper: t.InfoHelper, Factory: t.Factory, Mapper: t.Mapper, @@ -183,26 +185,31 @@ func (t *TaskQueueBuilder) AppendWaitTask(waitIds []object.ObjMetadata, conditio waitIds, condition, waitTimeout, - t.Mapper), - ) + t.Mapper, + )) t.waitCounter += 1 return t } // AppendInvAddTask appends a task to delete objects from the cluster to the task queue. // Returns a pointer to the Builder to chain function calls. -func (t *TaskQueueBuilder) AppendPruneTask(pruneObjs []*unstructured.Unstructured, - pruneFilters []filter.ValidationFilter, o Options) *TaskQueueBuilder { +func (t *TaskQueueBuilder) AppendPruneTask( + pruneObjs []*unstructured.Unstructured, + pruneFilters []filter.ValidationFilter, + opts Options, +) *TaskQueueBuilder { klog.V(2).Infof("adding prune task (%d objects)", len(pruneObjs)) t.tasks = append(t.tasks, &task.PruneTask{ - TaskName: fmt.Sprintf("prune-%d", t.pruneCounter), - Objects: pruneObjs, - Filters: pruneFilters, - PruneOptions: t.PruneOptions, - PropagationPolicy: o.PrunePropagationPolicy, - DryRunStrategy: o.DryRunStrategy, - Destroy: t.Destroy, + TaskName: fmt.Sprintf("prune-%d", t.pruneCounter), + Objects: pruneObjs, + Filters: pruneFilters, + Options: prune.Options{ + DeletionPropagationPolicy: opts.DeletePropagationPolicy, + DryRunStrategy: opts.DryRunStrategy, + Destroy: t.Destroy, + }, + PruneOptions: t.Pruner, }, ) t.pruneCounter += 1 @@ -212,18 +219,21 @@ func (t *TaskQueueBuilder) AppendPruneTask(pruneObjs []*unstructured.Unstructure // AppendApplyWaitTasks adds apply and wait tasks to the task queue, // depending on build variables (like dry-run) and resource types // (like CRD's). Returns a pointer to the Builder to chain function calls. -func (t *TaskQueueBuilder) AppendApplyWaitTasks(applyObjs []*unstructured.Unstructured, - applyFilters []filter.ValidationFilter, o Options) *TaskQueueBuilder { +func (t *TaskQueueBuilder) AppendApplyWaitTasks( + applyObjs []*unstructured.Unstructured, + applyFilters []filter.ValidationFilter, + opts Options, +) *TaskQueueBuilder { // Use the "depends-on" annotation to create a graph, ands sort the // objects to apply into sets using a topological sort. 
applySets, err := graph.SortObjs(applyObjs) if err != nil { t.err = err } - addWaitTask, waitTimeout := waitTaskTimeout(o.DryRunStrategy.ClientOrServerDryRun(), - len(applySets), o.ReconcileTimeout) + addWaitTask, waitTimeout := waitTaskTimeout(opts.DryRunStrategy.ClientOrServerDryRun(), + len(applySets), opts.ReconcileTimeout) for _, applySet := range applySets { - t.AppendApplyTask(applySet, applyFilters, o) + t.AppendApplyTask(applySet, applyFilters, opts) if addWaitTask { applyIds := object.UnstructuredsToObjMetasOrDie(applySet) t.AppendWaitTask(applyIds, taskrunner.AllCurrent, waitTimeout) @@ -235,23 +245,24 @@ func (t *TaskQueueBuilder) AppendApplyWaitTasks(applyObjs []*unstructured.Unstru // AppendPruneWaitTasks adds prune and wait tasks to the task queue // based on build variables (like dry-run). Returns a pointer to the // Builder to chain function calls. -func (t *TaskQueueBuilder) AppendPruneWaitTasks(pruneObjs []*unstructured.Unstructured, - pruneFilters []filter.ValidationFilter, o Options) *TaskQueueBuilder { - if o.Prune { - // Use the "depends-on" annotation to create a graph, ands sort the - // objects to prune into sets using a (reverse) topological sort. - pruneSets, err := graph.ReverseSortObjs(pruneObjs) - if err != nil { - t.err = err - } - addWaitTask, waitTimeout := waitTaskTimeout(o.DryRunStrategy.ClientOrServerDryRun(), - len(pruneSets), o.ReconcileTimeout) - for _, pruneSet := range pruneSets { - t.AppendPruneTask(pruneSet, pruneFilters, o) - if addWaitTask { - pruneIds := object.UnstructuredsToObjMetasOrDie(pruneSet) - t.AppendWaitTask(pruneIds, taskrunner.AllNotFound, waitTimeout) - } +func (t *TaskQueueBuilder) AppendPruneWaitTasks( + pruneObjs []*unstructured.Unstructured, + pruneFilters []filter.ValidationFilter, + opts Options, +) *TaskQueueBuilder { + // Use the "depends-on" annotation to create a graph, ands sort the + // objects to prune into sets using a (reverse) topological sort. 
+ pruneSets, err := graph.ReverseSortObjs(pruneObjs) + if err != nil { + t.err = err + } + addWaitTask, waitTimeout := waitTaskTimeout(opts.DryRunStrategy.ClientOrServerDryRun(), + len(pruneSets), opts.ReconcileTimeout) + for _, pruneSet := range pruneSets { + t.AppendPruneTask(pruneSet, pruneFilters, opts) + if addWaitTask { + pruneIds := object.UnstructuredsToObjMetasOrDie(pruneSet) + t.AppendWaitTask(pruneIds, taskrunner.AllNotFound, waitTimeout) } } return t diff --git a/pkg/apply/solver/solver_test.go b/pkg/apply/solver/solver_test.go index 32ffa929..c73935fb 100644 --- a/pkg/apply/solver/solver_test.go +++ b/pkg/apply/solver/solver_test.go @@ -21,8 +21,8 @@ import ( ) var ( - pruneOptions = &prune.PruneOptions{} - resources = map[string]string{ + pruner = &prune.Pruner{} + resources = map[string]string{ "pod": ` kind: Pod apiVersion: v1 @@ -352,9 +352,9 @@ func TestTaskQueueBuilder_AppendApplyWaitTasks(t *testing.T) { applyIds := object.UnstructuredsToObjMetasOrDie(tc.applyObjs) fakeInvClient := inventory.NewFakeInventoryClient(applyIds) tqb := TaskQueueBuilder{ - PruneOptions: pruneOptions, - Mapper: testutil.NewFakeRESTMapper(), - InvClient: fakeInvClient, + Pruner: pruner, + Mapper: testutil.NewFakeRESTMapper(), + InvClient: fakeInvClient, } tq, err := tqb.AppendApplyWaitTasks(tc.applyObjs, []filter.ValidationFilter{}, tc.options).Build() if tc.isError { @@ -398,7 +398,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { }{ "no resources, no tasks": { pruneObjs: []*unstructured.Unstructured{}, - options: Options{Prune: true}, + options: Options{}, expectedTasks: []taskrunner.Task{}, isError: false, }, @@ -406,7 +406,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { pruneObjs: []*unstructured.Unstructured{ testutil.Unstructured(t, resources["default-pod"]), }, - options: Options{Prune: true}, + options: Options{}, expectedTasks: []taskrunner.Task{ &task.PruneTask{ TaskName: "prune-0", @@ -422,7 +422,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { testutil.Unstructured(t, resources["default-pod"]), testutil.Unstructured(t, resources["pod"]), }, - options: Options{Prune: true}, + options: Options{}, expectedTasks: []taskrunner.Task{ &task.PruneTask{ TaskName: "prune-0", @@ -440,7 +440,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { testutil.AddDependsOn(t, testutil.Unstructured(t, resources["secret"]))), testutil.Unstructured(t, resources["secret"]), }, - options: Options{Prune: true}, + options: Options{}, // Opposite ordering when pruning/deleting expectedTasks: []taskrunner.Task{ &task.PruneTask{ @@ -480,7 +480,6 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { options: Options{ ReconcileTimeout: time.Minute, DryRunStrategy: common.DryRunServer, - Prune: true, }, // No wait task, since it is dry run expectedTasks: []taskrunner.Task{ @@ -500,7 +499,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { testutil.Unstructured(t, resources["crd"]), testutil.Unstructured(t, resources["crontab2"]), }, - options: Options{Prune: true}, + options: Options{}, // Opposite ordering when pruning/deleting. 
expectedTasks: []taskrunner.Task{ &task.PruneTask{ @@ -543,7 +542,6 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { options: Options{ ReconcileTimeout: time.Minute, DryRunStrategy: common.DryRunClient, - Prune: true, }, expectedTasks: []taskrunner.Task{ &task.PruneTask{ @@ -568,7 +566,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { testutil.Unstructured(t, resources["pod"]), testutil.Unstructured(t, resources["secret"]), }, - options: Options{Prune: true}, + options: Options{}, expectedTasks: []taskrunner.Task{ &task.PruneTask{ TaskName: "prune-0", @@ -608,7 +606,7 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { testutil.Unstructured(t, resources["secret"], testutil.AddDependsOn(t, testutil.Unstructured(t, resources["deployment"]))), }, - options: Options{Prune: true}, + options: Options{}, expectedTasks: []taskrunner.Task{}, isError: true, }, @@ -619,9 +617,9 @@ func TestTaskQueueBuilder_AppendPruneWaitTasks(t *testing.T) { pruneIds := object.UnstructuredsToObjMetasOrDie(tc.pruneObjs) fakeInvClient := inventory.NewFakeInventoryClient(pruneIds) tqb := TaskQueueBuilder{ - PruneOptions: pruneOptions, - Mapper: testutil.NewFakeRESTMapper(), - InvClient: fakeInvClient, + Pruner: pruner, + Mapper: testutil.NewFakeRESTMapper(), + InvClient: fakeInvClient, } emptyPruneFilters := []filter.ValidationFilter{} tq, err := tqb.AppendPruneWaitTasks(tc.pruneObjs, emptyPruneFilters, tc.options).Build() diff --git a/pkg/apply/task/apply_task.go b/pkg/apply/task/apply_task.go index 597f5114..080d5ab9 100644 --- a/pkg/apply/task/apply_task.go +++ b/pkg/apply/task/apply_task.go @@ -113,7 +113,7 @@ func (a *ApplyTask) Start(taskContext *taskrunner.TaskContext) { for _, filter := range a.Filters { klog.V(6).Infof("apply filter %s: %s", filter.Name(), id) var reason string - filtered, reason, filterErr = filter.Filter(obj) + filtered, reason, filterErr = filter.Filter(taskContext.Context(), obj) if filterErr != nil { if klog.V(5).Enabled() { klog.Errorf("error during %s, (%s): %s", filter.Name(), id, filterErr) @@ -161,7 +161,7 @@ func (a *ApplyTask) Start(taskContext *taskrunner.TaskContext) { }() } -func newApplyOptions(eventChannel chan event.Event, serverSideOptions common.ServerSideOptions, +func newApplyOptions(eventChannel chan<- event.Event, serverSideOptions common.ServerSideOptions, strategy common.DryRunStrategy, factory util.Factory) (applyOptions, error) { discovery, err := factory.ToDiscoveryClient() if err != nil { @@ -207,8 +207,8 @@ func (a *ApplyTask) sendTaskResult(taskContext *taskrunner.TaskContext) { taskContext.TaskChannel() <- taskrunner.TaskResult{} } -// ClearTimeout is not supported by the ApplyTask. -func (a *ApplyTask) ClearTimeout() {} +// OnStatusEvent is not supported by ApplyTask. +func (a *ApplyTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} // createApplyEvent is a helper function to package an apply event for a single resource. 
func createApplyEvent(id object.ObjMetadata, operation event.ApplyEventOperation, resource *unstructured.Unstructured) event.Event { @@ -254,7 +254,7 @@ func isStreamError(err error) bool { return strings.Contains(err.Error(), "stream error: stream ID ") } -func clientSideApply(info *resource.Info, eventChannel chan event.Event, strategy common.DryRunStrategy, factory util.Factory) error { +func clientSideApply(info *resource.Info, eventChannel chan<- event.Event, strategy common.DryRunStrategy, factory util.Factory) error { ao, err := applyOptionsFactoryFunc(eventChannel, common.ServerSideOptions{ServerSideApply: false}, strategy, factory) if err != nil { return err diff --git a/pkg/apply/task/apply_task_test.go b/pkg/apply/task/apply_task_test.go index fae754c4..bac3309f 100644 --- a/pkg/apply/task/apply_task_test.go +++ b/pkg/apply/task/apply_task_test.go @@ -4,6 +4,7 @@ package task import ( + "context" "fmt" "strings" "sync" @@ -79,12 +80,12 @@ func TestApplyTask_BasicAppliedObjects(t *testing.T) { t.Run(tn, func(t *testing.T) { eventChannel := make(chan event.Event) defer close(eventChannel) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) objs := toUnstructureds(tc.applied) oldAO := applyOptionsFactoryFunc - applyOptionsFactoryFunc = func(chan event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { + applyOptionsFactoryFunc = func(chan<- event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { return &fakeApplyOptions{}, nil } defer func() { applyOptionsFactoryFunc = oldAO }() @@ -164,12 +165,12 @@ func TestApplyTask_FetchGeneration(t *testing.T) { t.Run(tn, func(t *testing.T) { eventChannel := make(chan event.Event) defer close(eventChannel) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) objs := toUnstructureds(tc.rss) oldAO := applyOptionsFactoryFunc - applyOptionsFactoryFunc = func(chan event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { + applyOptionsFactoryFunc = func(chan<- event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { return &fakeApplyOptions{}, nil } defer func() { applyOptionsFactoryFunc = oldAO }() @@ -275,7 +276,7 @@ func TestApplyTask_DryRun(t *testing.T) { drs := common.Strategies[i] t.Run(tn, func(t *testing.T) { eventChannel := make(chan event.Event) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) restMapper := testutil.NewFakeRESTMapper(schema.GroupVersionKind{ Group: "apps", @@ -289,7 +290,7 @@ func TestApplyTask_DryRun(t *testing.T) { ao := &fakeApplyOptions{} oldAO := applyOptionsFactoryFunc - applyOptionsFactoryFunc = func(chan event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { + applyOptionsFactoryFunc = func(chan<- event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { return ao, nil } defer func() { applyOptionsFactoryFunc = oldAO }() @@ -409,7 +410,7 @@ func TestApplyTaskWithError(t *testing.T) { drs := common.DryRunNone t.Run(tn, func(t *testing.T) { eventChannel := make(chan event.Event) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) restMapper := 
testutil.NewFakeRESTMapper(schema.GroupVersionKind{ Group: "apps", @@ -423,7 +424,7 @@ func TestApplyTaskWithError(t *testing.T) { ao := &fakeApplyOptions{} oldAO := applyOptionsFactoryFunc - applyOptionsFactoryFunc = func(chan event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { + applyOptionsFactoryFunc = func(chan<- event.Event, common.ServerSideOptions, common.DryRunStrategy, util.Factory) (applyOptions, error) { return ao, nil } defer func() { applyOptionsFactoryFunc = oldAO }() diff --git a/pkg/apply/task/delete_inv_task.go b/pkg/apply/task/delete_inv_task.go index 356fa6aa..15ed221d 100644 --- a/pkg/apply/task/delete_inv_task.go +++ b/pkg/apply/task/delete_inv_task.go @@ -49,5 +49,5 @@ func (i *DeleteInvTask) Start(taskContext *taskrunner.TaskContext) { }() } -// ClearTimeout is not supported by the DeleteInvTask. -func (i *DeleteInvTask) ClearTimeout() {} +// OnStatusEvent is not supported by DeleteInvTask. +func (i *DeleteInvTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} diff --git a/pkg/apply/task/delete_inv_task_test.go b/pkg/apply/task/delete_inv_task_test.go index ab34fb61..087d4baa 100644 --- a/pkg/apply/task/delete_inv_task_test.go +++ b/pkg/apply/task/delete_inv_task_test.go @@ -4,6 +4,7 @@ package task import ( + "context" "testing" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -39,7 +40,7 @@ func TestDeleteInvTask(t *testing.T) { client := inventory.NewFakeInventoryClient([]object.ObjMetadata{}) client.Err = tc.err eventChannel := make(chan event.Event) - context := taskrunner.NewTaskContext(eventChannel) + context := taskrunner.NewTaskContext(context.TODO(), eventChannel) task := DeleteInvTask{ TaskName: taskName, InvClient: client, diff --git a/pkg/apply/task/inv_add_task.go b/pkg/apply/task/inv_add_task.go index 75f0c560..5806187e 100644 --- a/pkg/apply/task/inv_add_task.go +++ b/pkg/apply/task/inv_add_task.go @@ -60,8 +60,8 @@ func (i *InvAddTask) Start(taskContext *taskrunner.TaskContext) { }() } -// ClearTimeout is not supported by the InvAddTask. -func (i *InvAddTask) ClearTimeout() {} +// OnStatusEvent is not supported by InvAddTask. +func (i *InvAddTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} // inventoryNamespaceInSet returns the the namespace the passed inventory // object will be applied to, or nil if this namespace object does not exist diff --git a/pkg/apply/task/inv_add_task_test.go b/pkg/apply/task/inv_add_task_test.go index c7753825..753f77a5 100644 --- a/pkg/apply/task/inv_add_task_test.go +++ b/pkg/apply/task/inv_add_task_test.go @@ -4,6 +4,7 @@ package task import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -109,7 +110,7 @@ func TestInvAddTask(t *testing.T) { t.Run(name, func(t *testing.T) { client := inventory.NewFakeInventoryClient(tc.initialObjs) eventChannel := make(chan event.Event) - context := taskrunner.NewTaskContext(eventChannel) + context := taskrunner.NewTaskContext(context.TODO(), eventChannel) task := InvAddTask{ TaskName: taskName, InvClient: client, diff --git a/pkg/apply/task/inv_set_task.go b/pkg/apply/task/inv_set_task.go index 35e93362..b323e12a 100644 --- a/pkg/apply/task/inv_set_task.go +++ b/pkg/apply/task/inv_set_task.go @@ -63,5 +63,5 @@ func (i *InvSetTask) Start(taskContext *taskrunner.TaskContext) { }() } -// ClearTimeout is not supported by the InvSetTask. -func (i *InvSetTask) ClearTimeout() {} +// OnStatusEvent is not supported by InvSetTask. 
+func (i *InvSetTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} diff --git a/pkg/apply/task/inv_set_task_test.go b/pkg/apply/task/inv_set_task_test.go index c365ac09..ae8cd956 100644 --- a/pkg/apply/task/inv_set_task_test.go +++ b/pkg/apply/task/inv_set_task_test.go @@ -4,6 +4,7 @@ package task import ( + "context" "testing" "sigs.k8s.io/cli-utils/pkg/apply/event" @@ -105,7 +106,7 @@ func TestInvSetTask(t *testing.T) { t.Run(name, func(t *testing.T) { client := inventory.NewFakeInventoryClient([]object.ObjMetadata{}) eventChannel := make(chan event.Event) - context := taskrunner.NewTaskContext(eventChannel) + context := taskrunner.NewTaskContext(context.TODO(), eventChannel) prevInventory := make(map[object.ObjMetadata]bool, len(tc.prevInventory)) for _, prevInvID := range tc.prevInventory { prevInventory[prevInvID] = true diff --git a/pkg/apply/task/prune_task.go b/pkg/apply/task/prune_task.go index ee6db475..10d66072 100644 --- a/pkg/apply/task/prune_task.go +++ b/pkg/apply/task/prune_task.go @@ -4,14 +4,12 @@ package task import ( - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/klog/v2" "sigs.k8s.io/cli-utils/pkg/apply/event" "sigs.k8s.io/cli-utils/pkg/apply/filter" "sigs.k8s.io/cli-utils/pkg/apply/prune" "sigs.k8s.io/cli-utils/pkg/apply/taskrunner" - "sigs.k8s.io/cli-utils/pkg/common" "sigs.k8s.io/cli-utils/pkg/object" ) @@ -21,14 +19,10 @@ import ( type PruneTask struct { TaskName string - PruneOptions *prune.PruneOptions - Objects []*unstructured.Unstructured - Filters []filter.ValidationFilter - DryRunStrategy common.DryRunStrategy - PropagationPolicy metav1.DeletionPropagation - // True if we are destroying, which deletes the inventory object - // as well (possibly) the inventory namespace. - Destroy bool + Options prune.Options + PruneOptions *prune.Pruner + Objects []*unstructured.Unstructured + Filters []filter.ValidationFilter } func (p *PruneTask) Name() string { @@ -37,7 +31,7 @@ func (p *PruneTask) Name() string { func (p *PruneTask) Action() event.ResourceAction { action := event.PruneAction - if p.Destroy { + if p.Options.Destroy { action = event.DeleteAction } return action @@ -60,17 +54,15 @@ func (p *PruneTask) Start(taskContext *taskrunner.TaskContext) { CurrentUIDs: taskContext.AppliedResourceUIDs(), } p.Filters = append(p.Filters, uidFilter) - err := p.PruneOptions.Prune(p.Objects, - p.Filters, taskContext, prune.Options{ - DryRunStrategy: p.DryRunStrategy, - PropagationPolicy: p.PropagationPolicy, - Destroy: p.Destroy, - }) - taskContext.TaskChannel() <- taskrunner.TaskResult{ - Err: err, - } + err := p.PruneOptions.Prune( + taskContext, + p.Objects, + p.Filters, + p.Options, + ) + taskContext.TaskChannel() <- taskrunner.TaskResult{Err: err} }() } -// ClearTimeout is not supported by the PruneTask. -func (p *PruneTask) ClearTimeout() {} +// OnStatusEvent is not supported by PruneTask. 
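The PruneTask hunk above moves the per-run knobs (dry-run strategy, propagation policy, destroy flag) into a single prune.Options value, while the shared *prune.Pruner does the actual work. A hedged sketch of constructing the reshaped task; the concrete option values are illustrative, but the field names match the code being replaced above:

// Sketch: the reshaped PruneTask. prune.Options carries the fields that were
// previously passed inline to Prune (DryRunStrategy, PropagationPolicy, Destroy).
pruneTask := &task.PruneTask{
	TaskName:     "prune-0",
	PruneOptions: pruner, // shared *prune.Pruner
	Objects:      pruneObjs,
	Filters:      pruneFilters,
	Options: prune.Options{
		DryRunStrategy:    common.DryRunNone,
		PropagationPolicy: metav1.DeletePropagationBackground,
		Destroy:           false, // true only when driven by the destroyer
	},
}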
+func (p *PruneTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} diff --git a/pkg/apply/task/resetmapper_task.go b/pkg/apply/task/resetmapper_task.go index 877b8f15..4b41ebc9 100644 --- a/pkg/apply/task/resetmapper_task.go +++ b/pkg/apply/task/resetmapper_task.go @@ -51,11 +51,5 @@ func extractDeferredDiscoveryRESTMapper(mapper meta.RESTMapper) (*restmapper.Def } func (r *ResetRESTMapperTask) sendTaskResult(taskContext *taskrunner.TaskContext, err error) { - taskContext.TaskChannel() <- taskrunner.TaskResult{ - Err: err, - } + taskContext.TaskChannel() <- taskrunner.TaskResult{Err: err} } - -// ClearTimeout doesn't do anything as ResetRESTMapperTask doesn't support -// timeouts. -func (r *ResetRESTMapperTask) ClearTimeout() {} diff --git a/pkg/apply/task/resetmapper_task_test.go b/pkg/apply/task/resetmapper_task_test.go index 5eb68100..f336a16d 100644 --- a/pkg/apply/task/resetmapper_task_test.go +++ b/pkg/apply/task/resetmapper_task_test.go @@ -4,6 +4,7 @@ package task import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -42,7 +43,7 @@ func TestResetRESTMapperTask(t *testing.T) { t.Run(tn, func(t *testing.T) { eventChannel := make(chan event.Event) defer close(eventChannel) - taskContext := taskrunner.NewTaskContext(eventChannel) + taskContext := taskrunner.NewTaskContext(context.TODO(), eventChannel) mapper, discoveryClient := tc.toRESTMapper() diff --git a/pkg/apply/task/send_event_task.go b/pkg/apply/task/send_event_task.go index b1f79d65..8d4bdab5 100644 --- a/pkg/apply/task/send_event_task.go +++ b/pkg/apply/task/send_event_task.go @@ -25,6 +25,5 @@ func (s *SendEventTask) Start(taskContext *taskrunner.TaskContext) { }() } -// ClearTimeout doesn't do anything as SendEventTask doesn't support -// timeouts. -func (s *SendEventTask) ClearTimeout() {} +// OnStatusEvent is not supported by SendEventTask. +func (s *SendEventTask) OnStatusEvent(taskContext *taskrunner.TaskContext, e event.StatusEvent) {} diff --git a/pkg/apply/taskrunner/collector.go b/pkg/apply/taskrunner/collector.go index f608f914..8158454a 100644 --- a/pkg/apply/taskrunner/collector.go +++ b/pkg/apply/taskrunner/collector.go @@ -4,51 +4,85 @@ package taskrunner import ( + "sync" + "sigs.k8s.io/cli-utils/pkg/kstatus/polling/event" "sigs.k8s.io/cli-utils/pkg/kstatus/status" "sigs.k8s.io/cli-utils/pkg/object" ) -// newResourceStatusCollector returns a new resourceStatusCollector -// that will keep track of the status of the provided resources. -func newResourceStatusCollector(identifiers []object.ObjMetadata) *resourceStatusCollector { - rm := make(map[object.ObjMetadata]resourceStatus) +// Condition is a type that defines the types of conditions +// which a WaitTask can use. +type Condition string - for _, obj := range identifiers { - rm[obj] = resourceStatus{ - Identifier: obj, - CurrentStatus: status.UnknownStatus, - } +const ( + // AllCurrent Condition means all the provided resources + // has reached (and remains in) the Current status. + AllCurrent Condition = "AllCurrent" + + // AllNotFound Condition means all the provided resources + // has reached the NotFound status, i.e. they are all deleted + // from the cluster. + AllNotFound Condition = "AllNotFound" +) + +// Meets returns true if the provided status meets the condition and +// false if it does not. 
+func (c Condition) Meets(s status.Status) bool { + switch c { + case AllCurrent: + return s == status.CurrentStatus + case AllNotFound: + return s == status.NotFoundStatus + default: + return false } - return &resourceStatusCollector{ - resourceMap: rm, +} + +type ResourceGeneration struct { + Identifier object.ObjMetadata + Generation int64 +} + +// NewResourceStatusCollector returns a new resourceStatusCollector +// that will keep track of the status of the provided resources. +func NewResourceStatusCollector() *ResourceStatusCollector { + return &ResourceStatusCollector{ + resourceMap: make(map[object.ObjMetadata]ResourceStatus), } } -// resourceStatusCollector keeps track of the latest seen status for all the +// ResourceStatusCollector keeps track of the latest seen status for all the // resources that is of interest during the operation. -type resourceStatusCollector struct { - resourceMap map[object.ObjMetadata]resourceStatus +type ResourceStatusCollector struct { + resourceMap map[object.ObjMetadata]ResourceStatus + // mu protects concurrent map access + mu sync.Mutex } // resoureStatus contains the latest status for a given // resource as identified by the Identifier. -type resourceStatus struct { - Identifier object.ObjMetadata +type ResourceStatus struct { CurrentStatus status.Status Message string Generation int64 } -// resourceStatus updates the collector with the latest -// seen status for the given resource. -func (a *resourceStatusCollector) resourceStatus(r *event.ResourceStatus) { - if ri, found := a.resourceMap[r.Identifier]; found { - ri.CurrentStatus = r.Status - ri.Message = r.Message - ri.Generation = getGeneration(r) - a.resourceMap[r.Identifier] = ri - } +// Put updates the collector with the specified status +func (a *ResourceStatusCollector) Put(id object.ObjMetadata, rs ResourceStatus) { + a.mu.Lock() + defer a.mu.Unlock() + a.resourceMap[id] = rs +} + +// PutEventStatus updates the collector with the latest status from an +// ResourceStatus event. +func (a *ResourceStatusCollector) PutEventStatus(rs *event.ResourceStatus) { + a.Put(rs.Identifier, ResourceStatus{ + CurrentStatus: rs.Status, + Message: rs.Message, + Generation: getGeneration(rs), + }) } // getGeneration looks up the value of the generation field in the @@ -61,9 +95,9 @@ func getGeneration(r *event.ResourceStatus) int64 { return r.Resource.GetGeneration() } -// conditionMet tests whether the provided Condition holds true for +// ConditionMet tests whether the provided Condition holds true for // all resources given by the list of Ids. -func (a *resourceStatusCollector) conditionMet(rwd []resourceWaitData, c Condition) bool { +func (a *ResourceStatusCollector) ConditionMet(rwd []ResourceGeneration, c Condition) bool { switch c { case AllCurrent: return a.allMatchStatus(rwd, status.CurrentStatus) @@ -74,15 +108,26 @@ func (a *resourceStatusCollector) conditionMet(rwd []resourceWaitData, c Conditi } } +// matchStatus returns the status of any resources with the specified +// identifiers that match the supplied status. +func (a *ResourceStatusCollector) Get(id object.ObjMetadata) ResourceStatus { + a.mu.Lock() + defer a.mu.Unlock() + rs, found := a.resourceMap[id] + if !found { + return ResourceStatus{ + CurrentStatus: status.UnknownStatus, + } + } + return rs +} + // allMatchStatus checks whether all resources given by the // Ids parameter has the provided status. 
-func (a *resourceStatusCollector) allMatchStatus(rwd []resourceWaitData, s status.Status) bool { +func (a *ResourceStatusCollector) allMatchStatus(rwd []ResourceGeneration, s status.Status) bool { for _, wd := range rwd { - ri, found := a.resourceMap[wd.identifier] - if !found { - return false - } - if ri.Generation < wd.generation || ri.CurrentStatus != s { + rs := a.Get(wd.Identifier) + if rs.Generation < wd.Generation || rs.CurrentStatus != s { return false } } @@ -91,13 +136,10 @@ func (a *resourceStatusCollector) allMatchStatus(rwd []resourceWaitData, s statu // noneMatchStatus checks whether none of the resources given // by the Ids parameters has the provided status. -func (a *resourceStatusCollector) noneMatchStatus(rwd []resourceWaitData, s status.Status) bool { +func (a *ResourceStatusCollector) noneMatchStatus(rwd []ResourceGeneration, s status.Status) bool { for _, wd := range rwd { - ri, found := a.resourceMap[wd.identifier] - if !found { - return false - } - if ri.Generation < wd.generation || ri.CurrentStatus == s { + rs := a.Get(wd.Identifier) + if rs.Generation < wd.Generation || rs.CurrentStatus == s { return false } } diff --git a/pkg/apply/taskrunner/collector_test.go b/pkg/apply/taskrunner/collector_test.go index 28df349b..311666ab 100644 --- a/pkg/apply/taskrunner/collector_test.go +++ b/pkg/apply/taskrunner/collector_test.go @@ -32,92 +32,86 @@ func TestCollector_ConditionMet(t *testing.T) { } testCases := map[string]struct { - collectorState map[object.ObjMetadata]resourceStatus - waitTaskData []resourceWaitData + collectorState map[object.ObjMetadata]ResourceStatus + waitTaskData []ResourceGeneration condition Condition expectedResult bool }{ "single resource with current status": { - collectorState: map[object.ObjMetadata]resourceStatus{ + collectorState: map[object.ObjMetadata]ResourceStatus{ identifiers["dep"]: { - Identifier: identifiers["dep"], CurrentStatus: status.CurrentStatus, Generation: int64(42), }, }, - waitTaskData: []resourceWaitData{ + waitTaskData: []ResourceGeneration{ { - identifier: identifiers["dep"], - generation: int64(42), + Identifier: identifiers["dep"], + Generation: int64(42), }, }, condition: AllCurrent, expectedResult: true, }, "single resource with current status and old generation": { - collectorState: map[object.ObjMetadata]resourceStatus{ + collectorState: map[object.ObjMetadata]ResourceStatus{ identifiers["dep"]: { - Identifier: identifiers["dep"], CurrentStatus: status.CurrentStatus, Generation: int64(41), }, }, - waitTaskData: []resourceWaitData{ + waitTaskData: []ResourceGeneration{ { - identifier: identifiers["dep"], - generation: int64(42), + Identifier: identifiers["dep"], + Generation: int64(42), }, }, condition: AllCurrent, expectedResult: false, }, "multiple resources not all current": { - collectorState: map[object.ObjMetadata]resourceStatus{ + collectorState: map[object.ObjMetadata]ResourceStatus{ identifiers["dep"]: { - Identifier: identifiers["dep"], CurrentStatus: status.CurrentStatus, Generation: int64(41), }, identifiers["custom"]: { - Identifier: identifiers["custom"], CurrentStatus: status.InProgressStatus, Generation: int64(0), }, }, - waitTaskData: []resourceWaitData{ + waitTaskData: []ResourceGeneration{ { - identifier: identifiers["dep"], - generation: int64(42), + Identifier: identifiers["dep"], + Generation: int64(42), }, { - identifier: identifiers["custom"], - generation: int64(0), + Identifier: identifiers["custom"], + Generation: int64(0), }, }, condition: AllCurrent, expectedResult: false, }, 
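Because the exported ResourceStatusCollector is now written to by the runner's status-event loop and read by a running WaitTask, it guards its map with a mutex. A minimal usage sketch, assuming the API shown in the collector hunks above (depID stands for any object.ObjMetadata):

// Writer side (the runner's status-event loop) records the latest status.
rsc := taskrunner.NewResourceStatusCollector()
rsc.Put(depID, taskrunner.ResourceStatus{
	CurrentStatus: status.CurrentStatus,
	Generation:    42,
})

// Reader side (a running WaitTask) checks its condition against the
// generations captured when the objects were applied.
met := rsc.ConditionMet([]taskrunner.ResourceGeneration{
	{Identifier: depID, Generation: 42},
}, taskrunner.AllCurrent)
// met is true: the object is Current at (or beyond) the expected generation.
// Both sides are safe to call from separate goroutines.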
"multiple resources single with old generation": { - collectorState: map[object.ObjMetadata]resourceStatus{ + collectorState: map[object.ObjMetadata]ResourceStatus{ identifiers["dep"]: { - Identifier: identifiers["dep"], CurrentStatus: status.CurrentStatus, Generation: int64(42), }, identifiers["custom"]: { - Identifier: identifiers["custom"], CurrentStatus: status.CurrentStatus, Generation: int64(4), }, }, - waitTaskData: []resourceWaitData{ + waitTaskData: []ResourceGeneration{ { - identifier: identifiers["dep"], - generation: int64(42), + Identifier: identifiers["dep"], + Generation: int64(42), }, { - identifier: identifiers["custom"], - generation: int64(5), + Identifier: identifiers["custom"], + Generation: int64(5), }, }, condition: AllCurrent, @@ -127,10 +121,10 @@ func TestCollector_ConditionMet(t *testing.T) { for tn, tc := range testCases { t.Run(tn, func(t *testing.T) { - rsc := newResourceStatusCollector([]object.ObjMetadata{}) + rsc := NewResourceStatusCollector() rsc.resourceMap = tc.collectorState - res := rsc.conditionMet(tc.waitTaskData, tc.condition) + res := rsc.ConditionMet(tc.waitTaskData, tc.condition) assert.Equal(t, tc.expectedResult, res) }) diff --git a/pkg/apply/taskrunner/context.go b/pkg/apply/taskrunner/context.go index df2dfef6..a3f6cdb5 100644 --- a/pkg/apply/taskrunner/context.go +++ b/pkg/apply/taskrunner/context.go @@ -4,6 +4,8 @@ package taskrunner import ( + "context" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" "sigs.k8s.io/cli-utils/pkg/apply/event" @@ -11,23 +13,29 @@ import ( ) // NewTaskContext returns a new TaskContext -func NewTaskContext(eventChannel chan event.Event) *TaskContext { +func NewTaskContext(ctx context.Context, eventChannel chan event.Event) *TaskContext { return &TaskContext{ - taskChannel: make(chan TaskResult), - eventChannel: eventChannel, - appliedResources: make(map[object.ObjMetadata]applyInfo), - failedResources: make(map[object.ObjMetadata]struct{}), - pruneFailures: make(map[object.ObjMetadata]struct{}), + context: ctx, + taskChannel: make(chan TaskResult), + eventChannel: eventChannel, + resourceStatusCollector: NewResourceStatusCollector(), + appliedResources: make(map[object.ObjMetadata]applyInfo), + failedResources: make(map[object.ObjMetadata]struct{}), + pruneFailures: make(map[object.ObjMetadata]struct{}), } } // TaskContext defines a context that is passed between all // the tasks that is in a taskqueue. type TaskContext struct { + context context.Context + taskChannel chan TaskResult eventChannel chan event.Event + resourceStatusCollector *ResourceStatusCollector + appliedResources map[object.ObjMetadata]applyInfo // failedResources records the IDs of resources that are failed during applying. @@ -37,14 +45,22 @@ type TaskContext struct { pruneFailures map[object.ObjMetadata]struct{} } +func (tc *TaskContext) Context() context.Context { + return tc.context +} + func (tc *TaskContext) TaskChannel() chan TaskResult { return tc.taskChannel } -func (tc *TaskContext) EventChannel() chan event.Event { +func (tc *TaskContext) EventChannel() chan<- event.Event { return tc.eventChannel } +func (tc *TaskContext) ResourceStatusCollector() *ResourceStatusCollector { + return tc.resourceStatusCollector +} + // ResourceApplied updates the context with information about the // resource identified by the provided id. Currently, we keep information // about the generation of the resource after the apply operation completed. 
diff --git a/pkg/apply/taskrunner/main_test.go b/pkg/apply/taskrunner/main_test.go new file mode 100644 index 00000000..d3552342 --- /dev/null +++ b/pkg/apply/taskrunner/main_test.go @@ -0,0 +1,19 @@ +// Copyright 2021 The Kubernetes Authors. +// SPDX-License-Identifier: Apache-2.0 + +package taskrunner + +import ( + "os" + "testing" + + "k8s.io/klog/v2" +) + +// TestMain executes the tests for this package. +// Adds support for parsing logging flags. Example: +// go test sigs.k8s.io/cli-utils/pkg/apply/taskrunner -v -args -v=5 +func TestMain(m *testing.M) { + klog.InitFlags(nil) + os.Exit(m.Run()) +} diff --git a/pkg/apply/taskrunner/runner.go b/pkg/apply/taskrunner/runner.go index 803e760d..93881308 100644 --- a/pkg/apply/taskrunner/runner.go +++ b/pkg/apply/taskrunner/runner.go @@ -6,6 +6,7 @@ package taskrunner import ( "context" "fmt" + "reflect" "time" "sigs.k8s.io/cli-utils/pkg/apply/event" @@ -22,7 +23,7 @@ func NewTaskStatusRunner(identifiers []object.ObjMetadata, statusPoller poller.P identifiers: identifiers, statusPoller: statusPoller, - baseRunner: newBaseRunner(newResourceStatusCollector(identifiers)), + baseRunner: newBaseRunner(), } } @@ -47,33 +48,53 @@ type Options struct { // Run starts the execution of the taskqueue. It will start the // statusPoller and then pass the statusChannel to the baseRunner // that does most of the work. -func (tsr *taskStatusRunner) Run(ctx context.Context, taskQueue chan Task, - eventChannel chan event.Event, options Options) error { +func (tsr *taskStatusRunner) Run( + ctx context.Context, + taskQueue chan Task, + eventChannel chan event.Event, + opts Options, +) error { + // statusPoller gets its own context to ensure it is cancelled after the + // taskQueue is cancelled. statusCtx, cancelFunc := context.WithCancel(context.Background()) - statusChannel := tsr.statusPoller.Poll(statusCtx, tsr.identifiers, polling.Options{ - PollInterval: options.PollInterval, - UseCache: options.UseCache, - }) - o := baseOptions{ - emitStatusEvents: options.EmitStatusEvents, - } - err := tsr.baseRunner.run(ctx, taskQueue, statusChannel, eventChannel, o) - // cancel the statusPoller by cancelling the context. - cancelFunc() - // drain the statusChannel to make sure the lack of a consumer - // doesn't block the shutdown of the statusPoller. - for range statusChannel { - } - return err + // start polling in the background + statusChannel := tsr.statusPoller.Poll( + statusCtx, + tsr.identifiers, + polling.Options{ + PollInterval: opts.PollInterval, + UseCache: opts.UseCache, + }, + ) + + // defer draining, in case the runner panics + defer func() { + // cancel the statusPoller by cancelling the context. + cancelFunc() + // drain the statusChannel to make sure the lack of a consumer + // doesn't block the shutdown of the statusPoller. + for range statusChannel { + } + }() + + // execute the task queue + return tsr.baseRunner.run( + ctx, + taskQueue, + statusChannel, + eventChannel, + baseOptions{ + emitStatusEvents: opts.EmitStatusEvents, + }, + ) } // NewTaskRunner returns a new taskRunner. It can process taskqueues // that does not contain any wait tasks. func NewTaskRunner() *taskRunner { - collector := newResourceStatusCollector([]object.ObjMetadata{}) return &taskRunner{ - baseRunner: newBaseRunner(collector), + baseRunner: newBaseRunner(), } } @@ -87,22 +108,28 @@ type taskRunner struct { // Run starts the execution of the task queue. It delegates the // work to the baseRunner, but gives it as nil channel as the statusChannel. 
-func (tr *taskRunner) Run(ctx context.Context, taskQueue chan Task, - eventChannel chan event.Event) error { +func (tr *taskRunner) Run( + ctx context.Context, + taskQueue chan Task, + eventChannel chan event.Event, +) error { var nilStatusChannel chan pollevent.Event - o := baseOptions{ - // The taskRunner doesn't poll for status, so there are not - // statusEvents to emit. - emitStatusEvents: false, - } - return tr.baseRunner.run(ctx, taskQueue, nilStatusChannel, eventChannel, o) + return tr.baseRunner.run( + ctx, + taskQueue, + nilStatusChannel, + eventChannel, + baseOptions{ + // The taskRunner doesn't poll for status, so there are not + // statusEvents to emit. + emitStatusEvents: false, + }, + ) } -// newBaseRunner returns a new baseRunner using the given collector. -func newBaseRunner(collector *resourceStatusCollector) *baseRunner { - return &baseRunner{ - collector: collector, - } +// newBaseRunner returns a new baseRunner +func newBaseRunner() *baseRunner { + return &baseRunner{} } // baseRunner provides the basic task runner functionality. It needs @@ -111,9 +138,7 @@ func newBaseRunner(collector *resourceStatusCollector) *baseRunner { // cases where polling and waiting for status is not needed. // This is not meant to be used directly. It is used by the // taskRunner and the taskStatusRunner. -type baseRunner struct { - collector *resourceStatusCollector -} +type baseRunner struct{} type baseOptions struct { emitStatusEvents bool @@ -122,13 +147,21 @@ type baseOptions struct { // run is the main function that implements the processing of // tasks in the taskqueue. It sets up a loop where a single goroutine // will process events from three different channels. -func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, - statusChannel <-chan pollevent.Event, eventChannel chan event.Event, - o baseOptions) error { +func (b *baseRunner) run( + ctx context.Context, + taskQueue chan Task, + statusChannel <-chan pollevent.Event, + eventChannel chan event.Event, + opts baseOptions, +) error { // wrap the context to allow task cancellation on error + ctx, cancel := context.WithCancel(ctx) + // always cancel to clean up resources + defer cancel() + // taskContext is passed into all tasks when they are started. It // provides access to the eventChannel and the taskChannel, and // also provides a way to pass data between tasks. - taskContext := NewTaskContext(eventChannel) + taskContext := NewTaskContext(ctx, eventChannel) // Find and start the first task in the queue. currentTask, done := b.nextTask(taskQueue, taskContext) @@ -155,7 +188,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, // events are passed through to the eventChannel. This means // that listeners of the eventChannel will get updates on status // even while other tasks (like apply tasks) are running. - case statusEvent, ok := <-statusChannel: + case kStatusEvent, ok := <-statusChannel: // If the statusChannel has closed or we are preparing // to abort the task processing, we just ignore all // statusEvents. @@ -168,40 +201,41 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, // An error event on the statusChannel means the StatusPoller // has encountered a problem so it can't continue. This means // the statusChannel will be closed soon. 
- if statusEvent.EventType == pollevent.ErrorEvent { + if kStatusEvent.EventType == pollevent.ErrorEvent { abort = true abortReason = fmt.Errorf("polling for status failed: %v", - statusEvent.Error) - // If the current task is a wait task, we just set it - // to complete so we can exit the loop as soon as possible. - completeIfWaitTask(currentTask, taskContext) + kStatusEvent.Error) + + // cancel any running tasks + cancel() continue } - if o.emitStatusEvents { + // convert types + statusEvent := event.StatusEvent{ + Identifier: kStatusEvent.Resource.Identifier, + PollResourceInfo: kStatusEvent.Resource, + Resource: kStatusEvent.Resource.Resource, + Error: kStatusEvent.Error, + } + + if opts.emitStatusEvents { // Forward all normal events to the eventChannel eventChannel <- event.Event{ - Type: event.StatusType, - StatusEvent: event.StatusEvent{ - Identifier: statusEvent.Resource.Identifier, - PollResourceInfo: statusEvent.Resource, - Resource: statusEvent.Resource.Resource, - Error: statusEvent.Error, - }, + Type: event.StatusType, + StatusEvent: statusEvent, } } // The collector needs to keep track of the latest status // for all resources so we can check whether wait task conditions // has been met. - b.collector.resourceStatus(statusEvent.Resource) - // If the current task is a wait task, we check whether - // the condition has been met. If so, we complete the task. - if wt, ok := currentTask.(*WaitTask); ok { - if wt.checkCondition(taskContext, b.collector) { - completeIfWaitTask(currentTask, taskContext) - } - } + taskContext.ResourceStatusCollector().PutEventStatus(kStatusEvent.Resource) + + // Send the event to the current task. + // This allows tasks to handle status updates from the status poller. + currentTask.OnStatusEvent(taskContext, statusEvent) + // A message on the taskChannel means that the current task // has either completed or failed. If it has failed, we return // the error. If the abort flag is true, which means something @@ -209,7 +243,6 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, // finish, we exit. // If everything is ok, we fetch and start the next task. case msg := <-taskContext.TaskChannel(): - currentTask.ClearTimeout() taskContext.EventChannel() <- event.Event{ Type: event.ActionGroupType, ActionGroupEvent: event.ActionGroupEvent{ @@ -219,7 +252,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, }, } if msg.Err != nil { - b.amendTimeoutError(msg.Err) + b.amendTimeoutError(taskContext, msg.Err) return msg.Err } if abort { @@ -237,40 +270,27 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task, case <-doneCh: doneCh = nil // Set doneCh to nil so we don't enter a busy loop. 
abort = true - completeIfWaitTask(currentTask, taskContext) } } } -func (b *baseRunner) amendTimeoutError(err error) { +func (b *baseRunner) amendTimeoutError(taskContext *TaskContext, err error) { if timeoutErr, ok := err.(*TimeoutError); ok { var timedOutResources []TimedOutResource for _, id := range timeoutErr.Identifiers { - ls, found := b.collector.resourceMap[id] - if !found { - continue - } - if timeoutErr.Condition.Meets(ls.CurrentStatus) { - continue + rs := taskContext.ResourceStatusCollector().Get(id) + if !timeoutErr.Condition.Meets(rs.CurrentStatus) { + timedOutResources = append(timedOutResources, TimedOutResource{ + Identifier: id, + Status: rs.CurrentStatus, + Message: rs.Message, + }) } - timedOutResources = append(timedOutResources, TimedOutResource{ - Identifier: id, - Status: ls.CurrentStatus, - Message: ls.Message, - }) } timeoutErr.TimedOutResources = timedOutResources } } -// completeIfWaitTask checks if the current task is a wait task. If so, -// we invoke the complete function to complete it. -func completeIfWaitTask(currentTask Task, taskContext *TaskContext) { - if wt, ok := currentTask.(*WaitTask); ok { - wt.complete(taskContext) - } -} - // nextTask fetches the latest task from the taskQueue and // starts it. If the taskQueue is empty, it the second // return value will be true. @@ -296,20 +316,8 @@ func (b *baseRunner) nextTask(taskQueue chan Task, }, } - switch st := tsk.(type) { - case *WaitTask: - // The wait tasks need to be handled specifically here. Before - // starting a new wait task, we check if the condition is already - // met. Without this check, a task might end up waiting for - // status events when the condition is in fact already met. - if st.checkCondition(taskContext, b.collector) { - st.startAndComplete(taskContext) - } else { - st.Start(taskContext) - } - default: - tsk.Start(taskContext) - } + tsk.Start(taskContext) + return tsk, false } @@ -349,6 +357,21 @@ func (te TimeoutError) Error() string { te.Timeout.Seconds(), len(te.Identifiers), te.Condition) } +// Is satisfies the Is interface from errors.Is +func (te *TimeoutError) Is(err error) bool { + if err == nil { + return te == nil + } + bTe, ok := err.(TimeoutError) + if !ok { + return false + } + return object.SetEquals(te.Identifiers, bTe.Identifiers) && + te.Timeout == bTe.Timeout && + te.Condition == bTe.Condition && + reflect.DeepEqual(te.TimedOutResources, bTe.TimedOutResources) +} + // IsTimeoutError checks whether a given error is // a TimeoutError. 
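The new Is method lets callers match a wait-task timeout with the standard errors package. Note that it compares against a TimeoutError value and requires all four fields to match, including the timed-out resource list. A hedged usage sketch (depID and err are assumed to come from the surrounding test or caller):

// Sketch: matching a specific wait timeout via errors.Is.
expected := taskrunner.TimeoutError{
	Identifiers: []object.ObjMetadata{depID}, // the object(s) being waited on
	Timeout:     2 * time.Second,
	Condition:   taskrunner.AllCurrent,
}
if errors.Is(err, expected) {
	// err is, or wraps, a *TimeoutError whose identifiers, timeout,
	// condition and timed-out resources all equal the target.
}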
func IsTimeoutError(err error) (*TimeoutError, bool) { diff --git a/pkg/apply/taskrunner/runner_test.go b/pkg/apply/taskrunner/runner_test.go index dc1fc107..42ac05ec 100644 --- a/pkg/apply/taskrunner/runner_test.go +++ b/pkg/apply/taskrunner/runner_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/cli-utils/pkg/apply/event" pollevent "sigs.k8s.io/cli-utils/pkg/kstatus/polling/event" @@ -40,7 +41,6 @@ var ( func TestBaseRunner(t *testing.T) { testCases := map[string]struct { - identifiers []object.ObjMetadata tasks []Task statusEventsDelay time.Duration statusEvents []pollevent.Event @@ -49,7 +49,6 @@ func TestBaseRunner(t *testing.T) { expectedTimedOutResources []TimedOutResource }{ "wait task runs until condition is met": { - identifiers: []object.ObjMetadata{depID, cmID}, tasks: []Task{ &fakeApplyTask{ resultEvent: event.Event{ @@ -97,7 +96,6 @@ func TestBaseRunner(t *testing.T) { }, }, "wait task times out eventually": { - identifiers: []object.ObjMetadata{depID, cmID}, tasks: []Task{ NewWaitTask("wait", []object.ObjMetadata{depID, cmID}, AllCurrent, 2*time.Second, testutil.NewFakeRESTMapper()), @@ -124,7 +122,6 @@ func TestBaseRunner(t *testing.T) { }, }, "tasks run in order": { - identifiers: []object.ObjMetadata{}, tasks: []Task{ &fakeApplyTask{ resultEvent: event.Event{ @@ -172,7 +169,7 @@ func TestBaseRunner(t *testing.T) { for tn, tc := range testCases { t.Run(tn, func(t *testing.T) { - runner := newBaseRunner(newResourceStatusCollector(tc.identifiers)) + runner := newBaseRunner() eventChannel := make(chan event.Event) taskQueue := make(chan Task, len(tc.tasks)) for _, tsk := range tc.tasks { @@ -239,16 +236,15 @@ func TestBaseRunnerCancellation(t *testing.T) { testError := fmt.Errorf("this is a test error") testCases := map[string]struct { - identifiers []object.ObjMetadata tasks []Task statusEventsDelay time.Duration statusEvents []pollevent.Event contextTimeout time.Duration + contextCancel time.Duration expectedError error expectedEventTypes []event.Type }{ - "cancellation while custom task is running": { - identifiers: []object.ObjMetadata{depID}, + "timeout while custom task is running": { tasks: []Task{ &fakeApplyTask{ resultEvent: event.Event{ @@ -264,17 +260,22 @@ func TestBaseRunnerCancellation(t *testing.T) { }, }, contextTimeout: 2 * time.Second, + expectedError: context.DeadlineExceeded, expectedEventTypes: []event.Type{ event.ActionGroupType, event.ApplyType, event.ActionGroupType, }, }, - "cancellation while wait task is running": { - identifiers: []object.ObjMetadata{depID}, + "timeout while wait task is running": { tasks: []Task{ - NewWaitTask("wait", []object.ObjMetadata{depID}, AllCurrent, - 20*time.Second, testutil.NewFakeRESTMapper()), + NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 20*time.Second, + testutil.NewFakeRESTMapper(), + ), &fakeApplyTask{ resultEvent: event.Event{ Type: event.PruneType, @@ -283,13 +284,36 @@ func TestBaseRunnerCancellation(t *testing.T) { }, }, contextTimeout: 2 * time.Second, + expectedError: context.DeadlineExceeded, + expectedEventTypes: []event.Type{ + event.ActionGroupType, + event.ActionGroupType, + }, + }, + "cancel while wait task is running": { + tasks: []Task{ + NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 20*time.Second, + testutil.NewFakeRESTMapper(), + ), + &fakeApplyTask{ + resultEvent: event.Event{ + Type: event.PruneType, + }, + duration: 2 * 
time.Second, + }, + }, + contextCancel: 2 * time.Second, + expectedError: context.Canceled, expectedEventTypes: []event.Type{ event.ActionGroupType, event.ActionGroupType, }, }, "error while custom task is running": { - identifiers: []object.ObjMetadata{depID}, tasks: []Task{ &fakeApplyTask{ resultEvent: event.Event{ @@ -314,7 +338,6 @@ func TestBaseRunnerCancellation(t *testing.T) { }, }, "error from status poller while wait task is running": { - identifiers: []object.ObjMetadata{depID}, tasks: []Task{ NewWaitTask("wait", []object.ObjMetadata{depID}, AllCurrent, 20*time.Second, testutil.NewFakeRESTMapper()), @@ -333,7 +356,7 @@ func TestBaseRunnerCancellation(t *testing.T) { }, }, contextTimeout: 30 * time.Second, - expectedError: testError, + expectedError: context.Canceled, expectedEventTypes: []event.Type{ event.ActionGroupType, event.ActionGroupType, @@ -343,7 +366,7 @@ func TestBaseRunnerCancellation(t *testing.T) { for tn, tc := range testCases { t.Run(tn, func(t *testing.T) { - runner := newBaseRunner(newResourceStatusCollector(tc.identifiers)) + runner := newBaseRunner() eventChannel := make(chan event.Event) taskQueue := make(chan Task, len(tc.tasks)) @@ -376,20 +399,34 @@ func TestBaseRunnerCancellation(t *testing.T) { } }() - ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) - defer cancel() - err := runner.run(ctx, taskQueue, statusChannel, eventChannel, - baseOptions{emitStatusEvents: false}) + ctx := context.Background() + var cancel context.CancelFunc + if tc.contextTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, tc.contextTimeout) + defer cancel() + } else if tc.contextCancel > 0 { + ctx, cancel = context.WithCancel(ctx) + go func() { + time.Sleep(tc.contextCancel) + cancel() + }() + defer cancel() + } + + err := runner.run( + ctx, + taskQueue, + statusChannel, + eventChannel, + baseOptions{emitStatusEvents: false}, + ) close(statusChannel) close(eventChannel) wg.Wait() - if tc.expectedError == nil && err != nil { - t.Errorf("expected no error, but got %v", err) - } - - if tc.expectedError != nil && err == nil { - t.Errorf("expected error %v, but didn't get one", tc.expectedError) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equal(t, tc.expectedError, err) } if want, got := len(tc.expectedEventTypes), len(events); want != got { @@ -427,12 +464,20 @@ func (f *fakeApplyTask) Identifiers() []object.ObjMetadata { func (f *fakeApplyTask) Start(taskContext *TaskContext) { go func() { - <-time.NewTimer(f.duration).C - taskContext.EventChannel() <- f.resultEvent - taskContext.TaskChannel() <- TaskResult{ - Err: f.err, + var err error + ctx := taskContext.Context() + timer := time.NewTimer(f.duration) + select { + case <-ctx.Done(): + // context cancel/timeout + err = ctx.Err() + case <-timer.C: + // task duration timeout + err = f.err } + taskContext.EventChannel() <- f.resultEvent + taskContext.TaskChannel() <- TaskResult{Err: err} }() } -func (f *fakeApplyTask) ClearTimeout() {} +func (f *fakeApplyTask) OnStatusEvent(taskContext *TaskContext, e event.StatusEvent) {} diff --git a/pkg/apply/taskrunner/task.go b/pkg/apply/taskrunner/task.go index f8f06ff0..bba1fa3b 100644 --- a/pkg/apply/taskrunner/task.go +++ b/pkg/apply/taskrunner/task.go @@ -4,8 +4,11 @@ package taskrunner import ( + "context" + "errors" "fmt" "reflect" + "sync" "time" v1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" @@ -14,7 +17,6 @@ import ( "k8s.io/client-go/restmapper" "k8s.io/klog/v2" "sigs.k8s.io/cli-utils/pkg/apply/event" - 
"sigs.k8s.io/cli-utils/pkg/kstatus/status" "sigs.k8s.io/cli-utils/pkg/object" ) @@ -24,25 +26,19 @@ type Task interface { Name() string Action() event.ResourceAction Identifiers() []object.ObjMetadata - Start(taskContext *TaskContext) - ClearTimeout() + Start(*TaskContext) + OnStatusEvent(*TaskContext, event.StatusEvent) } // NewWaitTask creates a new wait task where we will wait until // the resources specifies by ids all meet the specified condition. func NewWaitTask(name string, ids []object.ObjMetadata, cond Condition, timeout time.Duration, mapper meta.RESTMapper) *WaitTask { - // Create the token channel and only add one item. - tokenChannel := make(chan struct{}, 1) - tokenChannel <- struct{}{} - return &WaitTask{ name: name, Ids: ids, Condition: cond, Timeout: timeout, - - mapper: mapper, - token: tokenChannel, + mapper: mapper, } } @@ -66,17 +62,9 @@ type WaitTask struct { mapper meta.RESTMapper - // cancelFunc is a function that will cancel the timeout timer - // on the task. - cancelFunc func() - - // token is a channel that is provided a single item when the - // task is created. Goroutines are only allowed to write to the - // taskChannel if they are able to get the item from the channel. - // This makes sure that the task only results in one message on the - // taskChannel, even if the condition is met and the task times out - // at the same time. - token chan struct{} + // once ensures errorCh only receives one error + once sync.Once + errorCh chan error } func (w *WaitTask) Name() string { @@ -95,52 +83,92 @@ func (w *WaitTask) Identifiers() []object.ObjMetadata { // setting up the timeout timer. func (w *WaitTask) Start(taskContext *TaskContext) { klog.V(2).Infof("starting wait task (%d objects)", len(w.Ids)) - w.setTimer(taskContext) -} -// setTimer creates the timer with the timeout value taken from -// the WaitTask struct. Once the timer expires, it will send -// a message on the TaskChannel provided in the taskContext. -func (w *WaitTask) setTimer(taskContext *TaskContext) { - timer := time.NewTimer(w.Timeout) + // reset task + w.once = sync.Once{} + w.errorCh = make(chan error) + + // WaitTask gets its own context to diferentiate between WaitTask timeout + // and TaskQueue timeout. First timeout wins! + taskCtx := context.Background() + var taskCancel func() + if w.Timeout != 0 { + klog.V(5).Infof("wait task timeout: %s", w.Timeout) + taskCtx, taskCancel = context.WithTimeout(taskCtx, w.Timeout) + } else { + taskCtx, taskCancel = context.WithCancel(taskCtx) + } + + // Wrap the parent context to ensure it's done. + ctx, cancel := context.WithCancel(taskContext.Context()) + + // wait until parent timeout or cancel go func() { - // TODO(mortent): See if there is a better way to do this. This - // solution will cause the goroutine to hang forever if the - // Timeout is cancelled. - <-timer.C - select { - // We only send the taskResult if no one has gotten - // to the token first. 
- case <-w.token: - taskContext.TaskChannel() <- TaskResult{ - Err: &TimeoutError{ - Identifiers: w.Ids, - Timeout: w.Timeout, - Condition: w.Condition, - }, - } - default: - return + defer func() { + // cancel contexts to free up resources + taskCancel() + cancel() + }() + + <-ctx.Done() + klog.V(3).Info("wait task parent context done") + w.stop(ctx.Err()) + }() + + // wait until task timeout or cancel + go func() { + defer func() { + // cancel contexts to free up resources + taskCancel() + cancel() + }() + + <-taskCtx.Done() + klog.V(3).Info("wait task context done") + w.stop(w.unwrapTaskTimeout(taskCtx.Err())) + }() + + // wait until complete (optional error on errorCh) + go func() { + defer func() { + // cancel contexts to free up resources + taskCancel() + cancel() + }() + + err := <-w.errorCh + klog.V(3).Info("wait task completed") + taskContext.TaskChannel() <- TaskResult{ + Err: err, } }() - w.cancelFunc = func() { - timer.Stop() + + if w.conditionMet(taskContext) { + klog.V(3).Info("wait condition met, stopping task early") + w.stop(nil) } } -// checkCondition checks whether the condition set in the task +func (w *WaitTask) OnStatusEvent(taskContext *TaskContext, _ event.StatusEvent) { + if w.conditionMet(taskContext) { + klog.V(3).Info("wait condition met, stopping task") + w.stop(nil) + } +} + +// conditionMet checks whether the condition set in the task // is currently met given the status of resources in the collector. -func (w *WaitTask) checkCondition(taskContext *TaskContext, coll *resourceStatusCollector) bool { - rwd := w.computeResourceWaitData(taskContext) - return coll.conditionMet(rwd, w.Condition) +func (w *WaitTask) conditionMet(taskContext *TaskContext) bool { + rwd := w.resourcesToWaitFor(taskContext) + return taskContext.ResourceStatusCollector().ConditionMet(rwd, w.Condition) } -// computeResourceWaitData creates a slice of resourceWaitData for +// resourcesToWaitFor creates a slice of ResourceGeneration for // the resources that is relevant to this wait task. The objective is // to match each resource with the generation seen after the resource // was applied. -func (w *WaitTask) computeResourceWaitData(taskContext *TaskContext) []resourceWaitData { - var rwd []resourceWaitData +func (w *WaitTask) resourcesToWaitFor(taskContext *TaskContext) []ResourceGeneration { + var rwd []ResourceGeneration for _, id := range w.Ids { // Skip checking condition for resources which have failed // to apply or failed to prune/delete (depending on wait condition). @@ -150,90 +178,44 @@ func (w *WaitTask) computeResourceWaitData(taskContext *TaskContext) []resourceW continue } gen, _ := taskContext.ResourceGeneration(id) - rwd = append(rwd, resourceWaitData{ - identifier: id, - generation: gen, + rwd = append(rwd, ResourceGeneration{ + Identifier: id, + Generation: gen, }) } return rwd } -// startAndComplete is invoked when the condition is already -// met when the task should be started. In this case there is no -// need to start a timer. So it just sets the cancelFunc and then -// completes the task. -func (w *WaitTask) startAndComplete(taskContext *TaskContext) { - w.cancelFunc = func() {} - w.complete(taskContext) -} - -// complete is invoked by the taskrunner when all the conditions -// for the task has been met, or something has failed so the task -// need to be stopped. 
-func (w *WaitTask) complete(taskContext *TaskContext) { - var err error - for _, obj := range w.Ids { - if (obj.GroupKind.Group == v1.SchemeGroupVersion.Group || - obj.GroupKind.Group == v1beta1.SchemeGroupVersion.Group) && - obj.GroupKind.Kind == "CustomResourceDefinition" { - ddRESTMapper, err := extractDeferredDiscoveryRESTMapper(w.mapper) - if err == nil { - ddRESTMapper.Reset() - // We only need to reset once. - break +// stop resets the RESTMapper if any Ids are CRDs and sends the error to the +// errorCh, once per task start. +func (w *WaitTask) stop(err error) { + w.once.Do(func() { + klog.V(3).Info("wait task complete") + for _, obj := range w.Ids { + if (obj.GroupKind.Group == v1.SchemeGroupVersion.Group || + obj.GroupKind.Group == v1beta1.SchemeGroupVersion.Group) && + obj.GroupKind.Kind == "CustomResourceDefinition" { + ddRESTMapper, err := extractDeferredDiscoveryRESTMapper(w.mapper) + if err == nil { + ddRESTMapper.Reset() + // We only need to reset once. + break + } } - continue } - } - select { - // Only do something if we can get the token. - case <-w.token: - go func() { - taskContext.TaskChannel() <- TaskResult{ - Err: err, - } - }() - default: - return - } + w.errorCh <- err + }) } -// ClearTimeout cancels the timeout for the wait task. -func (w *WaitTask) ClearTimeout() { - w.cancelFunc() -} - -type resourceWaitData struct { - identifier object.ObjMetadata - generation int64 -} - -// Condition is a type that defines the types of conditions -// which a WaitTask can use. -type Condition string - -const ( - // AllCurrent Condition means all the provided resources - // has reached (and remains in) the Current status. - AllCurrent Condition = "AllCurrent" - - // AllNotFound Condition means all the provided resources - // has reached the NotFound status, i.e. they are all deleted - // from the cluster. - AllNotFound Condition = "AllNotFound" -) - -// Meets returns true if the provided status meets the condition and -// false if it does not. 
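unwrapTaskTimeout above is what preserves the old error surface: a deadline on the task's own context is translated back into a *TimeoutError carrying the ids, timeout and condition, while a parent cancel or deadline passes through untouched. Based on the test expectations in this diff, a caller can distinguish the failure modes roughly like this (err is assumed to come from the task runner):

// Sketch of the three outcomes a caller may see.
switch {
case errors.Is(err, context.DeadlineExceeded):
	// the command-level timeout (parent context) expired
case errors.Is(err, context.Canceled):
	// the run was cancelled, e.g. by the caller or after a poller failure
default:
	if timeoutErr, ok := taskrunner.IsTimeoutError(err); ok {
		// a single wait task exceeded its own reconcile/delete timeout;
		// TimedOutResources lists what never reached the condition
		for _, r := range timeoutErr.TimedOutResources {
			fmt.Printf("%v: %s %s\n", r.Identifier, r.Status, r.Message)
		}
	}
}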
-func (c Condition) Meets(s status.Status) bool { - switch c { - case AllCurrent: - return s == status.CurrentStatus - case AllNotFound: - return s == status.NotFoundStatus - default: - return false +func (w *WaitTask) unwrapTaskTimeout(err error) error { + if errors.Is(err, context.DeadlineExceeded) { + err = &TimeoutError{ + Identifiers: w.Ids, + Timeout: w.Timeout, + Condition: w.Condition, + } } + return err } // extractDeferredDiscoveryRESTMapper unwraps the provided RESTMapper diff --git a/pkg/apply/taskrunner/task_test.go b/pkg/apply/taskrunner/task_test.go index bdde8504..978f08f2 100644 --- a/pkg/apply/taskrunner/task_test.go +++ b/pkg/apply/taskrunner/task_test.go @@ -4,21 +4,31 @@ package taskrunner import ( + "context" "sync" "testing" "time" + "github.com/stretchr/testify/require" + "k8s.io/klog/v2" "sigs.k8s.io/cli-utils/pkg/apply/event" + kstatus "sigs.k8s.io/cli-utils/pkg/kstatus/status" "sigs.k8s.io/cli-utils/pkg/object" "sigs.k8s.io/cli-utils/pkg/testutil" ) -func TestWaitTask_TimeoutTriggered(t *testing.T) { - task := NewWaitTask("wait", []object.ObjMetadata{}, AllCurrent, - 2*time.Second, testutil.NewFakeRESTMapper()) +func TestWaitTask_TaskTimeout(t *testing.T) { + // ensure conditions are not met, or task with exit early + task := NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 2*time.Second, + testutil.NewFakeRESTMapper(), + ) eventChannel := make(chan event.Event) - taskContext := NewTaskContext(eventChannel) + taskContext := NewTaskContext(context.TODO(), eventChannel) defer close(eventChannel) task.Start(taskContext) @@ -30,53 +40,165 @@ func TestWaitTask_TimeoutTriggered(t *testing.T) { if _, ok := IsTimeoutError(res.Err); !ok { t.Errorf("expected timeout error, but got %v", res.Err) } - return + expected := &TimeoutError{ + Identifiers: []object.ObjMetadata{depID}, + Timeout: 2 * time.Second, + Condition: AllCurrent, + } + require.Equal(t, res.Err.Error(), expected.Error()) case <-timer.C: t.Errorf("expected timeout to trigger, but it didn't") } } -func TestWaitTask_TimeoutCancelled(t *testing.T) { - task := NewWaitTask("wait", []object.ObjMetadata{}, AllCurrent, - 2*time.Second, testutil.NewFakeRESTMapper()) +func TestWaitTask_ContextCancelled(t *testing.T) { + // ensure conditions are not met, or task with exit early + task := NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 2*time.Second, + testutil.NewFakeRESTMapper(), + ) + + ctx, cancel := context.WithCancel(context.Background()) + eventChannel := make(chan event.Event) + taskContext := NewTaskContext(ctx, eventChannel) + defer close(eventChannel) + defer cancel() + + task.Start(taskContext) + + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + + timer := time.NewTimer(3 * time.Second) + + select { + case res := <-taskContext.TaskChannel(): + require.ErrorIs(t, res.Err, context.Canceled) + case <-timer.C: + t.Errorf("unexpected timeout") + } +} +func TestWaitTask_ContextTimeout(t *testing.T) { + // ensure conditions are not met, or task with exit early + task := NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 2*time.Second, + testutil.NewFakeRESTMapper(), + ) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) eventChannel := make(chan event.Event) - taskContext := NewTaskContext(eventChannel) + taskContext := NewTaskContext(ctx, eventChannel) defer close(eventChannel) + defer cancel() task.Start(taskContext) - task.ClearTimeout() + timer := time.NewTimer(3 * time.Second) select { case res := 
<-taskContext.TaskChannel(): - t.Errorf("didn't expect timeout error, but got %v", res.Err) + require.ErrorIs(t, res.Err, context.DeadlineExceeded) case <-timer.C: - return + t.Errorf("unexpected timeout") } } -func TestWaitTask_SingleTaskResult(t *testing.T) { - task := NewWaitTask("wait", []object.ObjMetadata{}, AllCurrent, - 2*time.Second, testutil.NewFakeRESTMapper()) +// TestWaitTask_OnStatusEvent tests that OnStatusEvent with the right status in +// the ResourceStatusCollector triggers a TaskResult on the TaskChannel. +func TestWaitTask_OnStatusEvent(t *testing.T) { + // ensure conditions are not met, or task with exit early. + task := NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 0*time.Second, + testutil.NewFakeRESTMapper(), + ) eventChannel := make(chan event.Event) - taskContext := NewTaskContext(eventChannel) + taskContext := NewTaskContext(context.TODO(), eventChannel) taskContext.taskChannel = make(chan TaskResult, 10) defer close(eventChannel) - var completeWg sync.WaitGroup + task.Start(taskContext) - for i := 0; i < 10; i++ { - completeWg.Add(1) - go func() { - defer completeWg.Done() - task.complete(taskContext) - }() + go func() { + klog.V(5).Infof("status event %d", 1) + taskContext.ResourceStatusCollector().Put(depID, ResourceStatus{ + CurrentStatus: kstatus.CurrentStatus, + Generation: 1, + }) + task.OnStatusEvent(taskContext, event.StatusEvent{}) + klog.V(5).Infof("status event %d handled", 1) + }() + + timer := time.NewTimer(4 * time.Second) + + select { + case res := <-taskContext.TaskChannel(): + require.NoError(t, res.Err) + case <-timer.C: + t.Errorf("unexpected timeout") } - completeWg.Wait() +} + +// TestWaitTask_SingleTaskResult tests that WaitTask can handle more than one +// call to OnStatusEvent and still only send one result on the TaskChannel. +func TestWaitTask_SingleTaskResult(t *testing.T) { + // ensure conditions are not met, or task with exit early. 
+ task := NewWaitTask( + "wait", + []object.ObjMetadata{depID}, + AllCurrent, + 0*time.Second, + testutil.NewFakeRESTMapper(), + ) - <-taskContext.TaskChannel() + eventChannel := make(chan event.Event) + taskContext := NewTaskContext(context.TODO(), eventChannel) + taskContext.taskChannel = make(chan TaskResult, 10) + defer close(eventChannel) + + task.Start(taskContext) + + var completeWg sync.WaitGroup + completeWg.Add(1) + go func() { + defer completeWg.Done() + klog.V(5).Info("waiting for task result") + res := <-taskContext.TaskChannel() + klog.V(5).Infof("received task result: %v", res.Err) + require.NoError(t, res.Err) + }() + completeWg.Add(4) + go func() { + for i := 0; i < 4; i++ { + index := i + go func() { + defer completeWg.Done() + time.Sleep(time.Duration(index) * time.Second) + klog.V(5).Infof("status event %d", index) + if index > 2 { + taskContext.ResourceStatusCollector().Put(depID, ResourceStatus{ + CurrentStatus: kstatus.CurrentStatus, + Generation: 1, + }) + } + task.OnStatusEvent(taskContext, event.StatusEvent{}) + klog.V(5).Infof("status event %d handled", index) + }() + } + }() + completeWg.Wait() timer := time.NewTimer(4 * time.Second) diff --git a/test/e2e/apply_and_destroy_test.go b/test/e2e/apply_and_destroy_test.go index 96c4d297..2c7ba3c0 100644 --- a/test/e2e/apply_and_destroy_test.go +++ b/test/e2e/apply_and_destroy_test.go @@ -55,9 +55,11 @@ func applyAndDestroyTest(c client.Client, invConfig InventoryConfig, inventoryNa By("Destroy resources") destroyer := invConfig.DestroyerFactoryFunc() + // TODO: test timeout/cancel behavior + ctx := context.TODO() destroyInv := createInventoryInfo(invConfig, inventoryName, namespaceName, inventoryID) options := apply.DestroyerOptions{InventoryPolicy: inventory.AdoptIfNoInventory} - destroyerEvents := runCollectNoErr(destroyer.Run(destroyInv, options)) + destroyerEvents := runCollectNoErr(destroyer.Run(ctx, destroyInv, options)) err = testutil.VerifyEvents([]testutil.ExpEvent{ { EventType: event.DeleteType, diff --git a/test/e2e/crd_test.go b/test/e2e/crd_test.go index 81758426..b7622579 100644 --- a/test/e2e/crd_test.go +++ b/test/e2e/crd_test.go @@ -87,9 +87,11 @@ func crdTest(_ client.Client, invConfig InventoryConfig, inventoryName, namespac Expect(err).ToNot(HaveOccurred()) By("destroy the resources, including the crd") + // TODO: test timeout/cancel behavior + ctx := context.TODO() destroyer := invConfig.DestroyerFactoryFunc() options := apply.DestroyerOptions{InventoryPolicy: inventory.AdoptIfNoInventory} - destroyerEvents := runCollectNoErr(destroyer.Run(inv, options)) + destroyerEvents := runCollectNoErr(destroyer.Run(ctx, inv, options)) err = testutil.VerifyEvents([]testutil.ExpEvent{ { // Initial event diff --git a/test/e2e/depends_on_test.go b/test/e2e/depends_on_test.go index 153cd3c5..1ffd05a8 100644 --- a/test/e2e/depends_on_test.go +++ b/test/e2e/depends_on_test.go @@ -109,9 +109,11 @@ func dependsOnTest(_ client.Client, invConfig InventoryConfig, inventoryName, na Expect(err).ToNot(HaveOccurred()) By("destroy resources in opposite order") + // TODO: test timeout/cancel behavior + ctx := context.TODO() destroyer := invConfig.DestroyerFactoryFunc() options := apply.DestroyerOptions{InventoryPolicy: inventory.AdoptIfNoInventory} - destroyerEvents := runCollectNoErr(destroyer.Run(inv, options)) + destroyerEvents := runCollectNoErr(destroyer.Run(ctx, inv, options)) err = testutil.VerifyEvents([]testutil.ExpEvent{ { // Initial event
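Tying the threads together: the e2e hunks above now hand a context to Destroyer.Run, which is how a command-level timeout or cancellation reaches every task, wait loop and status poller changed earlier. A hedged caller-side sketch, with destroyer and inv construction elided as in the tests above:

// Sketch: bound a destroy to five minutes from the caller side.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

options := apply.DestroyerOptions{InventoryPolicy: inventory.AdoptIfNoInventory}
for e := range destroyer.Run(ctx, inv, options) {
	// consume events until the channel closes; on timeout the run winds
	// down through the context plumbing shown earlier instead of hanging
	_ = e
}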