Skip to content

Add TLS Hot Reload. Fixes #94 #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Flags:
--listen string listen address (default ":8000")
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
--max-size int the maximum size of the repository in bytes
--no-auth disable .htpasswd authentication
--no-auth disable authentication
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
--path string data directory (default "/tmp/restic")
--private-repos users can only access their private repo
Expand All @@ -51,6 +51,8 @@ Flags:
--tls turn on TLS support
--tls-cert string TLS certificate path
--tls-key string TLS key path
--tls-load-dyn dynamically reload TLS key and cert file from disk if they change
--tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s)
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
-v, --version version for rest-server
```
Expand Down
92 changes: 92 additions & 0 deletions cmd/rest-server/dynamicchecker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package main

import (
"context"
"crypto/tls"
"io/fs"
"log"
"os"
"sync/atomic"
"time"
)

type dynamicChecker struct {
certificate atomic.Pointer[tls.Certificate]
keyFile, certFile string
keyFileInfo, certFileInfo fs.FileInfo
}

// newDynamicChecker creates a struct that holds the data we need to do
// dynamic certificate reloads from disk. If it cannot load the files
// or they are invalid, an error is returned. Following a successful
// instantiation, the getCertificate method will always return a valid
// certificate, and we should call the poll method to check for changes.
func newDynamicChecker(certFile, keyFile string) (*dynamicChecker, error) {
keyFileInfo, err := os.Stat(keyFile)
if err != nil {
return nil, err
}
certFileInfo, err := os.Stat(certFile)
if err != nil {
return nil, err
}
crt, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
dc := &dynamicChecker{
keyFile: keyFile,
certFile: certFile,
keyFileInfo: keyFileInfo,
certFileInfo: certFileInfo,
}
dc.certificate.Store(&crt)
return dc, nil
}

// getCertificate - always returns a valid tls.Certificate and nil error.
func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return dc.certificate.Load(), nil
}

// poll runs in a goroutine and periodically polls the key and cert for
// updates.
func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) {
go func() {
t := time.NewTimer(interval)
defer t.Stop() // go >= 1.23 means we don't have to check the return
for {
select {
case <-ctx.Done():
return
case <-t.C:
keyFileInfo, err := os.Stat(dc.keyFile)
if err != nil {
log.Printf("could not stat keyFile %s, using previous cert: %s", dc.keyFile, err)
break // select
}
certFileInfo, err := os.Stat(dc.certFile)
if err != nil {
log.Printf("could not stat certFile %s, using previous cert: %s", dc.certFile, err)
break // select
}
if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) ||
keyFileInfo.Size() != dc.keyFileInfo.Size() ||
!certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) ||
certFileInfo.Size() != dc.certFileInfo.Size() {
// they changed on disk, reload
crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile)
if err != nil {
log.Printf("could not load cert and key files, using previous cert: %s", err)
break // select
}
dc.certificate.Store(&crt)
dc.certFileInfo = certFileInfo
dc.keyFileInfo = keyFileInfo
log.Printf("successfully reloaded certificate from disk")
}
} // end select
t.Reset(interval)
} // end for
}()
}
150 changes: 150 additions & 0 deletions cmd/rest-server/dynamicchecker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package main

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"testing"
"time"
)

func TestDynamicReload(t *testing.T) {
cert, key, err := generateCertFiles()
if err != nil {
t.Fatal(err)
}
t.Logf("created %s and %s files", cert, key)
t.Cleanup(func() {
_ = os.Remove(cert)
_ = os.Remove(key)
})
err = generateSelfSigned(cert, key)
if err != nil {
t.Fatal(err)
}
dc, err := newDynamicChecker(cert, key)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dc.poll(ctx, time.Second)
crt1Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
err = generateSelfSigned(cert, key)
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Second * 2)
crt2Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 {
t.Fatal("expected certificate to be different")
}
t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber)
// force a certificate
_ = os.Remove(cert)
time.Sleep(time.Second * 2)
crt3Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 {
t.Fatal("expected certificate to be certificate")
}
}

func generateCertFiles() (cert, key string, err error) {
certFile, err := os.CreateTemp("", "cert")
if err != nil {
return "", "", err
}
cert = certFile.Name()
_ = certFile.Close()
keyFile, err := os.CreateTemp("", "key")
if err != nil {
return "", "", err
}
key = keyFile.Name()
_ = keyFile.Close()
return cert, key, nil
}

var serial = int64(9000)

func NextSerial() *big.Int {
serial++
return big.NewInt(serial)
}

func generateSelfSigned(certFile, keyFile string) error {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
}
template := &x509.Certificate{
SerialNumber: NextSerial(),
Subject: pkix.Name{
Organization: []string{"Widgets Inc"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
}
certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk)
if err != nil {
return err
}
keyDer, err := x509.MarshalECPrivateKey(pk)
if err != nil {
return err
}
keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer func() {
_ = keyFh.Close()
}()
certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer func() {
_ = certFh.Close()
}()
err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})
if err != nil {
return err
}
err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer})
if err != nil {
return err
}
return nil
}
34 changes: 29 additions & 5 deletions cmd/rest-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"runtime/pprof"
"sync"
"syscall"
"time"

restserver "github.com/restic/rest-server"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp {
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
},
Server: restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
TLSMinVer: "1.2",
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
TLSMinVer: "1.2",
TLSReloadTime: time.Minute,
},
}
rv.CmdRoot.RunE = rv.runRoot
Expand All @@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp {
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")

flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
Expand Down Expand Up @@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
default:
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
}

srv := &http.Server{
Handler: handler,
TLSConfig: tlscfg,
}

if enabledTLS {
if app.Server.TLSDynamicReload {
dc, err := newDynamicChecker(publicKey, privateKey)
if err != nil {
return fmt.Errorf("unable to load key pair: %w", err)
}
dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime)
tlscfg.GetCertificate = dc.getCertificate
} else {
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
if err != nil {
return fmt.Errorf("unable to load key pair: %w", err)
}
tlscfg.Certificates = []tls.Certificate{crt}
}
}

// run server in background
go func() {
if !enabledTLS {
err = srv.Serve(listener)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = srv.ServeTLS(listener, publicKey, privateKey)
if app.Server.TLSDynamicReload {
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
}
err = srv.ServeTLS(listener, "", "")
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen and serve returned err: %v", err)
Expand Down
3 changes: 3 additions & 0 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"

"github.com/restic/rest-server/quota"
"github.com/restic/rest-server/repo"
Expand All @@ -24,6 +25,8 @@ type Server struct {
TLSCert string
TLSMinVer string
TLS bool
TLSDynamicReload bool
TLSReloadTime time.Duration
NoAuth bool
ProxyAuthUsername string
AppendOnly bool
Expand Down