From 3444ebc2155d2e4600733ffb1f07e075e85fe23b Mon Sep 17 00:00:00 2001 From: Nathan Johnson Date: Mon, 14 Apr 2025 18:25:27 -0500 Subject: [PATCH] Add TLS Hot Reload. Fixes #94 This is a very simple implementation. It optionally polls the disk for changes to the key / cert files, and attempts to reload them if it detects the modification time or the size has changed. --- README.md | 4 +- cmd/rest-server/dynamicchecker.go | 92 +++++++++++++++ cmd/rest-server/dynamicchecker_test.go | 150 +++++++++++++++++++++++++ cmd/rest-server/main.go | 34 +++++- handlers.go | 3 + 5 files changed, 277 insertions(+), 6 deletions(-) create mode 100644 cmd/rest-server/dynamicchecker.go create mode 100644 cmd/rest-server/dynamicchecker_test.go diff --git a/README.md b/README.md index 1d6245a..34ba755 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ Flags: --listen string listen address (default ":8000") --log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout) --max-size int the maximum size of the repository in bytes - --no-auth disable .htpasswd authentication + --no-auth disable authentication --no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device --path string data directory (default "/tmp/restic") --private-repos users can only access their private repo @@ -51,6 +51,8 @@ Flags: --tls turn on TLS support --tls-cert string TLS certificate path --tls-key string TLS key path + --tls-load-dyn dynamically reload TLS key and cert file from disk if they change + --tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s) --tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2") -v, --version version for rest-server ``` diff --git a/cmd/rest-server/dynamicchecker.go b/cmd/rest-server/dynamicchecker.go new file mode 100644 index 0000000..b70fceb --- /dev/null +++ b/cmd/rest-server/dynamicchecker.go @@ -0,0 +1,92 @@ +package main + +import ( + "context" + "crypto/tls" + "io/fs" + "log" + "os" + "sync/atomic" + "time" +) + +type dynamicChecker struct { + certificate atomic.Pointer[tls.Certificate] + keyFile, certFile string + keyFileInfo, certFileInfo fs.FileInfo +} + +// newDynamicChecker creates a struct that holds the data we need to do +// dynamic certificate reloads from disk. If it cannot load the files +// or they are invalid, an error is returned. Following a successful +// instantiation, the getCertificate method will always return a valid +// certificate, and we should call the poll method to check for changes. +func newDynamicChecker(certFile, keyFile string) (*dynamicChecker, error) { + keyFileInfo, err := os.Stat(keyFile) + if err != nil { + return nil, err + } + certFileInfo, err := os.Stat(certFile) + if err != nil { + return nil, err + } + crt, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + dc := &dynamicChecker{ + keyFile: keyFile, + certFile: certFile, + keyFileInfo: keyFileInfo, + certFileInfo: certFileInfo, + } + dc.certificate.Store(&crt) + return dc, nil +} + +// getCertificate - always returns a valid tls.Certificate and nil error. +func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return dc.certificate.Load(), nil +} + +// poll runs in a goroutine and periodically polls the key and cert for +// updates. +func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) { + go func() { + t := time.NewTimer(interval) + defer t.Stop() // go >= 1.23 means we don't have to check the return + for { + select { + case <-ctx.Done(): + return + case <-t.C: + keyFileInfo, err := os.Stat(dc.keyFile) + if err != nil { + log.Printf("could not stat keyFile %s, using previous cert: %s", dc.keyFile, err) + break // select + } + certFileInfo, err := os.Stat(dc.certFile) + if err != nil { + log.Printf("could not stat certFile %s, using previous cert: %s", dc.certFile, err) + break // select + } + if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) || + keyFileInfo.Size() != dc.keyFileInfo.Size() || + !certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) || + certFileInfo.Size() != dc.certFileInfo.Size() { + // they changed on disk, reload + crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile) + if err != nil { + log.Printf("could not load cert and key files, using previous cert: %s", err) + break // select + } + dc.certificate.Store(&crt) + dc.certFileInfo = certFileInfo + dc.keyFileInfo = keyFileInfo + log.Printf("successfully reloaded certificate from disk") + } + } // end select + t.Reset(interval) + } // end for + }() +} diff --git a/cmd/rest-server/dynamicchecker_test.go b/cmd/rest-server/dynamicchecker_test.go new file mode 100644 index 0000000..25eebe6 --- /dev/null +++ b/cmd/rest-server/dynamicchecker_test.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "testing" + "time" +) + +func TestDynamicReload(t *testing.T) { + cert, key, err := generateCertFiles() + if err != nil { + t.Fatal(err) + } + t.Logf("created %s and %s files", cert, key) + t.Cleanup(func() { + _ = os.Remove(cert) + _ = os.Remove(key) + }) + err = generateSelfSigned(cert, key) + if err != nil { + t.Fatal(err) + } + dc, err := newDynamicChecker(cert, key) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dc.poll(ctx, time.Second) + crt1Raw, err := dc.getCertificate(nil) + if err != nil { + t.Fatal(err) + } + crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0]) + if err != nil { + t.Fatal(err) + } + err = generateSelfSigned(cert, key) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Second * 2) + crt2Raw, err := dc.getCertificate(nil) + if err != nil { + t.Fatal(err) + } + crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0]) + if err != nil { + t.Fatal(err) + } + if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 { + t.Fatal("expected certificate to be different") + } + t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber) + // force a certificate + _ = os.Remove(cert) + time.Sleep(time.Second * 2) + crt3Raw, err := dc.getCertificate(nil) + if err != nil { + t.Fatal(err) + } + crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0]) + if err != nil { + t.Fatal(err) + } + if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 { + t.Fatal("expected certificate to be certificate") + } +} + +func generateCertFiles() (cert, key string, err error) { + certFile, err := os.CreateTemp("", "cert") + if err != nil { + return "", "", err + } + cert = certFile.Name() + _ = certFile.Close() + keyFile, err := os.CreateTemp("", "key") + if err != nil { + return "", "", err + } + key = keyFile.Name() + _ = keyFile.Close() + return cert, key, nil +} + +var serial = int64(9000) + +func NextSerial() *big.Int { + serial++ + return big.NewInt(serial) +} + +func generateSelfSigned(certFile, keyFile string) error { + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + template := &x509.Certificate{ + SerialNumber: NextSerial(), + Subject: pkix.Name{ + Organization: []string{"Widgets Inc"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk) + if err != nil { + return err + } + keyDer, err := x509.MarshalECPrivateKey(pk) + if err != nil { + return err + } + keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer func() { + _ = keyFh.Close() + }() + certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer func() { + _ = certFh.Close() + }() + err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer}) + if err != nil { + return err + } + err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer}) + if err != nil { + return err + } + return nil +} diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index dfa594d..3b67420 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -15,6 +15,7 @@ import ( "runtime/pprof" "sync" "syscall" + "time" restserver "github.com/restic/rest-server" "github.com/spf13/cobra" @@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp { Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH), }, Server: restserver.Server{ - Path: filepath.Join(os.TempDir(), "restic"), - Listen: ":8000", - TLSMinVer: "1.2", + Path: filepath.Join(os.TempDir(), "restic"), + Listen: ":8000", + TLSMinVer: "1.2", + TLSReloadTime: time.Minute, }, } rv.CmdRoot.RunE = rv.runRoot @@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp { flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support") flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path") flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path") + flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change") + flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled") + flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)") flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication") flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"/.htpasswd)\"") @@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error { default: return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer) } - srv := &http.Server{ Handler: handler, TLSConfig: tlscfg, } + if enabledTLS { + if app.Server.TLSDynamicReload { + dc, err := newDynamicChecker(publicKey, privateKey) + if err != nil { + return fmt.Errorf("unable to load key pair: %w", err) + } + dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime) + tlscfg.GetCertificate = dc.getCertificate + } else { + crt, err := tls.LoadX509KeyPair(publicKey, privateKey) + if err != nil { + return fmt.Errorf("unable to load key pair: %w", err) + } + tlscfg.Certificates = []tls.Certificate{crt} + } + } + // run server in background go func() { if !enabledTLS { err = srv.Serve(listener) } else { log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) - err = srv.ServeTLS(listener, publicKey, privateKey) + if app.Server.TLSDynamicReload { + log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime) + } + err = srv.ServeTLS(listener, "", "") } if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("listen and serve returned err: %v", err) diff --git a/handlers.go b/handlers.go index 5938edd..03ac54e 100644 --- a/handlers.go +++ b/handlers.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "sync" + "time" "github.com/restic/rest-server/quota" "github.com/restic/rest-server/repo" @@ -24,6 +25,8 @@ type Server struct { TLSCert string TLSMinVer string TLS bool + TLSDynamicReload bool + TLSReloadTime time.Duration NoAuth bool ProxyAuthUsername string AppendOnly bool