@@ -28,6 +28,7 @@ import (
2828 "time"
2929
3030 "k8s.io/klog/v2"
31+ "k8s.io/utils/strings/slices"
3132
3233 "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
3334 "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
@@ -98,7 +99,7 @@ func init() {
9899 // Use V(4) for general debug information logging
99100 // Use V(5) for GCE Cloud Provider Call informational logging
100101 // Use V(6) for extra repeated/polling information
101- enumFlag (& computeEnvironment , "compute-environment" , allowedComputeEnvironment , "Operating compute environment" )
102+ stringEnumFlag (& computeEnvironment , "compute-environment" , allowedComputeEnvironment , "Operating compute environment" )
102103 urlFlag (& computeEndpoint , "compute-endpoint" , "Compute endpoint" )
103104 klog .InitFlags (flag .CommandLine )
104105 flag .Set ("logtostderr" , "true" )
@@ -175,23 +176,23 @@ func handle() {
175176 identityServer := driver .NewIdentityServer (gceDriver )
176177
177178 // Initialize requisite zones
178- fallbackRequisiteZones := strings . Split (* fallbackRequisiteZonesFlag , "," )
179+ fallbackRequisiteZones := parseCSVFlag (* fallbackRequisiteZonesFlag )
179180
180181 // Initialize multi-zone disk types
181- multiZoneVolumeHandleDiskTypes := strings . Split (* multiZoneVolumeHandleDiskTypesFlag , "," )
182+ multiZoneVolumeHandleDiskTypes := parseCSVFlag (* multiZoneVolumeHandleDiskTypesFlag )
182183 multiZoneVolumeHandleConfig := driver.MultiZoneVolumeHandleConfig {
183184 Enable : * multiZoneVolumeHandleEnableFlag ,
184185 DiskTypes : multiZoneVolumeHandleDiskTypes ,
185186 }
186187
187188 // Initialize waitForAttach config
188- useInstanceAPIOnWaitForAttachDiskTypes := strings . Split (* useInstanceAPIOnWaitForAttachDiskTypesFlag , "," )
189+ useInstanceAPIOnWaitForAttachDiskTypes := parseCSVFlag (* useInstanceAPIOnWaitForAttachDiskTypesFlag )
189190 waitForAttachConfig := gce.WaitForAttachConfig {
190191 UseInstancesAPIForDiskTypes : useInstanceAPIOnWaitForAttachDiskTypes ,
191192 }
192193
193194 // Initialize listVolumes config
194- instancesListFilters := strings . Split (* instancesListFiltersFlag , "," )
195+ instancesListFilters := parseCSVFlag (* instancesListFiltersFlag )
195196 listInstancesConfig := gce.ListInstancesConfig {
196197 Filters : instancesListFilters ,
197198 }
@@ -252,18 +253,48 @@ func handle() {
252253 gceDriver .Run (* endpoint , * grpcLogCharCap , * enableOtelTracing )
253254}
254255
255- func enumFlag (target * gce.Environment , name string , allowedComputeEnvironment []gce.Environment , usage string ) {
256+ func notEmpty (v string ) bool {
257+ return v != ""
258+ }
259+
260+ func parseCSVFlag (list string ) []string {
261+ return slices .Filter (nil , strings .Split (list , "," ), notEmpty )
262+ }
263+
264+ type enumConverter [T any ] interface {
265+ convert (v string ) (T , error )
266+ eq (a , b T ) bool
267+ }
268+
269+ type stringConverter [T ~ string ] struct {}
270+
271+ func (s stringConverter [T ]) convert (v string ) (T , error ) {
272+ return T (v ), nil
273+ }
274+
275+ func (s stringConverter [T ]) eq (a , b T ) bool {
276+ return a == b
277+ }
278+
279+ func stringEnumFlag [T ~ string ](target * T , name string , allowed []T , usage string ) {
280+ enumFlag (target , name , stringConverter [T ]{}, allowed , usage )
281+ }
282+
283+ func enumFlag [T any ](target * T , name string , converter enumConverter [T ], allowed []T , usage string ) {
256284 flag .Func (name , usage , func (flagValue string ) error {
257- for _ , allowedValue := range allowedComputeEnvironment {
258- if gce .Environment (flagValue ) == allowedValue {
259- * target = gce .Environment (flagValue )
285+ tValue , err := converter .convert (flagValue )
286+ if err != nil {
287+ return err
288+ }
289+ for _ , allowedValue := range allowed {
290+ if converter .eq (allowedValue , tValue ) {
291+ * target = tValue
260292 return nil
261293 }
262294 }
263295 errMsg := fmt .Sprintf (`must be one of %v` , allowedComputeEnvironment )
264296 return errors .New (errMsg )
265297 })
266-
267298}
268299
269300func urlFlag (target * * url.URL , name string , usage string ) {
0 commit comments