Skip to content

Commit 9d759bc

Browse files
Add TLS Hot Reload. Fixes #94
This is a very simple implementation. It optionally polls the disk for changes to the key / cert files, and attempts to reload them if it detects the modification time or the size has changed.
1 parent eee73d3 commit 9d759bc

File tree

5 files changed

+277
-6
lines changed

5 files changed

+277
-6
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Flags:
4141
--listen string listen address (default ":8000")
4242
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
4343
--max-size int the maximum size of the repository in bytes
44-
--no-auth disable .htpasswd authentication
44+
--no-auth disable authentication
4545
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
4646
--path string data directory (default "/tmp/restic")
4747
--private-repos users can only access their private repo
@@ -51,6 +51,8 @@ Flags:
5151
--tls turn on TLS support
5252
--tls-cert string TLS certificate path
5353
--tls-key string TLS key path
54+
--tls-load-dyn dynamically reload TLS key and cert file from disk if they change
55+
--tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s)
5456
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
5557
-v, --version version for rest-server
5658
```

cmd/rest-server/dynamicchecker.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"io/fs"
7+
"log"
8+
"os"
9+
"sync/atomic"
10+
"time"
11+
)
12+
13+
type dynamicChecker struct {
14+
certificate atomic.Pointer[tls.Certificate]
15+
keyFile, certFile string
16+
keyFileInfo, certFileInfo fs.FileInfo
17+
}
18+
19+
// newDynamicChecker creates a struct that holds the data we need to do
20+
// dynamic certificate reloads from disk. If it cannot load the files
21+
// or they are invalid, an error is returned. Following a successful
22+
// instantiation, the getCertificate method will always return a valid
23+
// certificate, and it wil poll for changes every threshold.
24+
func newDynamicChecker(certFile, keyFile string (*dynamicChecker, error) {
25+
keyFileInfo, err := os.Stat(keyFile)
26+
if err != nil {
27+
return nil, err
28+
}
29+
certFileInfo, err := os.Stat(certFile)
30+
if err != nil {
31+
return nil, err
32+
}
33+
crt, err := tls.LoadX509KeyPair(certFile, keyFile)
34+
if err != nil {
35+
return nil, err
36+
}
37+
dc := &dynamicChecker{
38+
keyFile: keyFile,
39+
certFile: certFile,
40+
keyFileInfo: keyFileInfo,
41+
certFileInfo: certFileInfo,
42+
}
43+
dc.certificate.Store(&crt)
44+
return dc, nil
45+
}
46+
47+
// getCertificate - always returns a valid tls.Certificate and nil error.
48+
func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
49+
return dc.certificate.Load(), nil
50+
}
51+
52+
// poll runs in a goroutine and periodically polls the key and cert for
53+
// updates.
54+
func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) {
55+
go func() {
56+
t := time.NewTimer(interval)
57+
defer t.Stop() // go >= 1.23 means we don't have to check the return
58+
for {
59+
select {
60+
case <-ctx.Done():
61+
return
62+
case <-t.C:
63+
keyFileInfo, err := os.Stat(dc.keyFile)
64+
if err != nil {
65+
log.Printf("could not stat keyFile %s: %s, using certificate tls", dc.keyFile, err)
66+
break // select
67+
}
68+
certFileInfo, err := os.Stat(dc.certFile)
69+
if err != nil {
70+
log.Printf("could not stat certFile %s: %s, using certificate tls", dc.certFile, err)
71+
break // select
72+
}
73+
if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) ||
74+
keyFileInfo.Size() != dc.keyFileInfo.Size() ||
75+
!certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) ||
76+
certFileInfo.Size() != dc.certFileInfo.Size() {
77+
// they changed on disk, reload
78+
crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile)
79+
if err != nil {
80+
log.Printf("could not load cert and key files: %s", err)
81+
break // select
82+
}
83+
dc.certificate.Store(&crt)
84+
dc.certFileInfo = certFileInfo
85+
dc.keyFileInfo = keyFileInfo
86+
log.Printf("reloaded certificate from disk as it was modified")
87+
}
88+
} // end select
89+
t.Reset(interval)
90+
} // end for
91+
}()
92+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
7+
"crypto/rand"
8+
"crypto/x509"
9+
"crypto/x509/pkix"
10+
"encoding/pem"
11+
"math/big"
12+
"os"
13+
"testing"
14+
"time"
15+
)
16+
17+
func TestDynamicReload(t *testing.T) {
18+
cert, key, err := generateCertFiles()
19+
if err != nil {
20+
t.Fatal(err)
21+
}
22+
t.Logf("created %s and %s files", cert, key)
23+
t.Cleanup(func() {
24+
_ = os.Remove(cert)
25+
_ = os.Remove(key)
26+
})
27+
err = generateSelfSigned(cert, key)
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
dc, err := newDynamicChecker(cert, key, time.Second)
32+
if err != nil {
33+
t.Fatal(err)
34+
}
35+
ctx, cancel := context.WithCancel(context.Background())
36+
defer cancel()
37+
dc.poll(ctx)
38+
crt1Raw, err := dc.getCertificate(nil)
39+
if err != nil {
40+
t.Fatal(err)
41+
}
42+
crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0])
43+
if err != nil {
44+
t.Fatal(err)
45+
}
46+
err = generateSelfSigned(cert, key)
47+
if err != nil {
48+
t.Fatal(err)
49+
}
50+
time.Sleep(time.Second * 2)
51+
crt2Raw, err := dc.getCertificate(nil)
52+
if err != nil {
53+
t.Fatal(err)
54+
}
55+
crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0])
56+
if err != nil {
57+
t.Fatal(err)
58+
}
59+
if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 {
60+
t.Fatal("expected certificate to be different")
61+
}
62+
t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber)
63+
// force a certificate
64+
_ = os.Remove(cert)
65+
time.Sleep(time.Second * 2)
66+
crt3Raw, err := dc.getCertificate(nil)
67+
if err != nil {
68+
t.Fatal(err)
69+
}
70+
crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0])
71+
if err != nil {
72+
t.Fatal(err)
73+
}
74+
if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 {
75+
t.Fatal("expected certificate to be certificate")
76+
}
77+
}
78+
79+
func generateCertFiles() (cert, key string, err error) {
80+
certFile, err := os.CreateTemp("", "cert")
81+
if err != nil {
82+
return "", "", err
83+
}
84+
cert = certFile.Name()
85+
_ = certFile.Close()
86+
keyFile, err := os.CreateTemp("", "key")
87+
if err != nil {
88+
return "", "", err
89+
}
90+
key = keyFile.Name()
91+
_ = keyFile.Close()
92+
return cert, key, nil
93+
}
94+
95+
var serial = int64(9000)
96+
97+
func NextSerial() *big.Int {
98+
serial++
99+
return big.NewInt(serial)
100+
}
101+
102+
func generateSelfSigned(certFile, keyFile string) error {
103+
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
104+
if err != nil {
105+
return err
106+
}
107+
template := &x509.Certificate{
108+
SerialNumber: NextSerial(),
109+
Subject: pkix.Name{
110+
Organization: []string{"Widgets Inc"},
111+
},
112+
NotBefore: time.Now(),
113+
NotAfter: time.Now().Add(time.Hour * 24),
114+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
115+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
116+
BasicConstraintsValid: true,
117+
DNSNames: []string{"localhost"},
118+
}
119+
certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk)
120+
if err != nil {
121+
return err
122+
}
123+
keyDer, err := x509.MarshalECPrivateKey(pk)
124+
if err != nil {
125+
return err
126+
}
127+
keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
128+
if err != nil {
129+
return err
130+
}
131+
defer func() {
132+
_ = keyFh.Close()
133+
}()
134+
certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
135+
if err != nil {
136+
return err
137+
}
138+
defer func() {
139+
_ = certFh.Close()
140+
}()
141+
err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})
142+
if err != nil {
143+
return err
144+
}
145+
err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer})
146+
if err != nil {
147+
return err
148+
}
149+
return nil
150+
}

cmd/rest-server/main.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"runtime/pprof"
1616
"sync"
1717
"syscall"
18+
"time"
1819

1920
restserver "github.com/restic/rest-server"
2021
"github.com/spf13/cobra"
@@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp {
4647
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
4748
},
4849
Server: restserver.Server{
49-
Path: filepath.Join(os.TempDir(), "restic"),
50-
Listen: ":8000",
51-
TLSMinVer: "1.2",
50+
Path: filepath.Join(os.TempDir(), "restic"),
51+
Listen: ":8000",
52+
TLSMinVer: "1.2",
53+
TLSReloadTime: time.Minute,
5254
},
5355
}
5456
rv.CmdRoot.RunE = rv.runRoot
@@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp {
6365
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
6466
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
6567
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
68+
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
69+
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")
70+
6671
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
6772
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
6873
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
@@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
198203
default:
199204
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
200205
}
201-
202206
srv := &http.Server{
203207
Handler: handler,
204208
TLSConfig: tlscfg,
205209
}
206210

211+
if enabledTLS {
212+
if app.Server.TLSDynamicReload {
213+
dc, err := newDynamicChecker(publicKey, privateKey, app.Server.TLSReloadTime)
214+
if err != nil {
215+
return fmt.Errorf("unable to load key pair: %w", err)
216+
}
217+
dc.poll(app.CmdRoot.Context())
218+
tlscfg.GetCertificate = dc.getCertificate
219+
} else {
220+
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
221+
if err != nil {
222+
return fmt.Errorf("unable to load key pair: %w", err)
223+
}
224+
tlscfg.Certificates = []tls.Certificate{crt}
225+
}
226+
}
227+
207228
// run server in background
208229
go func() {
209230
if !enabledTLS {
210231
err = srv.Serve(listener)
211232
} else {
212233
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
213-
err = srv.ServeTLS(listener, publicKey, privateKey)
234+
if app.Server.TLSDynamicReload {
235+
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
236+
}
237+
err = srv.ServeTLS(listener, "", "")
214238
}
215239
if err != nil && !errors.Is(err, http.ErrServerClosed) {
216240
log.Fatalf("listen and serve returned err: %v", err)

handlers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"path/filepath"
99
"strings"
1010
"sync"
11+
"time"
1112

1213
"github.com/restic/rest-server/quota"
1314
"github.com/restic/rest-server/repo"
@@ -24,6 +25,8 @@ type Server struct {
2425
TLSCert string
2526
TLSMinVer string
2627
TLS bool
28+
TLSDynamicReload bool
29+
TLSReloadTime time.Duration
2730
NoAuth bool
2831
ProxyAuthUsername string
2932
AppendOnly bool

0 commit comments

Comments
 (0)