diff --git a/pkg/scheduler/framework/statement.go b/pkg/scheduler/framework/statement.go index 9960820a4..a9527582a 100644 --- a/pkg/scheduler/framework/statement.go +++ b/pkg/scheduler/framework/statement.go @@ -35,6 +35,7 @@ import ( type Statement struct { operations []Operation + validOps []bool ssn *Session sessionID string } @@ -57,6 +58,7 @@ func (s *Statement) Rollback(cp Checkpoint) error { } s.operations = s.operations[:cp] + s.validOps = s.validOps[:cp] return nil } @@ -116,6 +118,7 @@ func (s *Statement) Evict(reclaimeeTask *pod_info.PodInfo, message string, }, }, ) + s.validOps = append(s.validOps, true) reclaimeeTask.IsVirtualStatus = true log.InfraLogger.V(6).Infof("Statement evicted task: <%v/%v> from node: <%v>", @@ -284,6 +287,7 @@ func (s *Statement) Pipeline(task *pod_info.PodInfo, hostname string, updateTask return s.unpipeline(task, previousNode, previousStatus, previousGpuGroup, previousResourceClaimInfo, previousIsVirtualStatus) }, }) + s.validOps = append(s.validOps, true) task.IsVirtualStatus = true log.InfraLogger.V(6).Infof( @@ -347,6 +351,7 @@ func (s *Statement) Allocate(task *pod_info.PodInfo, hostname string) error { }, }, ) + s.validOps = append(s.validOps, true) task.IsVirtualStatus = true log.InfraLogger.V(6).Infof( @@ -504,18 +509,22 @@ func (s *Statement) ConvertAllAllocatedToPipelined(jobID common_info.PodGroupID) } var newOperations []Operation - for _, op := range s.operations { + var newValidOps []bool + for i, op := range s.operations { if !(op.TaskInfo().Job == jobID && op.Name() == allocate) { newOperations = append(newOperations, op) + newValidOps = append(newValidOps, s.validOps[i]) } } s.operations = newOperations + s.validOps = newValidOps return nil } func (s *Statement) clearOperations() { - s.operations = []Operation{} + s.operations = s.operations[:0] + s.validOps = s.validOps[:0] } func (s *Statement) Discard() { @@ -638,6 +647,8 @@ func (s *Statement) undoOperation(index int) error { reverseOperation: redoOperation, }, ) + s.validOps = append(s.validOps, true) + s.toggleValidOp(index) return err } @@ -650,13 +661,14 @@ func (s *Statement) cleanupFailedAllocation(task *pod_info.PodInfo, node *node_i } func (s *Statement) operationValid(i int) bool { - for undoIndex, operation := range s.operations { - if operation.Name() != undo { - continue - } - if operation.(undoOperation).operationIndex == i { - return !s.operationValid(undoIndex) - } + return s.validOps[i] +} + +// toggleValidOp flips the validity of operation at index and propagates through +// any undo chain: if op[index] is itself an undoOperation, its target is toggled too. +func (s *Statement) toggleValidOp(index int) { + s.validOps[index] = !s.validOps[index] + if undoOp, ok := s.operations[index].(undoOperation); ok { + s.toggleValidOp(undoOp.operationIndex) } - return true }