Skip to content

Commit c816a29

Browse files
committed
sqlite3: handle trailing comments and multiple SQL statements in Queries
This commit fixes *SQLiteConn.Query to properly handle trailing comments after a SQL query statement. Previously, trailing comments could lead to an infinite loop. It also changes Query to error if the provided SQL statement contains multiple queries ("SELECT 1; SELECT 2;") - previously only the last query was executed ("SELECT 1; SELECT 2;" would yield only 2). This may be a breaking change as previously: Query consumed all of its args - despite only using the last query (Query now only uses the args required to satisfy the first query and errors if there is a mismatch); Query used only the last query and there may be code using this library that depends on this behavior. Personally, I believe the behavior introduced by this commit is correct and any code relying on the prior undocumented behavior incorrect, but it could still be a break.
1 parent 5880fdc commit c816a29

File tree

2 files changed

+168
-106
lines changed

2 files changed

+168
-106
lines changed

sqlite3.go

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ package sqlite3
3131
#endif
3232
#include <stdlib.h>
3333
#include <string.h>
34-
#include <ctype.h>
3534
3635
#ifdef __CYGWIN__
3736
# include <errno.h>
@@ -80,16 +79,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
8079
return rv;
8180
}
8281
83-
static const char *
84-
_trim_leading_spaces(const char *str) {
85-
if (str) {
86-
while (isspace(*str)) {
87-
str++;
88-
}
89-
}
90-
return str;
91-
}
92-
9382
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
9483
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
9584
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -110,11 +99,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
11099
static int
111100
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
112101
{
113-
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
114-
if (pzTail) {
115-
*pzTail = _trim_leading_spaces(*pzTail);
116-
}
117-
return rv;
102+
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
118103
}
119104
120105
#else
@@ -137,12 +122,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
137122
static int
138123
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
139124
{
140-
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
141-
if (pzTail) {
142-
*pzTail = _trim_leading_spaces(*pzTail);
143-
}
144-
return rv;
125+
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
145126
}
127+
146128
#endif
147129
148130
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
@@ -938,46 +920,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
938920
op := pquery // original pointer
939921
defer C.free(unsafe.Pointer(op))
940922

941-
var stmtArgs []driver.NamedValue
942923
var tail *C.char
943-
s := new(SQLiteStmt) // escapes to the heap so reuse it
944-
start := 0
945-
for {
946-
*s = SQLiteStmt{c: c, cls: true} // reset
947-
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
948-
if rv != C.SQLITE_OK {
949-
return nil, c.lastError()
924+
s := &SQLiteStmt{c: c, cls: true}
925+
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
926+
if rv != C.SQLITE_OK {
927+
return nil, c.lastError()
928+
}
929+
if s.s == nil {
930+
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
931+
}
932+
na := s.NumInput()
933+
if n := len(args); n != na {
934+
s.finalize()
935+
if n < na {
936+
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
950937
}
938+
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
939+
}
940+
rows, err := s.query(ctx, args)
941+
if err != nil && err != driver.ErrSkip {
942+
s.finalize() // WARN
943+
return rows, err
944+
}
951945

952-
na := s.NumInput()
953-
if len(args)-start < na {
954-
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
955-
}
956-
// consume the number of arguments used in the current
957-
// statement and append all named arguments not contained
958-
// therein
959-
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
960-
for i := range args {
961-
if (i < start || i >= na) && args[i].Name != "" {
962-
stmtArgs = append(stmtArgs, args[i])
963-
}
964-
}
965-
for i := range stmtArgs {
966-
stmtArgs[i].Ordinal = i + 1
967-
}
968-
rows, err := s.query(ctx, stmtArgs)
969-
if err != nil && err != driver.ErrSkip {
970-
s.finalize()
971-
return rows, err
946+
// Consume the rest of the query
947+
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
948+
var stmt *C.sqlite3_stmt
949+
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
950+
if rv != C.SQLITE_OK {
951+
rows.Close()
952+
return nil, c.lastError()
972953
}
973-
start += na
974-
if tail == nil || *tail == '\000' {
975-
return rows, nil
954+
if stmt != nil {
955+
rows.Close()
956+
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
976957
}
977-
rows.Close()
978-
s.finalize()
979-
pquery = tail
980958
}
959+
960+
return rows, nil
981961
}
982962

983963
// Begin transaction.
@@ -2029,7 +2009,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
20292009
return s.query(context.Background(), list)
20302010
}
20312011

2032-
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
2012+
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
20332013
if err := s.bind(args); err != nil {
20342014
return nil, err
20352015
}

sqlite3_test.go

Lines changed: 132 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"math/rand"
1919
"net/url"
2020
"os"
21+
"path/filepath"
2122
"reflect"
2223
"regexp"
2324
"runtime"
@@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
10801081
defer db.Close()
10811082

10821083
_, err = db.Exec(`
1083-
create table foo (id integer); -- one comment
1084-
insert into foo(id) values(?);
1085-
insert into foo(id) values(?);
1086-
insert into foo(id) values(?); -- another comment
1084+
CREATE TABLE foo (id INTEGER); -- one comment
1085+
INSERT INTO foo(id) VALUES(?);
1086+
INSERT INTO foo(id) VALUES(?);
1087+
INSERT INTO foo(id) VALUES(?); -- another comment
10871088
`, 1, 2, 3)
10881089
if err != nil {
10891090
t.Error("Failed to call db.Exec:", err)
10901091
}
10911092
}
10921093

1093-
func TestQueryer(t *testing.T) {
1094-
tempFilename := TempFilename(t)
1095-
defer os.Remove(tempFilename)
1096-
db, err := sql.Open("sqlite3", tempFilename)
1094+
func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
1095+
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
10971096
if err != nil {
10981097
t.Fatal("Failed to open database:", err)
10991098
}
11001099
defer db.Close()
11011100

1102-
_, err = db.Exec(`
1103-
create table foo (id integer);
1104-
`)
1105-
if err != nil {
1106-
t.Error("Failed to call db.Query:", err)
1101+
if seed {
1102+
if _, err := db.Exec(`create table foo (id integer);`); err != nil {
1103+
t.Fatal(err)
1104+
}
1105+
_, err := db.Exec(`
1106+
INSERT INTO foo(id) VALUES(?);
1107+
INSERT INTO foo(id) VALUES(?);
1108+
INSERT INTO foo(id) VALUES(?);
1109+
`, 3, 2, 1)
1110+
if err != nil {
1111+
t.Fatal(err)
1112+
}
11071113
}
11081114

1109-
_, err = db.Exec(`
1110-
insert into foo(id) values(?);
1111-
insert into foo(id) values(?);
1112-
insert into foo(id) values(?);
1113-
`, 3, 2, 1)
1114-
if err != nil {
1115-
t.Error("Failed to call db.Exec:", err)
1116-
}
1117-
rows, err := db.Query(`
1118-
select id from foo order by id;
1119-
`)
1120-
if err != nil {
1121-
t.Error("Failed to call db.Query:", err)
1122-
}
1123-
defer rows.Close()
1124-
n := 0
1125-
for rows.Next() {
1126-
var id int
1127-
err = rows.Scan(&id)
1115+
// Capture panic so tests can continue
1116+
defer func() {
1117+
if e := recover(); e != nil {
1118+
buf := make([]byte, 32*1024)
1119+
n := runtime.Stack(buf, false)
1120+
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
1121+
}
1122+
}()
1123+
test(t, db)
1124+
}
1125+
1126+
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
1127+
var values []interface{}
1128+
testQuery(t, true, func(t *testing.T, db *sql.DB) {
1129+
rows, err := db.Query(query, args...)
11281130
if err != nil {
1129-
t.Error("Failed to db.Query:", err)
1131+
t.Fatal(err)
11301132
}
1131-
if id != n + 1 {
1132-
t.Error("Failed to db.Query: not matched results")
1133+
if rows == nil {
1134+
t.Fatal("nil rows")
11331135
}
1134-
n = n + 1
1136+
for i := 0; rows.Next(); i++ {
1137+
if i > 1_000 {
1138+
t.Fatal("To many iterations of rows.Next():", i)
1139+
}
1140+
var v interface{}
1141+
if err := rows.Scan(&v); err != nil {
1142+
t.Fatal(err)
1143+
}
1144+
values = append(values, v)
1145+
}
1146+
if err := rows.Err(); err != nil {
1147+
t.Fatal(err)
1148+
}
1149+
if err := rows.Close(); err != nil {
1150+
t.Fatal(err)
1151+
}
1152+
})
1153+
return values
1154+
}
1155+
1156+
func TestQuery(t *testing.T) {
1157+
queries := []struct {
1158+
query string
1159+
args []interface{}
1160+
}{
1161+
{"SELECT id FROM foo ORDER BY id;", nil},
1162+
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
1163+
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
1164+
1165+
// Comments
1166+
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
1167+
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
1168+
{
1169+
`-- FOO
1170+
SELECT id FROM foo ORDER BY id; -- BAR
1171+
/* BAZ */`,
1172+
nil,
1173+
},
11351174
}
1136-
if err := rows.Err(); err != nil {
1137-
t.Errorf("Post-scan failed: %v\n", err)
1175+
want := []interface{}{
1176+
int64(1),
1177+
int64(2),
1178+
int64(3),
1179+
}
1180+
for _, q := range queries {
1181+
t.Run("", func(t *testing.T) {
1182+
got := testQueryValues(t, q.query, q.args...)
1183+
if !reflect.DeepEqual(got, want) {
1184+
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
1185+
}
1186+
})
11381187
}
1139-
if n != 3 {
1140-
t.Errorf("Expected 3 rows but retrieved %v", n)
1188+
}
1189+
1190+
func TestQueryNoSQL(t *testing.T) {
1191+
got := testQueryValues(t, "")
1192+
if got != nil {
1193+
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
11411194
}
11421195
}
11431196

1197+
func testQueryError(t *testing.T, query string, args ...interface{}) {
1198+
testQuery(t, true, func(t *testing.T, db *sql.DB) {
1199+
rows, err := db.Query(query, args...)
1200+
if err == nil {
1201+
t.Error("Expected an error got:", err)
1202+
}
1203+
if rows != nil {
1204+
t.Error("Returned rows should be nil on error!")
1205+
// Attempt to iterate over rows to make sure they don't panic.
1206+
for i := 0; rows.Next(); i++ {
1207+
if i > 1_000 {
1208+
t.Fatal("To many iterations of rows.Next():", i)
1209+
}
1210+
}
1211+
if err := rows.Err(); err != nil {
1212+
t.Error(err)
1213+
}
1214+
rows.Close()
1215+
}
1216+
})
1217+
}
1218+
1219+
func TestQueryNotEnoughArgs(t *testing.T) {
1220+
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
1221+
}
1222+
1223+
func TestQueryTooManyArgs(t *testing.T) {
1224+
// TODO: test error message / kind
1225+
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
1226+
}
1227+
1228+
func TestQueryMultipleStatements(t *testing.T) {
1229+
testQueryError(t, "SELECT 1; SELECT 2;")
1230+
}
1231+
1232+
func TestQueryInvalidTable(t *testing.T) {
1233+
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
1234+
}
1235+
11441236
func TestStress(t *testing.T) {
11451237
tempFilename := TempFilename(t)
11461238
defer os.Remove(tempFilename)
@@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
21122204
{Name: "BenchmarkRows", F: benchmarkRows},
21132205
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
21142206
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
2115-
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
21162207
}
21172208

21182209
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
@@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
25802671
}
25812672
}
25822673
}
2583-
2584-
func benchmarkQueryStep(b *testing.B) {
2585-
var i int
2586-
for n := 0; n < b.N; n++ {
2587-
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
2588-
b.Fatal(err)
2589-
}
2590-
}
2591-
}

0 commit comments

Comments
 (0)