@@ -14,6 +14,7 @@ type batchQueryRunner struct {
14
14
q Query
15
15
oneToOneRels []Relationship
16
16
oneToManyRels []Relationship
17
+ throughRels []Relationship
17
18
db squirrel.DBProxy
18
19
builder squirrel.SelectBuilder
19
20
total int
@@ -29,6 +30,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
29
30
var (
30
31
oneToOneRels []Relationship
31
32
oneToManyRels []Relationship
33
+ throughRels []Relationship
32
34
)
33
35
34
36
for _ , rel := range q .getRelationships () {
@@ -37,6 +39,8 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
37
39
oneToOneRels = append (oneToOneRels , rel )
38
40
case OneToMany :
39
41
oneToManyRels = append (oneToManyRels , rel )
42
+ case Through :
43
+ throughRels = append (throughRels , rel )
40
44
}
41
45
}
42
46
@@ -46,6 +50,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
46
50
q : q ,
47
51
oneToOneRels : oneToOneRels ,
48
52
oneToManyRels : oneToManyRels ,
53
+ throughRels : throughRels ,
49
54
db : db ,
50
55
builder : builder ,
51
56
}
@@ -125,8 +130,14 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
125
130
return nil , err
126
131
}
127
132
133
+ if len (records ) == 0 {
134
+ return nil , nil
135
+ }
136
+
128
137
var ids = make ([]interface {}, len (records ))
138
+ var identType Identifier
129
139
for i , r := range records {
140
+ identType = r .GetID ()
130
141
ids [i ] = r .GetID ().Raw ()
131
142
}
132
143
@@ -136,63 +147,142 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
136
147
return nil , err
137
148
}
138
149
139
- for _ , r := range records {
140
- err := r . SetRelationship ( rel . Field , indexedResults [ r . GetID (). Raw ()])
141
- if err != nil {
142
- return nil , err
143
- }
150
+ err = setIndexedResults ( records , rel , indexedResults )
151
+ if err != nil {
152
+ return nil , err
153
+ }
154
+ }
144
155
145
- // If the relationship is partial, we can not ensure the results
146
- // in the field reflect the truth of the database.
147
- // In this case, the parent is marked as non-writable.
148
- if rel .Filter != nil {
149
- r .setWritable (false )
150
- }
156
+ for _ , rel := range r .throughRels {
157
+ indexedResults , err := r .getRecordThroughRelationships (ids , rel , identType )
158
+ if err != nil {
159
+ return nil , err
160
+ }
161
+
162
+ err = setIndexedResults (records , rel , indexedResults )
163
+ if err != nil {
164
+ return nil , err
151
165
}
152
166
}
153
167
154
168
return records , nil
155
169
}
156
170
171
+ func setIndexedResults (records []Record , rel Relationship , indexedResults indexedRecords ) error {
172
+ for _ , r := range records {
173
+ err := r .SetRelationship (rel .Field , indexedResults [r .GetID ().Raw ()])
174
+ if err != nil {
175
+ return err
176
+ }
177
+
178
+ // If the relationship is partial, we can not ensure the results
179
+ // in the field reflect the truth of the database.
180
+ // In this case, the parent is marked as non-writable.
181
+ if rel .Filter != nil {
182
+ r .setWritable (false )
183
+ }
184
+ }
185
+
186
+ return nil
187
+ }
188
+
157
189
type indexedRecords map [interface {}][]Record
158
190
159
191
func (r * batchQueryRunner ) getRecordRelationships (ids []interface {}, rel Relationship ) (indexedRecords , error ) {
160
192
fk , ok := r .schema .ForeignKey (rel .Field )
161
193
if ! ok {
162
- return nil , fmt .Errorf ("kallax: cannot find foreign key on field %s for table %s" , rel .Field , r .schema .Table ())
194
+ return nil , fmt .Errorf ("kallax: cannot find foreign key on field %s of table %s" , rel .Field , r .schema .Table ())
163
195
}
164
196
165
197
filter := In (fk , ids ... )
166
198
if rel .Filter != nil {
167
- And (rel .Filter , filter )
168
- } else {
169
- rel .Filter = filter
199
+ filter = And (rel .Filter , filter )
170
200
}
171
201
172
202
q := NewBaseQuery (rel .Schema )
173
- q .Where (rel . Filter )
203
+ q .Where (filter )
174
204
cols , builder := q .compile ()
175
205
rows , err := builder .RunWith (r .db ).Query ()
176
206
if err != nil {
177
207
return nil , err
178
208
}
179
209
210
+ return indexedResultsFromRows (rows , cols , rel .Schema , fk , nil )
211
+ }
212
+
213
+ func (r * batchQueryRunner ) getRecordThroughRelationships (ids []interface {}, rel Relationship , identType Identifier ) (indexedRecords , error ) {
214
+ lfk , rfk , ok := r .schema .ForeignKeys (rel .Field )
215
+ if ! ok {
216
+ return nil , fmt .Errorf ("kallax: cannot find foreign keys for through relationship on field %s of table %s" , rel .Field , r .schema .Table ())
217
+ }
218
+
219
+ filter := In (r .schema .ID (), ids ... )
220
+ if rel .Filter != nil {
221
+ filter = And (rel .Filter , filter )
222
+ }
223
+
224
+ if rel .IntermediateFilter != nil {
225
+ filter = And (rel .IntermediateFilter , filter )
226
+ }
227
+
228
+ q := NewBaseQuery (rel .Schema )
229
+ lschema := r .schema .WithAlias (rel .Schema .Alias ())
230
+ intSchema := rel .IntermediateSchema .WithAlias (rel .Schema .Alias ())
231
+ q .joinThrough (lschema , intSchema , rel .Schema , lfk , rfk )
232
+ q .Where (filter )
233
+ cols , builder := q .compile ()
234
+ // manually add the extra column to also select the parent id
235
+ builder = builder .Column (lschema .ID ().QualifiedName (lschema ))
236
+ rows , err := builder .RunWith (r .db ).Query ()
237
+ if err != nil {
238
+ return nil , err
239
+ }
240
+
241
+ // we need to pass a new pointer of the parent identifier type so the
242
+ // resultset can fill it and we can know to which record it belongs when
243
+ // indexing by parent id.
244
+ return indexedResultsFromRows (rows , cols , rel .Schema , rfk , identType .newPtr ())
245
+ }
246
+
247
+ // indexedResultsFromRows returns the results in the given rows indexed by the
248
+ // parent id. In the case of many to many relationships, the record odes not
249
+ // have a specific field with the ID of the parent to index by it,
250
+ // that's why parentIDPtr is passed for these cases. parentIDPtr is a pointer
251
+ // to an ID of the type required by the parent to be filled by the result set.
252
+ func indexedResultsFromRows (rows * sql.Rows , cols []string , schema Schema , fk SchemaField , parentIDPtr interface {}) (indexedRecords , error ) {
180
253
relRs := NewResultSet (rows , false , nil , cols ... )
181
254
var indexedResults = make (indexedRecords )
182
255
for relRs .Next () {
183
- rec , err := relRs .Get (rel .Schema )
184
- if err != nil {
185
- return nil , err
256
+ var (
257
+ rec Record
258
+ err error
259
+ )
260
+
261
+ if parentIDPtr != nil {
262
+ rec , err = relRs .customGet (schema , parentIDPtr )
263
+ } else {
264
+ rec , err = relRs .Get (schema )
186
265
}
187
266
188
- val , err := rec .Value (fk .String ())
189
267
if err != nil {
190
268
return nil , err
191
269
}
192
270
193
271
rec .setPersisted ()
194
272
rec .setWritable (true )
195
- id := val .(Identifier ).Raw ()
273
+
274
+ var id interface {}
275
+ if parentIDPtr != nil {
276
+ id = parentIDPtr .(Identifier ).Raw ()
277
+ } else {
278
+ val , err := rec .Value (fk .String ())
279
+ if err != nil {
280
+ return nil , err
281
+ }
282
+
283
+ id = val .(Identifier ).Raw ()
284
+ }
285
+
196
286
indexedResults [id ] = append (indexedResults [id ], rec )
197
287
}
198
288
0 commit comments