Skip to content

Commit bc4fce0

Browse files
committed
Improve Close handshake behaviour
- For JS we ensure we indicate which size initiated the close first from our POV - For normal Go, concurrent closes block until the first one succeeds instead of returning early
1 parent 62ea6c1 commit bc4fce0

File tree

4 files changed

+57
-24
lines changed

4 files changed

+57
-24
lines changed

conn.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,13 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
851851
// complete.
852852
func (c *Conn) Close(code StatusCode, reason string) error {
853853
err := c.exportedClose(code, reason, true)
854+
var ec errClosing
855+
if errors.As(err, &ec) {
856+
<-c.closed
857+
// We wait until the connection closes.
858+
// We use writeClose and not exportedClose to avoid a second failed to marshal close frame error.
859+
err = c.writeClose(nil, ec.ce, true)
860+
}
854861
if err != nil {
855862
return fmt.Errorf("failed to close websocket connection: %w", err)
856863
}
@@ -878,15 +885,31 @@ func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) err
878885
return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake)
879886
}
880887

888+
type errClosing struct {
889+
ce error
890+
}
891+
892+
func (e errClosing) Error() string {
893+
return "already closing connection"
894+
}
895+
881896
func (c *Conn) writeClose(p []byte, ce error, handshake bool) error {
882-
select {
883-
case <-c.closed:
884-
return fmt.Errorf("tried to close with %v but connection already closed: %w", ce, c.closeErr)
885-
default:
897+
if c.isClosed() {
898+
return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
886899
}
887900

888901
if !c.closing.CAS(0, 1) {
889-
return fmt.Errorf("another goroutine is closing")
902+
// Normally, we would want to wait until the connection is closed,
903+
// at least for when a user calls into Close, so we handle that case in
904+
// the exported Close function.
905+
//
906+
// But for internal library usage, we always want to return early, e.g.
907+
// if we are performing a close handshake and the peer sends their close frame,
908+
// we do not want to block here waiting for c.closed to close because it won't,
909+
// at least not until we return since the gorouine that will close it is this one.
910+
return errClosing{
911+
ce: ce,
912+
}
890913
}
891914

892915
// No matter what happens next, close error should be set.

conn_common.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,12 @@ func (v *atomicInt64) Increment(delta int64) int64 {
234234
func (v *atomicInt64) CAS(old, new int64) (swapped bool) {
235235
return atomic.CompareAndSwapInt64(&v.v, old, new)
236236
}
237+
238+
func (c *Conn) isClosed() bool {
239+
select {
240+
case <-c.closed:
241+
return true
242+
default:
243+
return false
244+
}
245+
}

conn_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,11 @@ func TestConn(t *testing.T) {
602602
{
603603
name: "largeControlFrame",
604604
server: func(ctx context.Context, c *websocket.Conn) error {
605-
_, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte(strings.Repeat("x", 4096)))
605+
err := c.WriteHeader(ctx, websocket.Header{
606+
Fin: true,
607+
OpCode: websocket.OpClose,
608+
PayloadLength: 4096,
609+
})
606610
if err != nil {
607611
return err
608612
}

websocket_js.go

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type Conn struct {
2323
// read limit for a message in bytes.
2424
msgReadLimit *atomicInt64
2525

26-
closeMu sync.Mutex
26+
closingMu sync.Mutex
2727
isReadClosed *atomicInt64
2828
closeOnce sync.Once
2929
closed chan struct{}
@@ -43,6 +43,9 @@ func (c *Conn) close(err error, wasClean bool) {
4343
c.closeOnce.Do(func() {
4444
runtime.SetFinalizer(c, nil)
4545

46+
if !wasClean {
47+
err = fmt.Errorf("unclean connection close: %w", err)
48+
}
4649
c.setCloseErr(err)
4750
c.closeWasClean = wasClean
4851
close(c.closed)
@@ -59,14 +62,11 @@ func (c *Conn) init() {
5962
c.isReadClosed = &atomicInt64{}
6063

6164
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
62-
var err error = CloseError{
65+
err := CloseError{
6366
Code: StatusCode(e.Code),
6467
Reason: e.Reason,
6568
}
66-
if !e.WasClean {
67-
err = fmt.Errorf("connection close was not clean: %w", err)
68-
}
69-
c.close(err, e.WasClean)
69+
c.close(fmt.Errorf("received close: %w", err), e.WasClean)
7070

7171
c.releaseOnClose()
7272
c.releaseOnMessage()
@@ -182,15 +182,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
182182
}
183183
}
184184

185-
func (c *Conn) isClosed() bool {
186-
select {
187-
case <-c.closed:
188-
return true
189-
default:
190-
return false
191-
}
192-
}
193-
194185
// Close closes the websocket with the given code and reason.
195186
// It will wait until the peer responds with a close frame
196187
// or the connection is closed.
@@ -204,13 +195,19 @@ func (c *Conn) Close(code StatusCode, reason string) error {
204195
}
205196

206197
func (c *Conn) exportedClose(code StatusCode, reason string) error {
207-
c.closeMu.Lock()
208-
defer c.closeMu.Unlock()
198+
c.closingMu.Lock()
199+
defer c.closingMu.Unlock()
200+
201+
ce := fmt.Errorf("sent close: %w", CloseError{
202+
Code: code,
203+
Reason: reason,
204+
})
209205

210206
if c.isClosed() {
211-
return fmt.Errorf("already closed: %w", c.closeErr)
207+
return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
212208
}
213209

210+
c.setCloseErr(ce)
214211
err := c.ws.Close(int(code), reason)
215212
if err != nil {
216213
return err

0 commit comments

Comments
 (0)