Skip to content

Commit 920e809

Browse files
authored
Merge pull request #83 from netlify/nats-basic-auth
[nats] Implement more auth methods
2 parents 697102f + 0ef5271 commit 920e809

File tree

3 files changed

+123
-69
lines changed

3 files changed

+123
-69
lines changed

messaging/nats.go

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package messaging
22

33
import (
4+
"crypto/tls"
45
"fmt"
56
"io/ioutil"
7+
"strings"
68
"time"
79

810
"github.com/nats-io/go-nats"
@@ -24,11 +26,16 @@ func ConfigureNatsConnection(config *nconf.NatsConfig, log logrus.FieldLogger) (
2426
if log == nil {
2527
log = silent
2628
}
29+
2730
if config == nil {
2831
log.Debug("Skipping nats connection because there is no config")
2932
return nil, nil
3033
}
3134

35+
if !config.TLS.Enabled {
36+
log.Warn("Connection to NATS servers is not using TLS")
37+
}
38+
3239
if err := config.LoadServerNames(); err != nil {
3340
return nil, errors.Wrap(err, "Failed to discover new servers")
3441
}
@@ -43,14 +50,29 @@ func ConfigureNatsConnection(config *nconf.NatsConfig, log logrus.FieldLogger) (
4350
}
4451

4552
func ConnectToNats(config *nconf.NatsConfig, opts ...nats.Option) (*nats.Conn, error) {
46-
if config.TLS != nil {
47-
tlsConfig, err := config.TLS.TLSConfig()
48-
if err != nil {
49-
return nil, errors.Wrap(err, "Failed to configure TLS")
50-
}
51-
if tlsConfig != nil {
52-
opts = append(opts, nats.Secure(tlsConfig))
53+
tlsConfig, err := config.TLS.TLSConfig()
54+
if err != nil {
55+
return nil, errors.Wrap(err, "Failed to configure TLS")
56+
}
57+
58+
// If TLS is enabled, make the connection to NATS secure
59+
if config.TLS.Enabled {
60+
opts = append(opts, NatsRootCAs(tlsConfig))
61+
}
62+
63+
switch strings.ToLower(config.Auth.Method) {
64+
case nconf.NatsAuthMethodUser:
65+
opts = append(opts, nats.UserInfo(config.Auth.User, config.Auth.Password))
66+
case nconf.NatsAuthMethodToken:
67+
opts = append(opts, nats.Token(config.Auth.Token))
68+
case nconf.NatsAuthMethodTLS:
69+
// if using TLS auth, make sure the client certificate is loaded
70+
if tlsConfig == nil || len(tlsConfig.Certificates) == 0 {
71+
return nil, fmt.Errorf("TLS auth method is configured but no certificate was loaded")
5372
}
73+
opts = append(opts, nats.Secure(tlsConfig))
74+
default:
75+
return nil, fmt.Errorf("Invalid auth method: '%s'", config.Auth.Method)
5476
}
5577

5678
return nats.Connect(config.ServerString(), opts...)
@@ -109,3 +131,25 @@ func ErrorHandler(log logrus.FieldLogger) nats.Option {
109131
}
110132
return nats.ErrorHandler(handler)
111133
}
134+
135+
// NatsRootCAs is a NATS helper option to provide the RootCAs pool from a tls.Config struct. If Secure is
136+
// not already set this will set it as well.
137+
func NatsRootCAs(tlsConf *tls.Config) nats.Option {
138+
return func(o *nats.Options) error {
139+
if tlsConf.RootCAs == nil {
140+
return fmt.Errorf("nats: the RootCAs pool from the given tls.Config is nil")
141+
}
142+
143+
if o.TLSConfig == nil {
144+
o.TLSConfig = &tls.Config{
145+
MinVersion: tls.VersionTLS12,
146+
InsecureSkipVerify: tlsConf.InsecureSkipVerify,
147+
}
148+
}
149+
150+
o.TLSConfig.RootCAs = tlsConf.RootCAs
151+
o.Secure = true
152+
153+
return nil
154+
}
155+
}

nconf/nats.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,28 @@ import (
88
"github.com/sirupsen/logrus"
99
)
1010

11+
const (
12+
NatsAuthMethodUser = "user"
13+
NatsAuthMethodToken = "token"
14+
NatsAuthMethodTLS = "tls"
15+
)
16+
17+
type NatsAuth struct {
18+
Method string `mapstructure:"method"`
19+
User string `mapstructure:"user"`
20+
Password string `mapstructure:"password"`
21+
Token string `mapstructure:"token"`
22+
}
23+
1124
type NatsConfig struct {
1225
TLS *TLSConfig `mapstructure:"tls_conf"`
13-
DiscoveryName string `split_words:"true" mapstructure:"discovery_name"`
26+
DiscoveryName string `mapstructure:"discovery_name" split_words:"true"`
1427
Servers []string `mapstructure:"servers"`
28+
Auth NatsAuth `mapstructure:"auth"`
1529

1630
// for streaming
17-
ClusterID string `mapstructure:"cluster_id" envconfig:"cluster_id"`
18-
ClientID string `mapstructure:"client_id" envconfig:"client_id"`
31+
ClusterID string `mapstructure:"cluster_id" split_words:"true"`
32+
ClientID string `mapstructure:"client_id" split_words:"true"`
1933
StartPos string `mapstructure:"start_pos" split_words:"true"`
2034
}
2135

@@ -48,6 +62,10 @@ func (config *NatsConfig) Fields() logrus.Fields {
4862
"servers": strings.Join(config.Servers, ","),
4963
}
5064

65+
if config.Auth.Method != "" {
66+
f["auth_method"] = config.Auth.Method
67+
}
68+
5169
if config.TLS != nil {
5270
f["ca_files"] = strings.Join(config.TLS.CAFiles, ",")
5371
f["key_file"] = config.TLS.KeyFile

nconf/tls.go

Lines changed: 51 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"crypto/x509"
66
"fmt"
77
"io/ioutil"
8+
9+
"github.com/pkg/errors"
810
)
911

1012
type TLSConfig struct {
@@ -21,88 +23,78 @@ type TLSConfig struct {
2123
}
2224

2325
func (cfg TLSConfig) TLSConfig() (*tls.Config, error) {
24-
var tlsconf *tls.Config
2526
var err error
26-
if cfg.Cert != "" && cfg.Key != "" {
27-
tlsconf, err = LoadFromValues(cfg.Cert, cfg.Key, cfg.CA)
28-
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
29-
tlsconf, err = LoadFromFiles(cfg.CertFile, cfg.KeyFile, cfg.CAFiles)
30-
}
3127

32-
if err != nil {
33-
return nil, err
28+
tlsConf := &tls.Config{
29+
MinVersion: tls.VersionTLS12,
30+
InsecureSkipVerify: cfg.Insecure,
3431
}
3532

36-
if tlsconf != nil {
37-
tlsconf.InsecureSkipVerify = cfg.Insecure
33+
// Load CA
34+
if cfg.CA != "" {
35+
tlsConf.RootCAs, err = LoadCAFromValue(cfg.CA)
36+
} else if len(cfg.CAFiles) > 0 {
37+
tlsConf.RootCAs, err = LoadCAFromFiles(cfg.CAFiles)
38+
} else {
39+
tlsConf.RootCAs, err = x509.SystemCertPool()
3840
}
3941

40-
return tlsconf, nil
41-
}
42+
if err != nil {
43+
return nil, errors.Wrap(err, "Error setting up Root CA pool")
44+
}
4245

43-
func LoadFromValues(certPEM, keyPEM, ca string) (*tls.Config, error) {
44-
var pool *x509.CertPool
45-
// If no CA cert if provided, use system pool
46-
if ca == "" {
47-
p, err := x509.SystemCertPool()
48-
if err != nil {
49-
return nil, err
50-
}
51-
pool = p
52-
} else {
53-
pool = x509.NewCertPool()
54-
if !pool.AppendCertsFromPEM([]byte(ca)) {
55-
return nil, fmt.Errorf("Failed to add CA cert")
56-
}
46+
// Load Certs if any
47+
var cert tls.Certificate
48+
if cfg.Cert != "" && cfg.Key != "" {
49+
cert, err = LoadCertFromValues(cfg.Cert, cfg.Key)
50+
tlsConf.Certificates = append(tlsConf.Certificates, cert)
51+
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
52+
cert, err = LoadCertFromFiles(cfg.CertFile, cfg.KeyFile)
53+
tlsConf.Certificates = append(tlsConf.Certificates, cert)
5754
}
5855

59-
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
6056
if err != nil {
61-
return nil, err
57+
return nil, errors.Wrap(err, "Error loading certificate KeyPair")
6258
}
6359

64-
tlsConfig := &tls.Config{
65-
RootCAs: pool,
66-
Certificates: []tls.Certificate{cert},
67-
MinVersion: tls.VersionTLS12,
60+
// Backwards compatibility: if TLS is not explicitly enabled, return nil if no certificate was provided
61+
// Old code disabled TLS by not providing a certificate, which returned nil when calling TLSConfig()
62+
if !cfg.Enabled && len(tlsConf.Certificates) == 0 {
63+
return nil, nil
6864
}
6965

70-
return tlsConfig, nil
66+
return tlsConf, nil
67+
}
68+
69+
func LoadCertFromValues(certPEM, keyPEM string) (tls.Certificate, error) {
70+
return tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
7171
}
7272

73-
func LoadFromFiles(certFile, keyFile string, cafiles []string) (*tls.Config, error) {
74-
var pool *x509.CertPool
75-
if len(cafiles) == 0 {
76-
p, err := x509.SystemCertPool()
73+
func LoadCertFromFiles(certFile, keyFile string) (tls.Certificate, error) {
74+
return tls.LoadX509KeyPair(certFile, keyFile)
75+
}
76+
77+
func LoadCAFromFiles(cafiles []string) (*x509.CertPool, error) {
78+
pool := x509.NewCertPool()
79+
80+
for _, caFile := range cafiles {
81+
caData, err := ioutil.ReadFile(caFile)
7782
if err != nil {
7883
return nil, err
7984
}
80-
pool = p
81-
} else {
82-
pool = x509.NewCertPool()
83-
84-
for _, caFile := range cafiles {
85-
caData, err := ioutil.ReadFile(caFile)
86-
if err != nil {
87-
return nil, err
88-
}
8985

90-
if !pool.AppendCertsFromPEM(caData) {
91-
return nil, fmt.Errorf("Failed to add CA cert at %s", caFile)
92-
}
86+
if !pool.AppendCertsFromPEM(caData) {
87+
return nil, fmt.Errorf("Failed to add CA cert at %s", caFile)
9388
}
9489
}
9590

96-
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
97-
if err != nil {
98-
return nil, err
99-
}
91+
return pool, nil
92+
}
10093

101-
tlsConfig := &tls.Config{
102-
RootCAs: pool,
103-
Certificates: []tls.Certificate{cert},
104-
MinVersion: tls.VersionTLS12,
94+
func LoadCAFromValue(ca string) (*x509.CertPool, error) {
95+
pool := x509.NewCertPool()
96+
if !pool.AppendCertsFromPEM([]byte(ca)) {
97+
return nil, fmt.Errorf("Failed to add CA cert")
10598
}
106-
107-
return tlsConfig, nil
99+
return pool, nil
108100
}

0 commit comments

Comments
 (0)