@@ -60,62 +60,87 @@ func StoreFrom(to, from GenericStorer) {
60
60
// logs it.
61
61
type LoggerFunc func (string , ... interface {})
62
62
63
- // debugProxy is a database proxy that logs all SQL statements executed.
64
- type debugProxy struct {
63
+ func defaultLogger (message string , args ... interface {}) {
64
+ log .Printf ("%s, args: %v" , message , args )
65
+ }
66
+
67
+ // basicLogger is a database runner that logs all SQL statements executed.
68
+ type basicLogger struct {
65
69
logger LoggerFunc
66
- proxy squirrel.DBProxy
70
+ runner squirrel.BaseRunner
67
71
}
68
72
69
- func defaultLogger (message string , args ... interface {}) {
70
- log .Printf ("%s, args: %v" , message , args )
73
+ // basicLogger is a database runner that logs all SQL statements executed.
74
+ type proxyLogger struct {
75
+ basicLogger
71
76
}
72
77
73
- func (p * debugProxy ) Exec (query string , args ... interface {}) (sql.Result , error ) {
78
+ func (p * basicLogger ) Exec (query string , args ... interface {}) (sql.Result , error ) {
74
79
p .logger (fmt .Sprintf ("kallax: Exec: %s" , query ), args ... )
75
- return p .proxy .Exec (query , args ... )
80
+ return p .runner .Exec (query , args ... )
76
81
}
77
82
78
- func (p * debugProxy ) Query (query string , args ... interface {}) (* sql.Rows , error ) {
83
+ func (p * basicLogger ) Query (query string , args ... interface {}) (* sql.Rows , error ) {
79
84
p .logger (fmt .Sprintf ("kallax: Query: %s" , query ), args ... )
80
- return p .proxy .Query (query , args ... )
85
+ return p .runner .Query (query , args ... )
81
86
}
82
87
83
- func (p * debugProxy ) QueryRow (query string , args ... interface {}) squirrel.RowScanner {
84
- p .logger (fmt .Sprintf ("kallax: QueryRow: %s" , query ), args ... )
85
- return p .proxy .QueryRow (query , args ... )
88
+ func (p * proxyLogger ) QueryRow (query string , args ... interface {}) squirrel.RowScanner {
89
+ p .basicLogger .logger (fmt .Sprintf ("kallax: QueryRow: %s" , query ), args ... )
90
+ if queryRower , ok := p .basicLogger .runner .(squirrel.QueryRower ); ok {
91
+ return queryRower .QueryRow (query , args ... )
92
+ } else {
93
+ panic ("Called proxyLogger with a runner which doesn't implement QueryRower" )
94
+ }
86
95
}
87
96
88
- func (p * debugProxy ) Prepare (query string ) (* sql.Stmt , error ) {
89
- p .logger (fmt .Sprintf ("kallax: Prepare: %s" , query ))
90
- return p .proxy .Prepare (query )
97
+ func (p * proxyLogger ) Prepare (query string ) (* sql.Stmt , error ) {
98
+ // If chained runner is a proxy, run Prepare(). Otherwise, noop.
99
+ if preparer , ok := p .basicLogger .runner .(squirrel.Preparer ); ok {
100
+ p .basicLogger .logger (fmt .Sprintf ("kallax: Prepare: %s" , query ))
101
+ return preparer .Prepare (query )
102
+ } else {
103
+ panic ("Called proxyLogger with a runner which doesn't implement QueryRower" )
104
+ }
91
105
}
92
106
93
107
// Store is a structure capable of retrieving records from a concrete table in
94
108
// the database.
95
109
type Store struct {
96
- builder squirrel.StatementBuilderType
97
- db * sql.DB
98
- proxy squirrel.DBProxy
110
+ db interface {
111
+ squirrel.BaseRunner
112
+ squirrel.PreparerContext
113
+ }
114
+ runner squirrel.BaseRunner
115
+ useCacher bool
116
+ logger LoggerFunc
99
117
}
100
118
101
119
// NewStore returns a new Store instance.
102
120
func NewStore (db * sql.DB ) * Store {
103
- proxy := squirrel .NewStmtCacher (db )
104
- builder := squirrel .StatementBuilder .PlaceholderFormat (squirrel .Dollar ).RunWith (proxy )
105
- return & Store {
106
- db : db ,
107
- proxy : proxy ,
108
- builder : builder ,
109
- }
121
+ return (& Store {
122
+ db : db ,
123
+ useCacher : true ,
124
+ }).init ()
110
125
}
111
126
112
- func newStoreWithTransaction ( tx * sql. Tx ) * Store {
113
- proxy := squirrel . NewStmtCacher ( tx )
114
- builder := squirrel . StatementBuilder . PlaceholderFormat ( squirrel . Dollar ). RunWith ( proxy )
115
- return & Store {
116
- proxy : proxy ,
117
- builder : builder ,
127
+ // init initializes the store runner with debugging or caching, and returns itself for chainability
128
+ func ( s * Store ) init () * Store {
129
+ s . runner = s . db
130
+
131
+ if s . useCacher {
132
+ s . runner = squirrel . NewStmtCacher ( s . db )
118
133
}
134
+
135
+ if s .logger != nil && ! s .useCacher {
136
+ // Use BasicLogger as wrapper
137
+ s .runner = & basicLogger {s .logger , s .runner }
138
+ } else if s .logger != nil && s .useCacher {
139
+ // We're using a proxy (cacher), so use proxyLogger instead
140
+ s .runner = & proxyLogger {basicLogger {s .logger , s .runner }}
141
+ }
142
+
143
+ return s
119
144
}
120
145
121
146
// Debug returns a new store that will print all SQL statements to stdout using
@@ -127,12 +152,11 @@ func (s *Store) Debug() *Store {
127
152
// DebugWith returns a new store that will print all SQL statements using the
128
153
// given logger function.
129
154
func (s * Store ) DebugWith (logger LoggerFunc ) * Store {
130
- proxy := & debugProxy {logger , s .proxy }
131
- return & Store {
132
- builder : s .builder .RunWith (proxy ),
133
- db : s .db ,
134
- proxy : proxy ,
135
- }
155
+ return (& Store {
156
+ db : s .db ,
157
+ useCacher : s .useCacher ,
158
+ logger : logger ,
159
+ }).init ()
136
160
}
137
161
138
162
// Insert insert the given record in the table, returns error if no-new
@@ -192,9 +216,20 @@ func (s *Store) Insert(schema Schema, record Record) error {
192
216
}
193
217
194
218
query .WriteString (fmt .Sprintf (" RETURNING %s" , schema .ID ().String ()))
195
- err = s .proxy .QueryRow (query .String (), values ... ).Scan (pk )
219
+ //err = s.runner.QueryRow(query.String(), values...).Scan(pk)
220
+ rows , err := s .runner .Query (query .String (), values ... )
221
+ if err != nil {
222
+ return err
223
+ }
224
+ if rows .Next () {
225
+ err = rows .Scan (pk )
226
+ rows .Close ()
227
+ if err != nil {
228
+ return err
229
+ }
230
+ }
196
231
} else {
197
- _ , err = s .proxy .Exec (query .String (), values ... )
232
+ _ , err = s .runner .Exec (query .String (), values ... )
198
233
}
199
234
200
235
if err != nil {
@@ -255,7 +290,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
255
290
query .WriteRune ('=' )
256
291
query .WriteString (fmt .Sprintf ("$%d" , len (columnNames )+ 1 ))
257
292
258
- result , err := s .proxy .Exec (query .String (), append (values , record .GetID ())... )
293
+ result , err := s .runner .Exec (query .String (), append (values , record .GetID ())... )
259
294
if err != nil {
260
295
return 0 , err
261
296
}
@@ -300,7 +335,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
300
335
query .WriteString (schema .ID ().String ())
301
336
query .WriteString ("=$1" )
302
337
303
- _ , err := s .proxy .Exec (query .String (), record .GetID ())
338
+ _ , err := s .runner .Exec (query .String (), record .GetID ())
304
339
return err
305
340
}
306
341
@@ -309,7 +344,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
309
344
// WARNING: A result set created from a raw query can only be scanned using the
310
345
// RawScan method of ResultSet, instead of Scan.
311
346
func (s * Store ) RawQuery (sql string , params ... interface {}) (ResultSet , error ) {
312
- rows , err := s .proxy .Query (sql , params ... )
347
+ rows , err := s .runner .Query (sql , params ... )
313
348
if err != nil {
314
349
return nil , err
315
350
}
@@ -320,7 +355,7 @@ func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
320
355
// RawExec executes a raw SQL query with the given parameters and returns
321
356
// the number of affected rows.
322
357
func (s * Store ) RawExec (sql string , params ... interface {}) (int64 , error ) {
323
- result , err := s .proxy .Exec (sql , params ... )
358
+ result , err := s .runner .Exec (sql , params ... )
324
359
if err != nil {
325
360
return 0 , err
326
361
}
@@ -332,7 +367,7 @@ func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
332
367
func (s * Store ) Find (q Query ) (ResultSet , error ) {
333
368
rels := q .getRelationships ()
334
369
if containsRelationshipOfType (rels , OneToMany ) {
335
- return NewBatchingResultSet (newBatchQueryRunner (q .Schema (), s .proxy , q )), nil
370
+ return NewBatchingResultSet (newBatchQueryRunner (q .Schema (), s .runner , q )), nil
336
371
}
337
372
338
373
columns , builder := q .compile ()
@@ -344,7 +379,7 @@ func (s *Store) Find(q Query) (ResultSet, error) {
344
379
builder = builder .Limit (limit )
345
380
}
346
381
347
- rows , err := builder .RunWith (s .proxy ).Query ()
382
+ rows , err := builder .RunWith (s .runner ).Query ()
348
383
if err != nil {
349
384
return nil , err
350
385
}
@@ -379,7 +414,7 @@ func (s *Store) Reload(schema Schema, record Record) error {
379
414
q .Limit (1 )
380
415
columns , builder := q .compile ()
381
416
382
- rows , err := builder .RunWith (s .proxy ).Query ()
417
+ rows , err := builder .RunWith (s .runner ).Query ()
383
418
if err != nil {
384
419
return err
385
420
}
@@ -399,7 +434,7 @@ func (s *Store) Count(q Query) (count int64, err error) {
399
434
_ , queryBuilder := q .compile ()
400
435
builder := builder .Set (queryBuilder , "Columns" , nil ).(squirrel.SelectBuilder )
401
436
err = builder .Column (fmt .Sprintf ("COUNT(%s)" , all .QualifiedName (q .Schema ()))).
402
- RunWith (s .proxy ).
437
+ RunWith (s .runner ).
403
438
QueryRow ().
404
439
Scan (& count )
405
440
return
@@ -423,16 +458,26 @@ func (s *Store) MustCount(q Query) int64 {
423
458
// If a transaction is already opened in this store, instead of opening a new
424
459
// one, the other will be reused.
425
460
func (s * Store ) Transaction (callback func (* Store ) error ) error {
426
- if s .db == nil {
461
+ var tx * sql.Tx
462
+ var err error
463
+ if db , ok := s .db .(* sql.DB ); ok {
464
+ // db is *sql.DB, not *sql.Tx
465
+ tx , err = db .Begin ()
466
+ if err != nil {
467
+ return fmt .Errorf ("kallax: can't open transaction: %s" , err )
468
+ }
469
+ } else {
470
+ // store is already holding a transaction
427
471
return callback (s )
428
472
}
429
473
430
- tx , err := s .db .Begin ()
431
- if err != nil {
432
- return fmt .Errorf ("kallax: can't open transaction: %s" , err )
433
- }
474
+ txStore := (& Store {
475
+ db : tx ,
476
+ logger : s .logger ,
477
+ useCacher : true ,
478
+ }).init ()
434
479
435
- if err := callback (newStoreWithTransaction ( tx ) ); err != nil {
480
+ if err := callback (txStore ); err != nil {
436
481
if err := tx .Rollback (); err != nil {
437
482
return fmt .Errorf ("kallax: unable to rollback transaction: %s" , err )
438
483
}
0 commit comments