Skip to content

Commit 6390586

Browse files
committed
contrib(nvidia): match right apt repo based on os release
Signed-off-by: CrazyMax <[email protected]>
1 parent 7fbda52 commit 6390586

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

contrib/cdisetup/nvidia/nvidia.go

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,11 @@ func (s *setup) Run(ctx context.Context) (err error) {
9090
closeProgress(err)
9191
}()
9292

93-
isDistro, _ := isDebianOrUbuntu()
94-
if !isDistro {
93+
osid, osversion, err := getOSRelease()
94+
if err != nil {
95+
return err
96+
}
97+
if osid != "debian" && osid != "ubuntu" {
9598
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
9699
}
97100

@@ -131,7 +134,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
131134
return err
132135
}
133136

134-
if err := installPackages(ctx, dv, pw, dgst); err != nil {
137+
if err := installPackages(ctx, osid, osversion, dv, pw, dgst); err != nil {
135138
return err
136139
}
137140

@@ -167,8 +170,20 @@ func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Dig
167170
return cmd.Run()
168171
}
169172

170-
func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
171-
const aptDistro = "ubuntu2404"
173+
func installPackages(ctx context.Context, osid string, osversion string, dv string, pw progress.Writer, dgst digest.Digest) error {
174+
aptDistro := "ubuntu2404"
175+
switch osid {
176+
case "debian":
177+
if osversion == "" {
178+
aptDistro = "debian12"
179+
} else {
180+
aptDistro = "debian" + osversion
181+
}
182+
case "ubuntu":
183+
if osversion != "" {
184+
aptDistro = "ubuntu" + strings.ReplaceAll(osversion, ".", "")
185+
}
186+
}
172187

173188
var arch string
174189
switch runtime.GOARCH {
@@ -274,36 +289,33 @@ func hasNvidiaDevices() (bool, error) {
274289
return found, nil
275290
}
276291

277-
func getOSID() (string, error) {
292+
func getOSRelease() (string, string, error) {
278293
file, err := os.Open("/etc/os-release")
279294
if err != nil {
280-
return "", err
295+
return "", "", err
281296
}
282297
defer file.Close()
283298

299+
var id, versionID string
284300
scanner := bufio.NewScanner(file)
285301
for scanner.Scan() {
286302
line := scanner.Text()
287303
if strings.HasPrefix(line, "ID=") {
288-
id := strings.TrimPrefix(line, "ID=")
289-
return strings.Trim(id, `"`), nil // Remove potential quotes
304+
id = strings.Trim(strings.TrimPrefix(line, "ID="), `"`) // Remove potential quotes
305+
} else if strings.HasPrefix(line, "VERSION_ID=") {
306+
versionID = strings.Trim(strings.TrimPrefix(line, "VERSION_ID="), `"`)
290307
}
291308
}
292309

293310
if err := scanner.Err(); err != nil {
294-
return "", err
311+
return "", "", err
295312
}
296313

297-
return "", errors.Errorf("ID not found in /etc/os-release")
298-
}
299-
300-
func isDebianOrUbuntu() (bool, error) {
301-
id, err := getOSID()
302-
if err != nil {
303-
return false, err
314+
if id == "" {
315+
return "", "", errors.Errorf("ID not found in /etc/os-release")
304316
}
305317

306-
return id == "debian" || id == "ubuntu", nil
318+
return id, versionID, nil
307319
}
308320

309321
func hasWSLGPU() bool {

0 commit comments

Comments
 (0)