@@ -90,8 +90,11 @@ func (s *setup) Run(ctx context.Context) (err error) {
90
90
closeProgress (err )
91
91
}()
92
92
93
- isDistro , _ := isDebianOrUbuntu ()
94
- if ! isDistro {
93
+ osid , osversion , err := getOSRelease ()
94
+ if err != nil {
95
+ return err
96
+ }
97
+ if osid != "debian" && osid != "ubuntu" {
95
98
return errors .Errorf ("NVIDIA setup is currently only supported on Debian/Ubuntu" )
96
99
}
97
100
@@ -131,7 +134,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
131
134
return err
132
135
}
133
136
134
- if err := installPackages (ctx , dv , pw , dgst ); err != nil {
137
+ if err := installPackages (ctx , osid , osversion , dv , pw , dgst ); err != nil {
135
138
return err
136
139
}
137
140
@@ -167,8 +170,20 @@ func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Dig
167
170
return cmd .Run ()
168
171
}
169
172
170
- func installPackages (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
171
- const aptDistro = "ubuntu2404"
173
+ func installPackages (ctx context.Context , osid string , osversion string , dv string , pw progress.Writer , dgst digest.Digest ) error {
174
+ aptDistro := "ubuntu2404"
175
+ switch osid {
176
+ case "debian" :
177
+ if osversion == "" {
178
+ aptDistro = "debian12"
179
+ } else {
180
+ aptDistro = "debian" + osversion
181
+ }
182
+ case "ubuntu" :
183
+ if osversion != "" {
184
+ aptDistro = "ubuntu" + strings .ReplaceAll (osversion , "." , "" )
185
+ }
186
+ }
172
187
173
188
var arch string
174
189
switch runtime .GOARCH {
@@ -274,36 +289,33 @@ func hasNvidiaDevices() (bool, error) {
274
289
return found , nil
275
290
}
276
291
277
- func getOSID () (string , error ) {
292
+ func getOSRelease () (string , string , error ) {
278
293
file , err := os .Open ("/etc/os-release" )
279
294
if err != nil {
280
- return "" , err
295
+ return "" , "" , err
281
296
}
282
297
defer file .Close ()
283
298
299
+ var id , versionID string
284
300
scanner := bufio .NewScanner (file )
285
301
for scanner .Scan () {
286
302
line := scanner .Text ()
287
303
if strings .HasPrefix (line , "ID=" ) {
288
- id := strings .TrimPrefix (line , "ID=" )
289
- return strings .Trim (id , `"` ), nil // Remove potential quotes
304
+ id = strings .Trim (strings .TrimPrefix (line , "ID=" ), `"` ) // Remove potential quotes
305
+ } else if strings .HasPrefix (line , "VERSION_ID=" ) {
306
+ versionID = strings .Trim (strings .TrimPrefix (line , "VERSION_ID=" ), `"` )
290
307
}
291
308
}
292
309
293
310
if err := scanner .Err (); err != nil {
294
- return "" , err
311
+ return "" , "" , err
295
312
}
296
313
297
- return "" , errors .Errorf ("ID not found in /etc/os-release" )
298
- }
299
-
300
- func isDebianOrUbuntu () (bool , error ) {
301
- id , err := getOSID ()
302
- if err != nil {
303
- return false , err
314
+ if id == "" {
315
+ return "" , "" , errors .Errorf ("ID not found in /etc/os-release" )
304
316
}
305
317
306
- return id == "debian" || id == "ubuntu" , nil
318
+ return id , versionID , nil
307
319
}
308
320
309
321
func hasWSLGPU () bool {
0 commit comments