diff --git a/client.go b/client.go index d45b5c0..27ba0fe 100644 --- a/client.go +++ b/client.go @@ -1406,7 +1406,7 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) { case len(b)-n == 0: return n, nil case len(b)-n <= maxReadSize: - bs, isEOF, err := f.readAtChunk(len(b)-n, int64(n)+off) + bs, isEOF, rr, err := f.readAtChunk(len(b)-n, int64(n)+off) if err != nil { if err, ok := err.(*ResponseError); ok && erref.NtStatus(err.Code) == erref.STATUS_END_OF_FILE && n != 0 { return n, nil @@ -1415,12 +1415,13 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) { } n += copy(b[n:], bs) + rr.freeRecvBuf() if isEOF { return n, nil } default: - bs, isEOF, err := f.readAtChunk(maxReadSize, int64(n)+off) + bs, isEOF, rr, err := f.readAtChunk(maxReadSize, int64(n)+off) if err != nil { if err, ok := err.(*ResponseError); ok && erref.NtStatus(err.Code) == erref.STATUS_END_OF_FILE && n != 0 { return n, nil @@ -1429,6 +1430,7 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) { } n += copy(b[n:], bs) + rr.freeRecvBuf() if isEOF { return n, nil @@ -1437,7 +1439,7 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) { } } -func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, err error) { +func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, rr *requestResponse, err error) { creditCharge, m, err := f.fs.loanCredit(n) defer func() { if err != nil { @@ -1445,7 +1447,7 @@ func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, err error) } }() if err != nil { - return nil, false, err + return nil, false, nil, err } req := &smb2.ReadRequest{ @@ -1463,19 +1465,31 @@ func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, err error) req.CreditCharge = creditCharge - res, err := f.sendRecv(smb2.SMB2_READ, req) + rr, err = f.fs.send(req, f.fs.ctx) if err != nil { - return nil, false, err + return nil, false, nil, err + } + + pkt, err := f.fs.recv(rr) + if err != nil { + return nil, false, nil, err + } + + res, err := accept(smb2.SMB2_READ, pkt) + if err != nil { + rr.freeRecvBuf() + return nil, false, nil, err } r := smb2.ReadResponseDecoder(res) if r.IsInvalid() { - return nil, false, &InvalidResponseError{"broken read response format"} + rr.freeRecvBuf() + return nil, false, nil, &InvalidResponseError{"broken read response format"} } bs = r.Data() - return bs, len(bs) < m, nil + return bs, len(bs) < m, rr, nil } func (f *File) Readdir(n int) (fi []os.FileInfo, err error) { diff --git a/conn.go b/conn.go index 3fe06d8..a8a5422 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,9 @@ import ( "github.com/cloudsoda/go-smb2/internal/smb2" ) +// length of tag used to verify the integrity of the encrypted data +const AES_AUTH_TAG_LEN = 16 + // Negotiator contains options for func (*Dialer) Dial. type Negotiator struct { RequireMessageSigning bool // enforce signing? @@ -98,7 +101,7 @@ func (n *Negotiator) negotiate(t transport, a *account, ctx context.Context) (*c } go conn.runSender() - go conn.runReciever() + go conn.runReceiver() retry: req, err := n.makeRequest() @@ -227,6 +230,12 @@ retry: return conn, nil } +// recvBuf wraps a pooled receive buffer. Stored as *recvBuf in sync.Pool +// so the pointer fits directly in the interface without boxing allocations. +type recvBuf struct { + b []byte +} + type requestResponse struct { msgId uint64 asyncId uint64 @@ -235,6 +244,17 @@ type requestResponse struct { ctx context.Context recv chan []byte err error + rb *recvBuf // pooled receive buffer wrapper; return via freeRecvBuf + bufPool *sync.Pool // pool to return rb to +} + +// freeRecvBuf returns the pooled receive buffer, if any. Safe to call +// multiple times or when rb is nil. +func (rr *requestResponse) freeRecvBuf() { + if rr.bufPool != nil && rr.rb != nil { + rr.bufPool.Put(rr.rb) + rr.rb = nil + } } type outstandingRequests struct { @@ -301,7 +321,8 @@ type conn struct { wdone chan struct{} write chan []byte werr chan error - encodeBuf []byte // retained request encoding buffer; reused under conn.m + encodeBuf []byte // retained request encoding buffer; reused under conn.m + recvPool sync.Pool // reusable receive buffers m sync.Mutex @@ -448,7 +469,7 @@ func (conn *conn) makeRequestResponse(req smb2.Packet, tc *treeConn, ctx context if s != nil { if _, ok := req.(*smb2.SessionSetupRequest); !ok { if s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 || (tc != nil && tc.shareFlags&smb2.SMB2_SHAREFLAG_ENCRYPT_DATA != 0) { - needed := 52 + len(pkt) + 16 + needed := 52 + len(pkt) + AES_AUTH_TAG_LEN if cap(s.encryptBuf) < needed { s.encryptBuf = make([]byte, needed) } @@ -505,7 +526,7 @@ func (conn *conn) runSender() { } } -func (conn *conn) runReciever() { +func (conn *conn) runReceiver() { var err error for { @@ -516,10 +537,16 @@ func (conn *conn) runReciever() { goto exit } - pkt := make([]byte, n) + rb, ok := conn.recvPool.Get().(*recvBuf) + if !ok || cap(rb.b) < n { + rb = &recvBuf{b: make([]byte, n)} + } + pkt := rb.b[:n] _, e = conn.t.Read(pkt) if e != nil { + conn.freePoolBuf(rb) + err = &TransportError{e} goto exit @@ -530,16 +557,28 @@ func (conn *conn) runReciever() { var isEncrypted bool if hasSession { - pkt, e, isEncrypted = conn.tryDecrypt(pkt) + var pRb *recvBuf + pkt, pRb, e, isEncrypted = conn.tryDecrypt(pkt) if e != nil { + conn.freePoolBuf(rb) + logger.Println("skip:", e) continue } + if isEncrypted { + // Decrypt produced a new plaintext buffer; the + // original ciphertext buffer can be reused now. + conn.freePoolBuf(rb) + rb = pRb + } + p := smb2.PacketCodec(pkt) if s := conn.session; s != nil { if s.sessionId != p.SessionId() { + conn.freePoolBuf(rb) + logger.Println("skip:", &InvalidResponseError{"unknown session id"}) continue @@ -547,6 +586,8 @@ func (conn *conn) runReciever() { if tc, ok := s.treeConnTables[p.TreeId()]; ok { if tc.treeId != p.TreeId() { + conn.freePoolBuf(rb) + logger.Println("skip:", &InvalidResponseError{"unknown tree id"}) continue @@ -555,31 +596,43 @@ func (conn *conn) runReciever() { } } - var next []byte - - for { - p := smb2.PacketCodec(pkt) - - if off := p.NextCommand(); off != 0 { - pkt, next = pkt[:off], pkt[off:] - } else { - next = nil - } - + p := smb2.PacketCodec(pkt) + if p.NextCommand() == 0 { + // Single response: transfer the pooled buffer to the caller. if hasSession { e = conn.tryVerify(pkt, isEncrypted) } - - e = conn.tryHandle(pkt, e) - if e != nil { + if e = conn.tryHandle(pkt, e, rb); e != nil { logger.Println("skip:", e) } + } else { + // Compound response: sub-responses share the underlying + // buffer, so we cannot transfer ownership to any one caller. + // The buffer is intentionally not returned to the pool; it + // will be GC'd once all consumers finish with their pkt slices. + + var next []byte + for { + if off := p.NextCommand(); off != 0 { + pkt, next = pkt[:off], pkt[off:] + } else { + next = nil + } - if next == nil { - break - } + if hasSession { + e = conn.tryVerify(pkt, isEncrypted) + } + if e = conn.tryHandle(pkt, e, nil); e != nil { + logger.Println("skip:", e) + } - pkt = next + if next == nil { + break + } + + pkt = next + p = smb2.PacketCodec(pkt) + } } } @@ -683,31 +736,40 @@ func acceptError(status uint32, res []byte) error { return &ResponseError{Code: status, data: [][]byte{eData}} } -func (conn *conn) tryDecrypt(pkt []byte) ([]byte, error, bool) { +func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) { p := smb2.PacketCodec(pkt) if p.IsInvalid() { t := smb2.TransformCodec(pkt) if t.IsInvalid() { - return nil, &InvalidResponseError{"broken packet header format"}, false + return nil, nil, &InvalidResponseError{"broken packet header format"}, false } if t.Flags() != smb2.Encrypted { - return nil, &InvalidResponseError{"encrypted flag is not on"}, false + return nil, nil, &InvalidResponseError{"encrypted flag is not on"}, false } if conn.session == nil || conn.session.sessionId != t.SessionId() { - return nil, &InvalidResponseError{"unknown session id returned"}, false + return nil, nil, &InvalidResponseError{"unknown session id returned"}, false + } + + // Get a pooled buffer for the decrypt work-buffer (ciphertext + tag). + cLen := len(t.EncryptedData()) + AES_AUTH_TAG_LEN + pRb, ok := conn.recvPool.Get().(*recvBuf) + if !ok || cap(pRb.b) < cLen { + pRb = &recvBuf{b: make([]byte, cLen)} } + c := pRb.b[:cLen] - pkt, err := conn.session.decrypt(pkt) + pkt, err := conn.session.decrypt(pkt, c) if err != nil { - return nil, &InvalidResponseError{err.Error()}, false + conn.freePoolBuf(pRb) + return nil, nil, &InvalidResponseError{err.Error()}, false } - return pkt, nil, true + return pkt, pRb, nil, true } - return pkt, nil, false + return pkt, nil, nil, false } func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error { @@ -740,7 +802,7 @@ func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error { return nil } -func (conn *conn) tryHandle(pkt []byte, e error) error { +func (conn *conn) tryHandle(pkt []byte, e error, rb *recvBuf) error { p := smb2.PacketCodec(pkt) msgId := p.MessageId() @@ -748,20 +810,37 @@ func (conn *conn) tryHandle(pkt []byte, e error) error { rr, ok := conn.outstandingRequests.pop(msgId) switch { case !ok: + conn.freePoolBuf(rb) return &InvalidResponseError{"unknown message id returned"} case e != nil: rr.err = e + conn.freePoolBuf(rb) close(rr.recv) case erref.NtStatus(p.Status()) == erref.STATUS_PENDING: rr.asyncId = p.AsyncId() conn.account.charge(p.CreditResponse(), rr.creditRequest) conn.outstandingRequests.set(msgId, rr) + conn.freePoolBuf(rb) default: conn.account.charge(p.CreditResponse(), rr.creditRequest) + // Transfer ownership of the pooled receive buffer to the + // requestResponse so the caller can return it via freeRecvBuf + // after it has finished reading the response packet. (the + // error cases in this switch statement all free the buffer + // immediately) + rr.rb = rb + rr.bufPool = &conn.recvPool + rr.recv <- pkt } return nil } + +func (conn *conn) freePoolBuf(rb *recvBuf) { + if rb != nil { + conn.recvPool.Put(rb) + } +} diff --git a/conn_bench_test.go b/conn_bench_test.go index 2f637a6..35a3609 100644 --- a/conn_bench_test.go +++ b/conn_bench_test.go @@ -32,7 +32,7 @@ func newBenchConn(netConn net.Conn) (*conn, func()) { capabilities: smb2.SMB2_GLOBAL_CAP_LARGE_MTU, } go c.runSender() - go c.runReciever() + go c.runReceiver() cleanup := func() { c.rdone <- struct{}{} @@ -67,6 +67,9 @@ func fakeServer(t transport, responseData []byte) { } respBuf := make([]byte, resp.Size()) resp.Encode(respBuf) + // Fix DataOffset: Encode writes 16 (offset within response body), + // but the decoder expects offset from packet start (64 + 16 = 80). + respBuf[64+2] = 80 reqBuf := make([]byte, bufSize) @@ -81,9 +84,10 @@ func fakeServer(t transport, responseData []byte) { p := smb2.PacketCodec(reqBuf[:n]) - // Patch MessageId and CreditResponse into the template. + // Patch MessageId, Command, and CreditResponse into the template. rp := smb2.PacketCodec(respBuf) rp.SetMessageId(p.MessageId()) + rp.SetCommand(p.Command()) rp.SetCreditResponse(p.CreditRequest()) if _, err := t.Write(respBuf); err != nil { @@ -105,6 +109,9 @@ func fakeServerEncrypted(t transport, responseData []byte, dec, enc cipher.AEAD, } plainResp := make([]byte, resp.Size()) resp.Encode(plainResp) + // Fix DataOffset: Encode writes 16 (offset within response body), + // but the decoder expects offset from packet start (64 + 16 = 80). + plainResp[64+2] = 80 reqBuf := make([]byte, bufSize+52+16) // room for transform header + payload + tag decBuf := make([]byte, 0, bufSize+16) // decrypt work buffer @@ -154,6 +161,123 @@ func fakeServerEncrypted(t transport, responseData []byte, dec, enc cipher.AEAD, } } +// newBenchFile constructs a File wired through the production +// Share → treeConn → session → conn chain, so benchmarks can +// exercise readAt and other production code paths. The caller +// must set up c.session before calling this. +func newBenchFile(c *conn) *File { + tc := &treeConn{ + session: c.session, + } + + fs := &Share{ + treeConn: tc, + ctx: context.Background(), + } + + return &File{ + fs: fs, + fd: &smb2.FileId{}, + } +} + +func BenchmarkReadAt(b *testing.B) { + sizes := []struct { + name string + n int + }{ + {"1KB", 1 << 10}, + {"64KB", 1 << 16}, + {"1MB", 1 << 20}, + } + + for _, sz := range sizes { + b.Run("Plain/"+sz.name, func(b *testing.B) { + clientConn, serverConn := net.Pipe() + c, cleanup := newBenchConn(clientConn) + defer cleanup() + + c.session = &session{ + conn: c, + treeConnTables: make(map[uint32]*treeConn), + sessionFlags: smb2.SMB2_SESSION_FLAG_IS_GUEST, + } + + responseData := make([]byte, sz.n) + go fakeServer(direct(serverConn), responseData) + + f := newBenchFile(c) + buf := make([]byte, sz.n) + + b.SetBytes(int64(sz.n)) + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + n, err := f.readAt(buf, 0) + if err != nil { + b.Fatal(err) + } + if n != sz.n { + b.Fatalf("short read: %d != %d", n, sz.n) + } + } + }) + } + + for _, sz := range sizes { + b.Run("Encrypted/"+sz.name, func(b *testing.B) { + clientConn, serverConn := net.Pipe() + c, cleanup := newBenchConn(clientConn) + defer cleanup() + + keyC2S := make([]byte, 16) + keyS2C := make([]byte, 16) + if _, err := rand.Read(keyC2S); err != nil { + panic(err) + } + if _, err := rand.Read(keyS2C); err != nil { + panic(err) + } + + c.session = &session{ + conn: c, + treeConnTables: make(map[uint32]*treeConn), + sessionFlags: smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA, + sessionId: 0xdeadbeef, + encrypter: newGCM(keyC2S), + decrypter: newGCM(keyS2C), + } + c.enableSession() + + responseData := make([]byte, sz.n) + go fakeServerEncrypted( + direct(serverConn), responseData, + newGCM(keyC2S), + newGCM(keyS2C), + 0xdeadbeef, + ) + + f := newBenchFile(c) + buf := make([]byte, sz.n) + + b.SetBytes(int64(sz.n)) + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + n, err := f.readAt(buf, 0) + if err != nil { + b.Fatal(err) + } + if n != sz.n { + b.Fatalf("short read: %d != %d", n, sz.n) + } + } + }) + } +} + func BenchmarkRoundTrip(b *testing.B) { sizes := []struct { name string @@ -180,7 +304,7 @@ func BenchmarkRoundTrip(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { req := &smb2.ReadRequest{ Length: uint32(sz.n), Offset: 0, @@ -195,6 +319,7 @@ func BenchmarkRoundTrip(b *testing.B) { if _, err := c.recv(rr); err != nil { b.Fatal(err) } + rr.freeRecvBuf() } }) } @@ -242,7 +367,7 @@ func BenchmarkRoundTrip(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { req := &smb2.ReadRequest{ Length: uint32(sz.n), Offset: 0, @@ -257,6 +382,7 @@ func BenchmarkRoundTrip(b *testing.B) { if _, err := c.recv(rr); err != nil { b.Fatal(err) } + rr.freeRecvBuf() } }) } diff --git a/session.go b/session.go index 23a8852..36e03d6 100644 --- a/session.go +++ b/session.go @@ -392,10 +392,15 @@ func (s *session) encrypt(pkt, c []byte) ([]byte, error) { return c, nil } -func (s *session) decrypt(pkt []byte) ([]byte, error) { +// decrypt decrypts an SMB3 transform packet. c must be at least +// len(EncryptedData)+len(Signature) bytes; decrypt copies the +// ciphertext and tag into c and decrypts in-place. +func (s *session) decrypt(pkt, c []byte) ([]byte, error) { t := smb2.TransformCodec(pkt) - c := append(t.EncryptedData(), t.Signature()...) + c = c[:0] + c = append(c, t.EncryptedData()...) + c = append(c, t.Signature()...) return s.decrypter.Open( c[:0],