diff --git a/CHANGELOG.md b/CHANGELOG.md index e9487a8ff..563d0d300 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Rebuilt the `crd-upgrader` hook image on `alpine:3.20` instead of `ubi9/ubi-minimal`. Image size drops from ~165 MB to ~67 MB uncompressed (~60% reduction), shrinking cold-pull latency on ephemeral CI runners. The image is also reused by the `topology-migration` and `post-delete` hook jobs as a generic `kubectl + bash` toolbox, so bash is preserved on the runtime image. [#1404](https://github.com/kai-scheduler/KAI-Scheduler/issues/1404) ### Fixed +- Streamed snapshot JSON directly into the zip writer to avoid OOM on large clusters. The `/get-snapshot` endpoint previously buffered the entire JSON payload in memory (~3x the data size); it now streams per-element, reducing peak memory to ~1x. [#1564](https://github.com/kai-scheduler/KAI-Scheduler/pull/1564) - Fixed `additionalImagePullSecrets` in Config CR rendering as `map[name:...]` instead of plain strings by extracting `.name` from `global.imagePullSecrets` objects. Also propagated `global.imagePullSecrets` to all Helm hook jobs (`crd-upgrader`, `topology-migration`, `post-delete-cleanup`) - Added `global.nodeSelector`, `global.tolerations`, `global.affinity`, `global.securityContext` support to the post-delete job hook. - Fixed Helm template writing `imagesPullSecret` (string) instead of `additionalImagePullSecrets` (array) in Config CR, causing image pull secrets to be silently ignored. Added backward-compatible deprecated `imagesPullSecret` field to CRD schema. 
[#942](https://github.com/kai-scheduler/KAI-Scheduler/issues/942) diff --git a/pkg/scheduler/plugins/snapshot/snapshot.go b/pkg/scheduler/plugins/snapshot/snapshot.go index 5b55816b7..9e5bccc97 100644 --- a/pkg/scheduler/plugins/snapshot/snapshot.go +++ b/pkg/scheduler/plugins/snapshot/snapshot.go @@ -9,7 +9,6 @@ import ( "encoding/json" "io" "net/http" - "strings" v1 "k8s.io/api/core/v1" resourceapi "k8s.io/api/resource/v1" @@ -70,6 +69,18 @@ type snapshotPlugin struct { session *framework.Session } +type jsonStream struct { + writer io.Writer + encoder *json.Encoder +} + +type jsonObjectWriter struct { + stream *jsonStream + wroteField bool +} + +type jsonFieldWriter func(*jsonObjectWriter) error + func (sp *snapshotPlugin) Name() string { return "snapshot" } @@ -83,154 +94,227 @@ func (sp *snapshotPlugin) OnSessionOpen(ssn *framework.Session) { func (sp *snapshotPlugin) OnSessionClose(ssn *framework.Session) {} func (sp *snapshotPlugin) serveSnapshot(writer http.ResponseWriter, request *http.Request) { - rawObjects := &RawKubernetesObjects{} - var err error + writer.Header().Set("Content-Disposition", "attachment; filename=snapshot.zip") + writer.Header().Set("Content-Type", "application/zip") - dataLister := sp.session.Cache.GetDataLister() + zipWriter := zip.NewWriter(writer) + defer func() { + if err := zipWriter.Close(); err != nil { + log.InfraLogger.Errorf("Error closing snapshot zip: %v", err) + } + }() - rawObjects.Pods, err = dataLister.ListPods() + jsonWriter, err := zipWriter.Create(SnapshotFileName) if err != nil { - log.InfraLogger.Errorf("Error getting raw pods: %v", err) - rawObjects.Pods = []*v1.Pod{} + http.Error(writer, err.Error(), http.StatusInternalServerError) + return } - rawObjects.Nodes, err = dataLister.ListNodes() - if err != nil { - log.InfraLogger.Errorf("Error getting raw nodes: %v", err) - rawObjects.Nodes = []*v1.Node{} + if err := sp.writeSnapshot(request.Context(), jsonWriter); err != nil { + log.InfraLogger.Errorf("Error 
writing snapshot: %v", err) } +} - rawObjects.Queues, err = dataLister.ListQueues() - if err != nil { - log.InfraLogger.Errorf("Error getting raw queues: %v", err) - rawObjects.Queues = []*enginev2.Queue{} - } +func (sp *snapshotPlugin) writeSnapshot(ctx context.Context, writer io.Writer) error { + stream := newJSONStream(writer) + + return stream.writeObject(func(object *jsonObjectWriter) error { + return writeFields( + object, + valueField("config", sp.session.Config), + valueField("schedulerParams", &sp.session.SchedulerParams), + streamedField("rawObjects", sp.writeRawObjects), + streamedField("discovery", func(stream *jsonStream) error { + return sp.writeDiscoverySnapshot(ctx, stream) + }), + ) + }) +} - rawObjects.PodGroups, err = dataLister.ListPodGroups() - if err != nil { - log.InfraLogger.Errorf("Error getting raw pod groups: %v", err) - rawObjects.PodGroups = []*enginev2alpha2.PodGroup{} - } +func (sp *snapshotPlugin) writeRawObjects(stream *jsonStream) error { + dataLister := sp.session.Cache.GetDataLister() - rawObjects.BindRequests, err = dataLister.ListBindRequests() - if err != nil { - log.InfraLogger.Errorf("Error getting raw bind requests: %v", err) - rawObjects.BindRequests = []*schedulingv1alpha2.BindRequest{} - } + return stream.writeObject(func(object *jsonObjectWriter) error { + return writeFields( + object, + listedSliceField("pods", dataLister.ListPods, "Error getting raw pods"), + listedSliceField("nodes", dataLister.ListNodes, "Error getting raw nodes"), + listedSliceField("queues", dataLister.ListQueues, "Error getting raw queues"), + listedSliceField("podGroups", dataLister.ListPodGroups, "Error getting raw pod groups"), + listedSliceField("bindRequests", dataLister.ListBindRequests, "Error getting raw bind requests"), + listedSliceField("priorityClasses", dataLister.ListPriorityClasses, "Error getting raw priority classes"), + listedSliceField("configMaps", dataLister.ListConfigMaps, "Error getting raw config maps"), + 
listedSliceField("persistentVolumes", dataLister.ListPersistentVolumes, "Error getting raw persistent volumes"), + listedSliceField("persistentVolumeClaims", dataLister.ListPersistentVolumeClaims, + "Error getting raw persistent volume claims"), + listedSliceField("csiStorageCapacities", dataLister.ListCSIStorageCapacities, + "Error getting raw CSI storage capacities"), + listedSliceField("storageClasses", dataLister.ListStorageClasses, "Error getting raw storage classes"), + listedSliceField("csiDrivers", dataLister.ListCSIDrivers, "Error getting raw CSI drivers"), + listedSliceField("resourceClaims", dataLister.ListResourceClaims, "Error getting raw resource claims"), + listedSliceField("resourceSlices", dataLister.ListResourceSlices, "Error getting raw resource slices"), + listedSliceField("deviceClasses", dataLister.ListDeviceClasses, "Error getting raw device classes"), + listedSliceField("topologies", dataLister.ListTopologies, "Error getting raw topologies"), + ) + }) +} - rawObjects.PriorityClasses, err = dataLister.ListPriorityClasses() - if err != nil { - log.InfraLogger.Errorf("Error getting raw priority classes: %v", err) - rawObjects.PriorityClasses = []*v14.PriorityClass{} - } +func (sp *snapshotPlugin) writeDiscoverySnapshot(ctx context.Context, stream *jsonStream) error { + discoverySnapshot := sp.getDiscoverySnapshot(ctx) - rawObjects.ConfigMaps, err = dataLister.ListConfigMaps() - if err != nil { - log.InfraLogger.Errorf("Error getting raw config maps: %v", err) - rawObjects.ConfigMaps = []*v1.ConfigMap{} - } + return stream.writeObject(func(object *jsonObjectWriter) error { + return writeFields( + object, + valueField("serverVersion", discoverySnapshot.ServerVersion), + sliceField("resources", discoverySnapshot.Resources), + ) + }) +} - rawObjects.PersistentVolumes, err = dataLister.ListPersistentVolumes() +func (sp *snapshotPlugin) getDiscoverySnapshot(ctx context.Context) *DiscoverySnapshot { + discoverySnapshot := &DiscoverySnapshot{} + 
discoveryClient := sp.session.Cache.KubeClient().Discovery() + var err error + + discoverySnapshot.ServerVersion, err = getServerVersion(ctx, discoveryClient) if err != nil { - log.InfraLogger.Errorf("Error getting raw persistent volumes: %v", err) - rawObjects.PersistentVolumes = []*v1.PersistentVolume{} + log.InfraLogger.V(2).Warnf("Failed to snapshot server version: %v", err) + discoverySnapshot.ServerVersion = nil } - rawObjects.PersistentVolumeClaims, err = dataLister.ListPersistentVolumeClaims() + _, discoverySnapshot.Resources, err = discoveryClient.ServerGroupsAndResources() if err != nil { - log.InfraLogger.Errorf("Error getting raw persistent volume claims: %v", err) - rawObjects.PersistentVolumeClaims = []*v1.PersistentVolumeClaim{} + log.InfraLogger.V(2).Warnf("Failed to snapshot server resources: %v", err) + discoverySnapshot.Resources = nil } - rawObjects.CSIStorageCapacities, err = dataLister.ListCSIStorageCapacities() - if err != nil { - log.InfraLogger.Errorf("Error getting raw CSI storage capacities: %v", err) - rawObjects.CSIStorageCapacities = []*storage.CSIStorageCapacity{} + return discoverySnapshot +} + +func newJSONStream(writer io.Writer) *jsonStream { + return &jsonStream{ + writer: writer, + encoder: json.NewEncoder(writer), } +} - rawObjects.StorageClasses, err = dataLister.ListStorageClasses() - if err != nil { - log.InfraLogger.Errorf("Error getting raw storage classes: %v", err) - rawObjects.StorageClasses = []*storage.StorageClass{} +func (js *jsonStream) writeRaw(value string) error { + _, err := io.WriteString(js.writer, value) + return err +} + +func (js *jsonStream) writeValue(value any) error { + return js.encoder.Encode(value) +} + +func (js *jsonStream) writeObject(fn func(*jsonObjectWriter) error) error { + object := &jsonObjectWriter{stream: js} + if err := js.writeRaw("{"); err != nil { + return err + } + if err := fn(object); err != nil { + return err } + return js.writeRaw("}") +} - rawObjects.CSIDrivers, err = 
dataLister.ListCSIDrivers() - if err != nil { - log.InfraLogger.Errorf("Error getting raw CSI drivers: %v", err) - rawObjects.CSIDrivers = []*storage.CSIDriver{} +func (object *jsonObjectWriter) writeFieldName(fieldName string) error { + if object.wroteField { + if err := object.stream.writeRaw(","); err != nil { + return err + } } + object.wroteField = true - rawObjects.Topologies, err = dataLister.ListTopologies() + quotedFieldName, err := json.Marshal(fieldName) if err != nil { - log.InfraLogger.Errorf("Error getting raw topologies: %v", err) - rawObjects.Topologies = []*kaiv1alpha1.Topology{} + return err + } + if _, err := object.stream.writer.Write(quotedFieldName); err != nil { + return err } + return object.stream.writeRaw(":") +} - rawObjects.ResourceClaims, err = dataLister.ListResourceClaims() - if err != nil { - log.InfraLogger.Errorf("Error getting raw resource claims: %v", err) - rawObjects.ResourceClaims = []*resourceapi.ResourceClaim{} +func (object *jsonObjectWriter) writeField(fieldName string, value any) error { + if err := object.writeFieldName(fieldName); err != nil { + return err } + return object.stream.writeValue(value) +} - rawObjects.ResourceSlices, err = dataLister.ListResourceSlices() - if err != nil { - log.InfraLogger.Errorf("Error getting raw resource slices: %v", err) - rawObjects.ResourceSlices = []*resourceapi.ResourceSlice{} +func writeFields(object *jsonObjectWriter, fields ...jsonFieldWriter) error { + for _, field := range fields { + if err := field(object); err != nil { + return err + } } + return nil +} - rawObjects.DeviceClasses, err = dataLister.ListDeviceClasses() - if err != nil { - log.InfraLogger.Errorf("Error getting raw device classes: %v", err) - rawObjects.DeviceClasses = []*resourceapi.DeviceClass{} +func valueField(fieldName string, value any) jsonFieldWriter { + return func(object *jsonObjectWriter) error { + return object.writeField(fieldName, value) } +} - discoverySnapshot := &DiscoverySnapshot{} - 
discoveryClient := sp.session.Cache.KubeClient().Discovery() - discoverySnapshot.ServerVersion, err = getServerVersion(request.Context(), discoveryClient) - if err != nil { - log.InfraLogger.V(2).Warnf("Failed to snapshot server version: %v", err) - discoverySnapshot.ServerVersion = nil +func streamedField(fieldName string, writeValue func(*jsonStream) error) jsonFieldWriter { + return func(object *jsonObjectWriter) error { + if err := object.writeFieldName(fieldName); err != nil { + return err + } + return writeValue(object.stream) } +} - _, discoverySnapshot.Resources, err = discoveryClient.ServerGroupsAndResources() - if err != nil { - log.InfraLogger.V(2).Warnf("Failed to snapshot server resources: %v", err) - discoverySnapshot.Resources = nil +func listedSliceField[T any](fieldName string, list func() ([]T, error), errorMessage string) jsonFieldWriter { + return func(object *jsonObjectWriter) error { + return writeListedSlice(object, fieldName, list, errorMessage) } +} - snapshotAndConfig := Snapshot{ - Config: sp.session.Config, - SchedulerParams: &sp.session.SchedulerParams, - RawObjects: rawObjects, - Discovery: discoverySnapshot, +func sliceField[T any](fieldName string, values []T) jsonFieldWriter { + return func(object *jsonObjectWriter) error { + return writeSliceField(object, fieldName, values) } - jsonBytes, err := json.Marshal(snapshotAndConfig) +} + +func writeListedSlice[T any](object *jsonObjectWriter, fieldName string, list func() ([]T, error), errorMessage string) error { + values, err := list() if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + log.InfraLogger.Errorf("%s: %v", errorMessage, err) + values = []T{} } - writer.Header().Set("Content-Disposition", "attachment; filename=snapshot.zip") - writer.Header().Set("Content-Type", "application/zip") + return writeSliceField(object, fieldName, values) +} - zipWriter := zip.NewWriter(writer) - jsonWriter, err := zipWriter.Create(SnapshotFileName) - if 
err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return +func writeSliceField[T any](object *jsonObjectWriter, fieldName string, values []T) error { + if err := object.writeFieldName(fieldName); err != nil { + return err } + return writeSlice(object.stream, values) +} - _, err = io.Copy(jsonWriter, strings.NewReader(string(jsonBytes))) - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return +func writeSlice[T any](stream *jsonStream, values []T) error { + if values == nil { + return stream.writeValue(values) } - err = zipWriter.Close() - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + if err := stream.writeRaw("["); err != nil { + return err + } + for index, value := range values { + if index > 0 { + if err := stream.writeRaw(","); err != nil { + return err + } + } + if err := stream.writeValue(value); err != nil { + return err + } } + return stream.writeRaw("]") } func New(_ framework.PluginArguments) framework.Plugin {