Skip to content

Support connection over Unix socket #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ Basic example:

```go
// Connect
c := New("127.0.0.1:783", &net.Dialer{
Timeout: 20 * time.Second,
})
c := New(TCPDialer("127.0.0.1:783"))
ctx := context.Background()

msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n")
Expand Down
50 changes: 33 additions & 17 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package spamc
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -33,10 +34,8 @@ type Error struct {

func (e Error) Error() string { return e.msg }

// Dialer to connect to spamd; usually a net.Dialer instance.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Dialer to connect to spamd.
type Dialer func(context.Context) (net.Conn, error)

// Header for requests and responses.
type Header map[string]string
Expand Down Expand Up @@ -104,22 +103,39 @@ func (h Header) normalizeKey(k string) string {
}
}

// New created a new Client instance.
//
// The addr should be as "host:port"; as dialer most people will want to use
// net.Dialer:
// New creates a new Client instance.
//
// New("127.0.0.1:783", &net.Dialer{Timeout: 20 * time.Second})
// Examples:
//
// If the passed dialer is nil then this will be used as a default.
func New(addr string, d Dialer) *Client {
if d == nil {
d = &net.Dialer{Timeout: 20 * time.Second}
}
return &Client{
addr: addr,
dialer: d,
// New(TCPDialer("127.0.0.1:783"))
// New(UnixDialer("/path/to/sock.unix"))
func New(d Dialer) *Client {
return &Client{dialer: d}
}

// TCPDialer creates a TCP Dialer with a 20 second timeout.
func TCPDialer(addr string) Dialer {
return buildDialer(addr, "tcp", 20*time.Second)
}

// UnixDialer creates a Unix Dialer with a 20 second timeout.
func UnixDialer(addr string) Dialer {
return buildDialer(addr, "unix", 20*time.Second)
}

// TLSDialer creates a TLS Dialer.
func TLSDialer(addr string, c *tls.Config) Dialer {
timeout := 20 * time.Second
dialer := net.Dialer{Timeout: timeout}
return func(ctx context.Context) (net.Conn, error) {
conn, err := tls.DialWithDialer(&dialer, "tcp", addr, c)
if err != nil {
return conn, err
}
err = conn.SetDeadline(time.Now().Add(timeout))
return conn, err
}

}

// Ping returns a confirmation that spamd is alive.
Expand Down
16 changes: 5 additions & 11 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,18 +362,12 @@ func TestTell(t *testing.T) {
}
}

type testDialer struct {
conn fakeconn.Conn
}

func (d *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.conn, nil
}

func newClient(resp string) *Client {
d := &testDialer{conn: fakeconn.New()}
d.conn.ReadFrom.WriteString(resp)
return New("", d)
conn := fakeconn.New()
conn.ReadFrom.WriteString(resp)
return New(func(ctx context.Context) (net.Conn, error) {
return conn, nil
})
}

func TestHeader(t *testing.T) {
Expand Down
6 changes: 1 addition & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@ import (
"context"
"fmt"
"log"
"net"
"strings"
"time"
)

func Example() {
// Connect
c := New("127.0.0.1:783", &net.Dialer{
Timeout: 20 * time.Second,
})
c := New(TCPDialer("127.0.0.1:783"))
ctx := context.Background()

msg := strings.NewReader("Subject: Hello\r\n\r\nHey there!\r\n")
Expand Down
17 changes: 9 additions & 8 deletions sa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package spamc

import (
"context"
"crypto/tls"
"io/ioutil"
"os"
"strings"
Expand All @@ -13,15 +14,15 @@ import (
var addr = os.Getenv("SPAMC_SA_ADDRESS")

func TestSAPing(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
err := client.Ping(context.Background())
if err != nil {
t.Fatal(err)
}
}

func TestSACheck(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
r, err := client.Check(context.Background(), strings.NewReader("\r\nPenis viagra\r\n"), nil)
if err != nil {
t.Fatal(err)
Expand All @@ -35,7 +36,7 @@ func TestSACheck(t *testing.T) {
}

func TestSASymbols(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
r, err := client.Symbols(context.Background(), strings.NewReader(""+
"Date: now\r\n"+
"From: invalid\r\n"+
Expand All @@ -55,7 +56,7 @@ func TestSASymbols(t *testing.T) {
}

func TestSAReport(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
r, err := client.Report(context.Background(), strings.NewReader(""+
"Date: now\r\n"+
"From: [email protected]\r\n"+
Expand All @@ -76,7 +77,7 @@ func TestSAReport(t *testing.T) {
}

func TestSAProcess(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
r, err := client.Process(context.Background(), strings.NewReader(""+
"Date: now\r\n"+
"From: [email protected]\r\n"+
Expand Down Expand Up @@ -110,7 +111,7 @@ func TestSAProcess(t *testing.T) {
}

func TestSAHeaders(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
r, err := client.Headers(context.Background(), strings.NewReader(""+
"Date: now\r\n"+
"From: [email protected]\r\n"+
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestSAHeaders(t *testing.T) {
}

func TestSATell(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))
message := strings.NewReader("Subject: Hello, world!\r\n\r\nTest message.\r\n")
r, err := client.Tell(context.Background(), message, Header{}.
Set("Message-class", "spam").
Expand All @@ -166,7 +167,7 @@ func TestSATell(t *testing.T) {

// Make sure SA works when we send the message without trailing newline.
func TestSANoTrailingNewline(t *testing.T) {
client := New(addr, nil)
client := New(TCPDialer(addr))

r, err := client.Check(context.Background(), strings.NewReader("woot"), nil)
if err != nil {
Expand Down
39 changes: 17 additions & 22 deletions spamc.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ var errorMessages = map[int]string{
79: "Read timeout", // EX_TIMEOUT
}

// buildDialer is a helper to build a Dialer.
func buildDialer(addr string, proto string, timeout time.Duration) Dialer {
dialer := net.Dialer{Timeout: timeout}
return func(ctx context.Context) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, proto, addr)
if err != nil {
return conn, err
}
err = conn.SetDeadline(time.Now().Add(timeout))
return conn, err
}
}

// send a command to spamd.
func (c *Client) send(
ctx context.Context,
Expand All @@ -64,8 +77,11 @@ func (c *Client) send(
headers Header,
) (io.ReadCloser, error) {

conn, err := c.dial(ctx)
conn, err := c.dialer(ctx)
if err != nil {
if conn != nil {
conn.Close() // nolint: errcheck
}
return nil, errors.Wrapf(err, "could not dial to %v", c.addr)
}

Expand Down Expand Up @@ -160,27 +176,6 @@ func sizeFromReader(r io.Reader) (int64, error) {

}

func (c *Client) dial(ctx context.Context) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
if conn != nil {
conn.Close() // nolint: errcheck
}
return nil, errors.Wrap(err, "could not connect to spamd")
}

// Set connection timeout
if ndial, ok := c.dialer.(*net.Dialer); ok {
err = conn.SetDeadline(time.Now().Add(ndial.Timeout))
if err != nil {
conn.Close() // nolint: errcheck
return nil, errors.Wrap(err, "connection to spamd timed out")
}
}

return conn, nil
}

// The spamd protocol is a HTTP-esque protocol; a response's first line is the
// response code:
//
Expand Down