Skip to content

Commit f4f6ca3

Browse files
authored
Merge pull request #43 from jannson/fix-keepalive-block
fix: writeFrame may block forever in keepalive function
2 parents 6cf098d + d48c641 commit f4f6ca3

File tree

2 files changed

+143
-3
lines changed

2 files changed

+143
-3
lines changed

session.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func (s *Session) keepalive() {
289289
for {
290290
select {
291291
case <-tickerPing.C:
292-
s.writeFrame(newFrame(cmdNOP, 0))
292+
s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C)
293293
s.notifyBucket() // force a signal to the recvLoop
294294
case <-tickerTimeout.C:
295295
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
@@ -335,6 +335,11 @@ func (s *Session) sendLoop() {
335335
// writeFrame writes the frame to the underlying connection
336336
// and returns the number of bytes written if successful
337337
func (s *Session) writeFrame(f Frame) (n int, err error) {
338+
return s.writeFrameInternal(f, nil)
339+
}
340+
341+
// internal writeFrame version to support deadline used in keepalive
342+
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time) (int, error) {
338343
req := writeRequest{
339344
frame: f,
340345
result: make(chan writeResult, 1),
@@ -343,8 +348,16 @@ func (s *Session) writeFrame(f Frame) (n int, err error) {
343348
case <-s.die:
344349
return 0, errors.New(errBrokenPipe)
345350
case s.writes <- req:
351+
case <-deadline:
352+
return 0, errTimeout
346353
}
347354

348-
result := <-req.result
349-
return result.n, result.err
355+
select {
356+
case result := <-req.result:
357+
return result.n, result.err
358+
case <-deadline:
359+
return 0, errTimeout
360+
case <-s.die:
361+
return 0, errors.New(errBrokenPipe)
362+
}
350363
}

session_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,44 @@ func TestKeepAliveTimeout(t *testing.T) {
304304
}
305305
}
306306

307+
type blockWriteConn struct {
308+
net.Conn
309+
}
310+
311+
func (c *blockWriteConn) Write(b []byte) (n int, err error) {
312+
forever := time.Hour * 24
313+
time.Sleep(forever)
314+
return c.Conn.Write(b)
315+
}
316+
317+
func TestKeepAliveBlockWriteTimeout(t *testing.T) {
318+
ln, err := net.Listen("tcp", "localhost:0")
319+
if err != nil {
320+
t.Fatal(err)
321+
}
322+
defer ln.Close()
323+
go func() {
324+
ln.Accept()
325+
}()
326+
327+
cli, err := net.Dial("tcp", ln.Addr().String())
328+
if err != nil {
329+
t.Fatal(err)
330+
}
331+
defer cli.Close()
332+
//when writeFrame block, keepalive in old version never timeout
333+
blockWriteCli := &blockWriteConn{cli}
334+
335+
config := DefaultConfig()
336+
config.KeepAliveInterval = time.Second
337+
config.KeepAliveTimeout = 2 * time.Second
338+
session, _ := Client(blockWriteCli, config)
339+
time.Sleep(3 * time.Second)
340+
if !session.IsClosed() {
341+
t.Fatal("keepalive-timeout failed")
342+
}
343+
}
344+
307345
func TestServerEcho(t *testing.T) {
308346
ln, err := net.Listen("tcp", "localhost:0")
309347
if err != nil {
@@ -541,6 +579,95 @@ func TestRandomFrame(t *testing.T) {
541579

542580
session.conn.Write(buf)
543581
cli.Close()
582+
583+
// writeFrame after die
584+
cli, err = net.Dial("tcp", addr)
585+
if err != nil {
586+
t.Fatal(err)
587+
}
588+
session, _ = Client(cli, nil)
589+
//close first
590+
session.Close()
591+
for i := 0; i < 100; i++ {
592+
f := newFrame(byte(rand.Uint32()), rand.Uint32())
593+
session.writeFrame(f)
594+
}
595+
}
596+
597+
func TestWriteFrameInternal(t *testing.T) {
598+
addr, stop, cli, err := setupServer(t)
599+
if err != nil {
600+
t.Fatal(err)
601+
}
602+
defer stop()
603+
// pure random
604+
session, _ := Client(cli, nil)
605+
for i := 0; i < 100; i++ {
606+
rnd := make([]byte, rand.Uint32()%1024)
607+
io.ReadFull(crand.Reader, rnd)
608+
session.conn.Write(rnd)
609+
}
610+
cli.Close()
611+
612+
// writeFrame after die
613+
cli, err = net.Dial("tcp", addr)
614+
if err != nil {
615+
t.Fatal(err)
616+
}
617+
session, _ = Client(cli, nil)
618+
//close first
619+
session.Close()
620+
for i := 0; i < 100; i++ {
621+
f := newFrame(byte(rand.Uint32()), rand.Uint32())
622+
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout))
623+
}
624+
625+
// random cmds
626+
cli, err = net.Dial("tcp", addr)
627+
if err != nil {
628+
t.Fatal(err)
629+
}
630+
allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
631+
session, _ = Client(cli, nil)
632+
for i := 0; i < 100; i++ {
633+
f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
634+
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout))
635+
}
636+
//deadline occur
637+
{
638+
c := make(chan time.Time)
639+
close(c)
640+
f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
641+
_, err := session.writeFrameInternal(f, c)
642+
if err != errTimeout {
643+
t.Fatal("write frame with deadline failed", err)
644+
}
645+
}
646+
cli.Close()
647+
648+
{
649+
cli, err = net.Dial("tcp", addr)
650+
if err != nil {
651+
t.Fatal(err)
652+
}
653+
config := DefaultConfig()
654+
config.KeepAliveInterval = time.Second
655+
config.KeepAliveTimeout = 2 * time.Second
656+
session, _ = Client(&blockWriteConn{cli}, config)
657+
f := newFrame(byte(rand.Uint32()), rand.Uint32())
658+
c := make(chan time.Time)
659+
go func() {
660+
//die first, deadline second, better for coverage
661+
time.Sleep(time.Second)
662+
session.Close()
663+
time.Sleep(time.Second)
664+
close(c)
665+
}()
666+
_, err = session.writeFrameInternal(f, c)
667+
if err.Error() != errBrokenPipe {
668+
t.Fatal("write frame with deadline failed", err)
669+
}
670+
}
544671
}
545672

546673
func TestReadDeadline(t *testing.T) {

0 commit comments

Comments
 (0)