Skip to content

Commit c2349a3

Browse files
Miguel Molinaerizocosmico
Miguel Molina
authored andcommitted
implement logic to handle through relationships
1 parent 253a381 commit c2349a3

9 files changed

+585
-58
lines changed

batcher.go

+111-21
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type batchQueryRunner struct {
1414
q Query
1515
oneToOneRels []Relationship
1616
oneToManyRels []Relationship
17+
throughRels []Relationship
1718
db squirrel.DBProxy
1819
builder squirrel.SelectBuilder
1920
total int
@@ -29,6 +30,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
2930
var (
3031
oneToOneRels []Relationship
3132
oneToManyRels []Relationship
33+
throughRels []Relationship
3234
)
3335

3436
for _, rel := range q.getRelationships() {
@@ -37,6 +39,8 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
3739
oneToOneRels = append(oneToOneRels, rel)
3840
case OneToMany:
3941
oneToManyRels = append(oneToManyRels, rel)
42+
case Through:
43+
throughRels = append(throughRels, rel)
4044
}
4145
}
4246

@@ -46,6 +50,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
4650
q: q,
4751
oneToOneRels: oneToOneRels,
4852
oneToManyRels: oneToManyRels,
53+
throughRels: throughRels,
4954
db: db,
5055
builder: builder,
5156
}
@@ -125,8 +130,14 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
125130
return nil, err
126131
}
127132

133+
if len(records) == 0 {
134+
return nil, nil
135+
}
136+
128137
var ids = make([]interface{}, len(records))
138+
var identType Identifier
129139
for i, r := range records {
140+
identType = r.GetID()
130141
ids[i] = r.GetID().Raw()
131142
}
132143

@@ -136,63 +147,142 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
136147
return nil, err
137148
}
138149

139-
for _, r := range records {
140-
err := r.SetRelationship(rel.Field, indexedResults[r.GetID().Raw()])
141-
if err != nil {
142-
return nil, err
143-
}
150+
err = setIndexedResults(records, rel, indexedResults)
151+
if err != nil {
152+
return nil, err
153+
}
154+
}
144155

145-
// If the relationship is partial, we can not ensure the results
146-
// in the field reflect the truth of the database.
147-
// In this case, the parent is marked as non-writable.
148-
if rel.Filter != nil {
149-
r.setWritable(false)
150-
}
156+
for _, rel := range r.throughRels {
157+
indexedResults, err := r.getRecordThroughRelationships(ids, rel, identType)
158+
if err != nil {
159+
return nil, err
160+
}
161+
162+
err = setIndexedResults(records, rel, indexedResults)
163+
if err != nil {
164+
return nil, err
151165
}
152166
}
153167

154168
return records, nil
155169
}
156170

171+
func setIndexedResults(records []Record, rel Relationship, indexedResults indexedRecords) error {
172+
for _, r := range records {
173+
err := r.SetRelationship(rel.Field, indexedResults[r.GetID().Raw()])
174+
if err != nil {
175+
return err
176+
}
177+
178+
// If the relationship is partial, we can not ensure the results
179+
// in the field reflect the truth of the database.
180+
// In this case, the parent is marked as non-writable.
181+
if rel.Filter != nil {
182+
r.setWritable(false)
183+
}
184+
}
185+
186+
return nil
187+
}
188+
157189
type indexedRecords map[interface{}][]Record
158190

159191
func (r *batchQueryRunner) getRecordRelationships(ids []interface{}, rel Relationship) (indexedRecords, error) {
160192
fk, ok := r.schema.ForeignKey(rel.Field)
161193
if !ok {
162-
return nil, fmt.Errorf("kallax: cannot find foreign key on field %s for table %s", rel.Field, r.schema.Table())
194+
return nil, fmt.Errorf("kallax: cannot find foreign key on field %s of table %s", rel.Field, r.schema.Table())
163195
}
164196

165197
filter := In(fk, ids...)
166198
if rel.Filter != nil {
167-
And(rel.Filter, filter)
168-
} else {
169-
rel.Filter = filter
199+
filter = And(rel.Filter, filter)
170200
}
171201

172202
q := NewBaseQuery(rel.Schema)
173-
q.Where(rel.Filter)
203+
q.Where(filter)
174204
cols, builder := q.compile()
175205
rows, err := builder.RunWith(r.db).Query()
176206
if err != nil {
177207
return nil, err
178208
}
179209

210+
return indexedResultsFromRows(rows, cols, rel.Schema, fk, nil)
211+
}
212+
213+
func (r *batchQueryRunner) getRecordThroughRelationships(ids []interface{}, rel Relationship, identType Identifier) (indexedRecords, error) {
214+
lfk, rfk, ok := r.schema.ForeignKeys(rel.Field)
215+
if !ok {
216+
return nil, fmt.Errorf("kallax: cannot find foreign keys for through relationship on field %s of table %s", rel.Field, r.schema.Table())
217+
}
218+
219+
filter := In(r.schema.ID(), ids...)
220+
if rel.Filter != nil {
221+
filter = And(rel.Filter, filter)
222+
}
223+
224+
if rel.IntermediateFilter != nil {
225+
filter = And(rel.IntermediateFilter, filter)
226+
}
227+
228+
q := NewBaseQuery(rel.Schema)
229+
lschema := r.schema.WithAlias(rel.Schema.Alias())
230+
intSchema := rel.IntermediateSchema.WithAlias(rel.Schema.Alias())
231+
q.joinThrough(lschema, intSchema, rel.Schema, lfk, rfk)
232+
q.Where(filter)
233+
cols, builder := q.compile()
234+
// manually add the extra column to also select the parent id
235+
builder = builder.Column(lschema.ID().QualifiedName(lschema))
236+
rows, err := builder.RunWith(r.db).Query()
237+
if err != nil {
238+
return nil, err
239+
}
240+
241+
// we need to pass a new pointer of the parent identifier type so the
242+
// resultset can fill it and we can know to which record it belongs when
243+
// indexing by parent id.
244+
return indexedResultsFromRows(rows, cols, rel.Schema, rfk, identType.newPtr())
245+
}
246+
247+
// indexedResultsFromRows returns the results in the given rows indexed by the
248+
// parent id. In the case of many to many relationships, the record odes not
249+
// have a specific field with the ID of the parent to index by it,
250+
// that's why parentIDPtr is passed for these cases. parentIDPtr is a pointer
251+
// to an ID of the type required by the parent to be filled by the result set.
252+
func indexedResultsFromRows(rows *sql.Rows, cols []string, schema Schema, fk SchemaField, parentIDPtr interface{}) (indexedRecords, error) {
180253
relRs := NewResultSet(rows, false, nil, cols...)
181254
var indexedResults = make(indexedRecords)
182255
for relRs.Next() {
183-
rec, err := relRs.Get(rel.Schema)
184-
if err != nil {
185-
return nil, err
256+
var (
257+
rec Record
258+
err error
259+
)
260+
261+
if parentIDPtr != nil {
262+
rec, err = relRs.customGet(schema, parentIDPtr)
263+
} else {
264+
rec, err = relRs.Get(schema)
186265
}
187266

188-
val, err := rec.Value(fk.String())
189267
if err != nil {
190268
return nil, err
191269
}
192270

193271
rec.setPersisted()
194272
rec.setWritable(true)
195-
id := val.(Identifier).Raw()
273+
274+
var id interface{}
275+
if parentIDPtr != nil {
276+
id = parentIDPtr.(Identifier).Raw()
277+
} else {
278+
val, err := rec.Value(fk.String())
279+
if err != nil {
280+
return nil, err
281+
}
282+
283+
id = val.(Identifier).Raw()
284+
}
285+
196286
indexedResults[id] = append(indexedResults[id], rec)
197287
}
198288

0 commit comments

Comments
 (0)