Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package smb2

import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"math/rand"
"net"
"os"
Expand Down Expand Up @@ -916,6 +918,26 @@ func (fs *Share) ReadDir(dirname string) ([]os.FileInfo, error) {
return fis, nil
}

// ReadDirPlus returns all directory entries enriched with security descriptors,
// sorted by name. For directories with many entries, prefer opening the
// directory and calling File.ReaddirPlus incrementally.
func (fs *Share) ReadDirPlus(dirname string, securityInfo SecurityInformationRequestFlags) ([]DirEntryPlus, error) {
f, err := fs.Open(dirname)
if err != nil {
return nil, err
}
defer f.Close()

entries, err := f.ReaddirPlus(-1, securityInfo)
if err != nil {
return nil, err
}

sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() })

return entries, nil
}

const (
intSize = 32 << (^uint(0) >> 63) // 32 or 64
maxInt = 1<<(intSize-1) - 1
Expand Down Expand Up @@ -1465,7 +1487,7 @@ func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, rr *request

req.CreditCharge = creditCharge

rr, err = f.fs.send(req, f.fs.ctx)
rr, err = f.fs.send(f.fs.ctx, req)
if err != nil {
return nil, false, nil, err
}
Expand Down Expand Up @@ -1537,6 +1559,71 @@ func (f *File) Readdir(n int) (fi []os.FileInfo, err error) {
return fi, nil
}

// ReaddirPlus reads directory entries enriched with security descriptors.
// Like Readdir(n), it is incremental — each call returns the next n entries
// with their security descriptors. Returns io.EOF after the last entry.
//
// Security queries use SMB2 compound requests (CREATE+QUERY_INFO+CLOSE batched
// into single round-trips), sub-batching internally to respect credit limits.
func (f *File) ReaddirPlus(n int, securityInfo SecurityInformationRequestFlags) ([]DirEntryPlus, error) {
fi, err := f.Readdir(n)
if len(fi) == 0 {
return nil, err
}

// Build relative paths for the compound security query.
// File names from Readdir are bare names; prefix with the directory path.
dir := f.name
paths := make([]string, len(fi))
for i, info := range fi {
if dir == "" {
paths[i] = info.Name()
} else {
paths[i] = dir + `\` + info.Name()
}
}

tc := f.fs.treeConn
secResults, secErr := tc.compoundSecurityInfoBatch(
paths, uint32(securityInfo), f.mapping, f.fs.ctx,
)
if secErr != nil {
// If the compound batch itself failed, return entries without security info
// and propagate the batch error on the first entry.
entries := make([]DirEntryPlus, len(fi))
for i, info := range fi {
entries[i] = DirEntryPlus{FileInfo: info, Err: secErr}
}
return entries, errors.Join(err, secErr)
}

entries := make([]DirEntryPlus, 0, len(fi))
for i, info := range fi {
var sd *sddl.SecurityDescriptor
var err error
if secResults[i].err != nil {
Comment thread
arashpayan marked this conversation as resolved.
if isFileDeleted(secResults[i].err) {
continue // file deleted between Readdir and security query
}
err = secResults[i].err
} else if secResults[i].data != nil {
var parseErr error
sd, parseErr = sddl.FromBinary(secResults[i].data)
if parseErr != nil {
sd = nil // belt-and-suspenders assignment
err = fmt.Errorf("parsing security descriptor for %s: %w", info.Name(), parseErr)
}
}
entries = append(entries, DirEntryPlus{
FileInfo: info,
SecurityDescriptor: sd,
Err: err,
})
}

return entries, err
}

func (f *File) Readdirnames(n int) (names []string, err error) {
fi, err := f.Readdir(n)
if err != nil {
Expand Down Expand Up @@ -2418,3 +2505,18 @@ func (fs *FileStat) IsDir() bool {
func (fs *FileStat) Sys() any {
return fs
}

// DirEntryPlus extends os.FileInfo with security metadata obtained via
// compound requests. It implements fs.DirEntry.
type DirEntryPlus struct {
os.FileInfo
SecurityDescriptor *sddl.SecurityDescriptor

Err error // non-nil if the security query failed for this entry
}

var _ fs.DirEntry = &DirEntryPlus{}

func (d *DirEntryPlus) Type() os.FileMode { return d.FileInfo.Mode().Type() }
func (d *DirEntryPlus) Info() (os.FileInfo, error) { return d.FileInfo, nil }
func (d *DirEntryPlus) Name() string { return d.FileInfo.Name() }
206 changes: 201 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,13 @@ type conn struct {

account *account

rdone chan struct{}
wdone chan struct{}
write chan []byte
werr chan error
encodeBuf []byte // retained request encoding buffer; reused under conn.m
rdone chan struct{}
wdone chan struct{}
write chan []byte
werr chan error
recvPool sync.Pool // reusable receive buffers
encodeBuf []byte // retained request encoding buffer; reused under conn.m
compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m

m sync.Mutex

Expand Down Expand Up @@ -427,6 +428,201 @@ func (conn *conn) sendWith(ctx context.Context, req smb2.Packet, tc *treeConn) (
return rr, nil
}

// compoundEntry describes a single request within a compound.
type compoundEntry struct {
req smb2.Packet
tc *treeConn
related bool // set SMB2_FLAGS_RELATED_OPERATIONS on this request
}

// sendCompound serializes multiple SMB2 requests into a single transport frame
// and sends them as a compound request. Related entries share a file handle via
// the sentinel FileId. Returns one requestResponse per entry for receiving
// individual responses.
//
// Caller must have already set CreditCharge on each entry's header (via loanCredit).
func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([]*requestResponse, error) {
conn.m.Lock()
defer conn.m.Unlock()

if conn.err != nil {
return nil, conn.err
}

select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

n := len(entries)

// Reuse retained sizes buffer.
if cap(conn.compoundSizes) < n {
conn.compoundSizes = make([]int, n)
}
sizes := conn.compoundSizes[:n]

// Phase 1: Set header fields and compute sizes.
var totalCreditCharge uint16

for i, entry := range entries {
hdr := entry.req.Header()

msgId := conn.sequenceWindow
creditCharge := hdr.CreditCharge
conn.sequenceWindow += uint64(creditCharge)
totalCreditCharge += creditCharge

hdr.MessageId = msgId

// Only the last request asks for new credits.
if i < n-1 {
hdr.CreditRequestResponse = 0
} else {
if hdr.CreditRequestResponse == 0 {
hdr.CreditRequestResponse = totalCreditCharge
}
hdr.CreditRequestResponse += conn.account.opening()
}

if entry.related {
hdr.Flags |= smb2.SMB2_FLAGS_RELATED_OPERATIONS
}

s := conn.session
if s != nil {
hdr.SessionId = s.sessionId
if entry.tc != nil {
hdr.TreeId = entry.tc.treeId
}
}

sizes[i] = entry.req.Size()
}

// Compute total buffer size with 8-byte alignment between requests.
totalSize := 0
for i, sz := range sizes {
if i < n-1 {
totalSize += (sz + 7) &^ 7
} else {
totalSize += sz
}
}

// Phase 2: Encode into conn.encodeBuf (retained compound buffer).
if cap(conn.encodeBuf) < totalSize {
conn.encodeBuf = make([]byte, totalSize)
}
compound := conn.encodeBuf[:totalSize]
clear(compound)

offset := 0
for i, entry := range entries {
pkt := compound[offset : offset+sizes[i]]
entry.req.Encode(pkt)

if i < n-1 {
aligned := (sizes[i] + 7) &^ 7
smb2.PacketCodec(pkt).SetNextCommand(uint32(aligned))
offset += aligned
} else {
offset += sizes[i]
}
}

// Phase 3: Sign or encrypt.
wirePkt := compound

s := conn.session
if s != nil {
encrypt := s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0
if !encrypt {
for _, entry := range entries {
if entry.tc != nil && entry.tc.shareFlags&smb2.SMB2_SHAREFLAG_ENCRYPT_DATA != 0 {
encrypt = true
break
}
}
}

if encrypt {
// Encrypt the entire compound as one unit using s.encryptBuf.
needed := 52 + len(compound) + 16
if cap(s.encryptBuf) < needed {
s.encryptBuf = make([]byte, needed)
}
clear(s.encryptBuf[:needed])
var err error
wirePkt, err = s.encrypt(compound, s.encryptBuf[:needed])
if err != nil {
return nil, &InternalError{err.Error()}
}
} else if s.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 {
// Sign each packet individually in-place.
// Per MS-SMB2 3.3.5.2.4, the server uses the NextCommand value as the
// message length for signature verification (8-byte aligned size), so
// non-last entries must be signed over the aligned length including padding.
off := 0
for i := range entries {
signLen := sizes[i]
if i < n-1 {
signLen = (sizes[i] + 7) &^ 7
}
pkt := compound[off : off+signLen]
s.sign(pkt)
off += signLen
}
}
}

// Phase 4: Register all requestResponses and send.
rrs := make([]*requestResponse, n)
off := 0
for i := range entries {
p := smb2.PacketCodec(compound[off : off+sizes[i]])
rrs[i] = &requestResponse{
msgId: p.MessageId(),
creditRequest: p.CreditRequest(),
ctx: ctx,
recv: make(chan []byte, 1),
}
conn.outstandingRequests.set(rrs[i].msgId, rrs[i])

if i < n-1 {
off += (sizes[i] + 7) &^ 7
} else {
off += sizes[i]
}
}

select {
case conn.write <- wirePkt:
select {
case err := <-conn.werr:
if err != nil {
for _, rr := range rrs {
conn.outstandingRequests.pop(rr.msgId)
}
return nil, &TransportError{err}
}
case <-ctx.Done():
for _, rr := range rrs {
conn.outstandingRequests.pop(rr.msgId)
}
return nil, ctx.Err()
}
case <-ctx.Done():
for _, rr := range rrs {
conn.outstandingRequests.pop(rr.msgId)
}
return nil, ctx.Err()
}

return rrs, nil
}

func (conn *conn) makeRequestResponse(ctx context.Context, req smb2.Packet, tc *treeConn) (rr *requestResponse, err error) {
hdr := req.Header()

Expand Down
10 changes: 10 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package smb2
import (
"errors"
"fmt"
"os"

"github.com/cloudsoda/go-smb2/internal/erref"
)
Expand Down Expand Up @@ -50,3 +51,12 @@ type ResponseError struct {
func (err *ResponseError) Error() string {
return fmt.Sprintf("response error: %v", erref.NtStatus(err.Code))
}

// isFileDeleted reports whether err indicates a file was deleted or is pending
// deletion. This is used to silently skip directory entries that vanish between
// Readdir and a subsequent compound security query.
func isFileDeleted(err error) bool {
var re *ResponseError
return errors.Is(err, os.ErrNotExist) ||
(errors.As(err, &re) && re.Code == uint32(erref.STATUS_DELETE_PENDING))
}
Loading
Loading