diff --git a/client.go b/client.go index 2968ada..5712c83 100644 --- a/client.go +++ b/client.go @@ -2,8 +2,10 @@ package smb2 import ( "context" + "errors" "fmt" "io" + "io/fs" "math/rand" "net" "os" @@ -916,6 +918,26 @@ func (fs *Share) ReadDir(dirname string) ([]os.FileInfo, error) { return fis, nil } +// ReadDirPlus returns all directory entries enriched with security descriptors, +// sorted by name. For directories with many entries, prefer opening the +// directory and calling File.ReaddirPlus incrementally. +func (fs *Share) ReadDirPlus(dirname string, securityInfo SecurityInformationRequestFlags) ([]DirEntryPlus, error) { + f, err := fs.Open(dirname) + if err != nil { + return nil, err + } + defer f.Close() + + entries, err := f.ReaddirPlus(-1, securityInfo) + if err != nil { + return nil, err + } + + sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + + return entries, nil +} + const ( intSize = 32 << (^uint(0) >> 63) // 32 or 64 maxInt = 1<<(intSize-1) - 1 @@ -1465,7 +1487,7 @@ func (f *File) readAtChunk(n int, off int64) (bs []byte, isEOF bool, rr *request req.CreditCharge = creditCharge - rr, err = f.fs.send(req, f.fs.ctx) + rr, err = f.fs.send(f.fs.ctx, req) if err != nil { return nil, false, nil, err } @@ -1537,6 +1559,71 @@ func (f *File) Readdir(n int) (fi []os.FileInfo, err error) { return fi, nil } +// ReaddirPlus reads directory entries enriched with security descriptors. +// Like Readdir(n), it is incremental — each call returns the next n entries +// with their security descriptors. Returns io.EOF after the last entry. +// +// Security queries use SMB2 compound requests (CREATE+QUERY_INFO+CLOSE batched +// into single round-trips), sub-batching internally to respect credit limits. +func (f *File) ReaddirPlus(n int, securityInfo SecurityInformationRequestFlags) ([]DirEntryPlus, error) { + fi, err := f.Readdir(n) + if len(fi) == 0 { + return nil, err + } + + // Build relative paths for the compound security query. + // File names from Readdir are bare names; prefix with the directory path. + dir := f.name + paths := make([]string, len(fi)) + for i, info := range fi { + if dir == "" { + paths[i] = info.Name() + } else { + paths[i] = dir + `\` + info.Name() + } + } + + tc := f.fs.treeConn + secResults, secErr := tc.compoundSecurityInfoBatch( + paths, uint32(securityInfo), f.mapping, f.fs.ctx, + ) + if secErr != nil { + // If the compound batch itself failed, return entries without security info + // and propagate the batch error on the first entry. + entries := make([]DirEntryPlus, len(fi)) + for i, info := range fi { + entries[i] = DirEntryPlus{FileInfo: info, Err: secErr} + } + return entries, errors.Join(err, secErr) + } + + entries := make([]DirEntryPlus, 0, len(fi)) + for i, info := range fi { + var sd *sddl.SecurityDescriptor + var err error + if secResults[i].err != nil { + if isFileDeleted(secResults[i].err) { + continue // file deleted between Readdir and security query + } + err = secResults[i].err + } else if secResults[i].data != nil { + var parseErr error + sd, parseErr = sddl.FromBinary(secResults[i].data) + if parseErr != nil { + sd = nil // belt-and-suspenders assignment + err = fmt.Errorf("parsing security descriptor for %s: %w", info.Name(), parseErr) + } + } + entries = append(entries, DirEntryPlus{ + FileInfo: info, + SecurityDescriptor: sd, + Err: err, + }) + } + + return entries, err +} + func (f *File) Readdirnames(n int) (names []string, err error) { fi, err := f.Readdir(n) if err != nil { @@ -2418,3 +2505,18 @@ func (fs *FileStat) IsDir() bool { func (fs *FileStat) Sys() any { return fs } + +// DirEntryPlus extends os.FileInfo with security metadata obtained via +// compound requests. It implements fs.DirEntry. +type DirEntryPlus struct { + os.FileInfo + SecurityDescriptor *sddl.SecurityDescriptor + + Err error // non-nil if the security query failed for this entry +} + +var _ fs.DirEntry = &DirEntryPlus{} + +func (d *DirEntryPlus) Type() os.FileMode { return d.FileInfo.Mode().Type() } +func (d *DirEntryPlus) Info() (os.FileInfo, error) { return d.FileInfo, nil } +func (d *DirEntryPlus) Name() string { return d.FileInfo.Name() } diff --git a/conn.go b/conn.go index 3b1186b..d785206 100644 --- a/conn.go +++ b/conn.go @@ -317,12 +317,13 @@ type conn struct { account *account - rdone chan struct{} - wdone chan struct{} - write chan []byte - werr chan error - encodeBuf []byte // retained request encoding buffer; reused under conn.m + rdone chan struct{} + wdone chan struct{} + write chan []byte + werr chan error recvPool sync.Pool // reusable receive buffers + encodeBuf []byte // retained request encoding buffer; reused under conn.m + compoundSizes []int // retained sizes buffer for compound requests; reused under conn.m m sync.Mutex @@ -427,6 +428,201 @@ func (conn *conn) sendWith(ctx context.Context, req smb2.Packet, tc *treeConn) ( return rr, nil } +// compoundEntry describes a single request within a compound. +type compoundEntry struct { + req smb2.Packet + tc *treeConn + related bool // set SMB2_FLAGS_RELATED_OPERATIONS on this request +} + +// sendCompound serializes multiple SMB2 requests into a single transport frame +// and sends them as a compound request. Related entries share a file handle via +// the sentinel FileId. Returns one requestResponse per entry for receiving +// individual responses. +// +// Caller must have already set CreditCharge on each entry's header (via loanCredit). +func (conn *conn) sendCompound(ctx context.Context, entries []compoundEntry) ([]*requestResponse, error) { + conn.m.Lock() + defer conn.m.Unlock() + + if conn.err != nil { + return nil, conn.err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + n := len(entries) + + // Reuse retained sizes buffer. + if cap(conn.compoundSizes) < n { + conn.compoundSizes = make([]int, n) + } + sizes := conn.compoundSizes[:n] + + // Phase 1: Set header fields and compute sizes. + var totalCreditCharge uint16 + + for i, entry := range entries { + hdr := entry.req.Header() + + msgId := conn.sequenceWindow + creditCharge := hdr.CreditCharge + conn.sequenceWindow += uint64(creditCharge) + totalCreditCharge += creditCharge + + hdr.MessageId = msgId + + // Only the last request asks for new credits. + if i < n-1 { + hdr.CreditRequestResponse = 0 + } else { + if hdr.CreditRequestResponse == 0 { + hdr.CreditRequestResponse = totalCreditCharge + } + hdr.CreditRequestResponse += conn.account.opening() + } + + if entry.related { + hdr.Flags |= smb2.SMB2_FLAGS_RELATED_OPERATIONS + } + + s := conn.session + if s != nil { + hdr.SessionId = s.sessionId + if entry.tc != nil { + hdr.TreeId = entry.tc.treeId + } + } + + sizes[i] = entry.req.Size() + } + + // Compute total buffer size with 8-byte alignment between requests. + totalSize := 0 + for i, sz := range sizes { + if i < n-1 { + totalSize += (sz + 7) &^ 7 + } else { + totalSize += sz + } + } + + // Phase 2: Encode into conn.encodeBuf (retained compound buffer). + if cap(conn.encodeBuf) < totalSize { + conn.encodeBuf = make([]byte, totalSize) + } + compound := conn.encodeBuf[:totalSize] + clear(compound) + + offset := 0 + for i, entry := range entries { + pkt := compound[offset : offset+sizes[i]] + entry.req.Encode(pkt) + + if i < n-1 { + aligned := (sizes[i] + 7) &^ 7 + smb2.PacketCodec(pkt).SetNextCommand(uint32(aligned)) + offset += aligned + } else { + offset += sizes[i] + } + } + + // Phase 3: Sign or encrypt. + wirePkt := compound + + s := conn.session + if s != nil { + encrypt := s.sessionFlags&smb2.SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 + if !encrypt { + for _, entry := range entries { + if entry.tc != nil && entry.tc.shareFlags&smb2.SMB2_SHAREFLAG_ENCRYPT_DATA != 0 { + encrypt = true + break + } + } + } + + if encrypt { + // Encrypt the entire compound as one unit using s.encryptBuf. + needed := 52 + len(compound) + 16 + if cap(s.encryptBuf) < needed { + s.encryptBuf = make([]byte, needed) + } + clear(s.encryptBuf[:needed]) + var err error + wirePkt, err = s.encrypt(compound, s.encryptBuf[:needed]) + if err != nil { + return nil, &InternalError{err.Error()} + } + } else if s.sessionFlags&(smb2.SMB2_SESSION_FLAG_IS_GUEST|smb2.SMB2_SESSION_FLAG_IS_NULL) == 0 { + // Sign each packet individually in-place. + // Per MS-SMB2 3.3.5.2.4, the server uses the NextCommand value as the + // message length for signature verification (8-byte aligned size), so + // non-last entries must be signed over the aligned length including padding. + off := 0 + for i := range entries { + signLen := sizes[i] + if i < n-1 { + signLen = (sizes[i] + 7) &^ 7 + } + pkt := compound[off : off+signLen] + s.sign(pkt) + off += signLen + } + } + } + + // Phase 4: Register all requestResponses and send. + rrs := make([]*requestResponse, n) + off := 0 + for i := range entries { + p := smb2.PacketCodec(compound[off : off+sizes[i]]) + rrs[i] = &requestResponse{ + msgId: p.MessageId(), + creditRequest: p.CreditRequest(), + ctx: ctx, + recv: make(chan []byte, 1), + } + conn.outstandingRequests.set(rrs[i].msgId, rrs[i]) + + if i < n-1 { + off += (sizes[i] + 7) &^ 7 + } else { + off += sizes[i] + } + } + + select { + case conn.write <- wirePkt: + select { + case err := <-conn.werr: + if err != nil { + for _, rr := range rrs { + conn.outstandingRequests.pop(rr.msgId) + } + return nil, &TransportError{err} + } + case <-ctx.Done(): + for _, rr := range rrs { + conn.outstandingRequests.pop(rr.msgId) + } + return nil, ctx.Err() + } + case <-ctx.Done(): + for _, rr := range rrs { + conn.outstandingRequests.pop(rr.msgId) + } + return nil, ctx.Err() + } + + return rrs, nil +} + func (conn *conn) makeRequestResponse(ctx context.Context, req smb2.Packet, tc *treeConn) (rr *requestResponse, err error) { hdr := req.Header() diff --git a/errors.go b/errors.go index a6cf7c4..c2ca2a2 100644 --- a/errors.go +++ b/errors.go @@ -3,6 +3,7 @@ package smb2 import ( "errors" "fmt" + "os" "github.com/cloudsoda/go-smb2/internal/erref" ) @@ -50,3 +51,12 @@ type ResponseError struct { func (err *ResponseError) Error() string { return fmt.Sprintf("response error: %v", erref.NtStatus(err.Code)) } + +// isFileDeleted reports whether err indicates a file was deleted or is pending +// deletion. This is used to silently skip directory entries that vanish between +// Readdir and a subsequent compound security query. +func isFileDeleted(err error) bool { + var re *ResponseError + return errors.Is(err, os.ErrNotExist) || + (errors.As(err, &re) && re.Code == uint32(erref.STATUS_DELETE_PENDING)) +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..3bd934e --- /dev/null +++ b/errors_test.go @@ -0,0 +1,31 @@ +package smb2 + +import ( + "fmt" + "os" + "testing" + + "github.com/cloudsoda/go-smb2/internal/erref" +) + +func TestIsFileDeleted(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"ErrNotExist", os.ErrNotExist, true}, + {"wrapped ErrNotExist", fmt.Errorf("open foo: %w", os.ErrNotExist), true}, + {"STATUS_OBJECT_NAME_NOT_FOUND", os.ErrNotExist, true}, // accept maps this to ErrNotExist + {"STATUS_DELETE_PENDING", &ResponseError{Code: uint32(erref.STATUS_DELETE_PENDING)}, true}, + {"other ResponseError", &ResponseError{Code: uint32(erref.STATUS_ACCESS_DENIED)}, false}, + {"unrelated error", fmt.Errorf("something else"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isFileDeleted(tt.err); got != tt.want { + t.Errorf("isFileDeleted(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} diff --git a/smb2_test.go b/smb2_test.go index 019cf0d..604dd7b 100644 --- a/smb2_test.go +++ b/smb2_test.go @@ -942,3 +942,152 @@ func TestSecurityDescriptor(t *testing.T) { t.Error("unexpected nil SD") } } + +func TestReaddirPlus(t *testing.T) { + if fs == nil { + t.Skip() + } + testDir := fmt.Sprintf("testDir-%d-TestReaddirPlus", os.Getpid()) + err := fs.Mkdir(testDir, 0755) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = fs.RemoveAll(testDir) + }() + + // Create enough files to exercise multi-batch incremental reads. + fileNames := []string{ + "alpha.txt", "bravo.txt", "charlie.txt", "delta.txt", + "echo.txt", "foxtrot.txt", "golf.txt", "hotel.txt", + "india.txt", "juliet.txt", + } + for _, name := range fileNames { + f, err := fs.Create(testDir + `\` + name) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + flags := smb2.OwnerSecurityInformation | smb2.GroupSecurityInformation | smb2.DACLSecurityInformation + + t.Run("File.ReaddirPlus_all", func(t *testing.T) { + d, err := fs.Open(testDir) + if err != nil { + t.Fatal(err) + } + defer d.Close() + + entries, err := d.ReaddirPlus(-1, flags) + if err != nil { + t.Fatal(err) + } + + if len(entries) != len(fileNames) { + t.Fatalf("expected %d entries, got %d", len(fileNames), len(entries)) + } + + for _, e := range entries { + if e.Err != nil { + t.Errorf("entry %s: unexpected error: %v", e.Name(), e.Err) + continue + } + if e.SecurityDescriptor == nil { + t.Errorf("entry %s: expected non-nil SecurityDescriptor", e.Name()) + } + if e.IsDir() { + t.Errorf("entry %s: expected file, got directory", e.Name()) + } + } + + // Verify EOF on subsequent call. + entries2, err := d.ReaddirPlus(1, flags) + require.Equal(t, io.EOF, err) + require.Empty(t, entries2) + }) + + t.Run("File.ReaddirPlus_incremental", func(t *testing.T) { + d, err := fs.Open(testDir) + if err != nil { + t.Fatal(err) + } + defer d.Close() + + var all []smb2.DirEntryPlus + + // Read 3 at a time to exercise multiple batches. + const batchSize = 3 + for { + entries, err := d.ReaddirPlus(batchSize, flags) + all = append(all, entries...) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if len(entries) == 0 || len(entries) > batchSize { + t.Fatalf("expected 1-%d entries, got %d", batchSize, len(entries)) + } + } + + if len(all) != len(fileNames) { + t.Fatalf("expected %d total entries, got %d", len(fileNames), len(all)) + } + + for _, e := range all { + if e.Err != nil { + t.Errorf("entry %s: unexpected error: %v", e.Name(), e.Err) + } + if e.SecurityDescriptor == nil { + t.Errorf("entry %s: expected non-nil SecurityDescriptor", e.Name()) + } + } + }) + + t.Run("Share.ReadDirPlus", func(t *testing.T) { + entries, err := fs.ReadDirPlus(testDir, flags) + if err != nil { + t.Fatal(err) + } + + if len(entries) != len(fileNames) { + t.Fatalf("expected %d entries, got %d", len(fileNames), len(entries)) + } + + // Share.ReadDirPlus sorts by name. + sortedNames := make([]string, len(fileNames)) + copy(sortedNames, fileNames) + sort.Strings(sortedNames) + + for i, name := range sortedNames { + if entries[i].Name() != name { + t.Errorf("entry %d: expected name %q, got %q", i, name, entries[i].Name()) + } + if entries[i].Err != nil { + t.Errorf("entry %s: unexpected error: %v", name, entries[i].Err) + continue + } + if entries[i].SecurityDescriptor == nil { + t.Errorf("entry %s: expected non-nil SecurityDescriptor", name) + } + } + }) + + t.Run("empty_directory", func(t *testing.T) { + emptyDir := testDir + `\empty` + err := fs.Mkdir(emptyDir, 0755) + if err != nil { + t.Fatal(err) + } + + entries, err := fs.ReadDirPlus(emptyDir, flags) + if err != nil { + t.Fatal(err) + } + if len(entries) != 0 { + t.Errorf("expected 0 entries, got %d", len(entries)) + } + }) +} diff --git a/tree_conn.go b/tree_conn.go index b2223ba..903c8c6 100644 --- a/tree_conn.go +++ b/tree_conn.go @@ -8,6 +8,183 @@ import ( "github.com/cloudsoda/go-smb2/internal/utf16le" ) +// sentinelFileId is used for related compound operations. +// The server interprets it as "use the handle from the preceding CREATE". +var sentinelFileId = &smb2.FileId{ + Persistent: [8]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + Volatile: [8]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, +} + +// compoundSecurityResult holds the result of one file's security query. +type compoundSecurityResult struct { + data []byte // raw security descriptor bytes; nil on error + err error +} + +// compoundSecurityInfoBatch queries security descriptors for multiple files +// using compound CREATE+QUERY_INFO+CLOSE requests. It sub-batches internally +// based on available credits and returns one result per path (in order). +func (tc *treeConn) compoundSecurityInfoBatch( + paths []string, + securityInfo uint32, + mapping utf16le.MapChars, + ctx context.Context, +) ([]compoundSecurityResult, error) { + results := make([]compoundSecurityResult, len(paths)) + + // Determine access rights. + access := uint32(smb2.READ_CONTROL) + if securityInfo&smb2.SACL_SECUIRTY_INFORMATION != 0 { + access |= smb2.ACCESS_SYSTEM_SECURITY + } + + for off := 0; off < len(paths); { + remaining := len(paths) - off + + // Loan credits: 3 per file (CREATE + QUERY_INFO + CLOSE). + // Compute in int to avoid uint16 overflow on large directories, + // then clamp to the credit balance capacity. + wanted := min(remaining*3, cap(tc.account.balance)) + + granted, _, err := tc.account.loan(ctx, uint16(wanted)) + if err != nil { + return nil, err + } + + batchSize := int(granted / 3) + if batchSize == 0 { + tc.chargeCredit(granted) + return nil, &InternalError{"insufficient credits for compound request"} + } + if batchSize > remaining { + batchSize = remaining + } + + // Return excess credits. + if excess := granted - uint16(batchSize*3); excess > 0 { + tc.chargeCredit(excess) + } + + err = tc.sendSecurityBatch(paths[off:off+batchSize], results[off:off+batchSize], access, securityInfo, mapping, ctx) + if err != nil { + return nil, err + } + + off += batchSize + } + + return results, nil +} + +// sendSecurityBatch sends one compound batch and populates results. +func (tc *treeConn) sendSecurityBatch( + paths []string, + results []compoundSecurityResult, + access, securityInfo uint32, + mapping utf16le.MapChars, + ctx context.Context, +) error { + n := len(paths) + entries := make([]compoundEntry, n*3) + + for i, path := range paths { + base := i * 3 + + // CREATE — first in each related triplet. + entries[base] = compoundEntry{ + req: &smb2.CreateRequest{ + PacketHeader: smb2.PacketHeader{CreditCharge: 1}, + RequestedOplockLevel: smb2.SMB2_OPLOCK_LEVEL_NONE, + ImpersonationLevel: smb2.Impersonation, + DesiredAccess: access, + FileAttributes: smb2.FILE_ATTRIBUTE_NORMAL, + ShareAccess: smb2.FILE_SHARE_READ, + CreateDisposition: smb2.FILE_OPEN, + Name: path, + Mapping: mapping, + }, + tc: tc, + } + + // QUERY_INFO — related, uses sentinel FileId. + entries[base+1] = compoundEntry{ + req: &smb2.QueryInfoRequest{ + PacketHeader: smb2.PacketHeader{CreditCharge: 1}, + InfoType: smb2.SMB2_0_INFO_SECURITY, + FileInfoClass: 0, + OutputBufferLength: 64 * 1024, + AdditionalInformation: securityInfo, + FileId: sentinelFileId, + }, + tc: tc, + related: true, + } + + // CLOSE — related, uses sentinel FileId. + entries[base+2] = compoundEntry{ + req: &smb2.CloseRequest{ + PacketHeader: smb2.PacketHeader{CreditCharge: 1}, + FileId: sentinelFileId, + }, + tc: tc, + related: true, + } + } + + rrs, err := tc.sendCompound(ctx, entries) + if err != nil { + return err + } + + // Receive all responses. Each triplet: CREATE, QUERY_INFO, CLOSE. + for i := range paths { + base := i * 3 + + // CREATE response. + createPkt, createErr := tc.recv(rrs[base]) + if createErr != nil { + results[i].err = createErr + // Still drain QUERY_INFO and CLOSE responses. + tc.recv(rrs[base+1]) + tc.recv(rrs[base+2]) + continue + } + if _, createErr = accept(smb2.SMB2_CREATE, createPkt); createErr != nil { + results[i].err = createErr + tc.recv(rrs[base+1]) + tc.recv(rrs[base+2]) + continue + } + + // QUERY_INFO response — extract security descriptor. + qiPkt, qiErr := tc.recv(rrs[base+1]) + if qiErr != nil { + results[i].err = qiErr + tc.recv(rrs[base+2]) + continue + } + qiRes, qiErr := accept(smb2.SMB2_QUERY_INFO, qiPkt) + if qiErr != nil { + results[i].err = qiErr + tc.recv(rrs[base+2]) + continue + } + + r := smb2.QueryInfoResponseDecoder(qiRes) + if r.IsInvalid() { + results[i].err = &InvalidResponseError{"broken query info response format"} + tc.recv(rrs[base+2]) + continue + } + results[i].data = r.OutputBuffer() + + // CLOSE response — just drain it. + tc.recv(rrs[base+2]) + } + + return nil +} + type treeConn struct { *session treeId uint32