From d186c9d650cb5a6b71c1015c1a9738395c42801f Mon Sep 17 00:00:00 2001 From: Byron Rakitzis Date: Tue, 17 Mar 2026 03:46:58 -0700 Subject: [PATCH 1/2] gofmt --- conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index d785206..fd2c3f3 100644 --- a/conn.go +++ b/conn.go @@ -321,9 +321,9 @@ type conn struct { wdone chan struct{} write chan []byte werr chan error - recvPool sync.Pool // reusable receive buffers - encodeBuf []byte // retained request encoding buffer; reused under conn.m - compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m + recvPool sync.Pool // reusable receive buffers + encodeBuf []byte // retained request encoding buffer; reused under conn.m + compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m m sync.Mutex From b5293152dca5dc55be58379c180c18108fa35be1 Mon Sep 17 00:00:00 2001 From: Byron Rakitzis Date: Tue, 17 Mar 2026 04:09:38 -0700 Subject: [PATCH 2/2] Replace conn.session with atomic.Pointer The session variable in conn was accessed after _useSession was polled via an atomic.Load(), but this obscured the safety of loads and stores of conn.session. This change introcuces an atomic.Pointer for clarity, and also guarantees that if conn.session is accessed without first verifying useSession that conn will not be accessed unsafely. This change provides some code clarity, but should not affect the correctness of the code. --- conn.go | 42 +++++++++++++++++------------------------- conn_bench_test.go | 19 +++++++++---------- session.go | 4 ++-- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/conn.go b/conn.go index fd2c3f3..100186d 100644 --- a/conn.go +++ b/conn.go @@ -302,7 +302,7 @@ func (r *outstandingRequests) shutdown(err error) { type conn struct { t transport - session *session + session atomic.Pointer[session] outstandingRequests *outstandingRequests sequenceWindow uint64 dialect uint16 @@ -333,15 +333,7 @@ type conn struct { // serverGuid [16]byte // clientGuid [16]byte - _useSession int32 // receiver use session? -} - -func (conn *conn) useSession() bool { - return atomic.LoadInt32(&conn._useSession) != 0 -} - -func (conn *conn) enableSession() { - atomic.StoreInt32(&conn._useSession, 1) + useSession atomic.Bool } //lint:ignore U1000 appears to be legacy, unsure, so leaving for now @@ -490,8 +482,7 @@ func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([] hdr.Flags |= smb2.SMB2_FLAGS_RELATED_OPERATIONS } - s := conn.session - if s != nil { + if s := conn.session.Load(); s != nil { hdr.SessionId = s.sessionId if entry.tc != nil { hdr.TreeId = entry.tc.treeId @@ -535,8 +526,7 @@ func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([] // Phase 3: Sign or encrypt. wirePkt := compound - s := conn.session - if s != nil { + if s := conn.session.Load(); s != nil { encrypt := s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 if !encrypt { for _, entry := range entries { @@ -643,11 +633,10 @@ func (conn *conn) makeRequestResponse(ctx context.Context, req smb2.Packet, tc * hdr.MessageId = msgId - s := conn.session + s := conn.session.Load() if s != nil { hdr.SessionId = s.sessionId - if tc != nil { hdr.TreeId = tc.treeId } @@ -748,7 +737,7 @@ func (conn *conn) runReceiver() { goto exit } - hasSession := conn.useSession() + hasSession := conn.useSession.Load() var isEncrypted bool @@ -771,7 +760,7 @@ func (conn *conn) runReceiver() { } p := smb2.PacketCodec(pkt) - if s := conn.session; s != nil { + if s := conn.session.Load(); s != nil { if s.sessionId != p.SessionId() { conn.freePoolBuf(rb) @@ -944,7 +933,9 @@ func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) { return nil, nil, &InvalidResponseError{"encrypted flag is not on"}, false } - if conn.session == nil || conn.session.sessionId != t.SessionId() { + s := conn.session.Load() + + if s == nil || s.sessionId != t.SessionId() { return nil, nil, &InvalidResponseError{"unknown session id returned"}, false } @@ -956,7 +947,7 @@ func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) { } c := pRb.b[:cLen] - pkt, err := conn.session.decrypt(pkt, c) + pkt, err := s.decrypt(pkt, c) if err != nil { conn.freePoolBuf(pRb) return nil, nil, &InvalidResponseError{err.Error()}, false @@ -972,21 +963,22 @@ func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error { p := smb2.PacketCodec(pkt) msgId := p.MessageId() + s := conn.session.Load() if msgId != 0xFFFFFFFFFFFFFFFF { if p.Flags()&smb2.SMB2_FLAGS_SIGNED != 0 { - if conn.session == nil || conn.session.sessionId != p.SessionId() { + if s == nil || s.sessionId != p.SessionId() { return &InvalidResponseError{"unknown session id returned"} } else { - if !conn.session.verify(pkt) { + if !s.verify(pkt) { return &InvalidResponseError{"unverified packet returned"} } } } else { if conn.requireSigning && !isEncrypted { - if conn.session != nil { - if conn.session.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 { - if conn.session.sessionId == p.SessionId() { + if s != nil { + if s.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 { + if s.sessionId == p.SessionId() { return &InvalidResponseError{"signing required"} } } diff --git a/conn_bench_test.go b/conn_bench_test.go index a4ab084..9cafd57 100644 --- a/conn_bench_test.go +++ b/conn_bench_test.go @@ -167,7 +167,7 @@ func fakeServerEncrypted(t transport, responseData []byte, dec, enc cipher.AEAD, // must set up c.session before calling this. func newBenchFile(c *conn) *File { tc := &treeConn{ - session: c.session, + session: c.session.Load(), } fs := &Share{ @@ -197,11 +197,11 @@ func BenchmarkReadAt(b *testing.B) { c, cleanup := newBenchConn(clientConn) defer cleanup() - c.session = &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_IS_GUEST, - } + }) responseData := make([]byte, sz.n) go fakeServer(direct(serverConn), responseData) @@ -240,15 +240,15 @@ func BenchmarkReadAt(b *testing.B) { panic(err) } - c.session = &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA, sessionId: 0xdeadbeef, encrypter: newGCM(keyC2S), decrypter: newGCM(keyS2C), - } - c.enableSession() + }) + c.useSession.Store(true) responseData := make([]byte, sz.n) go fakeServerEncrypted( @@ -341,16 +341,15 @@ func BenchmarkRoundTrip(b *testing.B) { panic(err) } - s := &session{ + c.session.Store(&session{ conn: c, treeConnTables: make(map[uint32]*treeConn), sessionFlags: smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA, sessionId: 0xdeadbeef, encrypter: newGCM(keyC2S), decrypter: newGCM(keyS2C), - } - c.session = s - c.enableSession() + }) + c.useSession.Store(true) responseData := make([]byte, sz.n) go fakeServerEncrypted( diff --git a/session.go b/session.go index 745ef6d..8ec8beb 100644 --- a/session.go +++ b/session.go @@ -115,7 +115,7 @@ func sessionSetup(ctx context.Context, conn *conn, i Initiator) (*session, error // We set session before sending packet just for setting hdr.SessionId. // But, we should not permit access from receiver until the session information is completed. - conn.session = s + conn.session.Store(s) if status == erref.STATUS_MORE_PROCESSING_REQUIRED { req.SecurityBuffer = outputToken @@ -255,7 +255,7 @@ func sessionSetup(ctx context.Context, conn *conn, i Initiator) (*session, error s.sessionFlags = r.SessionFlags() // now, allow access from receiver - s.enableSession() + s.useSession.Store(true) return s, nil }