Skip to content

Commit 1dec030

Browse files
committedJan 29, 2018
Refactored Store to use BaseRunners instead of proxies
Query logger is now inherited for transactions (fixes src-d#254) Paves the way for src-d#256
1 parent 3162cdd commit 1dec030

File tree

3 files changed

+103
-58
lines changed

3 files changed

+103
-58
lines changed
 

‎batcher.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type batchQueryRunner struct {
1414
q Query
1515
oneToOneRels []Relationship
1616
oneToManyRels []Relationship
17-
db squirrel.DBProxy
17+
db squirrel.BaseRunner
1818
builder squirrel.SelectBuilder
1919
total int
2020
eof bool
@@ -24,7 +24,7 @@ type batchQueryRunner struct {
2424

2525
var errNoMoreRows = errors.New("kallax: there are no more rows in the result set")
2626

27-
func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQueryRunner {
27+
func newBatchQueryRunner(schema Schema, db squirrel.BaseRunner, q Query) *batchQueryRunner {
2828
cols, builder := q.compile()
2929
var (
3030
oneToOneRels []Relationship

‎batcher_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func TestBatcherLimit(t *testing.T) {
5454
q.BatchSize(2)
5555
q.Limit(5)
5656
r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1")))
57-
runner := newBatchQueryRunner(ModelSchema, store.proxy, q)
57+
runner := newBatchQueryRunner(ModelSchema, store.runner, q)
5858
rs := NewBatchingResultSet(runner)
5959

6060
var count int
@@ -91,7 +91,7 @@ func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) {
9191
var queries int
9292
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
9393
queries++
94-
}).proxy
94+
}).runner
9595
runner := newBatchQueryRunner(ModelSchema, proxy, q)
9696
rs := NewBatchingResultSet(runner)
9797

@@ -130,7 +130,7 @@ func TestBatcherNoExtraQueryIfLessThanBatchSize(t *testing.T) {
130130
var queries int
131131
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
132132
queries++
133-
}).proxy
133+
}).runner
134134
runner := newBatchQueryRunner(ModelSchema, proxy, q)
135135
rs := NewBatchingResultSet(runner)
136136

‎store.go

+98-53
Original file line numberDiff line numberDiff line change
@@ -60,62 +60,87 @@ func StoreFrom(to, from GenericStorer) {
6060
// logs it.
6161
type LoggerFunc func(string, ...interface{})
6262

63-
// debugProxy is a database proxy that logs all SQL statements executed.
64-
type debugProxy struct {
63+
func defaultLogger(message string, args ...interface{}) {
64+
log.Printf("%s, args: %v", message, args)
65+
}
66+
67+
// basicLogger is a database runner that logs all SQL statements executed.
68+
type basicLogger struct {
6569
logger LoggerFunc
66-
proxy squirrel.DBProxy
70+
runner squirrel.BaseRunner
6771
}
6872

69-
func defaultLogger(message string, args ...interface{}) {
70-
log.Printf("%s, args: %v", message, args)
73+
// basicLogger is a database runner that logs all SQL statements executed.
74+
type proxyLogger struct {
75+
basicLogger
7176
}
7277

73-
func (p *debugProxy) Exec(query string, args ...interface{}) (sql.Result, error) {
78+
func (p *basicLogger) Exec(query string, args ...interface{}) (sql.Result, error) {
7479
p.logger(fmt.Sprintf("kallax: Exec: %s", query), args...)
75-
return p.proxy.Exec(query, args...)
80+
return p.runner.Exec(query, args...)
7681
}
7782

78-
func (p *debugProxy) Query(query string, args ...interface{}) (*sql.Rows, error) {
83+
func (p *basicLogger) Query(query string, args ...interface{}) (*sql.Rows, error) {
7984
p.logger(fmt.Sprintf("kallax: Query: %s", query), args...)
80-
return p.proxy.Query(query, args...)
85+
return p.runner.Query(query, args...)
8186
}
8287

83-
func (p *debugProxy) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
84-
p.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
85-
return p.proxy.QueryRow(query, args...)
88+
func (p *proxyLogger) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
89+
p.basicLogger.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
90+
if queryRower, ok := p.basicLogger.runner.(squirrel.QueryRower); ok {
91+
return queryRower.QueryRow(query, args...)
92+
} else {
93+
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
94+
}
8695
}
8796

88-
func (p *debugProxy) Prepare(query string) (*sql.Stmt, error) {
89-
p.logger(fmt.Sprintf("kallax: Prepare: %s", query))
90-
return p.proxy.Prepare(query)
97+
func (p *proxyLogger) Prepare(query string) (*sql.Stmt, error) {
98+
// If chained runner is a proxy, run Prepare(). Otherwise, noop.
99+
if preparer, ok := p.basicLogger.runner.(squirrel.Preparer); ok {
100+
p.basicLogger.logger(fmt.Sprintf("kallax: Prepare: %s", query))
101+
return preparer.Prepare(query)
102+
} else {
103+
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
104+
}
91105
}
92106

93107
// Store is a structure capable of retrieving records from a concrete table in
94108
// the database.
95109
type Store struct {
96-
builder squirrel.StatementBuilderType
97-
db *sql.DB
98-
proxy squirrel.DBProxy
110+
db interface {
111+
squirrel.BaseRunner
112+
squirrel.PreparerContext
113+
}
114+
runner squirrel.BaseRunner
115+
useCacher bool
116+
logger LoggerFunc
99117
}
100118

101119
// NewStore returns a new Store instance.
102120
func NewStore(db *sql.DB) *Store {
103-
proxy := squirrel.NewStmtCacher(db)
104-
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
105-
return &Store{
106-
db: db,
107-
proxy: proxy,
108-
builder: builder,
109-
}
121+
return (&Store{
122+
db: db,
123+
useCacher: true,
124+
}).init()
110125
}
111126

112-
func newStoreWithTransaction(tx *sql.Tx) *Store {
113-
proxy := squirrel.NewStmtCacher(tx)
114-
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
115-
return &Store{
116-
proxy: proxy,
117-
builder: builder,
127+
// init initializes the store runner with debugging or caching, and returns itself for chainability
128+
func (s *Store) init() *Store {
129+
s.runner = s.db
130+
131+
if s.useCacher {
132+
s.runner = squirrel.NewStmtCacher(s.db)
118133
}
134+
135+
if s.logger != nil && !s.useCacher {
136+
// Use BasicLogger as wrapper
137+
s.runner = &basicLogger{s.logger, s.runner}
138+
} else if s.logger != nil && s.useCacher {
139+
// We're using a proxy (cacher), so use proxyLogger instead
140+
s.runner = &proxyLogger{basicLogger{s.logger, s.runner}}
141+
}
142+
143+
return s
119144
}
120145

121146
// Debug returns a new store that will print all SQL statements to stdout using
@@ -127,12 +152,11 @@ func (s *Store) Debug() *Store {
127152
// DebugWith returns a new store that will print all SQL statements using the
128153
// given logger function.
129154
func (s *Store) DebugWith(logger LoggerFunc) *Store {
130-
proxy := &debugProxy{logger, s.proxy}
131-
return &Store{
132-
builder: s.builder.RunWith(proxy),
133-
db: s.db,
134-
proxy: proxy,
135-
}
155+
return (&Store{
156+
db: s.db,
157+
useCacher: s.useCacher,
158+
logger: logger,
159+
}).init()
136160
}
137161

138162
// Insert insert the given record in the table, returns error if no-new
@@ -192,9 +216,20 @@ func (s *Store) Insert(schema Schema, record Record) error {
192216
}
193217

194218
query.WriteString(fmt.Sprintf(" RETURNING %s", schema.ID().String()))
195-
err = s.proxy.QueryRow(query.String(), values...).Scan(pk)
219+
//err = s.runner.QueryRow(query.String(), values...).Scan(pk)
220+
rows, err := s.runner.Query(query.String(), values...)
221+
if err != nil {
222+
return err
223+
}
224+
if rows.Next() {
225+
err = rows.Scan(pk)
226+
rows.Close()
227+
if err != nil {
228+
return err
229+
}
230+
}
196231
} else {
197-
_, err = s.proxy.Exec(query.String(), values...)
232+
_, err = s.runner.Exec(query.String(), values...)
198233
}
199234

200235
if err != nil {
@@ -255,7 +290,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
255290
query.WriteRune('=')
256291
query.WriteString(fmt.Sprintf("$%d", len(columnNames)+1))
257292

258-
result, err := s.proxy.Exec(query.String(), append(values, record.GetID())...)
293+
result, err := s.runner.Exec(query.String(), append(values, record.GetID())...)
259294
if err != nil {
260295
return 0, err
261296
}
@@ -300,7 +335,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
300335
query.WriteString(schema.ID().String())
301336
query.WriteString("=$1")
302337

303-
_, err := s.proxy.Exec(query.String(), record.GetID())
338+
_, err := s.runner.Exec(query.String(), record.GetID())
304339
return err
305340
}
306341

@@ -309,7 +344,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
309344
// WARNING: A result set created from a raw query can only be scanned using the
310345
// RawScan method of ResultSet, instead of Scan.
311346
func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
312-
rows, err := s.proxy.Query(sql, params...)
347+
rows, err := s.runner.Query(sql, params...)
313348
if err != nil {
314349
return nil, err
315350
}
@@ -320,7 +355,7 @@ func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
320355
// RawExec executes a raw SQL query with the given parameters and returns
321356
// the number of affected rows.
322357
func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
323-
result, err := s.proxy.Exec(sql, params...)
358+
result, err := s.runner.Exec(sql, params...)
324359
if err != nil {
325360
return 0, err
326361
}
@@ -332,7 +367,7 @@ func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
332367
func (s *Store) Find(q Query) (ResultSet, error) {
333368
rels := q.getRelationships()
334369
if containsRelationshipOfType(rels, OneToMany) {
335-
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.proxy, q)), nil
370+
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.runner, q)), nil
336371
}
337372

338373
columns, builder := q.compile()
@@ -344,7 +379,7 @@ func (s *Store) Find(q Query) (ResultSet, error) {
344379
builder = builder.Limit(limit)
345380
}
346381

347-
rows, err := builder.RunWith(s.proxy).Query()
382+
rows, err := builder.RunWith(s.runner).Query()
348383
if err != nil {
349384
return nil, err
350385
}
@@ -379,7 +414,7 @@ func (s *Store) Reload(schema Schema, record Record) error {
379414
q.Limit(1)
380415
columns, builder := q.compile()
381416

382-
rows, err := builder.RunWith(s.proxy).Query()
417+
rows, err := builder.RunWith(s.runner).Query()
383418
if err != nil {
384419
return err
385420
}
@@ -399,7 +434,7 @@ func (s *Store) Count(q Query) (count int64, err error) {
399434
_, queryBuilder := q.compile()
400435
builder := builder.Set(queryBuilder, "Columns", nil).(squirrel.SelectBuilder)
401436
err = builder.Column(fmt.Sprintf("COUNT(%s)", all.QualifiedName(q.Schema()))).
402-
RunWith(s.proxy).
437+
RunWith(s.runner).
403438
QueryRow().
404439
Scan(&count)
405440
return
@@ -423,16 +458,26 @@ func (s *Store) MustCount(q Query) int64 {
423458
// If a transaction is already opened in this store, instead of opening a new
424459
// one, the other will be reused.
425460
func (s *Store) Transaction(callback func(*Store) error) error {
426-
if s.db == nil {
461+
var tx *sql.Tx
462+
var err error
463+
if db, ok := s.db.(*sql.DB); ok {
464+
// db is *sql.DB, not *sql.Tx
465+
tx, err = db.Begin()
466+
if err != nil {
467+
return fmt.Errorf("kallax: can't open transaction: %s", err)
468+
}
469+
} else {
470+
// store is already holding a transaction
427471
return callback(s)
428472
}
429473

430-
tx, err := s.db.Begin()
431-
if err != nil {
432-
return fmt.Errorf("kallax: can't open transaction: %s", err)
433-
}
474+
txStore := (&Store{
475+
db: tx,
476+
logger: s.logger,
477+
useCacher: true,
478+
}).init()
434479

435-
if err := callback(newStoreWithTransaction(tx)); err != nil {
480+
if err := callback(txStore); err != nil {
436481
if err := tx.Rollback(); err != nil {
437482
return fmt.Errorf("kallax: unable to rollback transaction: %s", err)
438483
}

0 commit comments

Comments
 (0)
Please sign in to comment.