diff --git a/internal/certconfig/certconfig.go b/internal/certconfig/certconfig.go index feddeedb..7e9a19cb 100644 --- a/internal/certconfig/certconfig.go +++ b/internal/certconfig/certconfig.go @@ -3,6 +3,11 @@ package certconfig import ( "context" "log" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/docker" + "github.com/AikidoSec/safechain-internals/internal/certconfig/firefox" + "github.com/AikidoSec/safechain-internals/internal/certconfig/node" + "github.com/AikidoSec/safechain-internals/internal/certconfig/pip" ) type Configurator interface { @@ -18,9 +23,10 @@ type CertConfig struct { func New() *CertConfig { return &CertConfig{ configurators: []Configurator{ - newNodeConfigurator(), - newPipConfigurator(), - newFirefoxConfigurator(), + node.New(), + pip.New(), + firefox.New(), + docker.New(), }, } } diff --git a/internal/certconfig/common.go b/internal/certconfig/common.go deleted file mode 100644 index 88a9f782..00000000 --- a/internal/certconfig/common.go +++ /dev/null @@ -1,154 +0,0 @@ -package certconfig - -import ( - "context" - "fmt" - "os" - "os/exec" - "strings" - - "github.com/AikidoSec/safechain-internals/internal/platform" - "github.com/AikidoSec/safechain-internals/internal/utils" -) - -type managedBlockFormat struct { - startMarker string - endMarker string -} - -const aikidoCertMarker = "AIKIDO_CERT=" - -const ( - pipCertEnvVar = "PIP_CERT" - requestsCABundleEnvVar = "REQUESTS_CA_BUNDLE" - sslCertFileEnvVar = "SSL_CERT_FILE" -) - -func buildManagedBlock(body string, format managedBlockFormat, newline string) string { - return format.startMarker + newline + body + newline + format.endMarker + newline -} - -func detectNewline(content string) string { - if strings.Contains(content, "\r\n") { - return "\r\n" - } - return "\n" -} - -func hasTrailingNewline(content string) bool { - return strings.HasSuffix(content, "\n") || strings.HasSuffix(content, "\r\n") -} - -func writeManagedBlock(path string, body string, perm os.FileMode, format managedBlockFormat) error { - content := "" - if data, err := os.ReadFile(path); err == nil { - content = string(data) - } else if !os.IsNotExist(err) { - return fmt.Errorf("failed to read %s: %w", path, err) - } - - newline := detectNewline(content) - - stripped, _, err := utils.RemoveMarkedBlock(content, format.startMarker, format.endMarker) - if err != nil { - return fmt.Errorf("failed to remove existing managed block in %s: %w", path, err) - } - - if stripped != "" && !hasTrailingNewline(stripped) { - stripped += newline - } - - body = strings.ReplaceAll(body, "\r\n", "\n") - if newline != "\n" { - body = strings.ReplaceAll(body, "\n", newline) - } - - return os.WriteFile(path, []byte(stripped+buildManagedBlock(body, format, newline)), perm) -} - -// extractMarkedCertValue scans output for a line starting with aikidoCertMarker -// and returns the value after it. This tolerates arbitrary text before or after -// the marker line, which interactive shells may produce. -func extractMarkedCertValue(output string) string { - for line := range strings.SplitSeq(output, "\n") { - if strings.HasPrefix(line, aikidoCertMarker) { - return strings.TrimSpace(line[len(aikidoCertMarker):]) - } - } - return "" -} - -func findSystemPipCABundle(ctx context.Context) string { - for _, pythonBin := range []string{"python3", "python"} { - pythonPath, err := exec.LookPath(pythonBin) - if err != nil { - continue - } - out, err := platform.RunAsCurrentUserWithPathEnv(ctx, pythonPath, "-c", "import certifi; print(certifi.where())") - if err != nil { - continue - } - path := strings.TrimSpace(out) - if path == "" { - continue - } - if _, err := os.Stat(path); err == nil { - return path - } - } - - // certifi is not installed — fall back to the path reported by Python's - // stdlib ssl module (available in every Python installation). - for _, pythonBin := range []string{"python3", "python"} { - pythonPath, err := exec.LookPath(pythonBin) - if err != nil { - continue - } - out, err := platform.RunAsCurrentUserWithPathEnv(ctx, pythonPath, "-c", - "import ssl; p = ssl.get_default_verify_paths(); print(p.cafile or p.openssl_cafile or '')") - if err != nil { - continue - } - path := strings.TrimSpace(out) - if path == "" { - continue - } - if _, err := os.Stat(path); err == nil { - return path - } - } - - return "" -} - -func extractMarkedPipCertSetting(output string) pipCertSetting { - return parsePipCertSettingString(extractMarkedCertValue(output)) -} - -func parsePipCertSettingString(value string) pipCertSetting { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return pipCertSetting{} - } - - envVar, path, ok := strings.Cut(trimmed, ":") - if !ok { - return pipCertSetting{ - EnvVar: pipCertEnvVar, - Path: trimmed, - } - } - - switch envVar { - case pipCertEnvVar, requestsCABundleEnvVar, sslCertFileEnvVar: - return pipCertSetting{ - EnvVar: envVar, - Path: strings.TrimSpace(path), - } - default: - return pipCertSetting{ - EnvVar: pipCertEnvVar, - Path: trimmed, - } - } -} diff --git a/internal/dockerca/dockerca.go b/internal/certconfig/docker/docker.go similarity index 91% rename from internal/dockerca/dockerca.go rename to internal/certconfig/docker/docker.go index 07e9f812..19e197ec 100644 --- a/internal/dockerca/dockerca.go +++ b/internal/certconfig/docker/docker.go @@ -1,4 +1,4 @@ -// Package dockerca installs the Safechain proxy CA certificate into running +// Package docker installs the Safechain proxy CA certificate into running // Docker containers. The following Linux distributions are supported: // // - Debian family (Debian, Ubuntu, Linux Mint, Pop!_OS, Kali): via @@ -10,7 +10,7 @@ // // Unsupported distributions (e.g. Arch, openSUSE, distroless/scratch images) // are detected and skipped with a warning log. -package dockerca +package docker import ( "bufio" @@ -36,7 +36,29 @@ const ( installMethodRHEL installMethod = "rhel" ) -func InstallCAOnRunningContainers(ctx context.Context) error { +// Configurator implements the certconfig Configurator interface for Docker containers. +type Configurator struct{} + +func New() *Configurator { + return &Configurator{} +} + +func (c *Configurator) Name() string { + return "docker" +} + +func (c *Configurator) Install(ctx context.Context) error { + return InstallDockerCA(ctx) +} + +// Uninstall is a no-op: containers are ephemeral, and removing CAs from +// running containers at teardown would be fragile and of little value. +func (c *Configurator) Uninstall(_ context.Context) error { + log.Println("Docker CA: skipping uninstall (containers are ephemeral)") + return nil +} + +func InstallDockerCA(ctx context.Context) error { dockerBinary, err := findDockerBinary() if err != nil { log.Println("Docker CA: docker binary not found, skipping reconcile") diff --git a/internal/dockerca/dockerca_darwin.go b/internal/certconfig/docker/docker_darwin.go similarity index 99% rename from internal/dockerca/dockerca_darwin.go rename to internal/certconfig/docker/docker_darwin.go index 6de8c6f8..29d51e3e 100644 --- a/internal/dockerca/dockerca_darwin.go +++ b/internal/certconfig/docker/docker_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package dockerca +package docker import ( "bufio" diff --git a/internal/dockerca/dockerca_test.go b/internal/certconfig/docker/docker_test.go similarity index 99% rename from internal/dockerca/dockerca_test.go rename to internal/certconfig/docker/docker_test.go index 52f54698..c796ef00 100644 --- a/internal/dockerca/dockerca_test.go +++ b/internal/certconfig/docker/docker_test.go @@ -1,4 +1,4 @@ -package dockerca +package docker import ( "strings" diff --git a/internal/dockerca/dockerca_windows.go b/internal/certconfig/docker/docker_windows.go similarity index 99% rename from internal/dockerca/dockerca_windows.go rename to internal/certconfig/docker/docker_windows.go index f1136d7f..11cfae3c 100644 --- a/internal/dockerca/dockerca_windows.go +++ b/internal/certconfig/docker/docker_windows.go @@ -1,6 +1,6 @@ //go:build windows -package dockerca +package docker import ( "context" diff --git a/internal/dockerca/osrelease.go b/internal/certconfig/docker/osrelease.go similarity index 98% rename from internal/dockerca/osrelease.go rename to internal/certconfig/docker/osrelease.go index bdf30e4f..8938a3a0 100644 --- a/internal/dockerca/osrelease.go +++ b/internal/certconfig/docker/osrelease.go @@ -1,4 +1,4 @@ -package dockerca +package docker import "strings" diff --git a/internal/certconfig/firefox.go b/internal/certconfig/firefox.go deleted file mode 100644 index 688d1e6c..00000000 --- a/internal/certconfig/firefox.go +++ /dev/null @@ -1,96 +0,0 @@ -package certconfig - -import ( - "context" - "os" - "path/filepath" - "runtime" - - "github.com/AikidoSec/safechain-internals/internal/platform" - "github.com/AikidoSec/safechain-internals/internal/utils" -) - -var firefoxManagedBlockFormat = managedBlockFormat{ - startMarker: "// aikido-endpoint-cert-config-start", - endMarker: "// aikido-endpoint-cert-config-end", -} - -type firefoxConfigurator struct{} - -func newFirefoxConfigurator() Configurator { - return &firefoxConfigurator{} -} - -func (c *firefoxConfigurator) Name() string { - return "firefox" -} - -func (c *firefoxConfigurator) Install(_ context.Context) error { - for _, profile := range firefoxProfiles() { - path := filepath.Join(profile, "user.js") - body := `user_pref("security.enterprise_roots.enabled", true);` - if err := writeManagedBlock(path, body, 0o644, firefoxManagedBlockFormat); err != nil { - return err - } - } - return nil -} - -func (c *firefoxConfigurator) Uninstall(_ context.Context) error { - for _, profile := range firefoxProfiles() { - path := filepath.Join(profile, "user.js") - if err := utils.RemoveManagedBlock(path, 0o644, firefoxManagedBlockFormat.startMarker, firefoxManagedBlockFormat.endMarker); err != nil { - return err - } - } - return nil -} - -func firefoxProfiles() []string { - profilesRoot := firefoxProfilesRoot() - if profilesRoot == "" { - return nil - } - - if _, err := os.Stat(profilesRoot); err != nil { - return nil - } - - entries, err := os.ReadDir(profilesRoot) - if err != nil { - return nil - } - - profiles := make([]string, 0, len(entries)) - for _, entry := range entries { - if !entry.IsDir() { - continue - } - profilePath := filepath.Join(profilesRoot, entry.Name()) - if isFirefoxProfileDir(profilePath) { - profiles = append(profiles, profilePath) - } - } - return profiles -} - -func isFirefoxProfileDir(profilePath string) bool { - for _, fileName := range []string{"prefs.js", "user.js"} { - if _, err := os.Stat(filepath.Join(profilePath, fileName)); err == nil { - return true - } - } - return false -} - -func firefoxProfilesRoot() string { - homeDir := platform.GetConfig().HomeDir - switch runtime.GOOS { - case "darwin": - return filepath.Join(homeDir, "Library", "Application Support", "Firefox", "Profiles") - case "windows": - return filepath.Join(homeDir, "AppData", "Roaming", "Mozilla", "Firefox", "Profiles") - default: - return "" - } -} diff --git a/internal/certconfig/firefox/firefox.go b/internal/certconfig/firefox/firefox.go new file mode 100644 index 00000000..0677d7b3 --- /dev/null +++ b/internal/certconfig/firefox/firefox.go @@ -0,0 +1,97 @@ +package firefox + +import ( + "context" + "os" + "path/filepath" + "runtime" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" + "github.com/AikidoSec/safechain-internals/internal/utils" +) + +var managedBlockFormat = shared.ManagedBlockFormat{ + StartMarker: "// aikido-endpoint-cert-config-start", + EndMarker: "// aikido-endpoint-cert-config-end", +} + +type Configurator struct{} + +func New() *Configurator { + return &Configurator{} +} + +func (c *Configurator) Name() string { + return "firefox" +} + +func (c *Configurator) Install(_ context.Context) error { + for _, profile := range profiles() { + path := filepath.Join(profile, "user.js") + body := `user_pref("security.enterprise_roots.enabled", true);` + if err := shared.WriteManagedBlock(path, body, 0o644, managedBlockFormat); err != nil { + return err + } + } + return nil +} + +func (c *Configurator) Uninstall(_ context.Context) error { + for _, profile := range profiles() { + path := filepath.Join(profile, "user.js") + if err := utils.RemoveManagedBlock(path, 0o644, managedBlockFormat.StartMarker, managedBlockFormat.EndMarker); err != nil { + return err + } + } + return nil +} + +func profiles() []string { + root := profilesRoot() + if root == "" { + return nil + } + + if _, err := os.Stat(root); err != nil { + return nil + } + + entries, err := os.ReadDir(root) + if err != nil { + return nil + } + + result := make([]string, 0, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + continue + } + profilePath := filepath.Join(root, entry.Name()) + if isProfileDir(profilePath) { + result = append(result, profilePath) + } + } + return result +} + +func isProfileDir(profilePath string) bool { + for _, fileName := range []string{"prefs.js", "user.js"} { + if _, err := os.Stat(filepath.Join(profilePath, fileName)); err == nil { + return true + } + } + return false +} + +func profilesRoot() string { + homeDir := platform.GetConfig().HomeDir + switch runtime.GOOS { + case "darwin": + return filepath.Join(homeDir, "Library", "Application Support", "Firefox", "Profiles") + case "windows": + return filepath.Join(homeDir, "AppData", "Roaming", "Mozilla", "Firefox", "Profiles") + default: + return "" + } +} diff --git a/internal/certconfig/firefox_test.go b/internal/certconfig/firefox/firefox_test.go similarity index 80% rename from internal/certconfig/firefox_test.go rename to internal/certconfig/firefox/firefox_test.go index 13366126..5e687c31 100644 --- a/internal/certconfig/firefox_test.go +++ b/internal/certconfig/firefox/firefox_test.go @@ -1,4 +1,4 @@ -package certconfig +package firefox import ( "os" @@ -8,18 +8,18 @@ import ( "github.com/AikidoSec/safechain-internals/internal/platform" ) -func TestIsFirefoxProfileDir(t *testing.T) { +func TestIsProfileDir(t *testing.T) { dir := t.TempDir() if err := os.WriteFile(filepath.Join(dir, "prefs.js"), []byte("// prefs"), 0o644); err != nil { t.Fatal(err) } - if !isFirefoxProfileDir(dir) { + if !isProfileDir(dir) { t.Fatal("expected prefs.js to mark directory as Firefox profile") } } -func TestFirefoxProfilesFiltersDirectories(t *testing.T) { +func TestProfilesFiltersDirectories(t *testing.T) { cfg := platform.GetConfig() originalHome := cfg.HomeDir t.Cleanup(func() { @@ -29,9 +29,9 @@ func TestFirefoxProfilesFiltersDirectories(t *testing.T) { home := t.TempDir() cfg.HomeDir = home - root := firefoxProfilesRoot() + root := profilesRoot() if root == "" { - t.Skip("unsupported OS for firefoxProfilesRoot") + t.Skip("unsupported OS for profilesRoot") } valid := filepath.Join(root, "abcd1234.default-release") @@ -47,7 +47,7 @@ func TestFirefoxProfilesFiltersDirectories(t *testing.T) { t.Fatal(err) } - got := firefoxProfiles() + got := profiles() if len(got) != 1 { t.Fatalf("expected 1 Firefox profile, got %d (%v)", len(got), got) } diff --git a/internal/certconfig/node.go b/internal/certconfig/node/node.go similarity index 56% rename from internal/certconfig/node.go rename to internal/certconfig/node/node.go index 5dfcee15..c2eadcd2 100644 --- a/internal/certconfig/node.go +++ b/internal/certconfig/node/node.go @@ -1,4 +1,4 @@ -package certconfig +package node import ( "context" @@ -7,44 +7,45 @@ import ( "path/filepath" "strings" + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" "github.com/AikidoSec/safechain-internals/internal/platform" ) -type nodeTrustConfigurator interface { +type trustConfigurator interface { Install(context.Context) error Uninstall(context.Context) error } -type nodeConfigurator struct { - trust nodeTrustConfigurator +type Configurator struct { + trust trustConfigurator } -func newNodeConfigurator() Configurator { - return &nodeConfigurator{ - trust: newNodeTrustConfigurator(combinedCaBundlePath()), +func New() *Configurator { + return &Configurator{ + trust: newTrustConfigurator(shared.CombinedCaBundlePath()), } } -func (c *nodeConfigurator) Name() string { +func (c *Configurator) Name() string { return "node" } -// originalNodeExtraCACertsPath is where we persist the user's pre-existing +// originalExtraCACertsPath is where we persist the user's pre-existing // NODE_EXTRA_CA_CERTS value so it can be preserved across reinstalls and // restored on uninstall. -func originalNodeExtraCACertsPath() string { +func originalExtraCACertsPath() string { return filepath.Join(platform.GetRunDir(), "endpoint-protection-node-original-extra-ca-certs.txt") } -// ensureOriginalNodeExtraCACerts returns the user's pre-existing NODE_EXTRA_CA_CERTS +// ensureOriginalExtraCACerts returns the user's pre-existing NODE_EXTRA_CA_CERTS // value, saving it to disk on first install. On reinstall the saved value is // returned directly — avoiding a live shell lookup that would return our own // combined bundle path instead of the user's original. -func ensureOriginalNodeExtraCACerts(ctx context.Context) (string, error) { - return ensureOriginalNodeExtraCACertsAt(ctx, originalNodeExtraCACertsPath(), runNodeExtraCACertsLookup) +func ensureOriginalExtraCACerts(ctx context.Context) (string, error) { + return ensureOriginalExtraCACertsAt(ctx, originalExtraCACertsPath(), runExtraCACertsLookup) } -func ensureOriginalNodeExtraCACertsAt( +func ensureOriginalExtraCACertsAt( ctx context.Context, savedPath string, lookup func(context.Context) (string, error), @@ -65,21 +66,21 @@ func ensureOriginalNodeExtraCACertsAt( return original, nil } -func (c *nodeConfigurator) Install(ctx context.Context) error { - original, err := ensureOriginalNodeExtraCACerts(ctx) +func (c *Configurator) Install(ctx context.Context) error { + original, err := ensureOriginalExtraCACerts(ctx) if err != nil { return err } - if _, err := ensureCombinedCABundle(original); err != nil { + if _, err := shared.EnsureCombinedCABundle(original); err != nil { return err } return c.trust.Install(ctx) } -func (c *nodeConfigurator) Uninstall(ctx context.Context) error { +func (c *Configurator) Uninstall(ctx context.Context) error { if err := c.trust.Uninstall(ctx); err != nil { return err } - _ = os.Remove(originalNodeExtraCACertsPath()) - return removeCombinedCABundle() + _ = os.Remove(originalExtraCACertsPath()) + return shared.RemoveCombinedCABundle() } diff --git a/internal/certconfig/node/node_darwin.go b/internal/certconfig/node/node_darwin.go new file mode 100644 index 00000000..5eb78af5 --- /dev/null +++ b/internal/certconfig/node/node_darwin.go @@ -0,0 +1,51 @@ +//go:build darwin + +package node + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +type darwinTrustConfigurator struct { + targets []shared.DarwinShellTarget +} + +var shellManagedBlockFormat = shared.ManagedBlockFormat{ + StartMarker: "# aikido-endpoint-cert-config-start", + EndMarker: "# aikido-endpoint-cert-config-end", +} + +func newTrustConfigurator(bundlePath string) trustConfigurator { + return &darwinTrustConfigurator{ + targets: darwinShellTargets(bundlePath), + } +} + +func (c *darwinTrustConfigurator) Install(_ context.Context) error { + return shared.InstallDarwinShellTargets(c.targets, shellManagedBlockFormat) +} + +func (c *darwinTrustConfigurator) Uninstall(_ context.Context) error { + return shared.UninstallDarwinShellTargets(c.targets, shellManagedBlockFormat) +} + +func darwinShellTargets(bundlePath string) []shared.DarwinShellTarget { + homeDir := platform.GetConfig().HomeDir + comment := "# Allow Node.js tooling to trust the SafeChain MITM CA while preserving public roots." + posix := comment + "\n" + fmt.Sprintf("export NODE_EXTRA_CA_CERTS=%q", bundlePath) + fish := comment + "\n" + fmt.Sprintf("set -gx NODE_EXTRA_CA_CERTS %q", bundlePath) + + return []shared.DarwinShellTarget{ + {Path: filepath.Join(homeDir, ".zshrc"), Body: posix}, + {Path: filepath.Join(homeDir, ".zprofile"), Body: posix}, + {Path: filepath.Join(homeDir, ".bash_profile"), Body: posix}, + {Path: filepath.Join(homeDir, ".bashrc"), Body: posix}, + {Path: filepath.Join(homeDir, ".profile"), Body: posix}, + {Path: filepath.Join(homeDir, ".config", "fish", "config.fish"), Body: fish, CreateIfMissing: true}, + } +} diff --git a/internal/certconfig/node/node_env_darwin.go b/internal/certconfig/node/node_env_darwin.go new file mode 100644 index 00000000..79d50ffa --- /dev/null +++ b/internal/certconfig/node/node_env_darwin.go @@ -0,0 +1,44 @@ +//go:build darwin + +package node + +import ( + "context" + "os/exec" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +// extraCACertsShellLookups tries login and interactive startup files separately +// because they are sourced in different modes: +// - Login non-interactive (-lc): ~/.zshenv, ~/.zprofile (zsh) or ~/.bash_profile (bash) +// - Interactive non-login (-ic): ~/.zshrc (zsh) or ~/.bashrc (bash) +// - fish --login: ~/.config/fish/config.fish +// +// Each command wraps the value in a unique marker so interactive startup noise +// (prompts, oh-my-zsh banners, etc.) does not contaminate the result. +var extraCACertsShellLookups = []shared.ShellLookup{ + {Name: "zsh", Args: []string{"-lc", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, + {Name: "zsh", Args: []string{"-ic", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, + {Name: "bash", Args: []string{"-lc", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, + {Name: "bash", Args: []string{"-ic", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, + // set -q guards against fish 3.x warnings on unset variable access. + {Name: "fish", Args: []string{"--login", "-c", "set -q NODE_EXTRA_CA_CERTS; and printf 'AIKIDO_CERT=%s\n' $NODE_EXTRA_CA_CERTS"}}, +} + +func runExtraCACertsLookup(ctx context.Context) (string, error) { + for _, lookup := range extraCACertsShellLookups { + shellPath, err := exec.LookPath(lookup.Name) + if err != nil { + continue // shell not installed + } + out, err := platform.RunAsCurrentUserWithPathEnv(ctx, shellPath, lookup.Args...) + if err == nil { + if value := shared.ExtractMarkedCertValue(out); value != "" { + return value, nil + } + } + } + return "", nil +} diff --git a/internal/certconfig/node_env_darwin_test.go b/internal/certconfig/node/node_env_darwin_test.go similarity index 71% rename from internal/certconfig/node_env_darwin_test.go rename to internal/certconfig/node/node_env_darwin_test.go index f0df4de5..d20b7121 100644 --- a/internal/certconfig/node_env_darwin_test.go +++ b/internal/certconfig/node/node_env_darwin_test.go @@ -1,6 +1,6 @@ //go:build darwin -package certconfig +package node import ( "context" @@ -10,32 +10,8 @@ import ( "testing" ) -func TestExtractMarkedCertValue(t *testing.T) { - tests := []struct { - name string - output string - want string - }{ - {"clean output", "AIKIDO_CERT=/path/to/ca.pem\n", "/path/to/ca.pem"}, - {"empty value", "AIKIDO_CERT=\n", ""}, - {"marker absent", "some random output\n", ""}, - {"marker buried in noise", "Welcome to zsh!\nAIKIDO_CERT=/corp/ca.pem\nsome trailing line", "/corp/ca.pem"}, - {"interactive startup noise before", "compinit output\n[oh-my-zsh]\nAIKIDO_CERT=/ca.pem\n", "/ca.pem"}, - {"whitespace trimmed", "AIKIDO_CERT= /ca.pem \n", "/ca.pem"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := extractMarkedCertValue(tt.output) - if got != tt.want { - t.Fatalf("got %q, want %q", got, tt.want) - } - }) - } -} - // TestNodeExtraCACertsShellLookups verifies that each shell command in -// nodeCACertsShellLookups correctly reads NODE_EXTRA_CA_CERTS from the +// extraCACertsShellLookups correctly reads NODE_EXTRA_CA_CERTS from the // shell's startup file (both login and interactive variants), and returns // empty when nothing is set. Tests are skipped if the shell is not installed. func TestNodeExtraCACertsShellLookups(t *testing.T) { @@ -97,7 +73,7 @@ func TestNodeExtraCACertsShellLookups(t *testing.T) { t.Fatal(err) } - got, err := runNodeExtraCACertsLookup(context.Background()) + got, err := runExtraCACertsLookup(context.Background()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -111,7 +87,7 @@ func TestNodeExtraCACertsShellLookups(t *testing.T) { t.Fatal(err) } - got, err := runNodeExtraCACertsLookup(context.Background()) + got, err := runExtraCACertsLookup(context.Background()) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/certconfig/node_env_windows.go b/internal/certconfig/node/node_env_windows.go similarity index 75% rename from internal/certconfig/node_env_windows.go rename to internal/certconfig/node/node_env_windows.go index d7537955..05368e41 100644 --- a/internal/certconfig/node_env_windows.go +++ b/internal/certconfig/node/node_env_windows.go @@ -1,6 +1,6 @@ //go:build windows -package certconfig +package node import ( "context" @@ -8,7 +8,7 @@ import ( "github.com/AikidoSec/safechain-internals/internal/platform" ) -func runNodeExtraCACertsLookup(ctx context.Context) (string, error) { +func runExtraCACertsLookup(ctx context.Context) (string, error) { return platform.RunAsCurrentUser(ctx, "powershell", []string{ "-NoProfile", "-NonInteractive", diff --git a/internal/certconfig/node_test.go b/internal/certconfig/node/node_test.go similarity index 70% rename from internal/certconfig/node_test.go rename to internal/certconfig/node/node_test.go index 4d72c0d6..74dc2f87 100644 --- a/internal/certconfig/node_test.go +++ b/internal/certconfig/node/node_test.go @@ -1,4 +1,4 @@ -package certconfig +package node import ( "context" @@ -8,13 +8,13 @@ import ( "testing" ) -func TestEnsureOriginalNodeExtraCACertsFirstInstall(t *testing.T) { +func TestEnsureOriginalExtraCACertsFirstInstall(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") lookup := func(_ context.Context) (string, error) { return "/corporate/ca.pem", nil } - got, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, lookup) + got, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, lookup) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -31,11 +31,11 @@ func TestEnsureOriginalNodeExtraCACertsFirstInstall(t *testing.T) { } } -func TestEnsureOriginalNodeExtraCACertsFirstInstallNothingSet(t *testing.T) { +func TestEnsureOriginalExtraCACertsFirstInstallNothingSet(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") lookup := func(_ context.Context) (string, error) { return "", nil } - got, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, lookup) + got, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, lookup) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -49,7 +49,7 @@ func TestEnsureOriginalNodeExtraCACertsFirstInstallNothingSet(t *testing.T) { } } -func TestEnsureOriginalNodeExtraCACertsReinstallSkipsLookup(t *testing.T) { +func TestEnsureOriginalExtraCACertsReinstallSkipsLookup(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") if err := os.WriteFile(savedPath, []byte("/saved/ca.pem"), 0o600); err != nil { t.Fatal(err) @@ -61,7 +61,7 @@ func TestEnsureOriginalNodeExtraCACertsReinstallSkipsLookup(t *testing.T) { return "/new-value/ca.pem", nil } - got, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, lookup) + got, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, lookup) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -73,13 +73,13 @@ func TestEnsureOriginalNodeExtraCACertsReinstallSkipsLookup(t *testing.T) { } } -func TestEnsureOriginalNodeExtraCACertsTrimsSavedWhitespace(t *testing.T) { +func TestEnsureOriginalExtraCACertsTrimsSavedWhitespace(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") if err := os.WriteFile(savedPath, []byte(" /trimmed/ca.pem\n"), 0o600); err != nil { t.Fatal(err) } - got, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, nil) + got, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -88,11 +88,11 @@ func TestEnsureOriginalNodeExtraCACertsTrimsSavedWhitespace(t *testing.T) { } } -func TestEnsureOriginalNodeExtraCACertsTrimsLookupWhitespace(t *testing.T) { +func TestEnsureOriginalExtraCACertsTrimsLookupWhitespace(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") lookup := func(_ context.Context) (string, error) { return " /padded/ca.pem\n", nil } - got, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, lookup) + got, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, lookup) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -101,13 +101,13 @@ func TestEnsureOriginalNodeExtraCACertsTrimsLookupWhitespace(t *testing.T) { } } -func TestEnsureOriginalNodeExtraCACertsLookupError(t *testing.T) { +func TestEnsureOriginalExtraCACertsLookupError(t *testing.T) { savedPath := filepath.Join(t.TempDir(), "original.txt") lookup := func(_ context.Context) (string, error) { return "", errors.New("shell not found") } - _, err := ensureOriginalNodeExtraCACertsAt(context.Background(), savedPath, lookup) + _, err := ensureOriginalExtraCACertsAt(context.Background(), savedPath, lookup) if err == nil { t.Fatal("expected error when lookup fails") } diff --git a/internal/certconfig/node/node_windows.go b/internal/certconfig/node/node_windows.go new file mode 100644 index 00000000..5446748e --- /dev/null +++ b/internal/certconfig/node/node_windows.go @@ -0,0 +1,51 @@ +//go:build windows + +package node + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" +) + +type windowsTrustConfigurator struct { + bundlePath string +} + +func newTrustConfigurator(bundlePath string) trustConfigurator { + return &windowsTrustConfigurator{ + bundlePath: bundlePath, + } +} + +func (c *windowsTrustConfigurator) Install(ctx context.Context) error { + return shared.RunPowerShellAsCurrentUser( + ctx, + fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', '%s', 'User')", + shared.EscapePowerShellSingleQuoted(c.bundlePath), + ), + ) +} + +func (c *windowsTrustConfigurator) Uninstall(ctx context.Context) error { + original := "" + if data, err := os.ReadFile(originalExtraCACertsPath()); err == nil { + original = strings.TrimSpace(string(data)) + } + // Saved file deletion is handled by Configurator.Uninstall. + + var script string + if original != "" { + script = fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', '%s', 'User')", + shared.EscapePowerShellSingleQuoted(original), + ) + } else { + script = "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', $null, 'User')" + } + return shared.RunPowerShellAsCurrentUser(ctx, script) +} diff --git a/internal/certconfig/node_darwin.go b/internal/certconfig/node_darwin.go deleted file mode 100644 index 1fb677be..00000000 --- a/internal/certconfig/node_darwin.go +++ /dev/null @@ -1,79 +0,0 @@ -//go:build darwin - -package certconfig - -import ( - "context" - "fmt" - "os" - "path/filepath" - - "github.com/AikidoSec/safechain-internals/internal/platform" - "github.com/AikidoSec/safechain-internals/internal/utils" -) - -type darwinNodeTrustConfigurator struct { - targets []darwinShellTarget -} - -type darwinShellTarget struct { - path string - body string - createIfMissing bool -} - -func newNodeTrustConfigurator(bundlePath string) nodeTrustConfigurator { - return &darwinNodeTrustConfigurator{ - targets: darwinShellTargets(bundlePath), - } -} - -func (c *darwinNodeTrustConfigurator) Install(_ context.Context) error { - for _, target := range c.targets { - if _, err := os.Stat(target.path); err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to stat %s: %w", target.path, err) - } - if !target.createIfMissing { - continue - } - if err := os.MkdirAll(filepath.Dir(target.path), 0o755); err != nil { - return fmt.Errorf("failed to create config dir for %s: %w", target.path, err) - } - } - if err := writeManagedBlock(target.path, target.body, 0o644, shellManagedBlockFormat); err != nil { - return err - } - } - return nil -} - -func (c *darwinNodeTrustConfigurator) Uninstall(_ context.Context) error { - for _, target := range c.targets { - if err := utils.RemoveManagedBlock(target.path, 0o644, shellManagedBlockFormat.startMarker, shellManagedBlockFormat.endMarker); err != nil { - return err - } - } - return nil -} - -func darwinShellTargets(bundlePath string) []darwinShellTarget { - homeDir := platform.GetConfig().HomeDir - comment := "# Allow Node.js tooling to trust the SafeChain MITM CA while preserving public roots." - posix := comment + "\n" + fmt.Sprintf("export NODE_EXTRA_CA_CERTS=%q", bundlePath) - fish := comment + "\n" + fmt.Sprintf("set -gx NODE_EXTRA_CA_CERTS %q", bundlePath) - - return []darwinShellTarget{ - {path: filepath.Join(homeDir, ".zshrc"), body: posix}, - {path: filepath.Join(homeDir, ".zprofile"), body: posix}, - {path: filepath.Join(homeDir, ".bash_profile"), body: posix}, - {path: filepath.Join(homeDir, ".bashrc"), body: posix}, - {path: filepath.Join(homeDir, ".profile"), body: posix}, - {path: filepath.Join(homeDir, ".config", "fish", "config.fish"), body: fish, createIfMissing: true}, - } -} - -var shellManagedBlockFormat = managedBlockFormat{ - startMarker: "# aikido-endpoint-cert-config-start", - endMarker: "# aikido-endpoint-cert-config-end", -} diff --git a/internal/certconfig/node_env_darwin.go b/internal/certconfig/node_env_darwin.go deleted file mode 100644 index 52b2af37..00000000 --- a/internal/certconfig/node_env_darwin.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build darwin - -package certconfig - -import ( - "context" - "os/exec" - - "github.com/AikidoSec/safechain-internals/internal/platform" -) - -// shellLookup describes how to query NODE_EXTRA_CA_CERTS from a specific shell. -type shellLookup struct { - name string - args []string -} - -// nodeCACertsShellLookups tries login and interactive startup files separately -// because they are sourced in different modes: -// - Login non-interactive (-lc): ~/.zshenv, ~/.zprofile (zsh) or ~/.bash_profile (bash) -// - Interactive non-login (-ic): ~/.zshrc (zsh) or ~/.bashrc (bash) -// - fish --login: ~/.config/fish/config.fish -// -// Each command wraps the value in a unique marker so interactive startup noise -// (prompts, oh-my-zsh banners, etc.) does not contaminate the result. -var nodeCACertsShellLookups = []shellLookup{ - {"zsh", []string{"-lc", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, - {"zsh", []string{"-ic", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, - {"bash", []string{"-lc", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, - {"bash", []string{"-ic", `printf 'AIKIDO_CERT=%s\n' "${NODE_EXTRA_CA_CERTS:-}"`}}, - // set -q guards against fish 3.x warnings on unset variable access. - {"fish", []string{"--login", "-c", "set -q NODE_EXTRA_CA_CERTS; and printf 'AIKIDO_CERT=%s\n' $NODE_EXTRA_CA_CERTS"}}, -} - -func runNodeExtraCACertsLookup(ctx context.Context) (string, error) { - for _, lookup := range nodeCACertsShellLookups { - shellPath, err := exec.LookPath(lookup.name) - if err != nil { - continue // shell not installed - } - out, err := platform.RunAsCurrentUserWithPathEnv(ctx, shellPath, lookup.args...) - if err == nil { - if value := extractMarkedCertValue(out); value != "" { - return value, nil - } - } - } - return "", nil -} diff --git a/internal/certconfig/node_windows.go b/internal/certconfig/node_windows.go deleted file mode 100644 index c301020a..00000000 --- a/internal/certconfig/node_windows.go +++ /dev/null @@ -1,65 +0,0 @@ -//go:build windows - -package certconfig - -import ( - "context" - "fmt" - "os" - "strings" - - "github.com/AikidoSec/safechain-internals/internal/platform" -) - -type windowsNodeTrustConfigurator struct { - bundlePath string -} - -func newNodeTrustConfigurator(bundlePath string) nodeTrustConfigurator { - return &windowsNodeTrustConfigurator{ - bundlePath: bundlePath, - } -} - -func (c *windowsNodeTrustConfigurator) Install(ctx context.Context) error { - return runPowerShellAsCurrentUser( - ctx, - fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', '%s', 'User')", - escapePowerShellSingleQuoted(c.bundlePath), - ), - ) -} - -func (c *windowsNodeTrustConfigurator) Uninstall(ctx context.Context) error { - original := "" - if data, err := os.ReadFile(originalNodeExtraCACertsPath()); err == nil { - original = strings.TrimSpace(string(data)) - } - // Saved file deletion is handled by nodeConfigurator.Uninstall. - - var script string - if original != "" { - script = fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', '%s', 'User')", - escapePowerShellSingleQuoted(original), - ) - } else { - script = "[Environment]::SetEnvironmentVariable('NODE_EXTRA_CA_CERTS', $null, 'User')" - } - return runPowerShellAsCurrentUser(ctx, script) -} - -func runPowerShellAsCurrentUser(ctx context.Context, script string) error { - _, err := platform.RunAsCurrentUser(ctx, "powershell", []string{ - "-NoProfile", - "-NonInteractive", - "-Command", - script, - }) - return err -} - -func escapePowerShellSingleQuoted(value string) string { - return strings.ReplaceAll(value, "'", "''") -} diff --git a/internal/certconfig/pip.go b/internal/certconfig/pip.go deleted file mode 100644 index a07ab576..00000000 --- a/internal/certconfig/pip.go +++ /dev/null @@ -1,138 +0,0 @@ -package certconfig - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/AikidoSec/safechain-internals/internal/platform" - "github.com/AikidoSec/safechain-internals/internal/utils" -) - -type pipTrustConfigurator interface { - Install(context.Context) error - Uninstall(context.Context) error -} - -type pipConfigurator struct { - trust pipTrustConfigurator -} - -type pipCertSetting struct { - EnvVar string `json:"env_var"` - Path string `json:"path"` -} - -func newPipConfigurator() Configurator { - return &pipConfigurator{ - trust: newPipTrustConfigurator(pipCombinedCaBundlePath()), - } -} - -func (c *pipConfigurator) Name() string { - return "pip" -} - -func originalPipCertPath() string { - return filepath.Join(platform.GetRunDir(), "endpoint-protection-pip-original-cert-path.txt") -} - -func ensureOriginalPipCert(ctx context.Context) (pipCertSetting, error) { - return ensureOriginalPipCertAt(ctx, originalPipCertPath(), runPipCertLookup) -} - -func ensureOriginalPipCertAt( - ctx context.Context, - savedPath string, - lookup func(context.Context) (pipCertSetting, error), -) (pipCertSetting, error) { - if data, err := os.ReadFile(savedPath); err == nil { - return parseSavedPipCertSetting(data) - } - - original, err := lookup(ctx) - if err != nil { - return pipCertSetting{}, fmt.Errorf("read existing pip certificate configuration: %w", err) - } - original.Path = strings.TrimSpace(original.Path) - - data, err := json.Marshal(original) - if err != nil { - return pipCertSetting{}, fmt.Errorf("marshal existing pip certificate configuration: %w", err) - } - if err := os.WriteFile(savedPath, data, 0o600); err != nil { - return pipCertSetting{}, fmt.Errorf("persist existing pip certificate configuration: %w", err) - } - return original, nil -} - -func (c *pipConfigurator) Install(ctx context.Context) error { - original, err := ensureOriginalPipCert(ctx) - if err != nil { - return err - } - baseCACertBundle, err := resolvePipBaseCACertBundle(ctx, original.Path) - if err != nil { - return err - } - if _, err := ensurePipCombinedCABundle(baseCACertBundle); err != nil { - return err - } - return c.trust.Install(ctx) -} - -func (c *pipConfigurator) Uninstall(ctx context.Context) error { - if err := c.trust.Uninstall(ctx); err != nil { - return err - } - _ = os.Remove(originalPipCertPath()) - return removePipCombinedCABundle() -} - -func parseSavedPipCertSetting(data []byte) (pipCertSetting, error) { - var setting pipCertSetting - if err := json.Unmarshal(data, &setting); err == nil { - setting.Path = strings.TrimSpace(setting.Path) - return setting, nil - } - - // Backward compatibility with the earlier plain-text format. - return pipCertSetting{ - EnvVar: pipCertEnvVar, - Path: strings.TrimSpace(string(data)), - }, nil -} - -func resolvePipBaseCACertBundle(ctx context.Context, original string) (string, error) { - return resolvePipBaseCACertBundleAt(ctx, original, findSystemPipCABundle) -} - -func resolvePipBaseCACertBundleAt( - ctx context.Context, - original string, - findCertifi func(context.Context) string, -) (string, error) { - if original != "" { - return validatePipBaseCABundle(original) - } - - if certifi := findCertifi(ctx); certifi != "" { - return validatePipBaseCABundle(certifi) - } - - return "", fmt.Errorf("no usable pip CA bundle found; refusing to set PIP_CERT without a trusted base bundle") -} - -func validatePipBaseCABundle(path string) (string, error) { - expanded := utils.ExpandHomePath(strings.TrimSpace(path), platform.GetConfig().HomeDir) - if expanded == "" { - return "", fmt.Errorf("pip CA bundle path is empty") - } - if _, err := readAndValidatePEMBundle(expanded); err != nil { - return "", fmt.Errorf("pip CA bundle %s is invalid: %w", expanded, err) - } - return expanded, nil -} diff --git a/internal/certconfig/pip/pip.go b/internal/certconfig/pip/pip.go new file mode 100644 index 00000000..1db16e2b --- /dev/null +++ b/internal/certconfig/pip/pip.go @@ -0,0 +1,139 @@ +package pip + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" + "github.com/AikidoSec/safechain-internals/internal/utils" +) + +type trustConfigurator interface { + Install(context.Context) error + Uninstall(context.Context) error +} + +type Configurator struct { + trust trustConfigurator +} + +type CertSetting struct { + EnvVar string `json:"env_var"` + Path string `json:"path"` +} + +func New() *Configurator { + return &Configurator{ + trust: newTrustConfigurator(shared.PipCombinedCaBundlePath()), + } +} + +func (c *Configurator) Name() string { + return "pip" +} + +func originalCertPath() string { + return filepath.Join(platform.GetRunDir(), "endpoint-protection-pip-original-cert-path.txt") +} + +func ensureOriginalCert(ctx context.Context) (CertSetting, error) { + return ensureOriginalCertAt(ctx, originalCertPath(), runCertLookup) +} + +func ensureOriginalCertAt( + ctx context.Context, + savedPath string, + lookup func(context.Context) (CertSetting, error), +) (CertSetting, error) { + if data, err := os.ReadFile(savedPath); err == nil { + return parseSavedCertSetting(data) + } + + original, err := lookup(ctx) + if err != nil { + return CertSetting{}, fmt.Errorf("read existing pip certificate configuration: %w", err) + } + original.Path = strings.TrimSpace(original.Path) + + data, err := json.Marshal(original) + if err != nil { + return CertSetting{}, fmt.Errorf("marshal existing pip certificate configuration: %w", err) + } + if err := os.WriteFile(savedPath, data, 0o600); err != nil { + return CertSetting{}, fmt.Errorf("persist existing pip certificate configuration: %w", err) + } + return original, nil +} + +func (c *Configurator) Install(ctx context.Context) error { + original, err := ensureOriginalCert(ctx) + if err != nil { + return err + } + baseCACertBundle, err := resolveBaseCACertBundle(ctx, original.Path) + if err != nil { + return err + } + if _, err := shared.EnsurePipCombinedCABundle(baseCACertBundle); err != nil { + return err + } + return c.trust.Install(ctx) +} + +func (c *Configurator) Uninstall(ctx context.Context) error { + if err := c.trust.Uninstall(ctx); err != nil { + return err + } + _ = os.Remove(originalCertPath()) + return shared.RemovePipCombinedCABundle() +} + +func parseSavedCertSetting(data []byte) (CertSetting, error) { + var setting CertSetting + if err := json.Unmarshal(data, &setting); err == nil { + setting.Path = strings.TrimSpace(setting.Path) + return setting, nil + } + + // Backward compatibility with the earlier plain-text format. + return CertSetting{ + EnvVar: CertEnvVar, + Path: strings.TrimSpace(string(data)), + }, nil +} + +func resolveBaseCACertBundle(ctx context.Context, original string) (string, error) { + return resolveBaseCACertBundleAt(ctx, original, findSystemCABundle) +} + +func resolveBaseCACertBundleAt( + ctx context.Context, + original string, + findCertifi func(context.Context) string, +) (string, error) { + if original != "" { + return validateBaseCABundle(original) + } + + if certifi := findCertifi(ctx); certifi != "" { + return validateBaseCABundle(certifi) + } + + return "", fmt.Errorf("no usable pip CA bundle found; refusing to set PIP_CERT without a trusted base bundle") +} + +func validateBaseCABundle(path string) (string, error) { + expanded := utils.ExpandHomePath(strings.TrimSpace(path), platform.GetConfig().HomeDir) + if expanded == "" { + return "", fmt.Errorf("pip CA bundle path is empty") + } + if _, err := shared.ReadAndValidatePEMBundle(expanded); err != nil { + return "", fmt.Errorf("pip CA bundle %s is invalid: %w", expanded, err) + } + return expanded, nil +} diff --git a/internal/certconfig/pip/pip_darwin.go b/internal/certconfig/pip/pip_darwin.go new file mode 100644 index 00000000..6a20a30e --- /dev/null +++ b/internal/certconfig/pip/pip_darwin.go @@ -0,0 +1,48 @@ +//go:build darwin + +package pip + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +var shellManagedBlockFormat = shared.ManagedBlockFormat{ + StartMarker: "# aikido-endpoint-pip-cert-config-start", + EndMarker: "# aikido-endpoint-pip-cert-config-end", +} + +type darwinTrustConfigurator struct { + targets []shared.DarwinShellTarget +} + +func newTrustConfigurator(bundlePath string) trustConfigurator { + comment := "# Allow pip to trust the SafeChain MITM CA while preserving user-provided roots." + posix := comment + "\n" + fmt.Sprintf("export PIP_CERT=%q", bundlePath) + fish := comment + "\n" + fmt.Sprintf("set -gx PIP_CERT %q", bundlePath) + + homeDir := platform.GetConfig().HomeDir + + return &darwinTrustConfigurator{ + targets: []shared.DarwinShellTarget{ + {Path: filepath.Join(homeDir, ".zshrc"), Body: posix}, + {Path: filepath.Join(homeDir, ".zprofile"), Body: posix}, + {Path: filepath.Join(homeDir, ".bash_profile"), Body: posix}, + {Path: filepath.Join(homeDir, ".bashrc"), Body: posix}, + {Path: filepath.Join(homeDir, ".profile"), Body: posix}, + {Path: filepath.Join(homeDir, ".config", "fish", "config.fish"), Body: fish, CreateIfMissing: true}, + }, + } +} + +func (c *darwinTrustConfigurator) Install(_ context.Context) error { + return shared.InstallDarwinShellTargets(c.targets, shellManagedBlockFormat) +} + +func (c *darwinTrustConfigurator) Uninstall(_ context.Context) error { + return shared.UninstallDarwinShellTargets(c.targets, shellManagedBlockFormat) +} diff --git a/internal/certconfig/pip/pip_env_darwin.go b/internal/certconfig/pip/pip_env_darwin.go new file mode 100644 index 00000000..e7b547b9 --- /dev/null +++ b/internal/certconfig/pip/pip_env_darwin.go @@ -0,0 +1,43 @@ +//go:build darwin + +package pip + +import ( + "context" + "os/exec" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +// certShellLookups tries login and interactive startup files separately +// because they are sourced in different modes: +// - Login non-interactive (-lc): ~/.zprofile or ~/.bash_profile +// - Interactive non-login (-ic): ~/.zshrc or ~/.bashrc +// - fish --login: ~/.config/fish/config.fish +// +// Each command wraps the result in a unique marker so shell startup noise does +// not contaminate the discovered Python CA bundle override. +var certShellLookups = []shared.ShellLookup{ + {Name: "zsh", Args: []string{"-lc", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, + {Name: "zsh", Args: []string{"-ic", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, + {Name: "bash", Args: []string{"-lc", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, + {Name: "bash", Args: []string{"-ic", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, + {Name: "fish", Args: []string{"--login", "-c", "if set -q PIP_CERT; printf 'AIKIDO_CERT=PIP_CERT:%s\n' $PIP_CERT; else if set -q REQUESTS_CA_BUNDLE; printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' $REQUESTS_CA_BUNDLE; else if set -q SSL_CERT_FILE; printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' $SSL_CERT_FILE; end; end; end"}}, +} + +func runCertLookup(ctx context.Context) (CertSetting, error) { + for _, lookup := range certShellLookups { + shellPath, err := exec.LookPath(lookup.Name) + if err != nil { + continue + } + out, err := platform.RunAsCurrentUserWithPathEnv(ctx, shellPath, lookup.Args...) + if err == nil { + if setting := extractMarkedCertSetting(out); setting.Path != "" { + return setting, nil + } + } + } + return CertSetting{}, nil +} diff --git a/internal/certconfig/pip_env_windows.go b/internal/certconfig/pip/pip_env_windows.go similarity index 76% rename from internal/certconfig/pip_env_windows.go rename to internal/certconfig/pip/pip_env_windows.go index 4e762125..39a97119 100644 --- a/internal/certconfig/pip_env_windows.go +++ b/internal/certconfig/pip/pip_env_windows.go @@ -1,6 +1,6 @@ //go:build windows -package certconfig +package pip import ( "context" @@ -8,7 +8,7 @@ import ( "github.com/AikidoSec/safechain-internals/internal/platform" ) -func runPipCertLookup(ctx context.Context) (pipCertSetting, error) { +func runCertLookup(ctx context.Context) (CertSetting, error) { out, err := platform.RunAsCurrentUser(ctx, "powershell", []string{ "-NoProfile", "-NonInteractive", @@ -16,7 +16,7 @@ func runPipCertLookup(ctx context.Context) (pipCertSetting, error) { `if ($env:PIP_CERT) { [Console]::Write('PIP_CERT:' + $env:PIP_CERT) } elseif ($env:REQUESTS_CA_BUNDLE) { [Console]::Write('REQUESTS_CA_BUNDLE:' + $env:REQUESTS_CA_BUNDLE) } elseif ($env:SSL_CERT_FILE) { [Console]::Write('SSL_CERT_FILE:' + $env:SSL_CERT_FILE) }`, }) if err != nil { - return pipCertSetting{}, err + return CertSetting{}, err } - return parsePipCertSettingString(out), nil + return parseCertSettingString(out), nil } diff --git a/internal/certconfig/pip/pip_test.go b/internal/certconfig/pip/pip_test.go new file mode 100644 index 00000000..e2fd634f --- /dev/null +++ b/internal/certconfig/pip/pip_test.go @@ -0,0 +1,180 @@ +package pip + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func TestEnsureOriginalCertFirstInstall(t *testing.T) { + savedPath := filepath.Join(t.TempDir(), "original.txt") + lookup := func(_ context.Context) (CertSetting, error) { + return CertSetting{EnvVar: CertEnvVar, Path: "/corporate/pip-ca.pem"}, nil + } + + got, err := ensureOriginalCertAt(context.Background(), savedPath, lookup) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.EnvVar != CertEnvVar || got.Path != "/corporate/pip-ca.pem" { + t.Fatalf("got %+v, want env=%q path=%q", got, CertEnvVar, "/corporate/pip-ca.pem") + } + + data, err := os.ReadFile(savedPath) + if err != nil { + t.Fatalf("saved file not written: %v", err) + } + parsed, err := parseSavedCertSetting(data) + if err != nil { + t.Fatalf("failed to parse saved state: %v", err) + } + if parsed.EnvVar != CertEnvVar || parsed.Path != "/corporate/pip-ca.pem" { + t.Fatalf("saved file contains %+v, want env=%q path=%q", parsed, CertEnvVar, "/corporate/pip-ca.pem") + } +} + +func TestEnsureOriginalCertFirstInstallNothingSet(t *testing.T) { + savedPath := filepath.Join(t.TempDir(), "original.txt") + lookup := func(_ context.Context) (CertSetting, error) { return CertSetting{}, nil } + + got, err := ensureOriginalCertAt(context.Background(), savedPath, lookup) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != (CertSetting{}) { + t.Fatalf("got %+v, want empty setting", got) + } + if _, err := os.Stat(savedPath); err != nil { + t.Fatalf("saved file not written for empty value: %v", err) + } +} + +func TestEnsureOriginalCertReinstallSkipsLookup(t *testing.T) { + savedPath := filepath.Join(t.TempDir(), "original.txt") + if err := os.WriteFile(savedPath, []byte("/saved/pip-ca.pem"), 0o600); err != nil { + t.Fatal(err) + } + + lookupCalled := false + lookup := func(_ context.Context) (CertSetting, error) { + lookupCalled = true + return CertSetting{EnvVar: CertEnvVar, Path: "/new-value/pip-ca.pem"}, nil + } + + got, err := ensureOriginalCertAt(context.Background(), savedPath, lookup) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.EnvVar != CertEnvVar || got.Path != "/saved/pip-ca.pem" { + t.Fatalf("got %+v, want env=%q path=%q", got, CertEnvVar, "/saved/pip-ca.pem") + } + if lookupCalled { + t.Fatal("lookup should not be called when saved file exists") + } +} + +func TestEnsureOriginalCertLookupError(t *testing.T) { + savedPath := filepath.Join(t.TempDir(), "original.txt") + lookup := func(_ context.Context) (CertSetting, error) { + return CertSetting{}, errors.New("shell not found") + } + + _, err := ensureOriginalCertAt(context.Background(), savedPath, lookup) + if err == nil { + t.Fatal("expected error when lookup fails") + } +} + +func TestParseSavedCertSettingLegacyFormat(t *testing.T) { + got, err := parseSavedCertSetting([]byte("/legacy/pip-ca.pem\n")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.EnvVar != CertEnvVar || got.Path != "/legacy/pip-ca.pem" { + t.Fatalf("got %+v, want env=%q path=%q", got, CertEnvVar, "/legacy/pip-ca.pem") + } +} + +func TestResolveBaseCACertBundleUsesOriginalWhenPresent(t *testing.T) { + path := filepath.Join(t.TempDir(), "base.pem") + if err := os.WriteFile(path, []byte(mustCreateTestCertificatePEM(t, "pip-base")), 0o644); err != nil { + t.Fatal(err) + } + + got, err := resolveBaseCACertBundleAt(context.Background(), path, func(context.Context) string { + t.Fatal("certifi lookup should not run when original bundle exists") + return "" + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != path { + t.Fatalf("got %q, want %q", got, path) + } +} + +func TestResolveBaseCACertBundleFallsBackToCertifi(t *testing.T) { + path := filepath.Join(t.TempDir(), "certifi.pem") + if err := os.WriteFile(path, []byte(mustCreateTestCertificatePEM(t, "certifi-base")), 0o644); err != nil { + t.Fatal(err) + } + + got, err := resolveBaseCACertBundleAt(context.Background(), "", func(context.Context) string { + return path + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != path { + t.Fatalf("got %q, want %q", got, path) + } +} + +func TestResolveBaseCACertBundleFailsClosedWithoutBase(t *testing.T) { + _, err := resolveBaseCACertBundleAt(context.Background(), "", func(context.Context) string { + return "" + }) + if err == nil { + t.Fatal("expected error when no base bundle is available") + } +} + +func mustCreateTestCertificatePEM(t *testing.T, commonName string) string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate failed: %v", err) + } + + return string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: der, + })) +} diff --git a/internal/certconfig/pip/pip_utils.go b/internal/certconfig/pip/pip_utils.go new file mode 100644 index 00000000..0bc93159 --- /dev/null +++ b/internal/certconfig/pip/pip_utils.go @@ -0,0 +1,92 @@ +package pip + +import ( + "context" + "os" + "os/exec" + "strings" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +const ( + CertEnvVar = "PIP_CERT" + RequestsCABundleEnvVar = "REQUESTS_CA_BUNDLE" + SSLCertFileEnvVar = "SSL_CERT_FILE" +) + +func findSystemCABundle(ctx context.Context) string { + for _, pythonBin := range []string{"python3", "python"} { + pythonPath, err := exec.LookPath(pythonBin) + if err != nil { + continue + } + out, err := platform.RunAsCurrentUserWithPathEnv(ctx, pythonPath, "-c", "import certifi; print(certifi.where())") + if err != nil { + continue + } + path := strings.TrimSpace(out) + if path == "" { + continue + } + if _, err := os.Stat(path); err == nil { + return path + } + } + + // certifi is not installed — fall back to the path reported by Python's + // stdlib ssl module (available in every Python installation). + for _, pythonBin := range []string{"python3", "python"} { + pythonPath, err := exec.LookPath(pythonBin) + if err != nil { + continue + } + out, err := platform.RunAsCurrentUserWithPathEnv(ctx, pythonPath, "-c", + "import ssl; p = ssl.get_default_verify_paths(); print(p.cafile or p.openssl_cafile or '')") + if err != nil { + continue + } + path := strings.TrimSpace(out) + if path == "" { + continue + } + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "" +} + +func extractMarkedCertSetting(output string) CertSetting { + return parseCertSettingString(shared.ExtractMarkedCertValue(output)) +} + +func parseCertSettingString(value string) CertSetting { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return CertSetting{} + } + + envVar, path, ok := strings.Cut(trimmed, ":") + if !ok { + return CertSetting{ + EnvVar: CertEnvVar, + Path: trimmed, + } + } + + switch envVar { + case CertEnvVar, RequestsCABundleEnvVar, SSLCertFileEnvVar: + return CertSetting{ + EnvVar: envVar, + Path: strings.TrimSpace(path), + } + default: + return CertSetting{ + EnvVar: CertEnvVar, + Path: trimmed, + } + } +} diff --git a/internal/certconfig/pip/pip_windows.go b/internal/certconfig/pip/pip_windows.go new file mode 100644 index 00000000..8f8a10f7 --- /dev/null +++ b/internal/certconfig/pip/pip_windows.go @@ -0,0 +1,70 @@ +//go:build windows + +package pip + +import ( + "context" + "fmt" + "os" + + "github.com/AikidoSec/safechain-internals/internal/certconfig/shared" +) + +type windowsTrustConfigurator struct { + bundlePath string +} + +func newTrustConfigurator(bundlePath string) trustConfigurator { + return &windowsTrustConfigurator{ + bundlePath: bundlePath, + } +} + +func (c *windowsTrustConfigurator) Install(ctx context.Context) error { + return shared.RunPowerShellAsCurrentUser( + ctx, + fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('PIP_CERT', '%s', 'User')", + shared.EscapePowerShellSingleQuoted(c.bundlePath), + ), + ) +} + +func (c *windowsTrustConfigurator) Uninstall(ctx context.Context) error { + original := CertSetting{} + if data, err := os.ReadFile(originalCertPath()); err == nil { + parsed, parseErr := parseSavedCertSetting(data) + if parseErr == nil { + original = parsed + } + } + + script := restoreWindowsEnvScript(original) + return shared.RunPowerShellAsCurrentUser(ctx, script) +} + +func restoreWindowsEnvScript(original CertSetting) string { + path := shared.EscapePowerShellSingleQuoted(original.Path) + + switch original.EnvVar { + case RequestsCABundleEnvVar: + return fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User'); [Environment]::SetEnvironmentVariable('REQUESTS_CA_BUNDLE', '%s', 'User')", + path, + ) + case SSLCertFileEnvVar: + return fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User'); [Environment]::SetEnvironmentVariable('SSL_CERT_FILE', '%s', 'User')", + path, + ) + case CertEnvVar: + if original.Path != "" { + return fmt.Sprintf( + "[Environment]::SetEnvironmentVariable('PIP_CERT', '%s', 'User')", + path, + ) + } + } + + return "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User')" +} diff --git a/internal/certconfig/pip_windows_test.go b/internal/certconfig/pip/pip_windows_test.go similarity index 76% rename from internal/certconfig/pip_windows_test.go rename to internal/certconfig/pip/pip_windows_test.go index 8c583580..64dd1c53 100644 --- a/internal/certconfig/pip_windows_test.go +++ b/internal/certconfig/pip/pip_windows_test.go @@ -1,15 +1,15 @@ //go:build windows -package certconfig +package pip import ( "strings" "testing" ) -func TestRestoreWindowsPipEnvScriptRestoresRequestsBundle(t *testing.T) { - script := restoreWindowsPipEnvScript(pipCertSetting{ - EnvVar: requestsCABundleEnvVar, +func TestRestoreWindowsEnvScriptRestoresRequestsBundle(t *testing.T) { + script := restoreWindowsEnvScript(CertSetting{ + EnvVar: RequestsCABundleEnvVar, Path: `C:\corp\bundle.pem`, }) diff --git a/internal/certconfig/pip_darwin.go b/internal/certconfig/pip_darwin.go deleted file mode 100644 index 8bb6bc77..00000000 --- a/internal/certconfig/pip_darwin.go +++ /dev/null @@ -1,70 +0,0 @@ -//go:build darwin - -package certconfig - -import ( - "context" - "fmt" - "os" - "path/filepath" - - "github.com/AikidoSec/safechain-internals/internal/platform" - "github.com/AikidoSec/safechain-internals/internal/utils" -) - -var pipShellManagedBlockFormat = managedBlockFormat{ - startMarker: "# aikido-endpoint-pip-cert-config-start", - endMarker: "# aikido-endpoint-pip-cert-config-end", -} - -type darwinPipTrustConfigurator struct { - targets []darwinShellTarget -} - -func newPipTrustConfigurator(bundlePath string) pipTrustConfigurator { - comment := "# Allow pip to trust the SafeChain MITM CA while preserving user-provided roots." - posix := comment + "\n" + fmt.Sprintf("export PIP_CERT=%q", bundlePath) - fish := comment + "\n" + fmt.Sprintf("set -gx PIP_CERT %q", bundlePath) - - homeDir := platform.GetConfig().HomeDir - - return &darwinPipTrustConfigurator{ - targets: []darwinShellTarget{ - {path: filepath.Join(homeDir, ".zshrc"), body: posix}, - {path: filepath.Join(homeDir, ".zprofile"), body: posix}, - {path: filepath.Join(homeDir, ".bash_profile"), body: posix}, - {path: filepath.Join(homeDir, ".bashrc"), body: posix}, - {path: filepath.Join(homeDir, ".profile"), body: posix}, - {path: filepath.Join(homeDir, ".config", "fish", "config.fish"), body: fish, createIfMissing: true}, - }, - } -} - -func (c *darwinPipTrustConfigurator) Install(_ context.Context) error { - for _, target := range c.targets { - if _, err := os.Stat(target.path); err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to stat %s: %w", target.path, err) - } - if !target.createIfMissing { - continue - } - if err := os.MkdirAll(filepath.Dir(target.path), 0o755); err != nil { - return fmt.Errorf("failed to create config dir for %s: %w", target.path, err) - } - } - if err := writeManagedBlock(target.path, target.body, 0o644, pipShellManagedBlockFormat); err != nil { - return err - } - } - return nil -} - -func (c *darwinPipTrustConfigurator) Uninstall(_ context.Context) error { - for _, target := range c.targets { - if err := utils.RemoveManagedBlock(target.path, 0o644, pipShellManagedBlockFormat.startMarker, pipShellManagedBlockFormat.endMarker); err != nil { - return err - } - } - return nil -} diff --git a/internal/certconfig/pip_env_darwin.go b/internal/certconfig/pip_env_darwin.go deleted file mode 100644 index 1e195861..00000000 --- a/internal/certconfig/pip_env_darwin.go +++ /dev/null @@ -1,42 +0,0 @@ -//go:build darwin - -package certconfig - -import ( - "context" - "os/exec" - - "github.com/AikidoSec/safechain-internals/internal/platform" -) - -// pipCertShellLookups tries login and interactive startup files separately -// because they are sourced in different modes: -// - Login non-interactive (-lc): ~/.zprofile or ~/.bash_profile -// - Interactive non-login (-ic): ~/.zshrc or ~/.bashrc -// - fish --login: ~/.config/fish/config.fish -// -// Each command wraps the result in a unique marker so shell startup noise does -// not contaminate the discovered Python CA bundle override. -var pipCertShellLookups = []shellLookup{ - {"zsh", []string{"-lc", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, - {"zsh", []string{"-ic", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, - {"bash", []string{"-lc", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, - {"bash", []string{"-ic", `if [ -n "${PIP_CERT:-}" ]; then printf 'AIKIDO_CERT=PIP_CERT:%s\n' "$PIP_CERT"; elif [ -n "${REQUESTS_CA_BUNDLE:-}" ]; then printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' "$REQUESTS_CA_BUNDLE"; elif [ -n "${SSL_CERT_FILE:-}" ]; then printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' "$SSL_CERT_FILE"; fi`}}, - {"fish", []string{"--login", "-c", "if set -q PIP_CERT; printf 'AIKIDO_CERT=PIP_CERT:%s\n' $PIP_CERT; else if set -q REQUESTS_CA_BUNDLE; printf 'AIKIDO_CERT=REQUESTS_CA_BUNDLE:%s\n' $REQUESTS_CA_BUNDLE; else if set -q SSL_CERT_FILE; printf 'AIKIDO_CERT=SSL_CERT_FILE:%s\n' $SSL_CERT_FILE; end; end; end"}}, -} - -func runPipCertLookup(ctx context.Context) (pipCertSetting, error) { - for _, lookup := range pipCertShellLookups { - shellPath, err := exec.LookPath(lookup.name) - if err != nil { - continue - } - out, err := platform.RunAsCurrentUserWithPathEnv(ctx, shellPath, lookup.args...) - if err == nil { - if setting := extractMarkedPipCertSetting(out); setting.Path != "" { - return setting, nil - } - } - } - return pipCertSetting{}, nil -} diff --git a/internal/certconfig/pip_test.go b/internal/certconfig/pip_test.go deleted file mode 100644 index c748a59c..00000000 --- a/internal/certconfig/pip_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package certconfig - -import ( - "context" - "errors" - "os" - "path/filepath" - "testing" -) - -func TestEnsureOriginalPipCertFirstInstall(t *testing.T) { - savedPath := filepath.Join(t.TempDir(), "original.txt") - lookup := func(_ context.Context) (pipCertSetting, error) { - return pipCertSetting{EnvVar: pipCertEnvVar, Path: "/corporate/pip-ca.pem"}, nil - } - - got, err := ensureOriginalPipCertAt(context.Background(), savedPath, lookup) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.EnvVar != pipCertEnvVar || got.Path != "/corporate/pip-ca.pem" { - t.Fatalf("got %+v, want env=%q path=%q", got, pipCertEnvVar, "/corporate/pip-ca.pem") - } - - data, err := os.ReadFile(savedPath) - if err != nil { - t.Fatalf("saved file not written: %v", err) - } - parsed, err := parseSavedPipCertSetting(data) - if err != nil { - t.Fatalf("failed to parse saved state: %v", err) - } - if parsed.EnvVar != pipCertEnvVar || parsed.Path != "/corporate/pip-ca.pem" { - t.Fatalf("saved file contains %+v, want env=%q path=%q", parsed, pipCertEnvVar, "/corporate/pip-ca.pem") - } -} - -func TestEnsureOriginalPipCertFirstInstallNothingSet(t *testing.T) { - savedPath := filepath.Join(t.TempDir(), "original.txt") - lookup := func(_ context.Context) (pipCertSetting, error) { return pipCertSetting{}, nil } - - got, err := ensureOriginalPipCertAt(context.Background(), savedPath, lookup) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != (pipCertSetting{}) { - t.Fatalf("got %+v, want empty setting", got) - } - if _, err := os.Stat(savedPath); err != nil { - t.Fatalf("saved file not written for empty value: %v", err) - } -} - -func TestEnsureOriginalPipCertReinstallSkipsLookup(t *testing.T) { - savedPath := filepath.Join(t.TempDir(), "original.txt") - if err := os.WriteFile(savedPath, []byte("/saved/pip-ca.pem"), 0o600); err != nil { - t.Fatal(err) - } - - lookupCalled := false - lookup := func(_ context.Context) (pipCertSetting, error) { - lookupCalled = true - return pipCertSetting{EnvVar: pipCertEnvVar, Path: "/new-value/pip-ca.pem"}, nil - } - - got, err := ensureOriginalPipCertAt(context.Background(), savedPath, lookup) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.EnvVar != pipCertEnvVar || got.Path != "/saved/pip-ca.pem" { - t.Fatalf("got %+v, want env=%q path=%q", got, pipCertEnvVar, "/saved/pip-ca.pem") - } - if lookupCalled { - t.Fatal("lookup should not be called when saved file exists") - } -} - -func TestEnsureOriginalPipCertLookupError(t *testing.T) { - savedPath := filepath.Join(t.TempDir(), "original.txt") - lookup := func(_ context.Context) (pipCertSetting, error) { - return pipCertSetting{}, errors.New("shell not found") - } - - _, err := ensureOriginalPipCertAt(context.Background(), savedPath, lookup) - if err == nil { - t.Fatal("expected error when lookup fails") - } -} - -func TestParseSavedPipCertSettingLegacyFormat(t *testing.T) { - got, err := parseSavedPipCertSetting([]byte("/legacy/pip-ca.pem\n")) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.EnvVar != pipCertEnvVar || got.Path != "/legacy/pip-ca.pem" { - t.Fatalf("got %+v, want env=%q path=%q", got, pipCertEnvVar, "/legacy/pip-ca.pem") - } -} - -func TestResolvePipBaseCACertBundleUsesOriginalWhenPresent(t *testing.T) { - path := filepath.Join(t.TempDir(), "base.pem") - if err := os.WriteFile(path, []byte(mustCreateTestCertificatePEM(t, "pip-base")), 0o644); err != nil { - t.Fatal(err) - } - - got, err := resolvePipBaseCACertBundleAt(context.Background(), path, func(context.Context) string { - t.Fatal("certifi lookup should not run when original bundle exists") - return "" - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != path { - t.Fatalf("got %q, want %q", got, path) - } -} - -func TestResolvePipBaseCACertBundleFallsBackToCertifi(t *testing.T) { - path := filepath.Join(t.TempDir(), "certifi.pem") - if err := os.WriteFile(path, []byte(mustCreateTestCertificatePEM(t, "certifi-base")), 0o644); err != nil { - t.Fatal(err) - } - - got, err := resolvePipBaseCACertBundleAt(context.Background(), "", func(context.Context) string { - return path - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != path { - t.Fatalf("got %q, want %q", got, path) - } -} - -func TestResolvePipBaseCACertBundleFailsClosedWithoutBase(t *testing.T) { - _, err := resolvePipBaseCACertBundleAt(context.Background(), "", func(context.Context) string { - return "" - }) - if err == nil { - t.Fatal("expected error when no base bundle is available") - } -} diff --git a/internal/certconfig/pip_windows.go b/internal/certconfig/pip_windows.go deleted file mode 100644 index 929812f3..00000000 --- a/internal/certconfig/pip_windows.go +++ /dev/null @@ -1,68 +0,0 @@ -//go:build windows - -package certconfig - -import ( - "context" - "fmt" - "os" -) - -type windowsPipTrustConfigurator struct { - bundlePath string -} - -func newPipTrustConfigurator(bundlePath string) pipTrustConfigurator { - return &windowsPipTrustConfigurator{ - bundlePath: bundlePath, - } -} - -func (c *windowsPipTrustConfigurator) Install(ctx context.Context) error { - return runPowerShellAsCurrentUser( - ctx, - fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('PIP_CERT', '%s', 'User')", - escapePowerShellSingleQuoted(c.bundlePath), - ), - ) -} - -func (c *windowsPipTrustConfigurator) Uninstall(ctx context.Context) error { - original := pipCertSetting{} - if data, err := os.ReadFile(originalPipCertPath()); err == nil { - parsed, parseErr := parseSavedPipCertSetting(data) - if parseErr == nil { - original = parsed - } - } - - script := restoreWindowsPipEnvScript(original) - return runPowerShellAsCurrentUser(ctx, script) -} - -func restoreWindowsPipEnvScript(original pipCertSetting) string { - path := escapePowerShellSingleQuoted(original.Path) - - switch original.EnvVar { - case requestsCABundleEnvVar: - return fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User'); [Environment]::SetEnvironmentVariable('REQUESTS_CA_BUNDLE', '%s', 'User')", - path, - ) - case sslCertFileEnvVar: - return fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User'); [Environment]::SetEnvironmentVariable('SSL_CERT_FILE', '%s', 'User')", - path, - ) - case pipCertEnvVar: - if original.Path != "" { - return fmt.Sprintf( - "[Environment]::SetEnvironmentVariable('PIP_CERT', '%s', 'User')", - path, - ) - } - } - - return "[Environment]::SetEnvironmentVariable('PIP_CERT', $null, 'User')" -} diff --git a/internal/certconfig/certbundle.go b/internal/certconfig/shared/certbundle.go similarity index 80% rename from internal/certconfig/certbundle.go rename to internal/certconfig/shared/certbundle.go index 8fa195ef..7953561c 100644 --- a/internal/certconfig/certbundle.go +++ b/internal/certconfig/shared/certbundle.go @@ -1,4 +1,4 @@ -package certconfig +package shared import ( "crypto/x509" @@ -18,32 +18,32 @@ const ( pipCombinedBundleName = "endpoint-protection-pip-combined-ca.pem" ) -func combinedCaBundlePath() string { +func CombinedCaBundlePath() string { return filepath.Join(platform.GetRunDir(), nodeCombinedBundleName) } -func pipCombinedCaBundlePath() string { +func PipCombinedCaBundlePath() string { return filepath.Join(platform.GetRunDir(), pipCombinedBundleName) } -// ensureCombinedCABundle writes the combined CA bundle containing the SafeChain CA +// EnsureCombinedCABundle writes the combined CA bundle containing the SafeChain CA // and, if non-empty, the user's pre-existing originalCACertsPath. The SafeChain CA // is mandatory — the call fails if it can't be read. The original is silently skipped // on error (missing file, invalid PEM, etc.). -func ensureCombinedCABundle(originalCACertsPath string) (string, error) { - return ensureCombinedCABundleAt(combinedCaBundlePath(), originalCACertsPath) +func EnsureCombinedCABundle(originalCACertsPath string) (string, error) { + return EnsureCombinedCABundleAt(CombinedCaBundlePath(), originalCACertsPath) } -// ensurePipCombinedCABundle builds the pip CA bundle. +// EnsurePipCombinedCABundle builds the pip CA bundle. // // Unlike NODE_EXTRA_CA_CERTS (which appends), PIP_CERT replaces pip's bundle // entirely. baseCACertsPath must already point to a validated PEM bundle that // pip should continue trusting after the SafeChain CA is added. -func ensurePipCombinedCABundle(baseCACertsPath string) (string, error) { - bundlePath := pipCombinedCaBundlePath() +func EnsurePipCombinedCABundle(baseCACertsPath string) (string, error) { + bundlePath := PipCombinedCaBundlePath() safeChainCACertPath := proxy.GetCaCertPath() - safeChainPayload, err := readAndValidatePEMBundle(safeChainCACertPath) + safeChainPayload, err := ReadAndValidatePEMBundle(safeChainCACertPath) if err != nil { return "", fmt.Errorf("failed to read SafeChain CA: %w", err) } @@ -55,7 +55,7 @@ func ensurePipCombinedCABundle(baseCACertsPath string) (string, error) { return "", fmt.Errorf("pip CA bundle path is empty") } if expanded != safeChainCACertPath && expanded != bundlePath { - payload, err := readAndValidatePEMBundle(expanded) + payload, err := ReadAndValidatePEMBundle(expanded) if err != nil { return "", fmt.Errorf("failed to read pip base CA bundle: %w", err) } @@ -68,9 +68,9 @@ func ensurePipCombinedCABundle(baseCACertsPath string) (string, error) { return bundlePath, nil } -func ensureCombinedCABundleAt(bundlePath string, originalCACertsPath string) (string, error) { +func EnsureCombinedCABundleAt(bundlePath string, originalCACertsPath string) (string, error) { safeChainCACertPath := proxy.GetCaCertPath() - safeChainPayload, err := readAndValidatePEMBundle(safeChainCACertPath) + safeChainPayload, err := ReadAndValidatePEMBundle(safeChainCACertPath) if err != nil { return "", fmt.Errorf("failed to read SafeChain CA: %w", err) } @@ -80,7 +80,7 @@ func ensureCombinedCABundleAt(bundlePath string, originalCACertsPath string) (st if originalCACertsPath != "" { expanded := utils.ExpandHomePath(strings.TrimSpace(originalCACertsPath), platform.GetConfig().HomeDir) if expanded != "" && expanded != safeChainCACertPath && expanded != bundlePath { - if payload, err := readAndValidatePEMBundle(expanded); err == nil { + if payload, err := ReadAndValidatePEMBundle(expanded); err == nil { parts = append(parts, payload) } } @@ -92,15 +92,15 @@ func ensureCombinedCABundleAt(bundlePath string, originalCACertsPath string) (st return bundlePath, nil } -func removeCombinedCABundle() error { - return removeCombinedCABundleAt(combinedCaBundlePath()) +func RemoveCombinedCABundle() error { + return RemoveCombinedCABundleAt(CombinedCaBundlePath()) } -func removePipCombinedCABundle() error { - return removeCombinedCABundleAt(pipCombinedCaBundlePath()) +func RemovePipCombinedCABundle() error { + return RemoveCombinedCABundleAt(PipCombinedCaBundlePath()) } -func removeCombinedCABundleAt(path string) error { +func RemoveCombinedCABundleAt(path string) error { err := os.Remove(path) if err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove combined CA bundle: %w", err) @@ -108,7 +108,7 @@ func removeCombinedCABundleAt(path string) error { return nil } -func readAndValidatePEMBundle(path string) (string, error) { +func ReadAndValidatePEMBundle(path string) (string, error) { info, err := os.Lstat(path) if err != nil { return "", err diff --git a/internal/certconfig/certbundle_test.go b/internal/certconfig/shared/certbundle_test.go similarity index 79% rename from internal/certconfig/certbundle_test.go rename to internal/certconfig/shared/certbundle_test.go index 779506a1..83437b2c 100644 --- a/internal/certconfig/certbundle_test.go +++ b/internal/certconfig/shared/certbundle_test.go @@ -1,4 +1,4 @@ -package certconfig +package shared import ( "crypto/rand" @@ -16,14 +16,14 @@ import ( func TestReadAndValidatePEMBundleValidCertificate(t *testing.T) { path := filepath.Join(t.TempDir(), "bundle.pem") - pemData := mustCreateTestCertificatePEM(t, "test-cert") + pemData := MustCreateTestCertificatePEM(t, "test-cert") if err := os.WriteFile(path, []byte(pemData), 0o644); err != nil { t.Fatal(err) } - got, err := readAndValidatePEMBundle(path) + got, err := ReadAndValidatePEMBundle(path) if err != nil { - t.Fatalf("readAndValidatePEMBundle failed: %v", err) + t.Fatalf("ReadAndValidatePEMBundle failed: %v", err) } if !strings.Contains(got, "BEGIN CERTIFICATE") { @@ -37,7 +37,7 @@ func TestReadAndValidatePEMBundleRejectsNonPEMContent(t *testing.T) { t.Fatal(err) } - _, err := readAndValidatePEMBundle(path) + _, err := ReadAndValidatePEMBundle(path) if err == nil { t.Fatal("expected error for non-PEM content") } @@ -48,7 +48,7 @@ func TestReadAndValidatePEMBundleRejectsSymlink(t *testing.T) { target := filepath.Join(dir, "target.pem") link := filepath.Join(dir, "link.pem") - pemData := mustCreateTestCertificatePEM(t, "symlink-test") + pemData := MustCreateTestCertificatePEM(t, "symlink-test") if err := os.WriteFile(target, []byte(pemData), 0o644); err != nil { t.Fatal(err) } @@ -56,13 +56,15 @@ func TestReadAndValidatePEMBundleRejectsSymlink(t *testing.T) { t.Skipf("symlink creation not supported: %v", err) } - _, err := readAndValidatePEMBundle(link) + _, err := ReadAndValidatePEMBundle(link) if err == nil { t.Fatal("expected error for symlinked bundle") } } -func mustCreateTestCertificatePEM(t *testing.T, commonName string) string { +// MustCreateTestCertificatePEM creates a self-signed test certificate PEM. +// Exported for use by sub-package tests. +func MustCreateTestCertificatePEM(t *testing.T, commonName string) string { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/internal/certconfig/shared/managed_block.go b/internal/certconfig/shared/managed_block.go new file mode 100644 index 00000000..ba1f56f8 --- /dev/null +++ b/internal/certconfig/shared/managed_block.go @@ -0,0 +1,56 @@ +package shared + +import ( + "fmt" + "os" + "strings" + + "github.com/AikidoSec/safechain-internals/internal/utils" +) + +type ManagedBlockFormat struct { + StartMarker string + EndMarker string +} + +func BuildManagedBlock(body string, format ManagedBlockFormat, newline string) string { + return format.StartMarker + newline + body + newline + format.EndMarker + newline +} + +func DetectNewline(content string) string { + if strings.Contains(content, "\r\n") { + return "\r\n" + } + return "\n" +} + +func HasTrailingNewline(content string) bool { + return strings.HasSuffix(content, "\n") || strings.HasSuffix(content, "\r\n") +} + +func WriteManagedBlock(path string, body string, perm os.FileMode, format ManagedBlockFormat) error { + content := "" + if data, err := os.ReadFile(path); err == nil { + content = string(data) + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to read %s: %w", path, err) + } + + newline := DetectNewline(content) + + stripped, _, err := utils.RemoveMarkedBlock(content, format.StartMarker, format.EndMarker) + if err != nil { + return fmt.Errorf("failed to remove existing managed block in %s: %w", path, err) + } + + if stripped != "" && !HasTrailingNewline(stripped) { + stripped += newline + } + + body = strings.ReplaceAll(body, "\r\n", "\n") + if newline != "\n" { + body = strings.ReplaceAll(body, "\n", newline) + } + + return os.WriteFile(path, []byte(stripped+BuildManagedBlock(body, format, newline)), perm) +} diff --git a/internal/certconfig/common_test.go b/internal/certconfig/shared/managed_block_test.go similarity index 78% rename from internal/certconfig/common_test.go rename to internal/certconfig/shared/managed_block_test.go index f3d3033a..935220be 100644 --- a/internal/certconfig/common_test.go +++ b/internal/certconfig/shared/managed_block_test.go @@ -1,4 +1,4 @@ -package certconfig +package shared import ( "os" @@ -13,9 +13,9 @@ func TestWriteManagedBlockReplacesExistingBlock(t *testing.T) { t.Helper() path := filepath.Join(t.TempDir(), "config.txt") - format := managedBlockFormat{ - startMarker: "# start", - endMarker: "# end", + format := ManagedBlockFormat{ + StartMarker: "# start", + EndMarker: "# end", } initial := strings.Join([]string{ @@ -30,8 +30,8 @@ func TestWriteManagedBlockReplacesExistingBlock(t *testing.T) { t.Fatal(err) } - if err := writeManagedBlock(path, "new", 0o644, format); err != nil { - t.Fatalf("writeManagedBlock failed: %v", err) + if err := WriteManagedBlock(path, "new", 0o644, format); err != nil { + t.Fatalf("WriteManagedBlock failed: %v", err) } gotBytes, err := os.ReadFile(path) @@ -53,9 +53,9 @@ func TestWriteManagedBlockReplacesExistingBlock(t *testing.T) { func TestWriteManagedBlockPreservesCRLF(t *testing.T) { path := filepath.Join(t.TempDir(), "config.txt") - format := managedBlockFormat{ - startMarker: "# start", - endMarker: "# end", + format := ManagedBlockFormat{ + StartMarker: "# start", + EndMarker: "# end", } initial := "before\r\nafter\r\n" @@ -63,8 +63,8 @@ func TestWriteManagedBlockPreservesCRLF(t *testing.T) { t.Fatal(err) } - if err := writeManagedBlock(path, "line1\nline2", 0o644, format); err != nil { - t.Fatalf("writeManagedBlock failed: %v", err) + if err := WriteManagedBlock(path, "line1\nline2", 0o644, format); err != nil { + t.Fatalf("WriteManagedBlock failed: %v", err) } gotBytes, err := os.ReadFile(path) @@ -80,9 +80,9 @@ func TestWriteManagedBlockPreservesCRLF(t *testing.T) { func TestRemoveManagedBlockRemovesOnlyManagedSection(t *testing.T) { path := filepath.Join(t.TempDir(), "config.txt") - format := managedBlockFormat{ - startMarker: "# start", - endMarker: "# end", + format := ManagedBlockFormat{ + StartMarker: "# start", + EndMarker: "# end", } initial := strings.Join([]string{ @@ -97,7 +97,7 @@ func TestRemoveManagedBlockRemovesOnlyManagedSection(t *testing.T) { t.Fatal(err) } - if err := utils.RemoveManagedBlock(path, 0o644, format.startMarker, format.endMarker); err != nil { + if err := utils.RemoveManagedBlock(path, 0o644, format.StartMarker, format.EndMarker); err != nil { t.Fatalf("RemoveManagedBlock failed: %v", err) } diff --git a/internal/certconfig/shared/marker.go b/internal/certconfig/shared/marker.go new file mode 100644 index 00000000..e6b1f52c --- /dev/null +++ b/internal/certconfig/shared/marker.go @@ -0,0 +1,17 @@ +package shared + +import "strings" + +const AikidoCertMarker = "AIKIDO_CERT=" + +// ExtractMarkedCertValue scans output for a line starting with AikidoCertMarker +// and returns the value after it. This tolerates arbitrary text before or after +// the marker line, which interactive shells may produce. +func ExtractMarkedCertValue(output string) string { + for line := range strings.SplitSeq(output, "\n") { + if strings.HasPrefix(line, AikidoCertMarker) { + return strings.TrimSpace(line[len(AikidoCertMarker):]) + } + } + return "" +} diff --git a/internal/certconfig/shared/marker_test.go b/internal/certconfig/shared/marker_test.go new file mode 100644 index 00000000..3f4b0534 --- /dev/null +++ b/internal/certconfig/shared/marker_test.go @@ -0,0 +1,27 @@ +package shared + +import "testing" + +func TestExtractMarkedCertValue(t *testing.T) { + tests := []struct { + name string + output string + want string + }{ + {"clean output", "AIKIDO_CERT=/path/to/ca.pem\n", "/path/to/ca.pem"}, + {"empty value", "AIKIDO_CERT=\n", ""}, + {"marker absent", "some random output\n", ""}, + {"marker buried in noise", "Welcome to zsh!\nAIKIDO_CERT=/corp/ca.pem\nsome trailing line", "/corp/ca.pem"}, + {"interactive startup noise before", "compinit output\n[oh-my-zsh]\nAIKIDO_CERT=/ca.pem\n", "/ca.pem"}, + {"whitespace trimmed", "AIKIDO_CERT= /ca.pem \n", "/ca.pem"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractMarkedCertValue(tt.output) + if got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/certconfig/shared/shell_darwin.go b/internal/certconfig/shared/shell_darwin.go new file mode 100644 index 00000000..a4b7b13b --- /dev/null +++ b/internal/certconfig/shared/shell_darwin.go @@ -0,0 +1,55 @@ +//go:build darwin + +package shared + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/AikidoSec/safechain-internals/internal/utils" +) + +// ShellLookup describes how to query an environment variable from a specific shell. +type ShellLookup struct { + Name string + Args []string +} + +// DarwinShellTarget represents a shell startup file to inject a managed block into. +type DarwinShellTarget struct { + Path string + Body string + CreateIfMissing bool +} + +// InstallDarwinShellTargets writes managed blocks to the given shell targets. +func InstallDarwinShellTargets(targets []DarwinShellTarget, format ManagedBlockFormat) error { + for _, target := range targets { + if _, err := os.Stat(target.Path); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to stat %s: %w", target.Path, err) + } + if !target.CreateIfMissing { + continue + } + if err := os.MkdirAll(filepath.Dir(target.Path), 0o755); err != nil { + return fmt.Errorf("failed to create config dir for %s: %w", target.Path, err) + } + } + if err := WriteManagedBlock(target.Path, target.Body, 0o644, format); err != nil { + return err + } + } + return nil +} + +// UninstallDarwinShellTargets removes managed blocks from the given shell targets. +func UninstallDarwinShellTargets(targets []DarwinShellTarget, format ManagedBlockFormat) error { + for _, target := range targets { + if err := utils.RemoveManagedBlock(target.Path, 0o644, format.StartMarker, format.EndMarker); err != nil { + return err + } + } + return nil +} diff --git a/internal/certconfig/shared/shell_windows.go b/internal/certconfig/shared/shell_windows.go new file mode 100644 index 00000000..6a053bde --- /dev/null +++ b/internal/certconfig/shared/shell_windows.go @@ -0,0 +1,24 @@ +//go:build windows + +package shared + +import ( + "context" + "strings" + + "github.com/AikidoSec/safechain-internals/internal/platform" +) + +func RunPowerShellAsCurrentUser(ctx context.Context, script string) error { + _, err := platform.RunAsCurrentUser(ctx, "powershell", []string{ + "-NoProfile", + "-NonInteractive", + "-Command", + script, + }) + return err +} + +func EscapePowerShellSingleQuoted(value string) string { + return strings.ReplaceAll(value, "'", "''") +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index d97668ef..d2ea74e7 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -14,7 +14,7 @@ import ( "github.com/AikidoSec/safechain-internals/internal/config" "github.com/AikidoSec/safechain-internals/internal/constants" "github.com/AikidoSec/safechain-internals/internal/device" - "github.com/AikidoSec/safechain-internals/internal/dockerca" + "github.com/AikidoSec/safechain-internals/internal/certconfig/docker" "github.com/AikidoSec/safechain-internals/internal/ingress" "github.com/AikidoSec/safechain-internals/internal/platform" "github.com/AikidoSec/safechain-internals/internal/proxy" @@ -296,7 +296,7 @@ func (d *Daemon) runDockerCALoop(ctx context.Context) { cycleErr = runDockerCACycle(quietCtx) } else { // Docker was running but went offline. Probe the daemon. - cycleErr = dockerca.ProbeDockerDaemon(quietCtx) + cycleErr = docker.ProbeDockerDaemon(quietCtx) } if ctx.Err() != nil { return @@ -320,10 +320,10 @@ func (d *Daemon) runDockerCALoop(ctx context.Context) { } func runDockerCACycle(ctx context.Context) error { - if err := dockerca.InstallCAOnRunningContainers(ctx); err != nil { + if err := docker.InstallDockerCA(ctx); err != nil { return err } - return dockerca.WatchContainerStarts(ctx) + return docker.WatchContainerStarts(ctx) } func (d *Daemon) printDaemonStatus() { diff --git a/internal/setup/steps/03_configure_certificate_trust/step.go b/internal/setup/steps/03_configure_certificate_trust/step.go index cb622bb4..9b828f9c 100644 --- a/internal/setup/steps/03_configure_certificate_trust/step.go +++ b/internal/setup/steps/03_configure_certificate_trust/step.go @@ -19,7 +19,7 @@ func (s *Step) InstallName() string { } func (s *Step) InstallDescription() string { - return "Configures ecosystem-specific trust settings for npm, pip, and Firefox" + return "Configures ecosystem-specific trust settings for npm, pip, Firefox, and Docker" } func (s *Step) UninstallName() string {