Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) {
case len(b)-n == 0:
return n, nil
case len(b)-n <= maxReadSize:
bs, isEOF, err := f.readAtChunk(len(b)-n, int64(n)+off)
bs, isEOF, rr, err := f.readAtChunk(len(b)-n, int64(n)+off)
if err != nil {
if err, ok := err.(*ResponseError); ok && erref.NtStatus(err.Code) == erref.STATUS_END_OF_FILE && n != 0 {
return n, nil
Expand All @@ -1415,12 +1415,13 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) {
}

n += copy(b[n:], bs)
rr.freeRecvBuf()

if isEOF {
return n, nil
}
default:
bs, isEOF, err := f.readAtChunk(maxReadSize, int64(n)+off)
bs, isEOF, rr, err := f.readAtChunk(maxReadSize, int64(n)+off)
if err != nil {
if err, ok := err.(*ResponseError); ok && erref.NtStatus(err.Code) == erref.STATUS_END_OF_FILE && n != 0 {
return n, nil
Expand All @@ -1429,6 +1430,7 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) {
}

n += copy(b[n:], bs)
rr.freeRecvBuf()

if isEOF {
return n, nil
Expand All @@ -1437,15 +1439,15 @@ func (f *File) readAt(b []byte, off int64) (n int, err error) {
}
}

func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, err error) {
func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, rr *requestResponse, err error) {
creditCharge, m, err := f.fs.loanCredit(n)
defer func() {
if err != nil {
f.fs.chargeCredit(creditCharge)
}
}()
if err != nil {
return nil, false, err
return nil, false, nil, err
}

req := &smb2.ReadRequest{
Expand All @@ -1463,19 +1465,31 @@ func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, err error)

req.CreditCharge = creditCharge

res, err := f.sendRecv(smb2.SMB2_READ, req)
rr, err = f.fs.send(req, f.fs.ctx)
if err != nil {
return nil, false, err
return nil, false, nil, err
}

pkt, err := f.fs.recv(rr)
if err != nil {
return nil, false, nil, err
}

res, err := accept(smb2.SMB2_READ, pkt)
if err != nil {
rr.freeRecvBuf()
return nil, false, nil, err
}

r := smb2.ReadResponseDecoder(res)
if r.IsInvalid() {
return nil, false, &InvalidResponseError{"broken read response format"}
rr.freeRecvBuf()
return nil, false, nil, &InvalidResponseError{"broken read response format"}
}

bs = r.Data()

return bs, len(bs) < m, nil
return bs, len(bs) < m, rr, nil
}

func (f *File) Readdir(n int) (fi []os.FileInfo, err error) {
Expand Down
145 changes: 112 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
"github.com/cloudsoda/go-smb2/internal/smb2"
)

// length of tag used to verify the integrity of the encrypted data
const AES_AUTH_TAG_LEN = 16

// Negotiator contains options for func (*Dialer) Dial.
type Negotiator struct {
RequireMessageSigning bool // enforce signing?
Expand Down Expand Up @@ -98,7 +101,7 @@ func (n *Negotiator) negotiate(t transport, a *account, ctx context.Context) (*c
}

go conn.runSender()
go conn.runReciever()
go conn.runReceiver()

retry:
req, err := n.makeRequest()
Expand Down Expand Up @@ -227,6 +230,12 @@ retry:
return conn, nil
}

// recvBuf wraps a pooled receive buffer. Stored as *recvBuf in sync.Pool
// so the pointer fits directly in the interface without boxing allocations.
type recvBuf struct {
b []byte
}

type requestResponse struct {
msgId uint64
asyncId uint64
Expand All @@ -235,6 +244,17 @@ type requestResponse struct {
ctx context.Context
recv chan []byte
err error
rb *recvBuf // pooled receive buffer wrapper; return via freeRecvBuf
bufPool *sync.Pool // pool to return rb to
}

// freeRecvBuf returns the pooled receive buffer, if any. Safe to call
// multiple times or when rb is nil.
func (rr *requestResponse) freeRecvBuf() {
if rr.bufPool != nil && rr.rb != nil {
rr.bufPool.Put(rr.rb)
rr.rb = nil
}
}

type outstandingRequests struct {
Expand Down Expand Up @@ -301,7 +321,8 @@ type conn struct {
wdone chan struct{}
write chan []byte
werr chan error
encodeBuf []byte // retained request encoding buffer; reused under conn.m
encodeBuf []byte // retained request encoding buffer; reused under conn.m
recvPool sync.Pool // reusable receive buffers

m sync.Mutex

Expand Down Expand Up @@ -448,7 +469,7 @@ func (conn *conn) makeRequestResponse(req smb2.Packet, tc *treeConn, ctx context
if s != nil {
if _, ok := req.(*smb2.SessionSetupRequest); !ok {
if s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 || (tc != nil && tc.shareFlags&smb2.SMB2_SHAREFLAG_ENCRYPT_DATA != 0) {
needed := 52 + len(pkt) + 16
needed := 52 + len(pkt) + AES_AUTH_TAG_LEN
if cap(s.encryptBuf) < needed {
s.encryptBuf = make([]byte, needed)
}
Expand Down Expand Up @@ -505,7 +526,7 @@ func (conn *conn) runSender() {
}
}

func (conn *conn) runReciever() {
func (conn *conn) runReceiver() {
var err error

for {
Expand All @@ -516,10 +537,16 @@ func (conn *conn) runReciever() {
goto exit
}

pkt := make([]byte, n)
rb, ok := conn.recvPool.Get().(*recvBuf)
if !ok || cap(rb.b) < n {
rb = &recvBuf{b: make([]byte, n)}
}
pkt := rb.b[:n]

_, e = conn.t.Read(pkt)
if e != nil {
conn.freePoolBuf(rb)

err = &TransportError{e}

goto exit
Expand All @@ -530,23 +557,37 @@ func (conn *conn) runReciever() {
var isEncrypted bool

if hasSession {
pkt, e, isEncrypted = conn.tryDecrypt(pkt)
var pRb *recvBuf
pkt, pRb, e, isEncrypted = conn.tryDecrypt(pkt)
if e != nil {
conn.freePoolBuf(rb)

logger.Println("skip:", e)

continue
}

if isEncrypted {
// Decrypt produced a new plaintext buffer; the
// original ciphertext buffer can be reused now.
conn.freePoolBuf(rb)
rb = pRb
}

p := smb2.PacketCodec(pkt)
if s := conn.session; s != nil {
if s.sessionId != p.SessionId() {
conn.freePoolBuf(rb)

logger.Println("skip:", &InvalidResponseError{"unknown session id"})

continue
}

if tc, ok := s.treeConnTables[p.TreeId()]; ok {
if tc.treeId != p.TreeId() {
conn.freePoolBuf(rb)

logger.Println("skip:", &InvalidResponseError{"unknown tree id"})

continue
Expand All @@ -555,31 +596,43 @@ func (conn *conn) runReciever() {
}
}

var next []byte

for {
p := smb2.PacketCodec(pkt)

if off := p.NextCommand(); off != 0 {
pkt, next = pkt[:off], pkt[off:]
} else {
next = nil
}

p := smb2.PacketCodec(pkt)
if p.NextCommand() == 0 {
// Single response: transfer the pooled buffer to the caller.
if hasSession {
e = conn.tryVerify(pkt, isEncrypted)
}

e = conn.tryHandle(pkt, e)
if e != nil {
if e = conn.tryHandle(pkt, e, rb); e != nil {
logger.Println("skip:", e)
}
} else {
// Compound response: sub-responses share the underlying
// buffer, so we cannot transfer ownership to any one caller.
// The buffer is intentionally not returned to the pool; it
// will be GC'd once all consumers finish with their pkt slices.

var next []byte
for {
if off := p.NextCommand(); off != 0 {
pkt, next = pkt[:off], pkt[off:]
} else {
next = nil
}

if next == nil {
break
}
if hasSession {
e = conn.tryVerify(pkt, isEncrypted)
}
if e = conn.tryHandle(pkt, e, nil); e != nil {
logger.Println("skip:", e)
}

pkt = next
if next == nil {
break
}

pkt = next
p = smb2.PacketCodec(pkt)
}
}
}

Expand Down Expand Up @@ -683,31 +736,40 @@ func acceptError(status uint32, res []byte) error {
return &ResponseError{Code: status, data: [][]byte{eData}}
}

func (conn *conn) tryDecrypt(pkt []byte) ([]byte, error, bool) {
func (conn *conn) tryDecrypt(pkt []byte) ([]byte, *recvBuf, error, bool) {
p := smb2.PacketCodec(pkt)
if p.IsInvalid() {
t := smb2.TransformCodec(pkt)
if t.IsInvalid() {
return nil, &InvalidResponseError{"broken packet header format"}, false
return nil, nil, &InvalidResponseError{"broken packet header format"}, false
}

if t.Flags() != smb2.Encrypted {
return nil, &InvalidResponseError{"encrypted flag is not on"}, false
return nil, nil, &InvalidResponseError{"encrypted flag is not on"}, false
}

if conn.session == nil || conn.session.sessionId != t.SessionId() {
return nil, &InvalidResponseError{"unknown session id returned"}, false
return nil, nil, &InvalidResponseError{"unknown session id returned"}, false
}

// Get a pooled buffer for the decrypt work-buffer (ciphertext + tag).
cLen := len(t.EncryptedData()) + AES_AUTH_TAG_LEN
pRb, ok := conn.recvPool.Get().(*recvBuf)
if !ok || cap(pRb.b) < cLen {
pRb = &recvBuf{b: make([]byte, cLen)}
}
c := pRb.b[:cLen]

pkt, err := conn.session.decrypt(pkt)
pkt, err := conn.session.decrypt(pkt, c)
if err != nil {
return nil, &InvalidResponseError{err.Error()}, false
conn.freePoolBuf(pRb)
return nil, nil, &InvalidResponseError{err.Error()}, false
}

return pkt, nil, true
return pkt, pRb, nil, true
}

return pkt, nil, false
return pkt, nil, nil, false
}

func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error {
Expand Down Expand Up @@ -740,28 +802,45 @@ func (conn *conn) tryVerify(pkt []byte, isEncrypted bool) error {
return nil
}

func (conn *conn) tryHandle(pkt []byte, e error) error {
func (conn *conn) tryHandle(pkt []byte, e error, rb *recvBuf) error {
p := smb2.PacketCodec(pkt)

msgId := p.MessageId()

rr, ok := conn.outstandingRequests.pop(msgId)
switch {
case !ok:
conn.freePoolBuf(rb)
return &InvalidResponseError{"unknown message id returned"}
case e != nil:
rr.err = e
conn.freePoolBuf(rb)

close(rr.recv)
case erref.NtStatus(p.Status()) == erref.STATUS_PENDING:
rr.asyncId = p.AsyncId()
conn.account.charge(p.CreditResponse(), rr.creditRequest)
conn.outstandingRequests.set(msgId, rr)
conn.freePoolBuf(rb)
default:
conn.account.charge(p.CreditResponse(), rr.creditRequest)

// Transfer ownership of the pooled receive buffer to the
// requestResponse so the caller can return it via freeRecvBuf
// after it has finished reading the response packet. (the
// error cases in this switch statement all free the buffer
// immediately)
rr.rb = rb
rr.bufPool = &conn.recvPool

rr.recv <- pkt
}

return nil
}

func (conn *conn) freePoolBuf(rb *recvBuf) {
if rb != nil {
conn.recvPool.Put(rb)
}
}
Loading
Loading