Skip to content

Commit 035af7b

Browse files
committed
pass metadata along with stream creation command
1 parent f4f6ca3 commit 035af7b

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

session.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
8080
}
8181

8282
// OpenStream is used to create a new stream
83-
func (s *Session) OpenStream() (*Stream, error) {
83+
func (s *Session) OpenStream(metadata ...byte) (*Stream, error) {
8484
if s.IsClosed() {
8585
return nil, errors.New(errBrokenPipe)
8686
}
@@ -101,9 +101,11 @@ func (s *Session) OpenStream() (*Stream, error) {
101101
}
102102
s.nextStreamIDLock.Unlock()
103103

104-
stream := newStream(sid, s.config.MaxFrameSize, s)
104+
stream := newStream(sid, metadata, s.config.MaxFrameSize, s)
105105

106-
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
106+
frame := newFrame(cmdSYN, sid)
107+
frame.data = metadata
108+
if _, err := s.writeFrame(frame); err != nil {
107109
return nil, errors.Wrap(err, "writeFrame")
108110
}
109111

@@ -247,7 +249,7 @@ func (s *Session) recvLoop() {
247249
case cmdSYN:
248250
s.streamLock.Lock()
249251
if _, ok := s.streams[f.sid]; !ok {
250-
stream := newStream(f.sid, s.config.MaxFrameSize, s)
252+
stream := newStream(f.sid, f.data, s.config.MaxFrameSize, s)
251253
s.streams[f.sid] = stream
252254
select {
253255
case s.chAccepts <- stream:

session_test.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package smux
22

33
import (
4+
"bytes"
45
crand "crypto/rand"
56
"encoding/binary"
67
"fmt"
@@ -16,7 +17,7 @@ import (
1617
// setupServer starts new server listening on a random localhost port and
1718
// returns address of the server, function to stop the server, new client
1819
// connection to this server or an error.
19-
func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
20+
func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) {
2021
ln, err := net.Listen("tcp", "localhost:0")
2122
if err != nil {
2223
return "", nil, nil, err
@@ -27,7 +28,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
2728
tb.Error(err)
2829
return
2930
}
30-
go handleConnection(conn)
31+
go handleConnection(tb, conn, metadata...)
3132
}()
3233
addr = ln.Addr().String()
3334
conn, err := net.Dial("tcp", addr)
@@ -38,10 +39,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
3839
return ln.Addr().String(), func() { ln.Close() }, conn, nil
3940
}
4041

41-
func handleConnection(conn net.Conn) {
42+
func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) {
4243
session, _ := Server(conn, nil)
4344
for {
4445
if stream, err := session.AcceptStream(); err == nil {
46+
if !bytes.Equal(metadata, stream.Metadata()) {
47+
tb.Fatal("metadata mimatch")
48+
}
4549
go func(s io.ReadWriteCloser) {
4650
buf := make([]byte, 65536)
4751
for {
@@ -58,6 +62,18 @@ func handleConnection(conn net.Conn) {
5862
}
5963
}
6064

65+
func TestMetadata(t *testing.T) {
66+
metadata := []byte("hello, world")
67+
_, stop, cli, err := setupServer(t, metadata...)
68+
if err != nil {
69+
t.Fatal(err)
70+
}
71+
defer stop()
72+
session, _ := Client(cli, nil)
73+
session.OpenStream(metadata...)
74+
session.Close()
75+
}
76+
6177
func TestEcho(t *testing.T) {
6278
_, stop, cli, err := setupServer(t)
6379
if err != nil {

stream.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
// Stream implements net.Conn
1515
type Stream struct {
1616
id uint32
17+
metadata []byte
1718
rstflag int32
1819
sess *Session
1920
buffer bytes.Buffer
@@ -27,9 +28,10 @@ type Stream struct {
2728
}
2829

2930
// newStream initiates a Stream struct
30-
func newStream(id uint32, frameSize int, sess *Session) *Stream {
31+
func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream {
3132
s := new(Stream)
3233
s.id = id
34+
s.metadata = metadata
3335
s.chReadEvent = make(chan struct{}, 1)
3436
s.frameSize = frameSize
3537
s.sess = sess
@@ -42,6 +44,11 @@ func (s *Stream) ID() uint32 {
4244
return s.id
4345
}
4446

47+
// Metadata returns stream metadata which was provided when opening stream.
48+
func (s *Stream) Metadata() []byte {
49+
return s.metadata
50+
}
51+
4552
// Read implements net.Conn
4653
func (s *Stream) Read(b []byte) (n int, err error) {
4754
if len(b) == 0 {

0 commit comments

Comments
 (0)