Skip to content

Commit 2499dd2

Browse files
rohanKanojiapraveenkumar
authored andcommitted
refactor (shell) : Windows Shell detection uses gopsutil (#4588)
+ Update `shell_windows.go` to use detectShellByCheckingProcessTree instead of relying on SHELL environment variable. + Remove hardcoded check from detectShellByCheckingProcessTree for shell types, use already present supportedShell slice. Signed-off-by: Rohan Kumar <[email protected]>
1 parent 644f23f commit 2499dd2

File tree

4 files changed

+131
-108
lines changed

4 files changed

+131
-108
lines changed

pkg/os/shell/shell.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55
"fmt"
66
"os"
7+
"slices"
78
"strings"
89

910
"github.com/shirou/gopsutil/v4/process"
@@ -208,7 +209,9 @@ func detectShellByCheckingProcessTree(p AbstractProcess) string {
208209
if err != nil {
209210
return ""
210211
}
211-
if processName == "zsh" || processName == "bash" || processName == "fish" {
212+
if slices.ContainsFunc(supportedShell, func(listElem string) bool {
213+
return strings.HasPrefix(processName, listElem)
214+
}) {
212215
return processName
213216
}
214217
p, err = p.Parent()

pkg/os/shell/shell_unix.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import (
88
"path/filepath"
99
)
1010

11+
var (
12+
supportedShell = []string{"bash", "zsh", "fish"}
13+
)
14+
1115
// detect detects user's current shell.
1216
func detect() (string, error) {
1317
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())

pkg/os/shell/shell_windows.go

Lines changed: 7 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,18 @@
11
package shell
22

33
import (
4-
"fmt"
5-
"math"
6-
"os"
7-
"path/filepath"
4+
"slices"
85
"sort"
96
"strconv"
107
"strings"
11-
"syscall"
12-
"unsafe"
138

149
"github.com/crc-org/crc/v2/pkg/crc/logging"
1510
)
1611

1712
var (
18-
supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"}
13+
supportedShell = []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}
1914
)
2015

21-
// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go
22-
func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) {
23-
snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0)
24-
if err != nil {
25-
return nil, err
26-
}
27-
defer func() {
28-
_ = syscall.CloseHandle(syscall.Handle(snapshot))
29-
}()
30-
31-
var processEntry syscall.ProcessEntry32
32-
processEntry.Size = uint32(unsafe.Sizeof(processEntry))
33-
err = syscall.Process32First(snapshot, &processEntry)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
for {
39-
if processEntry.ProcessID == pid {
40-
pe = &processEntry
41-
return
42-
}
43-
44-
err = syscall.Process32Next(snapshot, &processEntry)
45-
if err != nil {
46-
return nil, err
47-
}
48-
}
49-
}
50-
51-
// getNameAndItsPpid returns the exe file name its parent process id.
52-
func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) {
53-
pe, err := getProcessEntry(pid)
54-
if err != nil {
55-
return "", 0, err
56-
}
57-
58-
name := syscall.UTF16ToString(pe.ExeFile[:])
59-
return name, pe.ParentProcessID, nil
60-
}
61-
6216
func shellType(shell string, defaultShell string) string {
6317
switch {
6418
case strings.Contains(strings.ToLower(shell), "powershell"):
@@ -69,39 +23,15 @@ func shellType(shell string, defaultShell string) string {
6923
return "cmd"
7024
case strings.Contains(strings.ToLower(shell), "wsl"):
7125
return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="})
72-
case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"):
26+
case strings.Contains(strings.ToLower(shell), "bash"):
7327
return "bash"
7428
default:
7529
return defaultShell
7630
}
7731
}
7832

7933
func detect() (string, error) {
80-
shell := os.Getenv("SHELL")
81-
82-
if shell == "" {
83-
pid := os.Getppid()
84-
if pid < 0 || pid > math.MaxUint32 {
85-
return "", fmt.Errorf("integer overflow for pid: %v", pid)
86-
}
87-
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
88-
if err != nil {
89-
return "cmd", err // defaulting to cmd
90-
}
91-
shell = shellType(shell, "")
92-
if shell == "" {
93-
shell, _, err := getNameAndItsPpid(shellppid)
94-
if err != nil {
95-
return "cmd", err // defaulting to cmd
96-
}
97-
return shellType(shell, "cmd"), nil
98-
}
99-
return shell, nil
100-
}
101-
102-
if os.Getenv("__fish_bin_dir") != "" {
103-
return "fish", nil
104-
}
34+
shell := detectShellByCheckingProcessTree(currentProcessSupplier())
10535

10636
return shellType(shell, "cmd"), nil
10737
}
@@ -163,9 +93,9 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string {
16393
lines := strings.Split(psCommandOutput, "\n")
16494
for _, line := range lines {
16595
lineParts := strings.Split(strings.TrimSpace(line), " ")
166-
if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") ||
167-
strings.Contains(lineParts[1], "bash") ||
168-
strings.Contains(lineParts[1], "fish")) {
96+
if len(lineParts) == 2 && slices.ContainsFunc(supportedShell, func(listElem string) bool {
97+
return strings.HasPrefix(lineParts[1], listElem)
98+
}) {
16999
parsedProcessID, err := strconv.Atoi(lineParts[0])
170100
if err == nil {
171101
processOutputs = append(processOutputs, ProcessOutput{

pkg/os/shell/shell_windows_test.go

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,137 @@
11
package shell
22

33
import (
4-
"math"
5-
"os"
64
"testing"
75

86
"github.com/stretchr/testify/assert"
97
)
108

11-
func TestDetect(t *testing.T) {
12-
defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL"))
13-
os.Setenv("SHELL", "")
14-
15-
shell, err := detect()
9+
func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
processTree []MockedProcess
13+
expectedShellType string
14+
}{
15+
{
16+
"failure to get process details for given pid",
17+
[]MockedProcess{},
18+
"",
19+
},
20+
{
21+
"failure while getting name of process",
22+
[]MockedProcess{
23+
{
24+
name: "crc.exe",
25+
},
26+
{
27+
nameGetFails: true,
28+
},
29+
},
30+
"",
31+
},
32+
{
33+
"failure while getting ppid of process",
34+
[]MockedProcess{
35+
{
36+
name: "crc.exe",
37+
},
38+
{
39+
parentGetFails: true,
40+
},
41+
},
42+
"",
43+
},
44+
{
45+
"failure when no shell process in process tree",
46+
[]MockedProcess{
47+
{
48+
name: "crc.exe",
49+
},
50+
{
51+
name: "unknown.exe",
52+
},
53+
},
54+
"",
55+
},
56+
}
57+
for _, tt := range tests {
58+
t.Run(tt.name, func(t *testing.T) {
59+
// Given
60+
currentProcessSupplier = func() AbstractProcess {
61+
return createNewMockProcessTreeFrom(tt.processTree)
62+
}
1663

17-
assert.Contains(t, supportedShell, shell)
18-
assert.NoError(t, err)
19-
}
64+
// When
65+
shell, err := detect()
2066

21-
func TestGetNameAndItsPpidOfCurrent(t *testing.T) {
22-
pid := os.Getpid()
23-
if pid < 0 || pid > math.MaxUint32 {
24-
assert.Fail(t, "integer overflow detected")
25-
}
26-
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
27-
assert.Equal(t, "shell.test.exe", shell)
28-
ppid := os.Getppid()
29-
if ppid < 0 || ppid > math.MaxUint32 {
30-
assert.Fail(t, "integer overflow detected")
67+
// Then
68+
assert.NoError(t, err)
69+
assert.Equal(t, "cmd", shell)
70+
})
3171
}
32-
assert.Equal(t, uint32(ppid), shellppid)
33-
assert.NoError(t, err)
3472
}
3573

36-
func TestGetNameAndItsPpidOfParent(t *testing.T) {
37-
pid := os.Getppid()
38-
if pid < 0 || pid > math.MaxUint32 {
39-
assert.Fail(t, "integer overflow detected")
74+
func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) {
75+
tests := []struct {
76+
name string
77+
processTree []MockedProcess
78+
expectedShellType string
79+
}{
80+
{
81+
"bash shell, then detect bash shell",
82+
[]MockedProcess{
83+
{
84+
name: "crc.exe",
85+
},
86+
{
87+
name: "bash.exe",
88+
},
89+
},
90+
"bash",
91+
},
92+
{
93+
"powershell, then detect powershell",
94+
[]MockedProcess{
95+
{
96+
name: "crc.exe",
97+
},
98+
{
99+
name: "powershell.exe",
100+
},
101+
},
102+
"powershell",
103+
},
104+
{
105+
"cmd shell, then detect fish shell",
106+
[]MockedProcess{
107+
{
108+
name: "crc.exe",
109+
},
110+
{
111+
name: "cmd.exe",
112+
},
113+
},
114+
"cmd",
115+
},
40116
}
41-
shell, _, err := getNameAndItsPpid(uint32(pid))
117+
for _, tt := range tests {
118+
t.Run(tt.name, func(t *testing.T) {
119+
// Given
120+
currentProcessSupplier = func() AbstractProcess {
121+
return createNewMockProcessTreeFrom(tt.processTree)
122+
}
123+
// When
124+
shell, err := detect()
42125

43-
assert.Equal(t, "go.exe", shell)
44-
assert.NoError(t, err)
126+
// Then
127+
assert.Equal(t, tt.expectedShellType, shell)
128+
assert.NoError(t, err)
129+
})
130+
}
45131
}
46132

47133
func TestSupportedShells(t *testing.T) {
48-
assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell)
134+
assert.Equal(t, []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}, supportedShell)
49135
}
50136

51137
func TestShellType(t *testing.T) {

0 commit comments

Comments
 (0)