diff --git a/protocol_manager.go b/protocol_manager.go index 68c0407..e980b0b 100644 --- a/protocol_manager.go +++ b/protocol_manager.go @@ -327,10 +327,7 @@ func handleStreamClose(s *Session, hdr header, buf []byte) (int, bool, error) { id := binary.BigEndian.Uint32(buf[:4]) s.logger.debugf("receive peer stream[%d] goaway.", id) - s.streamLock.Lock() - stream := s.streams[id] - s.streamLock.Unlock() - + stream := s.getStreamById(id) if stream == nil { s.logger.warnf("missing stream: %d", id) return headerSize + idLen, false, nil diff --git a/session.go b/session.go index ef12772..c057e47 100644 --- a/session.go +++ b/session.go @@ -52,7 +52,7 @@ type Session struct { eventConn eventConn // streams maps a stream id to a stream protected by streamLock. streams map[uint32]*Stream - streamLock sync.Mutex + streamLock sync.RWMutex // acceptCh is used to pass ready streams to the client acceptCh chan *Stream @@ -234,11 +234,11 @@ func (s *Session) CloseChan() <-chan struct{} { return s.shutdownCh } -// ActiveStreams returns the number of currently open streams -func (s *Session) ActiveStreams() int { - s.streamLock.Lock() +// GetActiveStreamCount returns the number of currently open streams +func (s *Session) GetActiveStreamCount() int { + s.streamLock.RLock() num := len(s.streams) - s.streamLock.Unlock() + s.streamLock.RUnlock() return num } @@ -578,6 +578,13 @@ func (s *Session) getStream(id uint32, state streamState) (stream *Stream) { } +func (s *Session) getStreamById(id uint32) *Stream { + s.streamLock.RLock() + stream := s.streams[id] + s.streamLock.RUnlock() + return stream +} + // handleStreamMessage handles either a data frame func (s *Session) handleStreamMessage(stream *Stream, wrapper bufferSliceWrapper, state streamState) error { @@ -702,9 +709,7 @@ func (s *Session) hotRestart(epoch uint64, event eventType) error { // GetMetrics return the session's metrics for monitoring func (s *Session) GetMetrics() (PerformanceMetrics, StabilityMetrics, ShareMemoryMetrics) { - s.streamLock.Lock() - activeStreamCount := uint64(len(s.streams)) - s.streamLock.Unlock() + activeStreamCount := uint64(s.GetActiveStreamCount()) //session will close shutdownCH when it was stopping. and monitorLoop will wake up and flush metrics. //if queueManager do unmap at the moment, there will be panic. diff --git a/session_manager.go b/session_manager.go index c0df8f1..7464688 100644 --- a/session_manager.go +++ b/session_manager.go @@ -60,7 +60,7 @@ type SessionManager struct { epoch uint64 randID uint64 state sessionSateType - sync.Mutex + sync.RWMutex } // SessionManagerConfig is the configuration of SessionManager @@ -189,25 +189,25 @@ func (sm *SessionManager) background() { go func(id int) { defer sm.wg.Done() for { - sm.Lock() + sm.RLock() if sm.state == hotRestartState { - sm.Unlock() + sm.RUnlock() time.Sleep(500 * time.Millisecond) continue } pool := sm.pools[id] - sm.Unlock() + sm.RUnlock() select { case <-pool.Session().CloseChan(): - sm.Lock() + sm.RLock() epochID := sm.epoch randID := sm.randID if sm.state == hotRestartState || epochID != pool.Session().epochID || randID != pool.Session().randID { - sm.Unlock() + sm.RUnlock() break } - sm.Unlock() + sm.RUnlock() pool.close() for { diff --git a/session_test.go b/session_test.go index 434be6e..36588ae 100644 --- a/session_test.go +++ b/session_test.go @@ -237,8 +237,8 @@ func TestSendData_Small(t *testing.T) { t.Logf("err: %v", err) } - if server.ActiveStreams() != 1 { - t.Fatal("num of streams is ", server.ActiveStreams()) + if server.GetActiveStreamCount() != 1 { + t.Fatal("num of streams is ", server.GetActiveStreamCount()) } size := 0 @@ -265,7 +265,7 @@ func TestSendData_Small(t *testing.T) { t.Logf("err: %v", err) } - if client.ActiveStreams() != 1 { + if client.GetActiveStreamCount() != 1 { t.Logf("bad") } @@ -296,10 +296,10 @@ func TestSendData_Small(t *testing.T) { panic("timeout") } - if client.ActiveStreams() != 0 { - t.Fatalf("bad, streams:%d", client.ActiveStreams()) + if client.GetActiveStreamCount() != 0 { + t.Fatalf("bad, streams:%d", client.GetActiveStreamCount()) } - if server.ActiveStreams() != 0 { + if server.GetActiveStreamCount() != 0 { t.Fatalf("bad") } } @@ -449,8 +449,8 @@ func TestSendData_Small_Memfd(t *testing.T) { t.Logf("err: %v", err) } - if server.ActiveStreams() != 1 { - t.Fatal("num of streams is ", server.ActiveStreams()) + if server.GetActiveStreamCount() != 1 { + t.Fatal("num of streams is ", server.GetActiveStreamCount()) } size := 0 @@ -477,7 +477,7 @@ func TestSendData_Small_Memfd(t *testing.T) { t.Logf("err: %v", err) } - if client.ActiveStreams() != 1 { + if client.GetActiveStreamCount() != 1 { t.Logf("bad") } @@ -508,10 +508,10 @@ func TestSendData_Small_Memfd(t *testing.T) { panic("timeout") } - if client.ActiveStreams() != 0 { - t.Fatalf("bad, streams:%d", client.ActiveStreams()) + if client.GetActiveStreamCount() != 0 { + t.Fatalf("bad, streams:%d", client.GetActiveStreamCount()) } - if server.ActiveStreams() != 0 { + if server.GetActiveStreamCount() != 0 { t.Fatalf("bad") } } diff --git a/stream_test.go b/stream_test.go index 9c2f4bc..42b2a39 100644 --- a/stream_test.go +++ b/stream_test.go @@ -51,7 +51,7 @@ func newClientServerWithNoCheck(conf *Config) (client *Session, server *Session) return client, server } -//Close 94.4% +// Close 94.4% func TestStream_Close(t *testing.T) { c := testConf() c.QueueCap = 1 @@ -72,10 +72,10 @@ func TestStream_Close(t *testing.T) { <-hadForceCloseNotifyCh _, err = s.BufferReader().ReadByte() assert.Equal(t, ErrEndOfStream, err) - assert.Equal(t, 1, server.ActiveStreams()) + assert.Equal(t, 1, server.GetActiveStreamCount()) assert.Equal(t, uint32(streamHalfClosed), s.state) s.Close() - assert.Equal(t, 0, server.ActiveStreams()) + assert.Equal(t, 0, server.GetActiveStreamCount()) assert.Equal(t, uint32(streamClosed), s.state) close(doneCh) }() @@ -226,7 +226,7 @@ func TestStream_RandomPackageSize(t *testing.T) { } -//TODO:halfClose 75.0% +// TODO:halfClose 75.0% func TestStream_HalfClose(t *testing.T) { conf := testConf() conf.ShareMemoryBufferCap = 1 << 20 @@ -404,7 +404,7 @@ func TestStream_SendQueueFullTimeout(t *testing.T) { } } -//reset 92.9% +// reset 92.9% func TestStream_Reset(t *testing.T) { conf := testConf() client, server := testClientServerConfig(conf) @@ -509,7 +509,7 @@ func TestStream_fillDataToReadBuffer(t *testing.T) { } -//TODO: SwapBufferForReuse 66.7% +// TODO: SwapBufferForReuse 66.7% func TestStream_SwapBufferForReuse(t *testing.T) { }