-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix exponential memory allocation in Exec and improve performance #1296
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -137,6 +137,61 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ | |
} | ||
#endif | ||
|
||
static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) { | ||
const char *tail; | ||
int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); | ||
if (rv != SQLITE_OK) { | ||
return rv; | ||
} | ||
if (tail) { | ||
// Set oBytes to the number of bytes consumed instead of using the | ||
// **pzTail out param since that requires storing a Go pointer in | ||
// a C pointer, which is not allowed by CGO and will cause | ||
// runtime.cgoCheckPointer to fail. | ||
*oBytes = tail - zSql; | ||
} else { | ||
// NB: this should not happen, but if it does advance oBytes to the | ||
// end of the string so that we do not loop infinitely. | ||
*oBytes = nBytes; | ||
} | ||
return SQLITE_OK; | ||
} | ||
|
||
// _sqlite3_exec_no_args executes all of the statements in zSql. None of the | ||
// statements are allowed to have positional arguments. | ||
int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, int64_t *rowid, int64_t *changes) { | ||
while (*zSql && nBytes > 0) { | ||
sqlite3_stmt *stmt; | ||
const char *tail; | ||
int rv = sqlite3_prepare_v2(db, zSql, nBytes, &stmt, &tail); | ||
if (rv != SQLITE_OK) { | ||
return rv; | ||
} | ||
|
||
// Process statement | ||
do { | ||
rv = _sqlite3_step_internal(stmt); | ||
} while (rv == SQLITE_ROW); | ||
|
||
// Only record the number of changes made by the last statement. | ||
*changes = sqlite3_changes64(db); | ||
*rowid = sqlite3_last_insert_rowid(db); | ||
|
||
if (rv != SQLITE_OK && rv != SQLITE_DONE) { | ||
sqlite3_finalize(stmt); | ||
return rv; | ||
} | ||
rv = sqlite3_finalize(stmt); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This implies that there is no need to inspect the return value here, and it can be done unconditionally (before checking |
||
if (rv != SQLITE_OK) { | ||
return rv; | ||
} | ||
|
||
nBytes -= tail - zSql; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you are assuming |
||
zSql = tail; | ||
} | ||
return SQLITE_OK; | ||
} | ||
|
||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { | ||
sqlite3_result_text(ctx, s, -1, &free); | ||
} | ||
|
@@ -858,54 +913,119 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err | |
} | ||
|
||
func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||
start := 0 | ||
// Trim the query. This is mostly important for getting rid | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? |
||
// of any trailing space. | ||
query = strings.TrimSpace(query) | ||
if len(args) > 0 { | ||
return c.execArgs(ctx, query, args) | ||
} | ||
return c.execNoArgs(ctx, query) | ||
} | ||
|
||
func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||
var ( | ||
stmtArgs []driver.NamedValue | ||
start int | ||
s SQLiteStmt // escapes to the heap so reuse it | ||
sz C.int // number of query bytes consumed: escapes to the heap | ||
) | ||
for { | ||
s, err := c.prepare(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
s = SQLiteStmt{c: c} // reset | ||
sz = 0 | ||
rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), | ||
C.int(len(query)), &s.s, &sz) | ||
if rv != C.SQLITE_OK { | ||
return nil, c.lastError() | ||
} | ||
query = strings.TrimSpace(query[sz:]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why |
||
|
||
var res driver.Result | ||
if s.(*SQLiteStmt).s != nil { | ||
stmtArgs := make([]driver.NamedValue, 0, len(args)) | ||
if s.s != nil { | ||
na := s.NumInput() | ||
if len(args)-start < na { | ||
s.Close() | ||
s.finalize() | ||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) | ||
} | ||
// consume the number of arguments used in the current | ||
// statement and append all named arguments not | ||
// contained therein | ||
if len(args[start:start+na]) > 0 { | ||
stmtArgs = append(stmtArgs, args[start:start+na]...) | ||
for i := range args { | ||
if (i < start || i >= na) && args[i].Name != "" { | ||
stmtArgs = append(stmtArgs, args[i]) | ||
} | ||
} | ||
for i := range stmtArgs { | ||
stmtArgs[i].Ordinal = i + 1 | ||
stmtArgs = append(stmtArgs[:0], args[start:start+na]...) | ||
for i := range args { | ||
if (i < start || i >= na) && args[i].Name != "" { | ||
stmtArgs = append(stmtArgs, args[i]) | ||
} | ||
} | ||
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) | ||
for i := range stmtArgs { | ||
stmtArgs[i].Ordinal = i + 1 | ||
} | ||
var err error | ||
res, err = s.exec(ctx, stmtArgs) | ||
if err != nil && err != driver.ErrSkip { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does it mean if |
||
s.Close() | ||
s.finalize() | ||
return nil, err | ||
} | ||
start += na | ||
} | ||
tail := s.(*SQLiteStmt).t | ||
s.Close() | ||
if tail == "" { | ||
s.finalize() | ||
if len(query) == 0 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is an important nuance in play here. There are two ways to know you are done. The first is that However, the flaw in the current code is that it assumes that just checking whether (Please also add a test case for the trailing comment situation.) |
||
if res == nil { | ||
// https://github.com/mattn/go-sqlite3/issues/963 | ||
res = &SQLiteResult{0, 0} | ||
} | ||
return res, nil | ||
} | ||
query = tail | ||
} | ||
} | ||
|
||
// execNoArgsSync processes every SQL statement in query. All processing occurs | ||
// in C code, which reduces the overhead of CGO calls. | ||
func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) { | ||
var rowid, changes C.int64_t | ||
rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))), | ||
C.int(len(query)), &rowid, &changes) | ||
if rv != C.SQLITE_OK { | ||
err = c.lastError() | ||
} | ||
return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err | ||
} | ||
|
||
func (c *SQLiteConn) execNoArgs(ctx context.Context, query string) (driver.Result, error) { | ||
done := ctx.Done() | ||
if done == nil { | ||
return c.execNoArgsSync(query) | ||
} | ||
|
||
// Fast check if the Context is cancelled | ||
if err := ctx.Err(); err != nil { | ||
return nil, err | ||
} | ||
|
||
ch := make(chan struct{}) | ||
defer close(ch) | ||
go func() { | ||
select { | ||
case <-done: | ||
C.sqlite3_interrupt(c.db) | ||
// Wait until signaled. We need to ensure that this goroutine | ||
// will not call interrupt after this method returns, which is | ||
// why we can't check if only done is closed when waiting below. | ||
<-ch | ||
case <-ch: | ||
} | ||
}() | ||
|
||
res, err := c.execNoArgsSync(query) | ||
|
||
// Stop the goroutine and make sure we're at a point where | ||
// sqlite3_interrupt cannot be called again. | ||
ch <- struct{}{} | ||
|
||
if isInterruptErr(err) { | ||
err = ctx.Err() | ||
} | ||
return res, err | ||
} | ||
|
||
// Query implements Queryer. | ||
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||
list := make([]driver.NamedValue, len(args)) | ||
|
@@ -1914,6 +2034,13 @@ func (s *SQLiteStmt) Close() error { | |
return nil | ||
} | ||
|
||
func (s *SQLiteStmt) finalize() { | ||
if s.s != nil { | ||
C.sqlite3_finalize(s.s) | ||
s.s = nil | ||
} | ||
} | ||
|
||
// NumInput return a number of parameters. | ||
func (s *SQLiteStmt) NumInput() int { | ||
return int(C.sqlite3_bind_parameter_count(s.s)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ package sqlite3 | |
|
||
import ( | ||
"bytes" | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"errors" | ||
|
@@ -1090,6 +1091,67 @@ func TestExecer(t *testing.T) { | |
} | ||
} | ||
|
||
func TestExecDriverResult(t *testing.T) { | ||
setup := func(t *testing.T) *sql.DB { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only called once - can we just merge it into |
||
db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") | ||
if err != nil { | ||
t.Fatal("Failed to open database:", err) | ||
} | ||
if _, err := db.Exec(`CREATE TABLE foo (id INTEGER PRIMARY KEY);`); err != nil { | ||
t.Fatal(err) | ||
} | ||
t.Cleanup(func() { db.Close() }) | ||
return db | ||
} | ||
|
||
test := func(t *testing.T, execStmt string, args ...any) { | ||
db := setup(t) | ||
res, err := db.Exec(execStmt, args...) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
rows, err := res.RowsAffected() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
// We only return the changes from the last statement. | ||
if rows != 1 { | ||
t.Errorf("RowsAffected got: %d want: %d", rows, 1) | ||
} | ||
id, err := res.LastInsertId() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if id != 3 { | ||
t.Errorf("LastInsertId got: %d want: %d", id, 3) | ||
} | ||
var count int64 | ||
err = db.QueryRow(`SELECT COUNT(*) FROM foo WHERE id IN (1, 2, 3);`).Scan(&count) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if count != 3 { | ||
t.Errorf("Expected count to be %d got: %d", 3, count) | ||
} | ||
} | ||
|
||
t.Run("NoArgs", func(t *testing.T) { | ||
const stmt = ` | ||
INSERT INTO foo(id) VALUES(1); | ||
INSERT INTO foo(id) VALUES(2); | ||
INSERT INTO foo(id) VALUES(3);` | ||
test(t, stmt) | ||
}) | ||
|
||
t.Run("WithArgs", func(t *testing.T) { | ||
const stmt = ` | ||
INSERT INTO foo(id) VALUES(?); | ||
INSERT INTO foo(id) VALUES(?); | ||
INSERT INTO foo(id) VALUES(?);` | ||
test(t, stmt, 1, 2, 3) | ||
}) | ||
} | ||
|
||
func TestQueryer(t *testing.T) { | ||
tempFilename := TempFilename(t) | ||
defer os.Remove(tempFilename) | ||
|
@@ -2106,6 +2168,10 @@ var tests = []testing.InternalTest{ | |
|
||
var benchmarks = []testing.InternalBenchmark{ | ||
{Name: "BenchmarkExec", F: benchmarkExec}, | ||
{Name: "BenchmarkExecContext", F: benchmarkExecContext}, | ||
{Name: "BenchmarkExecStep", F: benchmarkExecStep}, | ||
{Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep}, | ||
{Name: "BenchmarkExecTx", F: benchmarkExecTx}, | ||
{Name: "BenchmarkQuery", F: benchmarkQuery}, | ||
{Name: "BenchmarkParams", F: benchmarkParams}, | ||
{Name: "BenchmarkStmt", F: benchmarkStmt}, | ||
|
@@ -2459,13 +2525,78 @@ func testExecEmptyQuery(t *testing.T) { | |
|
||
// benchmarkExec is benchmark for exec | ||
func benchmarkExec(b *testing.B) { | ||
b.Run("Params", func(b *testing.B) { | ||
for i := 0; i < b.N; i++ { | ||
if _, err := db.Exec("select ?;", int64(1)); err != nil { | ||
panic(err) | ||
} | ||
} | ||
}) | ||
b.Run("NoParams", func(b *testing.B) { | ||
for i := 0; i < b.N; i++ { | ||
if _, err := db.Exec("select 1;"); err != nil { | ||
panic(err) | ||
} | ||
} | ||
}) | ||
} | ||
|
||
func benchmarkExecContext(b *testing.B) { | ||
b.Run("Params", func(b *testing.B) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
for i := 0; i < b.N; i++ { | ||
if _, err := db.ExecContext(ctx, "select ?;", int64(1)); err != nil { | ||
panic(err) | ||
} | ||
} | ||
}) | ||
b.Run("NoParams", func(b *testing.B) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
for i := 0; i < b.N; i++ { | ||
if _, err := db.ExecContext(ctx, "select 1;"); err != nil { | ||
panic(err) | ||
} | ||
} | ||
}) | ||
} | ||
|
||
func benchmarkExecTx(b *testing.B) { | ||
for i := 0; i < b.N; i++ { | ||
if _, err := db.Exec("select 1"); err != nil { | ||
tx, err := db.Begin() | ||
if err != nil { | ||
panic(err) | ||
} | ||
if _, err := tx.Exec("select 1;"); err != nil { | ||
panic(err) | ||
} | ||
if err := tx.Commit(); err != nil { | ||
panic(err) | ||
} | ||
} | ||
} | ||
|
||
var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) | ||
|
||
func benchmarkExecStep(b *testing.B) { | ||
for n := 0; n < b.N; n++ { | ||
if _, err := db.Exec(largeSelectStmt); err != nil { | ||
b.Fatal(err) | ||
} | ||
} | ||
} | ||
|
||
func benchmarkExecContextStep(b *testing.B) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
for n := 0; n < b.N; n++ { | ||
if _, err := db.ExecContext(ctx, largeSelectStmt); err != nil { | ||
b.Fatal(err) | ||
} | ||
} | ||
} | ||
|
||
// benchmarkQuery is benchmark for query | ||
func benchmarkQuery(b *testing.B) { | ||
for i := 0; i < b.N; i++ { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we really do this if there was an error?