diff --git a/conn.go b/conn.go index d785206..100186d 100644 --- a/conn.go +++ b/conn.go @@ -302,7 +302,7 @@ func (r *outstandingRequests) shutdown(err error) { type conn struct { t transport - session *session + session atomic.Pointer[session] outstandingRequests *outstandingRequests sequenceWindow uint64 dialect uint16 @@ -321,9 +321,9 @@ type conn struct { wdone chan struct{} write chan []byte werr chan error - recvPool sync.Pool // reusable receive buffers - encodeBuf []byte // retained request encoding buffer; reused under conn.m - compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m + recvPool sync.Pool // reusable receive buffers + encodeBuf []byte // retained request encoding buffer; reused under conn.m + compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m m sync.Mutex @@ -333,15 +333,7 @@ type conn struct { // serverGuid [16]byte // clientGuid [16]byte - _useSession int32 // receiver use session? -} - -func (conn *conn) useSession() bool { - return atomic.LoadInt32(&conn._useSession) != 0 -} - -func (conn *conn) enableSession() { - atomic.StoreInt32(&conn._useSession, 1) + useSession atomic.Bool } //lint:ignore U1000 appears to be legacy, unsure, so leaving for now @@ -490,8 +482,7 @@ func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([] hdr.Flags |= smb2.SMB2_FLAGS_RELATED_OPERATIONS } - s := conn.session - if s != nil { + if s := conn.session.Load(); s != nil { hdr.SessionId = s.sessionId if entry.tc != nil { hdr.TreeId = entry.tc.treeId @@ -535,8 +526,7 @@ func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([] // Phase 3: Sign or encrypt. wirePkt := compound - s := conn.session - if s != nil { + if s := conn.session.Load(); s != nil { encrypt := s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 if !encrypt { for _, entry := range entries { @@ -643,11 +633,10 @@ func (conn *conn) makeRequestResponse(ctx context.Context, req smb2.Packet, tc * hdr.MessageId = msgId - s := conn.session + s := conn.session.Load() if s != nil { hdr.SessionId = s.sessionId - if tc != nil { hdr.TreeId = tc.treeId } @@ -748,7 +737,7 @@ func (conn *conn) runReceiver() { goto exit } - hasSession := conn.useSession() + hasSession := conn.useSession.Load() var isEncrypted bool @@ -771,7 +760,7 @@ func (conn *conn) runReceiver() { } p := smb2.PacketCodec(pkt) - if s := conn.session; s != nil { + if s := conn.session.Load(); s != nil { if s.sessionId != p.SessionId() { conn.freePoolBuf(rb) @@ -944,7 +933,9 @@ func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) { return nil, nil, &InvalidResponseError{"encrypted flag is not on"}, false } - if conn.session == nil || conn.session.sessionId != t.SessionId() { + s := conn.session.Load() + + if s == nil || s.sessionId != t.SessionId() { return nil, nil, &InvalidResponseError{"unknown session id returned"}, false } @@ -956,7 +947,7 @@ func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) { } c := pRb.b[:cLen] - pkt, err := conn.session.decrypt(pkt, c) + pkt, err := s.decrypt(pkt, c) if err != nil { conn.freePoolBuf(pRb) return nil, nil, &InvalidResponseError{err.Error()}, false @@ -972,21 +963,22 @@ func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error { p := smb2.PacketCodec(pkt) msgId := p.MessageId() + s := conn.session.Load() if msgId != 0xFFFFFFFFFFFFFFFF { if p.Flags()&smb2.SMB2_FLAGS_SIGNED != 0 { - if conn.session == nil || conn.session.sessionId != p.SessionId() { + if s == nil || s.sessionId != p.SessionId() { return &InvalidResponseError{"unknown session id returned"} } else { - if !conn.session.verify(pkt) { + if !s.verify(pkt) { return &InvalidResponseError{"unverified packet returned"} } } } else { if conn.requireSigning && !isEncrypted { - if conn.session != nil { - if conn.session.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 { - if conn.session.sessionId == p.SessionId() { + if s != nil { + if s.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 { + if s.sessionId == p.SessionId() { return &InvalidResponseError{"signing required"} } } diff --git a/conn_bench_test.go b/conn_bench_test.go index a4ab084..9cafd57 100644 --- a/conn_bench_test.go +++ b/conn_bench_test.go @@ -167,7 +167,7 @@ func fakeServerEncrypted(t transport, responseData []byte, dec, enc cipher.AEAD, // must set up c.session before calling this. func newBenchFile(c *conn) *File { tc := &treeConn{ - session: c.session, + session: c.session.Load(), } fs := &Share{ @@ -197,11 +197,11 @@ func BenchmarkReadAt(b *testing.B) { c, cleanup := newBenchConn(clientConn) defer cleanup() - c.session = &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_IS_GUEST, - } + }) responseData := make([]byte, sz.n) go fakeServer(direct(serverConn), responseData) @@ -240,15 +240,15 @@ func BenchmarkReadAt(b *testing.B) { panic(err) } - c.session = &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA, sessionId: 0xdeadbeef, encrypter: newGCM(keyC2S), decrypter: newGCM(keyS2C), - } - c.enableSession() + }) + c.useSession.Store(true) responseData := make([]byte, sz.n) go fakeServerEncrypted( @@ -341,16 +341,15 @@ func BenchmarkRoundTrip(b *testing.B) { panic(err) } - s := &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA, sessionId: 0xdeadbeef, encrypter: newGCM(keyC2S), decrypter: newGCM(keyS2C), - } - c.session = s - c.enableSession() + }) + c.useSession.Store(true) responseData := make([]byte, sz.n) go fakeServerEncrypted( diff --git a/session.go b/session.go index 745ef6d..8ec8beb 100644 --- a/session.go +++ b/session.go @@ -115,7 +115,7 @@ func sessionSetup(ctx context.Context, conn *conn, i Initiator) (*session, error // We set session before sending packet just for setting hdr.SessionId. // But, we should not permit access from receiver until the session information is completed. - conn.session = s + conn.session.Store(s) if status == erref.STATUS_MORE_PROCESSING_REQUIRED { req.SecurityBuffer = outputToken @@ -255,7 +255,7 @@ func sessionSetup(ctx context.Context, conn *conn, i Initiator) (*session, error s.sessionFlags = r.SessionFlags() // now, allow access from receiver - s.enableSession() + s.useSession.Store(true) return s, nil }