Skip to content

Commit

Permalink
Merge pull request #15 from liu-song/stream
Browse files Browse the repository at this point in the history
perf: uses RWMutex
  • Loading branch information
GuangmingLuo authored Apr 20, 2023
2 parents 98b9116 + cdfecc4 commit 145e239
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 37 deletions.
5 changes: 1 addition & 4 deletions protocol_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,7 @@ func handleStreamClose(s *Session, hdr header, buf []byte) (int, bool, error) {
id := binary.BigEndian.Uint32(buf[:4])
s.logger.debugf("receive peer stream[%d] goaway.", id)

s.streamLock.Lock()
stream := s.streams[id]
s.streamLock.Unlock()

stream := s.getStreamById(id)
if stream == nil {
s.logger.warnf("missing stream: %d", id)
return headerSize + idLen, false, nil
Expand Down
21 changes: 13 additions & 8 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Session struct {
eventConn eventConn
// streams maps a stream id to a stream protected by streamLock.
streams map[uint32]*Stream
streamLock sync.Mutex
streamLock sync.RWMutex

// acceptCh is used to pass ready streams to the client
acceptCh chan *Stream
Expand Down Expand Up @@ -234,11 +234,11 @@ func (s *Session) CloseChan() <-chan struct{} {
return s.shutdownCh
}

// ActiveStreams returns the number of currently open streams
func (s *Session) ActiveStreams() int {
s.streamLock.Lock()
// GetActiveStreamCount returns the number of currently open streams
func (s *Session) GetActiveStreamCount() int {
s.streamLock.RLock()
num := len(s.streams)
s.streamLock.Unlock()
s.streamLock.RUnlock()
return num
}

Expand Down Expand Up @@ -578,6 +578,13 @@ func (s *Session) getStream(id uint32, state streamState) (stream *Stream) {

}

func (s *Session) getStreamById(id uint32) *Stream {
s.streamLock.RLock()
stream := s.streams[id]
s.streamLock.RUnlock()
return stream
}

// handleStreamMessage handles either a data frame
func (s *Session) handleStreamMessage(stream *Stream, wrapper bufferSliceWrapper, state streamState) error {

Expand Down Expand Up @@ -702,9 +709,7 @@ func (s *Session) hotRestart(epoch uint64, event eventType) error {

// GetMetrics return the session's metrics for monitoring
func (s *Session) GetMetrics() (PerformanceMetrics, StabilityMetrics, ShareMemoryMetrics) {
s.streamLock.Lock()
activeStreamCount := uint64(len(s.streams))
s.streamLock.Unlock()
activeStreamCount := uint64(s.GetActiveStreamCount())

//session will close shutdownCH when it was stopping. and monitorLoop will wake up and flush metrics.
//if queueManager do unmap at the moment, there will be panic.
Expand Down
14 changes: 7 additions & 7 deletions session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type SessionManager struct {
epoch uint64
randID uint64
state sessionSateType
sync.Mutex
sync.RWMutex
}

// SessionManagerConfig is the configuration of SessionManager
Expand Down Expand Up @@ -189,25 +189,25 @@ func (sm *SessionManager) background() {
go func(id int) {
defer sm.wg.Done()
for {
sm.Lock()
sm.RLock()
if sm.state == hotRestartState {
sm.Unlock()
sm.RUnlock()
time.Sleep(500 * time.Millisecond)
continue
}
pool := sm.pools[id]
sm.Unlock()
sm.RUnlock()

select {
case <-pool.Session().CloseChan():
sm.Lock()
sm.RLock()
epochID := sm.epoch
randID := sm.randID
if sm.state == hotRestartState || epochID != pool.Session().epochID || randID != pool.Session().randID {
sm.Unlock()
sm.RUnlock()
break
}
sm.Unlock()
sm.RUnlock()

pool.close()
for {
Expand Down
24 changes: 12 additions & 12 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ func TestSendData_Small(t *testing.T) {
t.Logf("err: %v", err)
}

if server.ActiveStreams() != 1 {
t.Fatal("num of streams is ", server.ActiveStreams())
if server.GetActiveStreamCount() != 1 {
t.Fatal("num of streams is ", server.GetActiveStreamCount())
}

size := 0
Expand All @@ -265,7 +265,7 @@ func TestSendData_Small(t *testing.T) {
t.Logf("err: %v", err)
}

if client.ActiveStreams() != 1 {
if client.GetActiveStreamCount() != 1 {
t.Logf("bad")
}

Expand Down Expand Up @@ -296,10 +296,10 @@ func TestSendData_Small(t *testing.T) {
panic("timeout")
}

if client.ActiveStreams() != 0 {
t.Fatalf("bad, streams:%d", client.ActiveStreams())
if client.GetActiveStreamCount() != 0 {
t.Fatalf("bad, streams:%d", client.GetActiveStreamCount())
}
if server.ActiveStreams() != 0 {
if server.GetActiveStreamCount() != 0 {
t.Fatalf("bad")
}
}
Expand Down Expand Up @@ -449,8 +449,8 @@ func TestSendData_Small_Memfd(t *testing.T) {
t.Logf("err: %v", err)
}

if server.ActiveStreams() != 1 {
t.Fatal("num of streams is ", server.ActiveStreams())
if server.GetActiveStreamCount() != 1 {
t.Fatal("num of streams is ", server.GetActiveStreamCount())
}

size := 0
Expand All @@ -477,7 +477,7 @@ func TestSendData_Small_Memfd(t *testing.T) {
t.Logf("err: %v", err)
}

if client.ActiveStreams() != 1 {
if client.GetActiveStreamCount() != 1 {
t.Logf("bad")
}

Expand Down Expand Up @@ -508,10 +508,10 @@ func TestSendData_Small_Memfd(t *testing.T) {
panic("timeout")
}

if client.ActiveStreams() != 0 {
t.Fatalf("bad, streams:%d", client.ActiveStreams())
if client.GetActiveStreamCount() != 0 {
t.Fatalf("bad, streams:%d", client.GetActiveStreamCount())
}
if server.ActiveStreams() != 0 {
if server.GetActiveStreamCount() != 0 {
t.Fatalf("bad")
}
}
12 changes: 6 additions & 6 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func newClientServerWithNoCheck(conf *Config) (client *Session, server *Session)
return client, server
}

//Close 94.4%
// Close 94.4%
func TestStream_Close(t *testing.T) {
c := testConf()
c.QueueCap = 1
Expand All @@ -72,10 +72,10 @@ func TestStream_Close(t *testing.T) {
<-hadForceCloseNotifyCh
_, err = s.BufferReader().ReadByte()
assert.Equal(t, ErrEndOfStream, err)
assert.Equal(t, 1, server.ActiveStreams())
assert.Equal(t, 1, server.GetActiveStreamCount())
assert.Equal(t, uint32(streamHalfClosed), s.state)
s.Close()
assert.Equal(t, 0, server.ActiveStreams())
assert.Equal(t, 0, server.GetActiveStreamCount())
assert.Equal(t, uint32(streamClosed), s.state)
close(doneCh)
}()
Expand Down Expand Up @@ -226,7 +226,7 @@ func TestStream_RandomPackageSize(t *testing.T) {

}

//TODO:halfClose 75.0%
// TODO:halfClose 75.0%
func TestStream_HalfClose(t *testing.T) {
conf := testConf()
conf.ShareMemoryBufferCap = 1 << 20
Expand Down Expand Up @@ -404,7 +404,7 @@ func TestStream_SendQueueFullTimeout(t *testing.T) {
}
}

//reset 92.9%
// reset 92.9%
func TestStream_Reset(t *testing.T) {
conf := testConf()
client, server := testClientServerConfig(conf)
Expand Down Expand Up @@ -509,7 +509,7 @@ func TestStream_fillDataToReadBuffer(t *testing.T) {

}

//TODO: SwapBufferForReuse 66.7%
// TODO: SwapBufferForReuse 66.7%
func TestStream_SwapBufferForReuse(t *testing.T) {

}
Expand Down

0 comments on commit 145e239

Please sign in to comment.