diff --git a/sqlite3.go b/sqlite3.go index a628b02e..e6b8c166 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -137,6 +137,59 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ } #endif +#define GO_SQLITE_MULTIPLE_QUERIES -1 + +// Our own implementation of ctype.h's isspace (for simplicity and to avoid +// whatever locale shenanigans are involved with the Libc's isspace). +static int _sqlite3_isspace(unsigned char c) { + return c == ' ' || c - '\t' < 5; +} + +static int _sqlite3_prepare_query(sqlite3 *db, const char *zSql, int nBytes, + sqlite3_stmt **ppStmt, int *paramCount) { + + const char *tail; + int rc = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); + if (rc != SQLITE_OK) { + return rc; + } + *paramCount = sqlite3_bind_parameter_count(*ppStmt); + + // Check if the SQL query contains multiple statements. + + // Trim leading space to handle queries with trailing whitespace. + // This can save us an additional call to sqlite3_prepare_v2. + const char *end = zSql + nBytes; + while (tail < end && _sqlite3_isspace(*tail)) { + tail++; + } + nBytes -= (tail - zSql); + + // Attempt to parse the remaining SQL, if any. + if (nBytes > 0 && *tail) { + sqlite3_stmt *stmt; + rc = _sqlite3_prepare_v2_internal(db, tail, nBytes, &stmt, NULL); + if (rc != SQLITE_OK) { + // sqlite3 will return OK and a NULL statement if it was + goto error; + } + if (stmt != NULL) { + sqlite3_finalize(stmt); + rc = GO_SQLITE_MULTIPLE_QUERIES; + goto error; + } + } + + // Ok, the SQL contained one valid statement. + return SQLITE_OK; + +error: + if (*ppStmt) { + sqlite3_finalize(*ppStmt); + } + return rc; +} + static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) { const char *tail = NULL; int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); @@ -1123,46 +1176,42 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro return c.query(context.Background(), query, list) } +var closedRows = &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true} + func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - start := 0 - for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err - } - s.(*SQLiteStmt).cls = true - na := s.NumInput() - if len(args)-start < na { - s.Close() - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) - } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.Close() - return rows, err + s := SQLiteStmt{c: c, cls: true} + p := stringData(query) + var paramCount C.int + rv := C._sqlite3_prepare_query(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s.s, ¶mCount) + if rv != C.SQLITE_OK { + if rv == C.GO_SQLITE_MULTIPLE_QUERIES { + return nil, errors.New("query contains multiple SQL statements") } - start += na - tail := s.(*SQLiteStmt).t - if tail == "" { - return rows, nil + return nil, c.lastError() + } + + // The sqlite3_stmt will be nil if the SQL was valid but did not + // contain a query. For now we're supporting this for the sake of + // backwards compatibility, but that may change in the future. + if s.s == nil { + return closedRows, nil + } + + na := int(paramCount) + if n := len(args); n != na { + s.finalize() + if n < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } - rows.Close() - s.Close() - query = tail + return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args)) } + + rows, err := s.query(ctx, args) + if err != nil && err != driver.ErrSkip { + s.finalize() + return rows, err + } + return rows, nil } // Begin transaction. diff --git a/sqlite3_test.go b/sqlite3_test.go index 89682d41..73c3fd9a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -19,6 +19,7 @@ import ( "math/rand" "net/url" "os" + "path/filepath" "reflect" "regexp" "runtime" @@ -1203,6 +1204,163 @@ func TestQueryer(t *testing.T) { } } +func testQuery(t *testing.T, test func(t *testing.T, db *sql.DB)) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3")) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + CREATE TABLE FOO (id INTEGER); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + `, 3, 2, 1) + if err != nil { + t.Fatal(err) + } + + // Capture panic so tests can continue + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 32*1024) + n := runtime.Stack(buf, false) + t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n]) + } + }() + test(t, db) +} + +func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} { + var values []interface{} + testQuery(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) + if err != nil { + t.Fatal(err) + } + if rows == nil { + t.Fatal("nil rows") + } + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + var v interface{} + if err := rows.Scan(&v); err != nil { + t.Fatal(err) + } + values = append(values, v) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + }) + return values +} + +func TestQuery(t *testing.T) { + queries := []struct { + query string + args []interface{} + }{ + {"SELECT id FROM foo ORDER BY id;", nil}, + {"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}}, + {"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}}, + + // Comments + {"SELECT id FROM foo ORDER BY id; -- comment", nil}, + {"SELECT id FROM foo ORDER BY id -- comment", nil}, // Not terminated + {"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil}, + { + `-- FOO + SELECT id FROM foo ORDER BY id; -- BAR + /* BAZ */`, + nil, + }, + } + want := []interface{}{ + int64(1), + int64(2), + int64(3), + } + for _, q := range queries { + t.Run("", func(t *testing.T) { + got := testQueryValues(t, q.query, q.args...) + if !reflect.DeepEqual(got, want) { + t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want) + } + }) + } +} + +func TestQueryNoSQL(t *testing.T) { + got := testQueryValues(t, "") + if got != nil { + t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil) + } +} + +func testQueryError(t *testing.T, query string, args ...interface{}) { + testQuery(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) + if err == nil { + t.Error("Expected an error got:", err) + } + if rows != nil { + t.Error("Returned rows should be nil on error!") + // Attempt to iterate over rows to make sure they don't panic. + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + } + if err := rows.Err(); err != nil { + t.Error(err) + } + rows.Close() + } + }) +} + +func TestQueryNotEnoughArgs(t *testing.T) { + testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1) +} + +func TestQueryTooManyArgs(t *testing.T) { + // TODO: test error message / kind + testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2) +} + +func TestQueryMultipleStatements(t *testing.T) { + testQueryError(t, "SELECT 1; SELECT 2;") + testQueryError(t, "SELECT 1; SELECT 2; SELECT 3;") + testQueryError(t, "SELECT 1; ; SELECT 2;") // Empty statement in between + testQueryError(t, "SELECT 1; FOOBAR 2;") // Error in second statement + + // Test that multiple trailing semicolons (";;") are not an error + noError := func(t *testing.T, query string, args ...any) { + testQuery(t, func(t *testing.T, db *sql.DB) { + var n int64 + if err := db.QueryRow(query, args...).Scan(&n); err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("got: %d want: %d", n, 1) + } + }) + } + noError(t, "SELECT 1; ;") + noError(t, "SELECT ?; ;", 1) +} + +func TestQueryInvalidTable(t *testing.T) { + testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;") +} + func TestStress(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2180,6 +2338,7 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep}, {Name: "BenchmarkExecTx", F: benchmarkExecTx}, {Name: "BenchmarkQuery", F: benchmarkQuery}, + {Name: "BenchmarkQuerySimple", F: benchmarkQuerySimple}, {Name: "BenchmarkQueryContext", F: benchmarkQueryContext}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, @@ -2619,6 +2778,15 @@ func benchmarkQuery(b *testing.B) { } } +func benchmarkQuerySimple(b *testing.B) { + for i := 0; i < b.N; i++ { + var n int + if err := db.QueryRow("select 1;").Scan(&n); err != nil { + panic(err) + } + } +} + // benchmarkQueryContext is benchmark for QueryContext func benchmarkQueryContext(b *testing.B) { const createTableStmt = `