From b5507eef57120bd12b9495927772c4b119e5bc6f Mon Sep 17 00:00:00 2001 From: Hugo DELVAL Date: Wed, 20 Dec 2017 15:10:14 +0100 Subject: [PATCH 1/3] Support connection over Unix socket --- README.md | 4 +--- api.go | 50 +++++++++++++++++++++++++++++++++---------------- api_test.go | 8 +++++--- example_test.go | 6 +----- sa_test.go | 16 ++++++++-------- spamc.go | 12 +----------- 6 files changed, 50 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 38d808f..175ead3 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,7 @@ Basic example: ```go // Connect -c := New("127.0.0.1:783", &net.Dialer{ - Timeout: 20 * time.Second, -}) +c := New(TCPDialer("127.0.0.1:783")) ctx := context.Background() msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n") diff --git a/api.go b/api.go index 689ac3b..dd4ec23 100644 --- a/api.go +++ b/api.go @@ -33,10 +33,8 @@ type Error struct { func (e Error) Error() string { return e.msg } -// Dialer to connect to spamd; usually a net.Dialer instance. -type Dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} +// Dialer to connect to spamd; cf standard functions like `TCPDialer` or `UnixDialer`. +type Dialer func(context.Context) (net.Conn, error) // Header for requests and responses. type Header map[string]string @@ -104,24 +102,44 @@ func (h Header) normalizeKey(k string) string { } } -// New created a new Client instance. +// New creates a new Client instance. // -// The addr should be as "host:port"; as dialer most people will want to use -// net.Dialer: +// Examples: // -// New("127.0.0.1:783", &net.Dialer{Timeout: 20 * time.Second}) +// New(TCPDialer("127.0.0.1:783")) +// New(UnixDialer("/path/to/sock.unix")) +// New(BuildGenericDialer("127.0.0.1:783", "tcp", time.Second)) // -// If the passed dialer is nil then this will be used as a default. -func New(addr string, d Dialer) *Client { - if d == nil { - d = &net.Dialer{Timeout: 20 * time.Second} - } - return &Client{ - addr: addr, - dialer: d, +func New(d Dialer) *Client { + return &Client{dialer: d} +} + +// BuildGenericDialer is a generic method to build a Dialer (cf `New` method) +func BuildGenericDialer(addr string, proto string, timeout time.Duration) Dialer { + dialer := net.Dialer{Timeout: timeout} + return func(ctx context.Context) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, proto, addr) + if err != nil { + return conn, err + } + err = conn.SetDeadline(time.Now().Add(timeout)) + if err != nil { + return conn, err + } + return conn, nil } } +// TCPDialer creates a TCP Dialer with a 20 seconds timeout (cf `New` method) +func TCPDialer(addr string) Dialer { + return BuildGenericDialer(addr, "tcp", 20*time.Second) +} + +// UnixDialer creates a Unix Dialer with a 20 seconds timeout (cf `New` method) +func UnixDialer(addr string) Dialer { + return BuildGenericDialer(addr, "unix", 20*time.Second) +} + // Ping returns a confirmation that spamd is alive. func (c *Client) Ping(ctx context.Context) error { read, err := c.send(ctx, cmdPing, strings.NewReader(""), nil) diff --git a/api_test.go b/api_test.go index 909fad1..477bd12 100644 --- a/api_test.go +++ b/api_test.go @@ -371,9 +371,11 @@ func (d *testDialer) DialContext(ctx context.Context, network, address string) ( } func newClient(resp string) *Client { - d := &testDialer{conn: fakeconn.New()} - d.conn.ReadFrom.WriteString(resp) - return New("", d) + conn := fakeconn.New() + conn.ReadFrom.WriteString(resp) + return New(func(ctx context.Context) (net.Conn, error) { + return conn, nil + }) } func TestHeader(t *testing.T) { diff --git a/example_test.go b/example_test.go index 3dfafd3..239e1ea 100644 --- a/example_test.go +++ b/example_test.go @@ -4,16 +4,12 @@ import ( "context" "fmt" "log" - "net" "strings" - "time" ) func Example() { // Connect - c := New("127.0.0.1:783", &net.Dialer{ - Timeout: 20 * time.Second, - }) + c := New(TCPDialer("127.0.0.1:783")) ctx := context.Background() msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n") diff --git a/sa_test.go b/sa_test.go index 35d0c6d..3c65265 100644 --- a/sa_test.go +++ b/sa_test.go @@ -13,7 +13,7 @@ import ( var addr = os.Getenv("SPAMC_SA_ADDRESS") func TestSAPing(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) err := client.Ping(context.Background()) if err != nil { t.Fatal(err) @@ -21,7 +21,7 @@ func TestSAPing(t *testing.T) { } func TestSACheck(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Check(context.Background(), strings.NewReader("\r\nPenis viagra\r\n"), nil) if err != nil { t.Fatal(err) @@ -35,7 +35,7 @@ func TestSACheck(t *testing.T) { } func TestSASymbols(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Symbols(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: invalid\r\n"+ @@ -55,7 +55,7 @@ func TestSASymbols(t *testing.T) { } func TestSAReport(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Report(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -76,7 +76,7 @@ func TestSAReport(t *testing.T) { } func TestSAProcess(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Process(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -110,7 +110,7 @@ func TestSAProcess(t *testing.T) { } func TestSAHeaders(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Headers(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -144,7 +144,7 @@ func TestSAHeaders(t *testing.T) { } func TestSATell(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) message := strings.NewReader("Subject: Hello, world!\r\n\r\nTest message.\r\n") r, err := client.Tell(context.Background(), message, Header{}. Set("Message-class", "spam"). @@ -166,7 +166,7 @@ func TestSATell(t *testing.T) { // Make sure SA works when we send the message without trailing newline. func TestSANoTrailingNewline(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Check(context.Background(), strings.NewReader("woot"), nil) if err != nil { diff --git a/spamc.go b/spamc.go index ea6b07f..109ce92 100644 --- a/spamc.go +++ b/spamc.go @@ -12,7 +12,6 @@ import ( "regexp" "strconv" "strings" - "time" "github.com/pkg/errors" "github.com/teamwork/utils/mathutil" @@ -161,7 +160,7 @@ func sizeFromReader(r io.Reader) (int64, error) { } func (c *Client) dial(ctx context.Context) (net.Conn, error) { - conn, err := c.dialer.DialContext(ctx, "tcp", c.addr) + conn, err := c.dialer(ctx) if err != nil { if conn != nil { conn.Close() // nolint: errcheck @@ -169,15 +168,6 @@ func (c *Client) dial(ctx context.Context) (net.Conn, error) { return nil, errors.Wrap(err, "could not connect to spamd") } - // Set connection timeout - if ndial, ok := c.dialer.(*net.Dialer); ok { - err = conn.SetDeadline(time.Now().Add(ndial.Timeout)) - if err != nil { - conn.Close() // nolint: errcheck - return nil, errors.Wrap(err, "connection to spamd timed out") - } - } - return conn, nil } From e6d0584b1f24d0fcd93390736d6e480039914f36 Mon Sep 17 00:00:00 2001 From: Martin Tournoij Date: Fri, 22 Dec 2017 00:13:33 +0000 Subject: [PATCH 2/3] Some small changes - Some small documentation fixes. - Unexport the BuildGenericDialer() function for now. I'm not sure how useful it is to have this exported. - The Client.dial() function is rather redundant now, so remove that. --- api.go | 28 +++++----------------------- api_test.go | 8 -------- spamc.go | 31 ++++++++++++++++++------------- 3 files changed, 23 insertions(+), 44 deletions(-) diff --git a/api.go b/api.go index dd4ec23..b6bfe6f 100644 --- a/api.go +++ b/api.go @@ -33,7 +33,7 @@ type Error struct { func (e Error) Error() string { return e.msg } -// Dialer to connect to spamd; cf standard functions like `TCPDialer` or `UnixDialer`. +// Dialer to connect to spamd. type Dialer func(context.Context) (net.Conn, error) // Header for requests and responses. @@ -108,36 +108,18 @@ func (h Header) normalizeKey(k string) string { // // New(TCPDialer("127.0.0.1:783")) // New(UnixDialer("/path/to/sock.unix")) -// New(BuildGenericDialer("127.0.0.1:783", "tcp", time.Second)) -// func New(d Dialer) *Client { return &Client{dialer: d} } -// BuildGenericDialer is a generic method to build a Dialer (cf `New` method) -func BuildGenericDialer(addr string, proto string, timeout time.Duration) Dialer { - dialer := net.Dialer{Timeout: timeout} - return func(ctx context.Context) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, proto, addr) - if err != nil { - return conn, err - } - err = conn.SetDeadline(time.Now().Add(timeout)) - if err != nil { - return conn, err - } - return conn, nil - } -} - -// TCPDialer creates a TCP Dialer with a 20 seconds timeout (cf `New` method) +// TCPDialer creates a TCP Dialer with a 20 second timeout. func TCPDialer(addr string) Dialer { - return BuildGenericDialer(addr, "tcp", 20*time.Second) + return buildDialer(addr, "tcp", 20*time.Second) } -// UnixDialer creates a Unix Dialer with a 20 seconds timeout (cf `New` method) +// UnixDialer creates a Unix Dialer with a 20 second timeout. func UnixDialer(addr string) Dialer { - return BuildGenericDialer(addr, "unix", 20*time.Second) + return buildDialer(addr, "unix", 20*time.Second) } // Ping returns a confirmation that spamd is alive. diff --git a/api_test.go b/api_test.go index 477bd12..8d4b22d 100644 --- a/api_test.go +++ b/api_test.go @@ -362,14 +362,6 @@ func TestTell(t *testing.T) { } } -type testDialer struct { - conn fakeconn.Conn -} - -func (d *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return d.conn, nil -} - func newClient(resp string) *Client { conn := fakeconn.New() conn.ReadFrom.WriteString(resp) diff --git a/spamc.go b/spamc.go index 109ce92..7cad3fe 100644 --- a/spamc.go +++ b/spamc.go @@ -12,6 +12,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/pkg/errors" "github.com/teamwork/utils/mathutil" @@ -55,6 +56,19 @@ var errorMessages = map[int]string{ 79: "Read timeout", // EX_TIMEOUT } +// buildDialer is a helper to build a Dialer. +func buildDialer(addr string, proto string, timeout time.Duration) Dialer { + dialer := net.Dialer{Timeout: timeout} + return func(ctx context.Context) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, proto, addr) + if err != nil { + return conn, err + } + err = conn.SetDeadline(time.Now().Add(timeout)) + return conn, err + } +} + // send a command to spamd. func (c *Client) send( ctx context.Context, @@ -63,8 +77,11 @@ func (c *Client) send( headers Header, ) (io.ReadCloser, error) { - conn, err := c.dial(ctx) + conn, err := c.dialer(ctx) if err != nil { + if conn != nil { + conn.Close() // nolint: errcheck + } return nil, errors.Wrapf(err, "could not dial to %v", c.addr) } @@ -159,18 +176,6 @@ func sizeFromReader(r io.Reader) (int64, error) { } -func (c *Client) dial(ctx context.Context) (net.Conn, error) { - conn, err := c.dialer(ctx) - if err != nil { - if conn != nil { - conn.Close() // nolint: errcheck - } - return nil, errors.Wrap(err, "could not connect to spamd") - } - - return conn, nil -} - // The spamd protocol is a HTTP-esque protocol; a response's first line is the // response code: // From d5d275f24f3e7591bac540a89e8cef107ce5b6b4 Mon Sep 17 00:00:00 2001 From: Martin Tournoij Date: Fri, 22 Dec 2017 00:30:48 +0000 Subject: [PATCH 3/3] Add TLSDialer Hm, need to think about the API for a bit here. Can't say I particularly like this. --- api.go | 16 ++++++++++++++++ sa_test.go | 1 + 2 files changed, 17 insertions(+) diff --git a/api.go b/api.go index b6bfe6f..ffb5682 100644 --- a/api.go +++ b/api.go @@ -3,6 +3,7 @@ package spamc import ( "bufio" "context" + "crypto/tls" "fmt" "io" "net" @@ -122,6 +123,21 @@ func UnixDialer(addr string) Dialer { return buildDialer(addr, "unix", 20*time.Second) } +// TLSDialer creates a TLS Dialer. +func TLSDialer(addr string, c *tls.Config) Dialer { + timeout := 20 * time.Second + dialer := net.Dialer{Timeout: timeout} + return func(ctx context.Context) (net.Conn, error) { + conn, err := tls.DialWithDialer(&dialer, "tcp", addr, c) + if err != nil { + return conn, err + } + err = conn.SetDeadline(time.Now().Add(timeout)) + return conn, err + } + +} + // Ping returns a confirmation that spamd is alive. func (c *Client) Ping(ctx context.Context) error { read, err := c.send(ctx, cmdPing, strings.NewReader(""), nil) diff --git a/sa_test.go b/sa_test.go index 3c65265..7c8edc0 100644 --- a/sa_test.go +++ b/sa_test.go @@ -4,6 +4,7 @@ package spamc import ( "context" + "crypto/tls" "io/ioutil" "os" "strings"