diff --git a/sqlite3.go b/sqlite3.go index ed2a9e2a..91e428c8 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -135,6 +135,7 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ { return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); } + #endif void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { @@ -858,25 +859,34 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + pquery := C.CString(query) + op := pquery // original pointer + defer C.free(unsafe.Pointer(op)) + + var stmtArgs []driver.NamedValue + var tail *C.char + s := new(SQLiteStmt) // escapes to the heap so reuse it + defer s.finalize() start := 0 for { - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + *s = SQLiteStmt{c: c} // reset + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() } + var res driver.Result - if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) + if s.s != nil { na := s.NumInput() if len(args)-start < na { - s.Close() + s.finalize() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current // statement and append all named arguments not // contained therein if len(args[start:start+na]) > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) for i := range args { if (i < start || i >= na) && args[i].Name != "" { stmtArgs = append(stmtArgs, args[i]) @@ -886,23 +896,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named stmtArgs[i].Ordinal = i + 1 } } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + var err error + res, err = s.exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return nil, err } start += na } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { + s.finalize() + if tail == nil || *tail == '\000' { if res == nil { // https://github.com/mattn/go-sqlite3/issues/963 res = &SQLiteResult{0, 0} } return res, nil } - query = tail + pquery = tail } } @@ -919,44 +929,48 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - start := 0 - for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err - } - s.(*SQLiteStmt).cls = true - na := s.NumInput() - if len(args)-start < na { - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) - } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 + pquery := C.CString(query) + op := pquery // original pointer + defer C.free(unsafe.Pointer(op)) + + var tail *C.char + s := &SQLiteStmt{c: c, cls: true} + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() + } + if s.s == nil { + return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil + } + na := s.NumInput() + if n := len(args); n != na { + s.finalize() + if n < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } - rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.Close() - return rows, err + return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args)) + } + rows, err := s.query(ctx, args) + if err != nil && err != driver.ErrSkip { + s.finalize() + return rows, err + } + + // Consume the rest of the query + for pquery = tail; pquery != nil && *pquery != 0; pquery = tail { + var stmt *C.sqlite3_stmt + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail) + if rv != C.SQLITE_OK { + rows.Close() + return nil, c.lastError() } - start += na - tail := s.(*SQLiteStmt).t - if tail == "" { - return rows, nil + if stmt != nil { + rows.Close() + return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN } - rows.Close() - s.Close() - query = tail } + + return rows, nil } // Begin transaction. @@ -1818,8 +1832,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er return nil, c.lastError() } var t string - if tail != nil && *tail != '\000' { - t = strings.TrimSpace(C.GoString(tail)) + if tail != nil && *tail != 0 { + n := int(uintptr(unsafe.Pointer(tail))) - int(uintptr(unsafe.Pointer(pquery))) + if 0 <= n && n < len(query) { + t = strings.TrimSpace(query[n:]) + } } ss := &SQLiteStmt{c: c, s: s, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) @@ -1913,6 +1930,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) @@ -2000,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { return s.query(context.Background(), list) } -func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) { if err := s.bind(args); err != nil { return nil, err } diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..089cf21a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "net/url" "os" + "path/filepath" "reflect" "regexp" "runtime" @@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) { defer db.Close() _, err = db.Exec(` - create table foo (id integer); -- one comment - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); -- another comment + CREATE TABLE foo (id INTEGER); -- one comment + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); -- another comment `, 1, 2, 3) if err != nil { t.Error("Failed to call db.Exec:", err) } } -func TestQueryer(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) +func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3")) if err != nil { t.Fatal("Failed to open database:", err) } defer db.Close() - _, err = db.Exec(` - create table foo (id integer); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) + if seed { + if _, err := db.Exec(`create table foo (id integer);`); err != nil { + t.Fatal(err) + } + _, err := db.Exec(` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + `, 3, 2, 1) + if err != nil { + t.Fatal(err) + } } - _, err = db.Exec(` - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); - `, 3, 2, 1) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - rows, err := db.Query(` - select id from foo order by id; - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - defer rows.Close() - n := 0 - for rows.Next() { - var id int - err = rows.Scan(&id) + // Capture panic so tests can continue + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 32*1024) + n := runtime.Stack(buf, false) + t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n]) + } + }() + test(t, db) +} + +func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} { + var values []interface{} + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) if err != nil { - t.Error("Failed to db.Query:", err) + t.Fatal(err) } - if id != n+1 { - t.Error("Failed to db.Query: not matched results") + if rows == nil { + t.Fatal("nil rows") } - n = n + 1 + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + var v interface{} + if err := rows.Scan(&v); err != nil { + t.Fatal(err) + } + values = append(values, v) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + }) + return values +} + +func TestQuery(t *testing.T) { + queries := []struct { + query string + args []interface{} + }{ + {"SELECT id FROM foo ORDER BY id;", nil}, + {"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}}, + {"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}}, + + // Comments + {"SELECT id FROM foo ORDER BY id; -- comment", nil}, + {"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil}, + { + `-- FOO + SELECT id FROM foo ORDER BY id; -- BAR + /* BAZ */`, + nil, + }, } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) + want := []interface{}{ + int64(1), + int64(2), + int64(3), + } + for _, q := range queries { + t.Run("", func(t *testing.T) { + got := testQueryValues(t, q.query, q.args...) + if !reflect.DeepEqual(got, want) { + t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want) + } + }) } - if n != 3 { - t.Errorf("Expected 3 rows but retrieved %v", n) +} + +func TestQueryNoSQL(t *testing.T) { + got := testQueryValues(t, "") + if got != nil { + t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil) } } +func testQueryError(t *testing.T, query string, args ...interface{}) { + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) + if err == nil { + t.Error("Expected an error got:", err) + } + if rows != nil { + t.Error("Returned rows should be nil on error!") + // Attempt to iterate over rows to make sure they don't panic. + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + } + if err := rows.Err(); err != nil { + t.Error(err) + } + rows.Close() + } + }) +} + +func TestQueryNotEnoughArgs(t *testing.T) { + testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1) +} + +func TestQueryTooManyArgs(t *testing.T) { + // TODO: test error message / kind + testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2) +} + +func TestQueryMultipleStatements(t *testing.T) { + testQueryError(t, "SELECT 1; SELECT 2;") +} + +func TestQueryInvalidTable(t *testing.T) { + testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;") +} + func TestStress(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2111,6 +2203,7 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, + {Name: "BenchmarkExecStep", F: benchmarkExecStep}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2568,3 +2661,13 @@ func benchmarkStmtRows(b *testing.B) { } } } + +var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) + +func benchmarkExecStep(b *testing.B) { + for n := 0; n < b.N; n++ { + if _, err := db.Exec(largeSelectStmt); err != nil { + b.Fatal(err) + } + } +}