diff --git a/sqlite3.go b/sqlite3.go
index ce985ec8..20ac7b78 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -399,6 +399,7 @@ type SQLiteRows struct {
 	cls      bool
 	closed   bool
 	ctx      context.Context // no better alternative to pass context into Next() method
+	resultCh chan error
 }
 
 type functionInfo struct {
@@ -2172,24 +2173,29 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
 		return io.EOF
 	}
 
-	if rc.ctx.Done() == nil {
+	done := rc.ctx.Done()
+	if done == nil {
 		return rc.nextSyncLocked(dest)
 	}
-	resultCh := make(chan error)
-	defer close(resultCh)
+	if err := rc.ctx.Err(); err != nil {
+		return err // Fast check if the channel is closed
+	}
+	if rc.resultCh == nil {
+		rc.resultCh = make(chan error)
+	}
 	go func() {
-		resultCh <- rc.nextSyncLocked(dest)
+		rc.resultCh <- rc.nextSyncLocked(dest)
 	}()
 	select {
-	case err := <-resultCh:
+	case err := <-rc.resultCh:
 		return err
-	case <-rc.ctx.Done():
+	case <-done:
 		select {
-		case <-resultCh: // no need to interrupt
+		case <-rc.resultCh: // no need to interrupt
 		default:
 			// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
 			C.sqlite3_interrupt(rc.s.c.db)
-			<-resultCh // ensure goroutine completed
+			<-rc.resultCh // ensure goroutine completed
 		}
 		return rc.ctx.Err()
 	}
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index eec7479d..3ac1e754 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -11,6 +11,7 @@ package sqlite3
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"math/rand"
@@ -268,6 +269,141 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
 	}
 }
 
+// Test that we can successfully interrupt a long-running query when
+// the context is canceled. The previous two QueryRowContext tests
+// only test that we handle a previously canceled context and thus
+// do not call sqlite3_interrupt.
+func TestQueryRowContextCancelInterrupt(t *testing.T) {
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	const createTableStmt = `
+	CREATE TABLE timestamps (
+		ts TIMESTAMP NOT NULL
+	);`
+	if _, err := db.Exec(createTableStmt); err != nil {
+		t.Fatal(err)
+	}
+
+	stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer stmt.Close()
+
+	// Computationally expensive query that consumes many rows. This is needed
+	// to test cancellation because queries are not interrupted immediately.
+	// Instead, queries are only halted at certain checkpoints where the
+	// sqlite3.isInterrupted flag is checked and found to be true.
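+	// Calling sqlite3_interrupt therefore only takes effect the next time the
+	// running statement reaches such a checkpoint, so the query has to run
+	// long enough for the cancellation to land while it is still executing.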
+	queryStmt := `
+	SELECT
+		SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
+		SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
+		SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
+		SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
+	FROM
+		timestamps
+	WHERE datetime(ts, 'unixepoch', 'localtime')
+	LIKE
+		?;`
+
+	query := func(t *testing.T, timeout time.Duration) (int, error) {
+		// Create a complicated pattern to match timestamps
+		const pattern = "%2%0%2%4%-%-%:%:%"
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		defer cancel()
+		rows, err := db.QueryContext(ctx, queryStmt, pattern)
+		if err != nil {
+			return 0, err
+		}
+		var count int
+		for rows.Next() {
+			var n int64
+			if err := rows.Scan(&n, &n, &n, &n); err != nil {
+				return count, err
+			}
+			count++
+		}
+		return count, rows.Err()
+	}
+
+	average := func(n int, fn func()) time.Duration {
+		start := time.Now()
+		for i := 0; i < n; i++ {
+			fn()
+		}
+		return time.Since(start) / time.Duration(n)
+	}
+
+	createRows := func(n int) {
+		t.Logf("Creating %d rows", n)
+		if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
+			t.Fatal(err)
+		}
+		ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
+		rr := rand.New(rand.NewSource(1234))
+		for i := 0; i < n; i++ {
+			if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
+				t.Fatal(err)
+			}
+		}
+	}
+
+	const TargetRuntime = 200 * time.Millisecond
+	const N = 5_000 // Number of rows to insert at a time
+
+	// Create enough rows that the query takes ~200ms to run.
+	start := time.Now()
+	createRows(N)
+	baseAvg := average(4, func() {
+		if _, err := query(t, time.Hour); err != nil {
+			t.Fatal(err)
+		}
+	})
+	t.Log("Base average:", baseAvg)
+	rowCount := N * (int(TargetRuntime/baseAvg) + 1)
+	createRows(rowCount)
+	t.Log("Table setup time:", time.Since(start))
+
+	// Set the timeout to 1/10 of the average query time.
+	avg := average(2, func() {
+		n, err := query(t, time.Hour)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n == 0 {
+			t.Fatal("scanned zero rows")
+		}
+	})
+	// Guard against the timeout being too short to reliably test.
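+	// If the query ran much faster than TargetRuntime, the derived timeout of
+	// avg/10 could expire before the statement is even stepped, in which case
+	// the canceled context is caught early and sqlite3_interrupt is never called.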
+	if avg < TargetRuntime/2 {
+		t.Fatalf("Average query runtime should be around %s, got: %s",
+			TargetRuntime, avg)
+	}
+	timeout := (avg / 10).Round(100 * time.Microsecond)
+	t.Logf("Average: %s Timeout: %s", avg, timeout)
+
+	for i := 0; i < 10; i++ {
+		tt := time.Now()
+		n, err := query(t, timeout)
+		if !errors.Is(err, context.DeadlineExceeded) {
+			fn := t.Errorf
+			if err != nil {
+				fn = t.Fatalf
+			}
+			fn("expected error %v got %v", context.DeadlineExceeded, err)
+		}
+		d := time.Since(tt)
+		t.Logf("%d: rows: %d duration: %s", i, n, d)
+		if d > timeout*4 {
+			t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
+		}
+	}
+}
+
 func TestExecCancel(t *testing.T) {
 	db, err := sql.Open("sqlite3", ":memory:")
 	if err != nil {
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 63c939d3..1eb42d2e 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -10,6 +10,7 @@ package sqlite3
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 	"database/sql/driver"
 	"errors"
@@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
 }
 
 func TestSuite(t *testing.T) {
-	initializeTestDB(t)
+	initializeTestDB(t, false)
 	defer freeTestDB()
 
 	for _, test := range tests {
@@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
 }
 
 func BenchmarkSuite(b *testing.B) {
-	initializeTestDB(b)
+	initializeTestDB(b, true)
 	defer freeTestDB()
 
 	for _, benchmark := range benchmarks {
@@ -2068,8 +2069,13 @@ type TestDB struct {
 }
 
 var db *TestDB
 
-func initializeTestDB(t testing.TB) {
-	tempFilename := TempFilename(t)
+func initializeTestDB(t testing.TB, memory bool) {
+	var tempFilename string
+	if memory {
+		tempFilename = ":memory:"
+	} else {
+		tempFilename = TempFilename(t)
+	}
 	d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
 	if err != nil {
 		os.Remove(tempFilename)
@@ -2084,9 +2090,11 @@ func freeTestDB() {
 	if err != nil {
 		panic(err)
 	}
-	err = os.Remove(db.tempFilename)
-	if err != nil {
-		panic(err)
+	if db.tempFilename != "" && db.tempFilename != ":memory:" {
+		err := os.Remove(db.tempFilename)
+		if err != nil {
+			panic(err)
+		}
 	}
 }
 
@@ -2107,6 +2115,7 @@ var tests = []testing.InternalTest{
 var benchmarks = []testing.InternalBenchmark{
 	{Name: "BenchmarkExec", F: benchmarkExec},
 	{Name: "BenchmarkQuery", F: benchmarkQuery},
+	{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
 	{Name: "BenchmarkParams", F: benchmarkParams},
 	{Name: "BenchmarkStmt", F: benchmarkStmt},
 	{Name: "BenchmarkRows", F: benchmarkRows},
@@ -2479,6 +2488,65 @@ func benchmarkQuery(b *testing.B) {
 	}
 }
 
+// benchmarkQueryContext is benchmark for QueryContext
+func benchmarkQueryContext(b *testing.B) {
+	const createTableStmt = `
+	CREATE TABLE IF NOT EXISTS query_context(
+		id INTEGER PRIMARY KEY
+	);
+	DELETE FROM query_context;
+	VACUUM;`
+	test := func(ctx context.Context, b *testing.B) {
+		if _, err := db.Exec(createTableStmt); err != nil {
+			b.Fatal(err)
+		}
+		for i := 0; i < 10; i++ {
+			_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
+			if err != nil {
+				b.Fatal(err)
+			}
+		}
+		stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
+		if err != nil {
+			b.Fatal(err)
+		}
+		b.Cleanup(func() { stmt.Close() })
+
+		var n int
+		for i := 0; i < b.N; i++ {
+			rows, err := stmt.QueryContext(ctx)
+			if err != nil {
+				b.Fatal(err)
+			}
+			for rows.Next() {
+				if err := rows.Scan(&n); err != nil {
+					b.Fatal(err)
+				}
+			}
+			if err := rows.Err(); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+
+	// When the context does not have a Done channel we should use
+	// the fast path that directly handles the query instead of
+	// handling it in a goroutine. This benchmark also serves to
+	// highlight the performance impact of using a cancelable
+	// context.
+	b.Run("Background", func(b *testing.B) {
+		test(context.Background(), b)
+	})
+
+	// Benchmark a query with a context that can be canceled. This
+	// requires using a goroutine and is thus much slower.
+	b.Run("WithCancel", func(b *testing.B) {
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		test(ctx, b)
+	})
+}
+
 // benchmarkParams is benchmark for params
 func benchmarkParams(b *testing.B) {
 	for i := 0; i < b.N; i++ {