From d359cacfb75484126b2f3eecd6cf023776854e47 Mon Sep 17 00:00:00 2001 From: Reinier Criel Date: Thu, 9 Apr 2026 13:20:17 -0700 Subject: [PATCH] Check if PEM bundle is valid after downloading --- internal/certconfig/certbundle.go | 62 +---------------- internal/certconfig/certbundle_test.go | 50 -------------- internal/proxy/ca.go | 20 ++++-- internal/utils/pem.go | 74 ++++++++++++++++++++ internal/utils/pem_test.go | 94 ++++++++++++++++++++++++++ 5 files changed, 185 insertions(+), 115 deletions(-) create mode 100644 internal/utils/pem.go create mode 100644 internal/utils/pem_test.go diff --git a/internal/certconfig/certbundle.go b/internal/certconfig/certbundle.go index 8fa195ef..0c38eea8 100644 --- a/internal/certconfig/certbundle.go +++ b/internal/certconfig/certbundle.go @@ -1,8 +1,6 @@ package certconfig import ( - "crypto/x509" - "encoding/pem" "fmt" "os" "path/filepath" @@ -109,63 +107,5 @@ func removeCombinedCABundleAt(path string) error { } func readAndValidatePEMBundle(path string) (string, error) { - info, err := os.Lstat(path) - if err != nil { - return "", err - } - if info.Mode()&os.ModeSymlink != 0 { - return "", fmt.Errorf("refusing to read symlinked certificate bundle %s", path) - } - if !info.Mode().IsRegular() { - return "", fmt.Errorf("refusing to read non-regular certificate bundle %s", path) - } - - data, err := os.ReadFile(path) - if err != nil { - return "", err - } - - normalized := strings.TrimSpace(strings.ReplaceAll(string(data), "\r\n", "\n")) - if normalized == "" { - return "", fmt.Errorf("certificate bundle %s is empty", path) - } - - var ( - rest = []byte(normalized) - blocks []string - certCount int - ) - - for len(rest) > 0 { - block, remaining := pem.Decode(rest) - if block == nil { - if strings.TrimSpace(string(rest)) != "" { - return "", fmt.Errorf("certificate bundle %s contains non-PEM content", path) - } - break - } - - if block.Type != "CERTIFICATE" { - return "", fmt.Errorf("certificate bundle %s contains unsupported PEM block type %q", path, block.Type) - } - - // Include the certificate regardless of whether Go's strict x509 parser - // accepts it — legacy root CAs (e.g. negative serial numbers) are valid - // for OpenSSL/pip but rejected by Go. We still parse to catch genuinely - // malformed DER; those are skipped rather than failing the whole bundle. - if _, err := x509.ParseCertificate(block.Bytes); err != nil { - rest = remaining - continue - } - - blocks = append(blocks, strings.TrimSpace(string(pem.EncodeToMemory(block)))) - certCount++ - rest = remaining - } - - if certCount == 0 { - return "", fmt.Errorf("certificate bundle %s does not contain any valid certificates", path) - } - - return strings.Join(blocks, "\n"), nil + return utils.ReadAndValidatePEMBundle(path) } diff --git a/internal/certconfig/certbundle_test.go b/internal/certconfig/certbundle_test.go index 779506a1..7d9c4f06 100644 --- a/internal/certconfig/certbundle_test.go +++ b/internal/certconfig/certbundle_test.go @@ -7,60 +7,10 @@ import ( "crypto/x509/pkix" "encoding/pem" "math/big" - "os" - "path/filepath" - "strings" "testing" "time" ) -func TestReadAndValidatePEMBundleValidCertificate(t *testing.T) { - path := filepath.Join(t.TempDir(), "bundle.pem") - pemData := mustCreateTestCertificatePEM(t, "test-cert") - if err := os.WriteFile(path, []byte(pemData), 0o644); err != nil { - t.Fatal(err) - } - - got, err := readAndValidatePEMBundle(path) - if err != nil { - t.Fatalf("readAndValidatePEMBundle failed: %v", err) - } - - if !strings.Contains(got, "BEGIN CERTIFICATE") { - t.Fatalf("expected certificate PEM in output, got %q", got) - } -} - -func TestReadAndValidatePEMBundleRejectsNonPEMContent(t *testing.T) { - path := filepath.Join(t.TempDir(), "bundle.pem") - if err := os.WriteFile(path, []byte("not a certificate"), 0o644); err != nil { - t.Fatal(err) - } - - _, err := readAndValidatePEMBundle(path) - if err == nil { - t.Fatal("expected error for non-PEM content") - } -} - -func TestReadAndValidatePEMBundleRejectsSymlink(t *testing.T) { - dir := t.TempDir() - target := filepath.Join(dir, "target.pem") - link := filepath.Join(dir, "link.pem") - - pemData := mustCreateTestCertificatePEM(t, "symlink-test") - if err := os.WriteFile(target, []byte(pemData), 0o644); err != nil { - t.Fatal(err) - } - if err := os.Symlink(target, link); err != nil { - t.Skipf("symlink creation not supported: %v", err) - } - - _, err := readAndValidatePEMBundle(link) - if err == nil { - t.Fatal("expected error for symlinked bundle") - } -} func mustCreateTestCertificatePEM(t *testing.T, commonName string) string { t.Helper() diff --git a/internal/proxy/ca.go b/internal/proxy/ca.go index 811e0de5..a57e5dc4 100644 --- a/internal/proxy/ca.go +++ b/internal/proxy/ca.go @@ -45,20 +45,32 @@ func DownloadCACertFromL7Proxy() error { return fmt.Errorf("failed to get meta url: %v", err) } - if err := utils.DownloadBinary(context.Background(), metaUrl+"/ca", GetCaCertPath()); err != nil { + certPath := GetCaCertPath() + if err := utils.DownloadBinary(context.Background(), metaUrl+"/ca", certPath); err != nil { return fmt.Errorf("failed to download ca cert: %v", err) } - log.Println("Downloaded CA cert from proxy:", GetCaCertPath()) + if _, err := utils.ReadAndValidatePEMBundle(certPath); err != nil { + os.Remove(certPath) + return fmt.Errorf("downloaded ca cert is invalid: %w", err) + } + + log.Println("Downloaded CA cert from proxy:", certPath) return nil } func DownloadCACertFromL4Proxy(ctx context.Context) error { - if err := utils.DownloadBinary(ctx, l4HijackCAURL, GetCaCertPath()); err != nil { + certPath := GetCaCertPath() + if err := utils.DownloadBinary(ctx, l4HijackCAURL, certPath); err != nil { return fmt.Errorf("failed to download CA cert from L4 proxy: %v", err) } - log.Println("Downloaded CA cert from L4 proxy:", GetCaCertPath()) + if _, err := utils.ReadAndValidatePEMBundle(certPath); err != nil { + os.Remove(certPath) + return fmt.Errorf("downloaded L4 CA cert is invalid: %w", err) + } + + log.Println("Downloaded CA cert from L4 proxy:", certPath) return nil } diff --git a/internal/utils/pem.go b/internal/utils/pem.go new file mode 100644 index 00000000..83fd5f6b --- /dev/null +++ b/internal/utils/pem.go @@ -0,0 +1,74 @@ +package utils + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "strings" +) + +// ReadAndValidatePEMBundle reads a PEM certificate bundle from path, validates +// that it contains at least one well-formed CERTIFICATE block, and returns the +// normalised PEM content. Symlinks and non-regular files are rejected. +func ReadAndValidatePEMBundle(path string) (string, error) { + info, err := os.Lstat(path) + if err != nil { + return "", err + } + if info.Mode()&os.ModeSymlink != 0 { + return "", fmt.Errorf("refusing to read symlinked certificate bundle %s", path) + } + if !info.Mode().IsRegular() { + return "", fmt.Errorf("refusing to read non-regular certificate bundle %s", path) + } + + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + + normalized := strings.TrimSpace(strings.ReplaceAll(string(data), "\r\n", "\n")) + if normalized == "" { + return "", fmt.Errorf("certificate bundle %s is empty", path) + } + + var ( + rest = []byte(normalized) + blocks []string + certCount int + ) + + for len(rest) > 0 { + block, remaining := pem.Decode(rest) + if block == nil { + if strings.TrimSpace(string(rest)) != "" { + return "", fmt.Errorf("certificate bundle %s contains non-PEM content", path) + } + break + } + + if block.Type != "CERTIFICATE" { + return "", fmt.Errorf("certificate bundle %s contains unsupported PEM block type %q", path, block.Type) + } + + // Include the certificate regardless of whether Go's strict x509 parser + // accepts it — legacy root CAs (e.g. negative serial numbers) are valid + // for OpenSSL/pip but rejected by Go. We still parse to catch genuinely + // malformed DER; those are skipped rather than failing the whole bundle. + if _, err := x509.ParseCertificate(block.Bytes); err != nil { + rest = remaining + continue + } + + blocks = append(blocks, strings.TrimSpace(string(pem.EncodeToMemory(block)))) + certCount++ + rest = remaining + } + + if certCount == 0 { + return "", fmt.Errorf("certificate bundle %s does not contain any valid certificates", path) + } + + return strings.Join(blocks, "\n"), nil +} diff --git a/internal/utils/pem_test.go b/internal/utils/pem_test.go new file mode 100644 index 00000000..48033a2f --- /dev/null +++ b/internal/utils/pem_test.go @@ -0,0 +1,94 @@ +package utils + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestReadAndValidatePEMBundleValidCertificate(t *testing.T) { + path := filepath.Join(t.TempDir(), "bundle.pem") + pemData := mustCreateTestCertificatePEM(t, "test-cert") + if err := os.WriteFile(path, []byte(pemData), 0o644); err != nil { + t.Fatal(err) + } + + got, err := ReadAndValidatePEMBundle(path) + if err != nil { + t.Fatalf("ReadAndValidatePEMBundle failed: %v", err) + } + + if !strings.Contains(got, "BEGIN CERTIFICATE") { + t.Fatalf("expected certificate PEM in output, got %q", got) + } +} + +func TestReadAndValidatePEMBundleRejectsNonPEMContent(t *testing.T) { + path := filepath.Join(t.TempDir(), "bundle.pem") + if err := os.WriteFile(path, []byte("not a certificate"), 0o644); err != nil { + t.Fatal(err) + } + + _, err := ReadAndValidatePEMBundle(path) + if err == nil { + t.Fatal("expected error for non-PEM content") + } +} + +func TestReadAndValidatePEMBundleRejectsSymlink(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "target.pem") + link := filepath.Join(dir, "link.pem") + + pemData := mustCreateTestCertificatePEM(t, "symlink-test") + if err := os.WriteFile(target, []byte(pemData), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(target, link); err != nil { + t.Skipf("symlink creation not supported: %v", err) + } + + _, err := ReadAndValidatePEMBundle(link) + if err == nil { + t.Fatal("expected error for symlinked bundle") + } +} + +func mustCreateTestCertificatePEM(t *testing.T, commonName string) string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate failed: %v", err) + } + + return string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: der, + })) +}