Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions rowsaffected_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package mssql

import (
"context"
"fmt"
"math/rand"
"testing"
)

// TestRowsAffectedWithTrigger verifies that RowsAffected returns the correct
// count for the outermost DML statement even when AFTER triggers fire and
// produce their own intermediate DONEINPROC row counts. See #204.
func TestRowsAffectedWithTrigger(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

ctx := context.Background()

suffix := fmt.Sprintf("%d", rand.Intn(999999))
tbl := "test_ra204_" + suffix
audit := "test_ra204_audit_" + suffix
trg := "tr_ra204_" + suffix

cleanup := func() {
conn.ExecContext(ctx, "drop trigger if exists "+trg)
conn.ExecContext(ctx, "drop table if exists "+audit)
conn.ExecContext(ctx, "drop table if exists "+tbl)
Comment thread
dlevy-msft-sql marked this conversation as resolved.
}
cleanup()
defer cleanup()

_, err := conn.ExecContext(ctx, "create table "+tbl+" (id int primary key, value nvarchar(100))")
if err != nil {
t.Fatal("create table failed:", err)
}

_, err = conn.ExecContext(ctx, "insert into "+tbl+" values (1, 'old'), (2, 'old'), (3, 'old')")
if err != nil {
t.Fatal("insert failed:", err)
}

// Scenario 1: Basic update without trigger
result, err := conn.ExecContext(ctx, "update "+tbl+" set value = 'test' where id = 1")
if err != nil {
t.Fatal("update failed:", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected != 1 {
t.Errorf("basic update: expected RowsAffected=1, got %d", rowsAffected)
}

// Create audit table and trigger without NOCOUNT
_, err = conn.ExecContext(ctx, "create table "+audit+" (id int identity, action nvarchar(50))")
if err != nil {
t.Fatal("create audit table failed:", err)
}
_, err = conn.ExecContext(ctx, "create trigger "+trg+" on "+tbl+" after update as begin insert into "+audit+" (action) select 'updated' from inserted end")
if err != nil {
t.Fatal("create trigger failed:", err)
}

// Scenario 2: Update with trigger (no NOCOUNT) - trigger produces extra DONEINPROC
result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'triggered' where id = 1")
if err != nil {
t.Fatal("triggered update failed:", err)
}
rowsAffected, _ = result.RowsAffected()
if rowsAffected != 1 {
t.Errorf("trigger without NOCOUNT: expected RowsAffected=1, got %d", rowsAffected)
}
Comment thread
dlevy-msft-sql marked this conversation as resolved.

// Scenario 2b: Bulk update all rows with no-NOCOUNT trigger active.
// With the old += bug, the trigger's DONEINPROC (3 rows) plus the
// outer DONE (3 rows) would yield 6 instead of 3.
result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'bulk'")
if err != nil {
t.Fatal("bulk update without NOCOUNT failed:", err)
}
rowsAffected, _ = result.RowsAffected()
if rowsAffected != 3 {
t.Errorf("bulk update without NOCOUNT: expected RowsAffected=3, got %d", rowsAffected)
}

// Scenario 3: Recreate trigger with NOCOUNT
conn.ExecContext(ctx, "drop trigger "+trg)
Comment thread
dlevy-msft-sql marked this conversation as resolved.
_, err = conn.ExecContext(ctx, "create trigger "+trg+" on "+tbl+" after update as begin set nocount on; insert into "+audit+" (action) select 'updated' from inserted end")
if err != nil {
t.Fatal("create nocount trigger failed:", err)
}

result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'nocount' where id = 1")
if err != nil {
t.Fatal("nocount triggered update failed:", err)
}
rowsAffected, _ = result.RowsAffected()
if rowsAffected != 1 {
t.Errorf("trigger with NOCOUNT: expected RowsAffected=1, got %d", rowsAffected)
}

// Scenario 4: Update all rows with trigger
result, err = conn.ExecContext(ctx, "update "+tbl+" set value = 'all'")
if err != nil {
t.Fatal("bulk update failed:", err)
}
Comment thread
dlevy-msft-sql marked this conversation as resolved.
rowsAffected, _ = result.RowsAffected()
if rowsAffected != 3 {
t.Errorf("bulk update with trigger: expected RowsAffected=3, got %d", rowsAffected)
}
}

// TestMultiStatementBatchRowsAffected verifies that RowsAffected returns the
// last statement's count for a multi-statement batch. The assignment (=)
// semantics in doneStruct processing mean each DONE token replaces (not
// accumulates) the row count, so the final DONE is authoritative.
func TestMultiStatementBatchRowsAffected(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

ctx := context.Background()

suffix := fmt.Sprintf("%d", rand.Intn(999999))
tbl := "test_msbatch_" + suffix

cleanup := func() {
conn.ExecContext(ctx, "drop table if exists "+tbl)
}
cleanup()
defer cleanup()

_, err := conn.ExecContext(ctx, "create table "+tbl+" (id int primary key, value nvarchar(100))")
if err != nil {
t.Fatal("create table failed:", err)
}

// Seed 5 rows.
_, err = conn.ExecContext(ctx, "insert into "+tbl+" values (1,'a'),(2,'b'),(3,'c'),(4,'d'),(5,'e')")
if err != nil {
t.Fatal("insert failed:", err)
}

// Multi-statement batch: first UPDATE touches 3 rows, second touches 2.
// RowsAffected should be 2 (the last statement), not 5 (the sum).
batch := "update " + tbl + " set value='x' where id <= 3; " +
"update " + tbl + " set value='y' where id > 3"
result, err := conn.ExecContext(ctx, batch)
if err != nil {
t.Fatal("multi-statement batch failed:", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected != 2 {
t.Errorf("multi-statement batch: expected RowsAffected=2 (last stmt), got %d", rowsAffected)
}
}
13 changes: 11 additions & 2 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -1238,11 +1238,20 @@ func (t *tokenProcessor) iterateResponse() error {
t.lastRow = token
case doneInProcStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
// Assignment (not +=) mirrors the doneStruct logic.
// For RPC/sp_executesql with triggers, both trigger
// and outer statement counts arrive as DONEINPROC;
// assignment ensures the last one wins (#204).
t.rowCount = int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
// Assignment (not +=) so the final DONE token's count
// is authoritative. Prevents double-counting when AFTER
// triggers fire without SET NOCOUNT ON (#204).
// For multi-statement batches, this means RowsAffected()
// returns the last statement's count, not the sum.
t.rowCount = int64(token.RowCount)
Comment thread
dlevy-msft-sql marked this conversation as resolved.
}
if token.isError() && t.firstError == nil {
t.firstError = token.getError()
Expand Down
71 changes: 71 additions & 0 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,74 @@ func TestNextToken_CancelDrainClosedChannelStartsSecondResponse(t *testing.T) {
t.Fatal("expected attention packet to be written")
}
}

// TestRowCountAssignmentNotAccumulation verifies that iterateResponse uses
// assignment (=) for doneStruct row counts, not accumulation (+=). This
// prevents double-counting when AFTER triggers fire without SET NOCOUNT ON.
func TestRowCountAssignmentNotAccumulation(t *testing.T) {
tokChan := make(chan tokenStruct, 10)
tp := &tokenProcessor{
tokChan: tokChan,
ctx: context.Background(),
sess: &tdsSession{},
}

// Simulate trigger scenario: DONEINPROC from trigger (1 row), then
// final DONE from the outer UPDATE (also 1 row). With +=, rowCount
// would be 2. With =, it should be 1.
tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1}
tokChan <- doneStruct{Status: doneFinal | doneCount, RowCount: 1}
close(tokChan)

err := tp.iterateResponse()
assert.NoError(t, err)
assert.Equal(t, int64(1), tp.rowCount,
"rowCount should be 1 (assigned), not 2 (accumulated)")
}

// TestRowCountMultiStatement verifies that for multi-statement batches, the
// final DONE token's count wins (last statement), not the sum.
func TestRowCountMultiStatement(t *testing.T) {
tokChan := make(chan tokenStruct, 10)
tp := &tokenProcessor{
tokChan: tokChan,
ctx: context.Background(),
sess: &tdsSession{},
}

// First statement: DONE with 3 rows
tokChan <- doneStruct{Status: doneCount, RowCount: 3}
// Second statement: DONE with 2 rows (final)
tokChan <- doneStruct{Status: doneFinal | doneCount, RowCount: 2}
close(tokChan)

err := tp.iterateResponse()
assert.NoError(t, err)
assert.Equal(t, int64(2), tp.rowCount,
"rowCount should be 2 (last statement), not 5 (sum)")
}

// TestRowCountDoneInProcOnlyRPCPath verifies that the RPC/sp_executesql path
// (DONEINPROC-only) does not double-count when triggers fire. In this path,
// both trigger and outer statement counts arrive as DONEINPROC tokens, so
// assignment (=) must be used instead of accumulation (+=).
func TestRowCountDoneInProcOnlyRPCPath(t *testing.T) {
tokChan := make(chan tokenStruct, 10)
tp := &tokenProcessor{
tokChan: tokChan,
ctx: context.Background(),
sess: &tdsSession{},
}

// Simulate RPC with trigger: trigger's INSERT (1 row), outer UPDATE
// (1 row), then DONEPROC without doneCount (common for sp_executesql).
tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1} // trigger
tokChan <- doneInProcStruct{Status: doneCount, RowCount: 1} // outer stmt
tokChan <- doneStruct{Status: doneFinal} // DONEPROC, no count
close(tokChan)

err := tp.iterateResponse()
assert.NoError(t, err)
assert.Equal(t, int64(1), tp.rowCount,
"rowCount should be 1 (last DONEINPROC), not 2 (accumulated)")
}