Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Rebuilt the `crd-upgrader` hook image on `alpine:3.20` instead of `ubi9/ubi-minimal`. Image size drops from ~165 MB to ~67 MB uncompressed (~60% reduction), shrinking cold-pull latency on ephemeral CI runners. The image is also reused by the `topology-migration` and `post-delete` hook jobs as a generic `kubectl + bash` toolbox, so bash is preserved on the runtime image. [#1404](https://github.com/kai-scheduler/KAI-Scheduler/issues/1404)

### Fixed
- Streaming snapshot JSON directly into the zip writer to avoid OOM on large clusters. The `/get-snapshot` endpoint previously buffered the entire JSON payload in memory (~3x the data size); it now streams per-element, reducing peak memory to ~1x. [#1564](https://github.com/kai-scheduler/KAI-Scheduler/pull/1564)
- Fixed `additionalImagePullSecrets` in Config CR rendering as `map[name:...]` instead of plain strings by extracting `.name` from `global.imagePullSecrets` objects. Also propagated `global.imagePullSecrets` to all Helm hook jobs (`crd-upgrader`, `topology-migration`, `post-delete-cleanup`).
- Added `global.nodeSelector`, `global.tolerations`, `global.affinity`, `global.securityContext` support to the post-delete job hook.
- Fixed Helm template writing `imagesPullSecret` (string) instead of `additionalImagePullSecrets` (array) in Config CR, causing image pull secrets to be silently ignored. Added backward-compatible deprecated `imagesPullSecret` field to CRD schema. [#942](https://github.com/kai-scheduler/KAI-Scheduler/issues/942)
Expand Down
288 changes: 186 additions & 102 deletions pkg/scheduler/plugins/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"encoding/json"
"io"
"net/http"
"strings"

v1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1"
Expand Down Expand Up @@ -70,6 +69,18 @@ type snapshotPlugin struct {
session *framework.Session
}

// jsonStream incrementally writes one JSON document to an underlying
// io.Writer, mixing raw punctuation tokens with encoder-marshaled values so
// large payloads are streamed element by element instead of being buffered
// whole in memory.
type jsonStream struct {
	writer  io.Writer
	encoder *json.Encoder
}

// jsonObjectWriter tracks the state of a single in-progress JSON object on a
// jsonStream. wroteField records whether at least one field has been emitted,
// so a "," separator is written before each subsequent field.
type jsonObjectWriter struct {
	stream     *jsonStream
	wroteField bool
}

// jsonFieldWriter emits one `"name": value` pair into an in-progress object.
type jsonFieldWriter func(*jsonObjectWriter) error

// Name returns the plugin's registered name.
func (sp *snapshotPlugin) Name() string {
	return "snapshot"
}
Expand All @@ -83,154 +94,227 @@ func (sp *snapshotPlugin) OnSessionOpen(ssn *framework.Session) {
func (sp *snapshotPlugin) OnSessionClose(ssn *framework.Session) {}

func (sp *snapshotPlugin) serveSnapshot(writer http.ResponseWriter, request *http.Request) {
rawObjects := &RawKubernetesObjects{}
var err error
writer.Header().Set("Content-Disposition", "attachment; filename=snapshot.zip")
writer.Header().Set("Content-Type", "application/zip")

dataLister := sp.session.Cache.GetDataLister()
zipWriter := zip.NewWriter(writer)
defer func() {
if err := zipWriter.Close(); err != nil {
log.InfraLogger.Errorf("Error closing snapshot zip: %v", err)
}
}()

rawObjects.Pods, err = dataLister.ListPods()
jsonWriter, err := zipWriter.Create(SnapshotFileName)
if err != nil {
log.InfraLogger.Errorf("Error getting raw pods: %v", err)
rawObjects.Pods = []*v1.Pod{}
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}

rawObjects.Nodes, err = dataLister.ListNodes()
if err != nil {
log.InfraLogger.Errorf("Error getting raw nodes: %v", err)
rawObjects.Nodes = []*v1.Node{}
if err := sp.writeSnapshot(request.Context(), jsonWriter); err != nil {
log.InfraLogger.Errorf("Error writing snapshot: %v", err)
}
}

rawObjects.Queues, err = dataLister.ListQueues()
if err != nil {
log.InfraLogger.Errorf("Error getting raw queues: %v", err)
rawObjects.Queues = []*enginev2.Queue{}
}
// writeSnapshot streams the complete snapshot document — scheduler config,
// scheduler params, raw cluster objects, and API discovery data — as one JSON
// object, field by field, so the whole payload is never held in memory.
func (sp *snapshotPlugin) writeSnapshot(ctx context.Context, writer io.Writer) error {
	stream := newJSONStream(writer)

	return stream.writeObject(func(object *jsonObjectWriter) error {
		return writeFields(
			object,
			valueField("config", sp.session.Config),
			valueField("schedulerParams", &sp.session.SchedulerParams),
			streamedField("rawObjects", sp.writeRawObjects),
			streamedField("discovery", func(stream *jsonStream) error {
				return sp.writeDiscoverySnapshot(ctx, stream)
			}),
		)
	})
}

rawObjects.PodGroups, err = dataLister.ListPodGroups()
if err != nil {
log.InfraLogger.Errorf("Error getting raw pod groups: %v", err)
rawObjects.PodGroups = []*enginev2alpha2.PodGroup{}
}
// writeRawObjects streams every raw Kubernetes object kind tracked by the
// scheduler cache as one JSON object. A listing failure for any kind is
// logged and that field is emitted as an empty array rather than aborting
// the whole snapshot.
func (sp *snapshotPlugin) writeRawObjects(stream *jsonStream) error {
	lister := sp.session.Cache.GetDataLister()

	fields := []jsonFieldWriter{
		listedSliceField("pods", lister.ListPods, "Error getting raw pods"),
		listedSliceField("nodes", lister.ListNodes, "Error getting raw nodes"),
		listedSliceField("queues", lister.ListQueues, "Error getting raw queues"),
		listedSliceField("podGroups", lister.ListPodGroups, "Error getting raw pod groups"),
		listedSliceField("bindRequests", lister.ListBindRequests, "Error getting raw bind requests"),
		listedSliceField("priorityClasses", lister.ListPriorityClasses, "Error getting raw priority classes"),
		listedSliceField("configMaps", lister.ListConfigMaps, "Error getting raw config maps"),
		listedSliceField("persistentVolumes", lister.ListPersistentVolumes, "Error getting raw persistent volumes"),
		listedSliceField("persistentVolumeClaims", lister.ListPersistentVolumeClaims,
			"Error getting raw persistent volume claims"),
		listedSliceField("csiStorageCapacities", lister.ListCSIStorageCapacities,
			"Error getting raw CSI storage capacities"),
		listedSliceField("storageClasses", lister.ListStorageClasses, "Error getting raw storage classes"),
		listedSliceField("csiDrivers", lister.ListCSIDrivers, "Error getting raw CSI drivers"),
		listedSliceField("resourceClaims", lister.ListResourceClaims, "Error getting raw resource claims"),
		listedSliceField("resourceSlices", lister.ListResourceSlices, "Error getting raw resource slices"),
		listedSliceField("deviceClasses", lister.ListDeviceClasses, "Error getting raw device classes"),
		listedSliceField("topologies", lister.ListTopologies, "Error getting raw topologies"),
	}

	return stream.writeObject(func(object *jsonObjectWriter) error {
		return writeFields(object, fields...)
	})
}

rawObjects.PriorityClasses, err = dataLister.ListPriorityClasses()
if err != nil {
log.InfraLogger.Errorf("Error getting raw priority classes: %v", err)
rawObjects.PriorityClasses = []*v14.PriorityClass{}
}
// writeDiscoverySnapshot streams the API discovery portion of the snapshot
// (server version and served resource lists) as one JSON object.
func (sp *snapshotPlugin) writeDiscoverySnapshot(ctx context.Context, stream *jsonStream) error {
	snapshot := sp.getDiscoverySnapshot(ctx)

	return stream.writeObject(func(object *jsonObjectWriter) error {
		return writeFields(
			object,
			valueField("serverVersion", snapshot.ServerVersion),
			sliceField("resources", snapshot.Resources),
		)
	})
}

rawObjects.PersistentVolumes, err = dataLister.ListPersistentVolumes()
func (sp *snapshotPlugin) getDiscoverySnapshot(ctx context.Context) *DiscoverySnapshot {
discoverySnapshot := &DiscoverySnapshot{}
discoveryClient := sp.session.Cache.KubeClient().Discovery()
var err error

discoverySnapshot.ServerVersion, err = getServerVersion(ctx, discoveryClient)
if err != nil {
log.InfraLogger.Errorf("Error getting raw persistent volumes: %v", err)
rawObjects.PersistentVolumes = []*v1.PersistentVolume{}
log.InfraLogger.V(2).Warnf("Failed to snapshot server version: %v", err)
discoverySnapshot.ServerVersion = nil
}

rawObjects.PersistentVolumeClaims, err = dataLister.ListPersistentVolumeClaims()
_, discoverySnapshot.Resources, err = discoveryClient.ServerGroupsAndResources()
if err != nil {
log.InfraLogger.Errorf("Error getting raw persistent volume claims: %v", err)
rawObjects.PersistentVolumeClaims = []*v1.PersistentVolumeClaim{}
log.InfraLogger.V(2).Warnf("Failed to snapshot server resources: %v", err)
discoverySnapshot.Resources = nil
Comment thread
enoodle marked this conversation as resolved.
}

rawObjects.CSIStorageCapacities, err = dataLister.ListCSIStorageCapacities()
if err != nil {
log.InfraLogger.Errorf("Error getting raw CSI storage capacities: %v", err)
rawObjects.CSIStorageCapacities = []*storage.CSIStorageCapacity{}
return discoverySnapshot
}

// newJSONStream wraps writer in a jsonStream, sharing one json.Encoder for
// all marshaled values written to the stream.
func newJSONStream(writer io.Writer) *jsonStream {
	stream := &jsonStream{writer: writer}
	stream.encoder = json.NewEncoder(writer)
	return stream
}

rawObjects.StorageClasses, err = dataLister.ListStorageClasses()
if err != nil {
log.InfraLogger.Errorf("Error getting raw storage classes: %v", err)
rawObjects.StorageClasses = []*storage.StorageClass{}
// writeRaw emits a literal JSON token (structural punctuation such as "{",
// ",", or ":") directly to the underlying writer, bypassing the encoder.
func (js *jsonStream) writeRaw(value string) error {
	if _, err := io.WriteString(js.writer, value); err != nil {
		return err
	}
	return nil
}

// writeValue marshals value onto the stream with the shared encoder. Note
// that json.Encoder.Encode appends a newline after each value; that is legal
// inter-token JSON whitespace, so the resulting document stays valid.
func (js *jsonStream) writeValue(value any) error {
	return js.encoder.Encode(value)
}

// writeObject emits "{", invokes fn to stream the object's fields, then
// emits the closing "}".
func (js *jsonStream) writeObject(fn func(*jsonObjectWriter) error) error {
	if err := js.writeRaw("{"); err != nil {
		return err
	}
	if err := fn(&jsonObjectWriter{stream: js}); err != nil {
		return err
	}
	return js.writeRaw("}")
}

rawObjects.CSIDrivers, err = dataLister.ListCSIDrivers()
if err != nil {
log.InfraLogger.Errorf("Error getting raw CSI drivers: %v", err)
rawObjects.CSIDrivers = []*storage.CSIDriver{}
// writeFieldName emits `"fieldName":`, preceded by a "," separator when a
// previous field has already been written on this object.
func (object *jsonObjectWriter) writeFieldName(fieldName string) error {
	if object.wroteField {
		if err := object.stream.writeRaw(","); err != nil {
			return err
		}
	}
	object.wroteField = true

	// json.Marshal produces a correctly quoted and escaped JSON string for
	// the field name.
	quoted, err := json.Marshal(fieldName)
	if err != nil {
		return err
	}
	if _, err = object.stream.writer.Write(quoted); err != nil {
		return err
	}
	return object.stream.writeRaw(":")
}

rawObjects.ResourceClaims, err = dataLister.ListResourceClaims()
if err != nil {
log.InfraLogger.Errorf("Error getting raw resource claims: %v", err)
rawObjects.ResourceClaims = []*resourceapi.ResourceClaim{}
// writeField emits one complete `"fieldName": value` pair, marshaling value
// with the stream's encoder.
func (object *jsonObjectWriter) writeField(fieldName string, value any) error {
	err := object.writeFieldName(fieldName)
	if err != nil {
		return err
	}
	return object.stream.writeValue(value)
}

rawObjects.ResourceSlices, err = dataLister.ListResourceSlices()
if err != nil {
log.InfraLogger.Errorf("Error getting raw resource slices: %v", err)
rawObjects.ResourceSlices = []*resourceapi.ResourceSlice{}
// writeFields runs each field writer in order against object, stopping at
// the first error.
func writeFields(object *jsonObjectWriter, fields ...jsonFieldWriter) error {
	for _, writeField := range fields {
		if err := writeField(object); err != nil {
			return err
		}
	}
	return nil
}

rawObjects.DeviceClasses, err = dataLister.ListDeviceClasses()
if err != nil {
log.InfraLogger.Errorf("Error getting raw device classes: %v", err)
rawObjects.DeviceClasses = []*resourceapi.DeviceClass{}
// valueField returns a field writer that marshals value under fieldName.
func valueField(fieldName string, value any) jsonFieldWriter {
	return func(obj *jsonObjectWriter) error {
		return obj.writeField(fieldName, value)
	}
}

discoverySnapshot := &DiscoverySnapshot{}
discoveryClient := sp.session.Cache.KubeClient().Discovery()
discoverySnapshot.ServerVersion, err = getServerVersion(request.Context(), discoveryClient)
if err != nil {
log.InfraLogger.V(2).Warnf("Failed to snapshot server version: %v", err)
discoverySnapshot.ServerVersion = nil
// streamedField returns a field writer that emits the field name and then
// delegates the value to writeValue, which streams it element by element.
func streamedField(fieldName string, writeValue func(*jsonStream) error) jsonFieldWriter {
	return func(obj *jsonObjectWriter) error {
		if err := obj.writeFieldName(fieldName); err != nil {
			return err
		}
		return writeValue(obj.stream)
	}
}

_, discoverySnapshot.Resources, err = discoveryClient.ServerGroupsAndResources()
if err != nil {
log.InfraLogger.V(2).Warnf("Failed to snapshot server resources: %v", err)
discoverySnapshot.Resources = nil
// listedSliceField returns a field writer that lazily fetches a slice via
// list and streams it under fieldName; errorMessage is logged on a listing
// failure.
func listedSliceField[T any](fieldName string, list func() ([]T, error), errorMessage string) jsonFieldWriter {
	return func(obj *jsonObjectWriter) error {
		return writeListedSlice(obj, fieldName, list, errorMessage)
	}
}

snapshotAndConfig := Snapshot{
Config: sp.session.Config,
SchedulerParams: &sp.session.SchedulerParams,
RawObjects: rawObjects,
Discovery: discoverySnapshot,
// sliceField returns a field writer that streams an already-materialized
// slice under fieldName.
func sliceField[T any](fieldName string, values []T) jsonFieldWriter {
	return func(obj *jsonObjectWriter) error {
		return writeSliceField(obj, fieldName, values)
	}
}

func writeListedSlice[T any](object *jsonObjectWriter, fieldName string, list func() ([]T, error), errorMessage string) error {
values, err := list()
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
log.InfraLogger.Errorf("%s: %v", errorMessage, err)
values = []T{}
}
Comment thread
enoodle marked this conversation as resolved.

writer.Header().Set("Content-Disposition", "attachment; filename=snapshot.zip")
writer.Header().Set("Content-Type", "application/zip")
return writeSliceField(object, fieldName, values)
}

zipWriter := zip.NewWriter(writer)
jsonWriter, err := zipWriter.Create(SnapshotFileName)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
// writeSliceField emits the field name, then streams values as a JSON array.
func writeSliceField[T any](object *jsonObjectWriter, fieldName string, values []T) error {
	err := object.writeFieldName(fieldName)
	if err != nil {
		return err
	}
	return writeSlice(object.stream, values)
}

_, err = io.Copy(jsonWriter, strings.NewReader(string(jsonBytes)))
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
func writeSlice[T any](stream *jsonStream, values []T) error {
if values == nil {
return stream.writeValue(values)
}

err = zipWriter.Close()
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
if err := stream.writeRaw("["); err != nil {
return err
}
for index, value := range values {
if index > 0 {
if err := stream.writeRaw(","); err != nil {
return err
}
}
if err := stream.writeValue(value); err != nil {
return err
}
}
return stream.writeRaw("]")
}

func New(_ framework.PluginArguments) framework.Plugin {
Expand Down
Loading