diff --git a/rowsaffected_test.go b/rowsaffected_test.go new file mode 100644 index 00000000..4e2e33cc --- /dev/null +++ b/rowsaffected_test.go @@ -0,0 +1,155 @@ +package mssql + +import ( + "context" + "fmt" + "math/rand" + "testing" +) + +// TestRowsAffectedWithTrigger verifies that RowsAffected returns the correct +// count for the outermost DML statement even when AFTER triggers fire and +// produce their own intermediate DONEINPROC row counts. See #204. +func TestRowsAffectedWithTrigger(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + ctx := context.Background() + + suffix := fmt.Sprintf("%d", rand.Intn(999999)) + tbl := "test_ra204_" + suffix + audit := "test_ra204_audit_" + suffix + trg := "tr_ra204_" + suffix + + cleanup := func() { + conn.ExecContext(ctx, "drop trigger if exists "+trg) + conn.ExecContext(ctx, "drop table if exists "+audit) + conn.ExecContext(ctx, "drop table if exists "+tbl) + } + cleanup() + defer cleanup() + + _, err := conn.ExecContext(ctx, "create table "+tbl+" (id int primary key, value nvarchar(100))") + if err != nil { + t.Fatal("create table failed:", err) + } + + _, err = conn.ExecContext(ctx, "insert into "+tbl+" values (1, 'old'), (2, 'old'), (3, 'old')") + if err != nil { + t.Fatal("insert failed:", err) + } + + // Scenario 1: Basic update without trigger + result, err := conn.ExecContext(ctx, "update "+tbl+" set value = 'test' where id = 1") + if err != nil { + t.Fatal("update failed:", err) + } + rowsAffected, _ := result.RowsAffected() + if rowsAffected != 1 { + t.Errorf("basic update: expected RowsAffected=1, got %d", rowsAffected) + } + + // Create audit table and trigger without NOCOUNT + _, err = conn.ExecContext(ctx, "create table "+audit+" (id int identity, action nvarchar(50))") + if err != nil { + t.Fatal("create audit table failed:", err) + } + _, err = conn.ExecContext(ctx, "create trigger "+trg+" on "+tbl+" after update as begin insert into "+audit+" (action) select 'updated' from inserted end") + if err != nil { + t.Fatal("create trigger failed:", err) + } + + // Scenario 2: Update with trigger (no NOCOUNT) - trigger produces extra DONEINPROC + result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'triggered' where id = 1") + if err != nil { + t.Fatal("triggered update failed:", err) + } + rowsAffected, _ = result.RowsAffected() + if rowsAffected != 1 { + t.Errorf("trigger without NOCOUNT: expected RowsAffected=1, got %d", rowsAffected) + } + + // Scenario 2b: Bulk update all rows with no-NOCOUNT trigger active. + // With the old += bug, the trigger's DONEINPROC (3 rows) plus the + // outer DONE (3 rows) would yield 6 instead of 3. + result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'bulk'") + if err != nil { + t.Fatal("bulk update without NOCOUNT failed:", err) + } + rowsAffected, _ = result.RowsAffected() + if rowsAffected != 3 { + t.Errorf("bulk update without NOCOUNT: expected RowsAffected=3, got %d", rowsAffected) + } + + // Scenario 3: Recreate trigger with NOCOUNT + conn.ExecContext(ctx, "drop trigger "+trg) + _, err = conn.ExecContext(ctx, "create trigger "+trg+" on "+tbl+" after update as begin set nocount on; insert into "+audit+" (action) select 'updated' from inserted end") + if err != nil { + t.Fatal("create nocount trigger failed:", err) + } + + result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'nocount' where id = 1") + if err != nil { + t.Fatal("nocount triggered update failed:", err) + } + rowsAffected, _ = result.RowsAffected() + if rowsAffected != 1 { + t.Errorf("trigger with NOCOUNT: expected RowsAffected=1, got %d", rowsAffected) + } + + // Scenario 4: Update all rows with trigger + result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'all'") + if err != nil { + t.Fatal("bulk update failed:", err) + } + rowsAffected, _ = result.RowsAffected() + if rowsAffected != 3 { + t.Errorf("bulk update with trigger: expected RowsAffected=3, got %d", rowsAffected) + } +} + +// TestMultiStatementBatchRowsAffected verifies that RowsAffected returns the +// last statement's count for a multi-statement batch. The assignment (=) +// semantics in doneStruct processing mean each DONE token replaces (not +// accumulates) the row count, so the final DONE is authoritative. +func TestMultiStatementBatchRowsAffected(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + ctx := context.Background() + + suffix := fmt.Sprintf("%d", rand.Intn(999999)) + tbl := "test_msbatch_" + suffix + + cleanup := func() { + conn.ExecContext(ctx, "drop table if exists "+tbl) + } + cleanup() + defer cleanup() + + _, err := conn.ExecContext(ctx, "create table "+tbl+" (id int primary key, value nvarchar(100))") + if err != nil { + t.Fatal("create table failed:", err) + } + + // Seed 5 rows. + _, err = conn.ExecContext(ctx, "insert into "+tbl+" values (1,'a'),(2,'b'),(3,'c'),(4,'d'),(5,'e')") + if err != nil { + t.Fatal("insert failed:", err) + } + + // Multi-statement batch: first UPDATE touches 3 rows, second touches 2. + // RowsAffected should be 2 (the last statement), not 5 (the sum). + batch := "update " + tbl + " set value='x' where id <= 3; " + + "update " + tbl + " set value='y' where id > 3" + result, err := conn.ExecContext(ctx, batch) + if err != nil { + t.Fatal("multi-statement batch failed:", err) + } + rowsAffected, _ := result.RowsAffected() + if rowsAffected != 2 { + t.Errorf("multi-statement batch: expected RowsAffected=2 (last stmt), got %d", rowsAffected) + } +} diff --git a/token.go b/token.go index 834cb265..546a222d 100644 --- a/token.go +++ b/token.go @@ -1238,11 +1238,20 @@ func (t *tokenProcessor) iterateResponse() error { t.lastRow = token case doneInProcStruct: if token.Status&doneCount != 0 { - t.rowCount += int64(token.RowCount) + // Assignment (not +=) mirrors the doneStruct logic. + // For RPC/sp_executesql with triggers, both trigger + // and outer statement counts arrive as DONEINPROC; + // assignment ensures the last one wins (#204). + t.rowCount = int64(token.RowCount) } case doneStruct: if token.Status&doneCount != 0 { - t.rowCount += int64(token.RowCount) + // Assignment (not +=) so the final DONE token's count + // is authoritative. Prevents double-counting when AFTER + // triggers fire without SET NOCOUNT ON (#204). + // For multi-statement batches, this means RowsAffected() + // returns the last statement's count, not the sum. + t.rowCount = int64(token.RowCount) } if token.isError() && t.firstError == nil { t.firstError = token.getError() diff --git a/token_test.go b/token_test.go index 49c63c5d..8f32ce2c 100644 --- a/token_test.go +++ b/token_test.go @@ -467,3 +467,74 @@ func TestNextToken_CancelDrainClosedChannelStartsSecondResponse(t *testing.T) { t.Fatal("expected attention packet to be written") } } + +// TestRowCountAssignmentNotAccumulation verifies that iterateResponse uses +// assignment (=) for doneStruct row counts, not accumulation (+=). This +// prevents double-counting when AFTER triggers fire without SET NOCOUNT ON. +func TestRowCountAssignmentNotAccumulation(t *testing.T) { + tokChan := make(chan tokenStruct, 10) + tp := &tokenProcessor{ + tokChan: tokChan, + ctx: context.Background(), + sess: &tdsSession{}, + } + + // Simulate trigger scenario: DONEINPROC from trigger (1 row), then + // final DONE from the outer UPDATE (also 1 row). With +=, rowCount + // would be 2. With =, it should be 1. + tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1} + tokChan <- doneStruct{Status: doneFinal | doneCount, RowCount: 1} + close(tokChan) + + err := tp.iterateResponse() + assert.NoError(t, err) + assert.Equal(t, int64(1), tp.rowCount, + "rowCount should be 1 (assigned), not 2 (accumulated)") +} + +// TestRowCountMultiStatement verifies that for multi-statement batches, the +// final DONE token's count wins (last statement), not the sum. +func TestRowCountMultiStatement(t *testing.T) { + tokChan := make(chan tokenStruct, 10) + tp := &tokenProcessor{ + tokChan: tokChan, + ctx: context.Background(), + sess: &tdsSession{}, + } + + // First statement: DONE with 3 rows + tokChan <- doneStruct{Status: doneCount, RowCount: 3} + // Second statement: DONE with 2 rows (final) + tokChan <- doneStruct{Status: doneFinal | doneCount, RowCount: 2} + close(tokChan) + + err := tp.iterateResponse() + assert.NoError(t, err) + assert.Equal(t, int64(2), tp.rowCount, + "rowCount should be 2 (last statement), not 5 (sum)") +} + +// TestRowCountDoneInProcOnlyRPCPath verifies that the RPC/sp_executesql path +// (DONEINPROC-only) does not double-count when triggers fire. In this path, +// both trigger and outer statement counts arrive as DONEINPROC tokens, so +// assignment (=) must be used instead of accumulation (+=). +func TestRowCountDoneInProcOnlyRPCPath(t *testing.T) { + tokChan := make(chan tokenStruct, 10) + tp := &tokenProcessor{ + tokChan: tokChan, + ctx: context.Background(), + sess: &tdsSession{}, + } + + // Simulate RPC with trigger: trigger's INSERT (1 row), outer UPDATE + // (1 row), then DONEPROC without doneCount (common for sp_executesql). + tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1} // trigger + tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1} // outer stmt + tokChan <- doneStruct{Status: doneFinal} // DONEPROC, no count + close(tokChan) + + err := tp.iterateResponse() + assert.NoError(t, err) + assert.Equal(t, int64(1), tp.rowCount, + "rowCount should be 1 (last DONEINPROC), not 2 (accumulated)") +}