Skip to content

Commit 4462583

Browse files
authored
Merge pull request #366 from SiaFoundation/chris/cancel-read
Add final Read when closing QUIC stream + reproduction test
2 parents 80e72f3 + fb4ddc9 commit 4462583

File tree

5 files changed

+152
-34
lines changed

5 files changed

+152
-34
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
default: patch
3+
---
4+
5+
# Fix QUIC streams not getting closed properly.

rhp/v4/quic/quic.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/tls"
66
"errors"
77
"fmt"
8+
"io"
89
"net"
910
"net/http"
1011
"time"
@@ -25,6 +26,10 @@ const (
2526

2627
// TLSNextProtoRHP4 is the ALPN identifier for the Quic RHP4 protocol.
2728
TLSNextProtoRHP4 = "sia/rhp4"
29+
30+
// maxIncomingStreams is the maximum number of incoming streams allowed per
31+
// QUIC connection by the server.
32+
maxIncomingStreams = 100000
2833
)
2934

3035
type (
@@ -60,6 +65,16 @@ func WithTLSConfig(fn func(*tls.Config)) ClientOption {
6065
}
6166
}
6267

68+
func (s *stream) Close() error {
69+
err := s.Stream.Close()
70+
_, errCopy := io.CopyN(io.Discard, s, 4096)
71+
if !errors.Is(errCopy, io.EOF) {
72+
// fall back to forcefully canceling the read if we couldn't reach EOF
73+
s.CancelRead(1)
74+
}
75+
return err
76+
}
77+
6378
// LocalAddr implements net.Conn
6479
func (s *stream) LocalAddr() net.Addr {
6580
return s.localAddr
@@ -192,7 +207,7 @@ func Listen(conn net.PacketConn, certs CertManager) (*quic.Listener, error) {
192207
EnableDatagrams: true,
193208
KeepAlivePeriod: 30 * time.Second,
194209
MaxIdleTimeout: 30 * time.Minute,
195-
MaxIncomingStreams: 1000,
210+
MaxIncomingStreams: maxIncomingStreams,
196211
})
197212
}
198213

rhp/v4/quic/quic_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package quic
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"net"
7+
"sync"
8+
"testing"
9+
10+
"github.com/quic-go/quic-go"
11+
"go.sia.tech/core/types"
12+
"go.sia.tech/coreutils/rhp/v4"
13+
"go.sia.tech/coreutils/testutil/certs"
14+
)
15+
16+
// setupTestPair sets up a QUIC server and client for testing.
17+
func setupTestPair(tb testing.TB) (*quic.Listener, rhp.TransportClient) {
18+
tb.Helper()
19+
20+
udpAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
21+
if err != nil {
22+
tb.Fatal(err)
23+
}
24+
25+
conn, err := net.ListenUDP("udp", udpAddr)
26+
if err != nil {
27+
tb.Fatal(err)
28+
}
29+
tb.Cleanup(func() { conn.Close() })
30+
31+
l, err := Listen(conn, &certs.EphemeralCertManager{})
32+
if err != nil {
33+
tb.Fatal(err)
34+
}
35+
tb.Cleanup(func() { l.Close() })
36+
37+
client, err := Dial(context.Background(), conn.LocalAddr().String(), types.PublicKey{}, WithTLSConfig(func(tc *tls.Config) {
38+
tc.InsecureSkipVerify = true
39+
}))
40+
if err != nil {
41+
tb.Fatal(err)
42+
}
43+
tb.Cleanup(func() { client.Close() })
44+
return l, client
45+
}
46+
47+
// TestStreamLimit is a regression test that ensures we can perform more than
48+
// maxIncomingStreams RPCs on a single QUIC connection. This should be the case
49+
// if both sides close their streams.
50+
// Before fixing the issue, this test would hang for the full 10 minute test
51+
// timeout.
52+
func TestStreamLimit(t *testing.T) {
53+
server, client := setupTestPair(t)
54+
55+
// the server just accepts streams and cancels them
56+
go func() {
57+
for {
58+
conn, err := server.Accept(context.Background())
59+
if err != nil {
60+
return
61+
}
62+
63+
transport := &transport{conn}
64+
65+
for {
66+
stream, err := transport.AcceptStream()
67+
if err != nil {
68+
return
69+
}
70+
stream.Close()
71+
}
72+
}
73+
}()
74+
75+
// open the maximum number of streams + 1 which should neither block forever
76+
// nor return an error
77+
var wg sync.WaitGroup
78+
for range maxIncomingStreams + 1 {
79+
wg.Add(1)
80+
go func() {
81+
defer wg.Done()
82+
83+
stream, err := client.DialStream(context.Background())
84+
if err != nil {
85+
t.Error(err)
86+
return
87+
}
88+
stream.Close()
89+
}()
90+
}
91+
wg.Wait()
92+
}

testutil/certs/certs.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package certs
2+
3+
import (
4+
"crypto/rsa"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"encoding/pem"
8+
"fmt"
9+
"math/big"
10+
11+
"lukechampine.com/frand"
12+
)
13+
14+
// An EphemeralCertManager is an in-memory minimal rhp4.CertManager for testing.
15+
// Calls to GetCertificate will return a new self-signed certificate each time.
16+
type EphemeralCertManager struct{}
17+
18+
// GetCertificate returns a new self-signed certificate each time it is called.
19+
func (ec *EphemeralCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
20+
key, err := rsa.GenerateKey(frand.Reader, 2048)
21+
if err != nil {
22+
return nil, fmt.Errorf("failed to generate key: %w", err)
23+
}
24+
template := x509.Certificate{SerialNumber: big.NewInt(1)}
25+
certDER, err := x509.CreateCertificate(frand.Reader, &template, &template, &key.PublicKey, key)
26+
if err != nil {
27+
return nil, fmt.Errorf("failed to create cert: %w", err)
28+
}
29+
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
30+
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
31+
32+
cert, err := tls.X509KeyPair(certPEM, keyPEM)
33+
if err != nil {
34+
return nil, fmt.Errorf("failed to create tls cert: %w", err)
35+
}
36+
return &cert, nil
37+
}

testutil/host.go

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
package testutil
22

33
import (
4-
"crypto/rsa"
5-
"crypto/tls"
6-
"crypto/x509"
7-
"encoding/pem"
84
"errors"
9-
"fmt"
10-
"math/big"
115
"net"
126
"sync"
137
"testing"
@@ -19,35 +13,10 @@ import (
1913
rhp4 "go.sia.tech/coreutils/rhp/v4"
2014
"go.sia.tech/coreutils/rhp/v4/quic"
2115
"go.sia.tech/coreutils/rhp/v4/siamux"
16+
"go.sia.tech/coreutils/testutil/certs"
2217
"go.uber.org/zap"
23-
"lukechampine.com/frand"
2418
)
2519

26-
// An EphemeralCertManager is an in-memory minimal rhp4.CertManager for testing.
27-
// Calls to GetCertificate will return a new self-signed certificate each time.
28-
type EphemeralCertManager struct{}
29-
30-
// GetCertificate returns a new self-signed certificate each time it is called.
31-
func (ec *EphemeralCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
32-
key, err := rsa.GenerateKey(frand.Reader, 2048)
33-
if err != nil {
34-
return nil, fmt.Errorf("failed to generate key: %w", err)
35-
}
36-
template := x509.Certificate{SerialNumber: big.NewInt(1)}
37-
certDER, err := x509.CreateCertificate(frand.Reader, &template, &template, &key.PublicKey, key)
38-
if err != nil {
39-
return nil, fmt.Errorf("failed to create cert: %w", err)
40-
}
41-
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
42-
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
43-
44-
cert, err := tls.X509KeyPair(certPEM, keyPEM)
45-
if err != nil {
46-
return nil, fmt.Errorf("failed to create tls cert: %w", err)
47-
}
48-
return &cert, nil
49-
}
50-
5120
// An EphemeralSectorStore is an in-memory minimal rhp4.SectorStore for testing.
5221
type EphemeralSectorStore struct {
5322
mu sync.Mutex
@@ -438,7 +407,7 @@ func ServeQUIC(tb testing.TB, s *rhp4.Server, log *zap.Logger) string {
438407
}
439408
tb.Cleanup(func() { conn.Close() })
440409

441-
l, err := quic.Listen(conn, &EphemeralCertManager{})
410+
l, err := quic.Listen(conn, &certs.EphemeralCertManager{})
442411
if err != nil {
443412
tb.Fatal(err)
444413
}

0 commit comments

Comments
 (0)