Skip to content

Commit

Permalink
Improve Query performance and return an error if the query includes m…
Browse files Browse the repository at this point in the history
…ultiple statements

This commit changes Query to return an error if more than one SQL
statement is provided. Previously, this library would only execute the
last query statement. It also improves query construction performance by
~15%.

This is a breaking a change since existing programs may rely on the
broken mattn/go-sqlite3 implementation. That said, any program relying
on this is also broken / using sqlite3 incorrectly.

```
goos: darwin
goarch: arm64
pkg: github.com/charlievieth/go-sqlite3
cpu: Apple M4 Pro
                              │   x1.txt    │               x2.txt                │
                              │   sec/op    │   sec/op     vs base                │
Suite/BenchmarkQuery-14         2.255µ ± 1%   1.837µ ± 1%  -18.56% (p=0.000 n=10)
Suite/BenchmarkQuerySimple-14   1.322µ ± 9%   1.124µ ± 4%  -15.02% (p=0.000 n=10)
geomean                         1.727µ        1.436µ       -16.81%

                              │   x1.txt   │              x2.txt               │
                              │    B/op    │    B/op     vs base               │
Suite/BenchmarkQuery-14         664.0 ± 0%   656.0 ± 0%  -1.20% (p=0.000 n=10)
Suite/BenchmarkQuerySimple-14   472.0 ± 0%   456.0 ± 0%  -3.39% (p=0.000 n=10)
geomean                         559.8        546.9       -2.30%

                              │   x1.txt   │              x2.txt               │
                              │ allocs/op  │ allocs/op   vs base               │
Suite/BenchmarkQuery-14         23.00 ± 0%   22.00 ± 0%  -4.35% (p=0.000 n=10)
Suite/BenchmarkQuerySimple-14   14.00 ± 0%   13.00 ± 0%  -7.14% (p=0.000 n=10)
geomean                         17.94        16.91       -5.76%
```
  • Loading branch information
charlievieth committed Nov 28, 2024
1 parent 976152a commit 8ba2e97
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 36 deletions.
121 changes: 85 additions & 36 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,59 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_
}
#endif
#define GO_SQLITE_MULTIPLE_QUERIES -1
// Our own implementation of ctype.h's isspace (for simplicity and to avoid
// whatever locale shenanigans are involved with the Libc's isspace).
static int _sqlite3_isspace(unsigned char c) {
return c == ' ' || c - '\t' < 5;
}
static int _sqlite3_prepare_query(sqlite3 *db, const char *zSql, int nBytes,
sqlite3_stmt **ppStmt, int *paramCount) {
const char *tail;
int rc = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
if (rc != SQLITE_OK) {
return rc;
}
*paramCount = sqlite3_bind_parameter_count(*ppStmt);
// Check if the SQL query contains multiple statements.
// Trim leading space to handle queries with trailing whitespace.
// This can save us an additional call to sqlite3_prepare_v2.
const char *end = zSql + nBytes;
while (tail < end && _sqlite3_isspace(*tail)) {
tail++;
}
nBytes -= (tail - zSql);
// Attempt to parse the remaining SQL, if any.
if (nBytes > 0 && *tail) {
sqlite3_stmt *stmt;
rc = _sqlite3_prepare_v2_internal(db, tail, nBytes, &stmt, NULL);
if (rc != SQLITE_OK) {
// sqlite3 will return OK and a NULL statement if it was
goto error;
}
if (stmt != NULL) {
sqlite3_finalize(stmt);
rc = GO_SQLITE_MULTIPLE_QUERIES;
goto error;
}
}
// Ok, the SQL contained one valid statement.
return SQLITE_OK;
error:
if (*ppStmt) {
sqlite3_finalize(*ppStmt);
}
return rc;
}
static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) {
const char *tail = NULL;
int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
Expand Down Expand Up @@ -1123,46 +1176,42 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
return c.query(context.Background(), query, list)
}

var closedRows = &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}

func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
start := 0
for {
stmtArgs := make([]driver.NamedValue, 0, len(args))
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
}
s.(*SQLiteStmt).cls = true
na := s.NumInput()
if len(args)-start < na {
s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
}
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
return rows, err
s := SQLiteStmt{c: c, cls: true}
p := stringData(query)
var paramCount C.int
rv := C._sqlite3_prepare_query(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s.s, &paramCount)
if rv != C.SQLITE_OK {
if rv == C.GO_SQLITE_MULTIPLE_QUERIES {
return nil, errors.New("query contains multiple SQL statements")
}
start += na
tail := s.(*SQLiteStmt).t
if tail == "" {
return rows, nil
return nil, c.lastError()
}

// The sqlite3_stmt will be nil if the SQL was valid but did not
// contain a query. For now we're supporting this for the sake of
// backwards compatibility, but that may change in the future.
if s.s == nil {
return closedRows, nil
}

na := int(paramCount)
if n := len(args); n != na {
s.finalize()
if n < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
rows.Close()
s.Close()
query = tail
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
}

rows, err := s.query(ctx, args)
if err != nil && err != driver.ErrSkip {
s.finalize()
return rows, err
}
return rows, nil
}

// Begin transaction.
Expand Down
168 changes: 168 additions & 0 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"math/rand"
"net/url"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
Expand Down Expand Up @@ -1203,6 +1204,163 @@ func TestQueryer(t *testing.T) {
}
}

func testQuery(t *testing.T, test func(t *testing.T, db *sql.DB)) {
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()

_, err = db.Exec(`
CREATE TABLE FOO (id INTEGER);
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
`, 3, 2, 1)
if err != nil {
t.Fatal(err)
}

// Capture panic so tests can continue
defer func() {
if e := recover(); e != nil {
buf := make([]byte, 32*1024)
n := runtime.Stack(buf, false)
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
}
}()
test(t, db)
}

func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
var values []interface{}
testQuery(t, func(t *testing.T, db *sql.DB) {
rows, err := db.Query(query, args...)
if err != nil {
t.Fatal(err)
}
if rows == nil {
t.Fatal("nil rows")
}
for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
var v interface{}
if err := rows.Scan(&v); err != nil {
t.Fatal(err)
}
values = append(values, v)
}
if err := rows.Err(); err != nil {
t.Fatal(err)
}
if err := rows.Close(); err != nil {
t.Fatal(err)
}
})
return values
}

func TestQuery(t *testing.T) {
queries := []struct {
query string
args []interface{}
}{
{"SELECT id FROM foo ORDER BY id;", nil},
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},

// Comments
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
{"SELECT id FROM foo ORDER BY id -- comment", nil}, // Not terminated
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
{
`-- FOO
SELECT id FROM foo ORDER BY id; -- BAR
/* BAZ */`,
nil,
},
}
want := []interface{}{
int64(1),
int64(2),
int64(3),
}
for _, q := range queries {
t.Run("", func(t *testing.T) {
got := testQueryValues(t, q.query, q.args...)
if !reflect.DeepEqual(got, want) {
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
}
})
}
}

func TestQueryNoSQL(t *testing.T) {
got := testQueryValues(t, "")
if got != nil {
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
}
}

func testQueryError(t *testing.T, query string, args ...interface{}) {
testQuery(t, func(t *testing.T, db *sql.DB) {
rows, err := db.Query(query, args...)
if err == nil {
t.Error("Expected an error got:", err)
}
if rows != nil {
t.Error("Returned rows should be nil on error!")
// Attempt to iterate over rows to make sure they don't panic.
for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
}
if err := rows.Err(); err != nil {
t.Error(err)
}
rows.Close()
}
})
}

func TestQueryNotEnoughArgs(t *testing.T) {
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
}

func TestQueryTooManyArgs(t *testing.T) {
// TODO: test error message / kind
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
}

func TestQueryMultipleStatements(t *testing.T) {
testQueryError(t, "SELECT 1; SELECT 2;")
testQueryError(t, "SELECT 1; SELECT 2; SELECT 3;")
testQueryError(t, "SELECT 1; ; SELECT 2;") // Empty statement in between
testQueryError(t, "SELECT 1; FOOBAR 2;") // Error in second statement

// Test that multiple trailing semicolons (";;") are not an error
noError := func(t *testing.T, query string, args ...any) {
testQuery(t, func(t *testing.T, db *sql.DB) {
var n int64
if err := db.QueryRow(query, args...).Scan(&n); err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("got: %d want: %d", n, 1)
}
})
}
noError(t, "SELECT 1; ;")
noError(t, "SELECT ?; ;", 1)
}

func TestQueryInvalidTable(t *testing.T) {
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
}

func TestStress(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
Expand Down Expand Up @@ -2180,6 +2338,7 @@ var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep},
{Name: "BenchmarkExecTx", F: benchmarkExecTx},
{Name: "BenchmarkQuery", F: benchmarkQuery},
{Name: "BenchmarkQuerySimple", F: benchmarkQuerySimple},
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
{Name: "BenchmarkParams", F: benchmarkParams},
{Name: "BenchmarkStmt", F: benchmarkStmt},
Expand Down Expand Up @@ -2619,6 +2778,15 @@ func benchmarkQuery(b *testing.B) {
}
}

func benchmarkQuerySimple(b *testing.B) {
for i := 0; i < b.N; i++ {
var n int
if err := db.QueryRow("select 1;").Scan(&n); err != nil {
panic(err)
}
}
}

// benchmarkQueryContext is benchmark for QueryContext
func benchmarkQueryContext(b *testing.B) {
const createTableStmt = `
Expand Down

0 comments on commit 8ba2e97

Please sign in to comment.