diff --git a/client.go b/client.go index 0028ddc..e555a6d 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "net" "regexp" "strings" + "sync" "time" ) @@ -57,6 +58,7 @@ type Client struct { notify chan Notification disconnect chan struct{} res []string + mutex sync.Mutex Server *ServerMethods } @@ -234,7 +236,15 @@ func (c *Client) ExecCmd(cmd *Cmd) ([]string, error) { return nil, ErrNotConnected } - c.work <- cmd.String() + c.mutex.Lock() + defer c.mutex.Unlock() + + select { + case c.work <- cmd.String(): + // continue + case <-time.After(c.timeout): + return nil, ErrTimeout + } select { case err := <-c.err: diff --git a/client_test.go b/client_test.go index cdc028d..a02d826 100644 --- a/client_test.go +++ b/client_test.go @@ -230,3 +230,41 @@ func TestClientBadHeader(t *testing.T) { // Should never get here assert.NoError(t, c.Close()) } + +func TestConcurrency(t *testing.T) { + s := newServer(t) + if s == nil { + return + } + defer func() { + assert.NoError(t, s.Close()) + }() + + c, err := NewClient(s.Addr, Timeout(time.Millisecond*100)) + if !assert.NoError(t, err) { + return + } + + const iterations = 10 + errors := make(chan error) + + go func() { + defer close(errors) + + for i := 0; i <= iterations; i++ { + if _, err2 := c.Server.GroupList(); err2 != nil { + errors <- err2 + } + } + }() + + for i := 0; i <= iterations; i++ { + _, err = c.Server.GroupList() + assert.NoError(t, err) + } + + // receive errors from go-routine and wait for completion + for err := range errors { + assert.NoError(t, err) + } +}