diff --git a/README.md b/README.md index 38d808f..175ead3 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,7 @@ Basic example: ```go // Connect -c := New("127.0.0.1:783", &net.Dialer{ - Timeout: 20 * time.Second, -}) +c := New(TCPDialer("127.0.0.1:783")) ctx := context.Background() msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n") diff --git a/api.go b/api.go index 689ac3b..ffb5682 100644 --- a/api.go +++ b/api.go @@ -3,6 +3,7 @@ package spamc import ( "bufio" "context" + "crypto/tls" "fmt" "io" "net" @@ -33,10 +34,8 @@ type Error struct { func (e Error) Error() string { return e.msg } -// Dialer to connect to spamd; usually a net.Dialer instance. -type Dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} +// Dialer to connect to spamd. +type Dialer func(context.Context) (net.Conn, error) // Header for requests and responses. type Header map[string]string @@ -104,22 +103,39 @@ func (h Header) normalizeKey(k string) string { } } -// New created a new Client instance. -// -// The addr should be as "host:port"; as dialer most people will want to use -// net.Dialer: +// New creates a new Client instance. // -// New("127.0.0.1:783", &net.Dialer{Timeout: 20 * time.Second}) +// Examples: // -// If the passed dialer is nil then this will be used as a default. -func New(addr string, d Dialer) *Client { - if d == nil { - d = &net.Dialer{Timeout: 20 * time.Second} - } - return &Client{ - addr: addr, - dialer: d, +// New(TCPDialer("127.0.0.1:783")) +// New(UnixDialer("/path/to/sock.unix")) +func New(d Dialer) *Client { + return &Client{dialer: d} +} + +// TCPDialer creates a TCP Dialer with a 20 second timeout. +func TCPDialer(addr string) Dialer { + return buildDialer(addr, "tcp", 20*time.Second) +} + +// UnixDialer creates a Unix Dialer with a 20 second timeout. +func UnixDialer(addr string) Dialer { + return buildDialer(addr, "unix", 20*time.Second) +} + +// TLSDialer creates a TLS Dialer. +func TLSDialer(addr string, c *tls.Config) Dialer { + timeout := 20 * time.Second + dialer := net.Dialer{Timeout: timeout} + return func(ctx context.Context) (net.Conn, error) { + conn, err := tls.DialWithDialer(&dialer, "tcp", addr, c) + if err != nil { + return conn, err + } + err = conn.SetDeadline(time.Now().Add(timeout)) + return conn, err } + } // Ping returns a confirmation that spamd is alive. diff --git a/api_test.go b/api_test.go index 909fad1..8d4b22d 100644 --- a/api_test.go +++ b/api_test.go @@ -362,18 +362,12 @@ func TestTell(t *testing.T) { } } -type testDialer struct { - conn fakeconn.Conn -} - -func (d *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return d.conn, nil -} - func newClient(resp string) *Client { - d := &testDialer{conn: fakeconn.New()} - d.conn.ReadFrom.WriteString(resp) - return New("", d) + conn := fakeconn.New() + conn.ReadFrom.WriteString(resp) + return New(func(ctx context.Context) (net.Conn, error) { + return conn, nil + }) } func TestHeader(t *testing.T) { diff --git a/example_test.go b/example_test.go index 3dfafd3..239e1ea 100644 --- a/example_test.go +++ b/example_test.go @@ -4,16 +4,12 @@ import ( "context" "fmt" "log" - "net" "strings" - "time" ) func Example() { // Connect - c := New("127.0.0.1:783", &net.Dialer{ - Timeout: 20 * time.Second, - }) + c := New(TCPDialer("127.0.0.1:783")) ctx := context.Background() msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n") diff --git a/sa_test.go b/sa_test.go index 35d0c6d..7c8edc0 100644 --- a/sa_test.go +++ b/sa_test.go @@ -4,6 +4,7 @@ package spamc import ( "context" + "crypto/tls" "io/ioutil" "os" "strings" @@ -13,7 +14,7 @@ import ( var addr = os.Getenv("SPAMC_SA_ADDRESS") func TestSAPing(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) err := client.Ping(context.Background()) if err != nil { t.Fatal(err) @@ -21,7 +22,7 @@ func TestSAPing(t *testing.T) { } func TestSACheck(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Check(context.Background(), strings.NewReader("\r\nPenis viagra\r\n"), nil) if err != nil { t.Fatal(err) @@ -35,7 +36,7 @@ func TestSACheck(t *testing.T) { } func TestSASymbols(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Symbols(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: invalid\r\n"+ @@ -55,7 +56,7 @@ func TestSASymbols(t *testing.T) { } func TestSAReport(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Report(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -76,7 +77,7 @@ func TestSAReport(t *testing.T) { } func TestSAProcess(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Process(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -110,7 +111,7 @@ func TestSAProcess(t *testing.T) { } func TestSAHeaders(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Headers(context.Background(), strings.NewReader(""+ "Date: now\r\n"+ "From: a@example.com\r\n"+ @@ -144,7 +145,7 @@ func TestSAHeaders(t *testing.T) { } func TestSATell(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) message := strings.NewReader("Subject: Hello, world!\r\n\r\nTest message.\r\n") r, err := client.Tell(context.Background(), message, Header{}. Set("Message-class", "spam"). @@ -166,7 +167,7 @@ func TestSATell(t *testing.T) { // Make sure SA works when we send the message without trailing newline. func TestSANoTrailingNewline(t *testing.T) { - client := New(addr, nil) + client := New(TCPDialer(addr)) r, err := client.Check(context.Background(), strings.NewReader("woot"), nil) if err != nil { diff --git a/spamc.go b/spamc.go index ea6b07f..7cad3fe 100644 --- a/spamc.go +++ b/spamc.go @@ -56,6 +56,19 @@ var errorMessages = map[int]string{ 79: "Read timeout", // EX_TIMEOUT } +// buildDialer is a helper to build a Dialer. +func buildDialer(addr string, proto string, timeout time.Duration) Dialer { + dialer := net.Dialer{Timeout: timeout} + return func(ctx context.Context) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, proto, addr) + if err != nil { + return conn, err + } + err = conn.SetDeadline(time.Now().Add(timeout)) + return conn, err + } +} + // send a command to spamd. func (c *Client) send( ctx context.Context, @@ -64,8 +77,11 @@ func (c *Client) send( headers Header, ) (io.ReadCloser, error) { - conn, err := c.dial(ctx) + conn, err := c.dialer(ctx) if err != nil { + if conn != nil { + conn.Close() // nolint: errcheck + } return nil, errors.Wrapf(err, "could not dial to %v", c.addr) } @@ -160,27 +176,6 @@ func sizeFromReader(r io.Reader) (int64, error) { } -func (c *Client) dial(ctx context.Context) (net.Conn, error) { - conn, err := c.dialer.DialContext(ctx, "tcp", c.addr) - if err != nil { - if conn != nil { - conn.Close() // nolint: errcheck - } - return nil, errors.Wrap(err, "could not connect to spamd") - } - - // Set connection timeout - if ndial, ok := c.dialer.(*net.Dialer); ok { - err = conn.SetDeadline(time.Now().Add(ndial.Timeout)) - if err != nil { - conn.Close() // nolint: errcheck - return nil, errors.Wrap(err, "connection to spamd timed out") - } - } - - return conn, nil -} - // The spamd protocol is a HTTP-esque protocol; a response's first line is the // response code: //