Skip to content

Commit 750956d

Browse files
author
Prashanth Govindarajan
authored
Improvements to the Merge routine (dotnet#5778)
* Handle nulls better in Merge * sq * sq * Add unit test
1 parent 19b2331 commit 750956d

File tree

6 files changed

+349
-120
lines changed

6 files changed

+349
-120
lines changed

src/Microsoft.Data.Analysis/ArrowStringDataFrameColumn.cs

+17-9
Original file line numberDiff line numberDiff line change
@@ -460,34 +460,42 @@ private ArrowStringDataFrameColumn Clone(PrimitiveDataFrameColumn<int> mapIndice
460460
/// <inheritdoc/>
461461
public override DataFrame ValueCounts()
462462
{
463-
Dictionary<string, ICollection<long>> groupedValues = GroupColumnValues<string>();
463+
Dictionary<string, ICollection<long>> groupedValues = GroupColumnValues<string>(out HashSet<long> _);
464464
return StringDataFrameColumn.ValueCountsImplementation(groupedValues);
465465
}
466466

467467
/// <inheritdoc/>
468468
public override GroupBy GroupBy(int columnIndex, DataFrame parent)
469469
{
470-
Dictionary<string, ICollection<long>> dictionary = GroupColumnValues<string>();
470+
Dictionary<string, ICollection<long>> dictionary = GroupColumnValues<string>(out HashSet<long> _);
471471
return new GroupBy<string>(parent, columnIndex, dictionary);
472472
}
473473

474474
/// <inheritdoc/>
475-
public override Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>()
475+
public override Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>(out HashSet<long> nullIndices)
476476
{
477477
if (typeof(TKey) == typeof(string))
478478
{
479+
nullIndices = new HashSet<long>();
479480
Dictionary<string, ICollection<long>> multimap = new Dictionary<string, ICollection<long>>(EqualityComparer<string>.Default);
480481
for (long i = 0; i < Length; i++)
481482
{
482-
string str = this[i] ?? "__null__";
483-
bool containsKey = multimap.TryGetValue(str, out ICollection<long> values);
484-
if (containsKey)
483+
string str = this[i];
484+
if (str != null)
485485
{
486-
values.Add(i);
486+
bool containsKey = multimap.TryGetValue(str, out ICollection<long> values);
487+
if (containsKey)
488+
{
489+
values.Add(i);
490+
}
491+
else
492+
{
493+
multimap.Add(str, new List<long>() { i });
494+
}
487495
}
488496
else
489497
{
490-
multimap.Add(str, new List<long>() { i });
498+
nullIndices.Add(i);
491499
}
492500
}
493501
return multimap as Dictionary<TKey, ICollection<long>>;
@@ -499,7 +507,7 @@ public override Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>()
499507
}
500508

501509
/// <inheritdoc/>
502-
public ArrowStringDataFrameColumn FillNulls(string value, bool inPlace = false)
510+
public ArrowStringDataFrameColumn FillNulls(string value, bool inPlace = false)
503511
{
504512
if (value == null)
505513
{

src/Microsoft.Data.Analysis/DataFrame.Join.cs

+91-92
Original file line numberDiff line numberDiff line change
@@ -168,82 +168,72 @@ public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string righ
168168
{
169169
// First hash other dataframe on the rightJoinColumn
170170
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
171-
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>();
171+
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
172172

173173
// Go over the records in this dataframe and match with the dictionary
174174
DataFrameColumn thisColumn = Columns[leftJoinColumn];
175175

176176
for (long i = 0; i < thisColumn.Length; i++)
177177
{
178178
var thisColumnValue = thisColumn[i];
179-
TKey thisColumnValueOrDefault = (TKey)(thisColumnValue == null ? default(TKey) : thisColumnValue);
180-
if (multimap.TryGetValue(thisColumnValueOrDefault, out ICollection<long> rowNumbers))
179+
if (thisColumnValue != null)
181180
{
182-
foreach (long row in rowNumbers)
181+
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
183182
{
184-
if (thisColumnValue == null)
183+
foreach (long row in rowNumbers)
185184
{
186-
// Match only with nulls in otherColumn
187-
if (otherColumn[row] == null)
188-
{
189-
leftRowIndices.Append(i);
190-
rightRowIndices.Append(row);
191-
}
192-
}
193-
else
194-
{
195-
// Cannot match nulls in otherColumn
196-
if (otherColumn[row] != null)
197-
{
198-
leftRowIndices.Append(i);
199-
rightRowIndices.Append(row);
200-
}
185+
leftRowIndices.Append(i);
186+
rightRowIndices.Append(row);
201187
}
202188
}
189+
else
190+
{
191+
leftRowIndices.Append(i);
192+
rightRowIndices.Append(null);
193+
}
203194
}
204195
else
205196
{
206-
leftRowIndices.Append(i);
207-
rightRowIndices.Append(null);
197+
foreach (long row in otherColumnNullIndices)
198+
{
199+
leftRowIndices.Append(i);
200+
rightRowIndices.Append(row);
201+
}
208202
}
209203
}
210204
}
211205
else if (joinAlgorithm == JoinAlgorithm.Right)
212206
{
213207
DataFrameColumn thisColumn = Columns[leftJoinColumn];
214-
Dictionary<TKey, ICollection<long>> multimap = thisColumn.GroupColumnValues<TKey>();
208+
Dictionary<TKey, ICollection<long>> multimap = thisColumn.GroupColumnValues<TKey>(out HashSet<long> thisColumnNullIndices);
215209

216210
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
217211
for (long i = 0; i < otherColumn.Length; i++)
218212
{
219213
var otherColumnValue = otherColumn[i];
220-
TKey otherColumnValueOrDefault = (TKey)(otherColumnValue == null ? default(TKey) : otherColumnValue);
221-
if (multimap.TryGetValue(otherColumnValueOrDefault, out ICollection<long> rowNumbers))
214+
if (otherColumnValue != null)
222215
{
223-
foreach (long row in rowNumbers)
216+
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
224217
{
225-
if (otherColumnValue == null)
218+
foreach (long row in rowNumbers)
226219
{
227-
if (thisColumn[row] == null)
228-
{
229-
leftRowIndices.Append(row);
230-
rightRowIndices.Append(i);
231-
}
232-
}
233-
else
234-
{
235-
if (thisColumn[row] != null)
236-
{
237-
leftRowIndices.Append(row);
238-
rightRowIndices.Append(i);
239-
}
220+
leftRowIndices.Append(row);
221+
rightRowIndices.Append(i);
240222
}
241223
}
224+
else
225+
{
226+
leftRowIndices.Append(null);
227+
rightRowIndices.Append(i);
228+
}
242229
}
243230
else
244231
{
245-
leftRowIndices.Append(null);
246-
rightRowIndices.Append(i);
232+
foreach (long thisColumnNullIndex in thisColumnNullIndices)
233+
{
234+
leftRowIndices.Append(thisColumnNullIndex);
235+
rightRowIndices.Append(i);
236+
}
247237
}
248238
}
249239
}
@@ -253,97 +243,106 @@ public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string righ
253243
long leftRowCount = Rows.Count;
254244
long rightRowCount = other.Rows.Count;
255245

256-
var leftColumnIsSmaller = (leftRowCount <= rightRowCount);
246+
bool leftColumnIsSmaller = leftRowCount <= rightRowCount;
257247
DataFrameColumn hashColumn = leftColumnIsSmaller ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn];
258248
DataFrameColumn otherColumn = ReferenceEquals(hashColumn, Columns[leftJoinColumn]) ? other.Columns[rightJoinColumn] : Columns[leftJoinColumn];
259-
Dictionary<TKey, ICollection<long>> multimap = hashColumn.GroupColumnValues<TKey>();
249+
Dictionary<TKey, ICollection<long>> multimap = hashColumn.GroupColumnValues<TKey>(out HashSet<long> smallerDataFrameColumnNullIndices);
260250

261251
for (long i = 0; i < otherColumn.Length; i++)
262252
{
263253
var otherColumnValue = otherColumn[i];
264-
TKey otherColumnValueOrDefault = (TKey)(otherColumnValue == null ? default(TKey) : otherColumnValue);
265-
if (multimap.TryGetValue(otherColumnValueOrDefault, out ICollection<long> rowNumbers))
254+
if (otherColumnValue != null)
266255
{
267-
foreach (long row in rowNumbers)
256+
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
268257
{
269-
if (otherColumnValue == null)
270-
{
271-
if (hashColumn[row] == null)
272-
{
273-
leftRowIndices.Append(leftColumnIsSmaller ? row : i);
274-
rightRowIndices.Append(leftColumnIsSmaller ? i : row);
275-
}
276-
}
277-
else
258+
foreach (long row in rowNumbers)
278259
{
279-
if (hashColumn[row] != null)
280-
{
281-
leftRowIndices.Append(leftColumnIsSmaller ? row : i);
282-
rightRowIndices.Append(leftColumnIsSmaller ? i : row);
283-
}
260+
leftRowIndices.Append(leftColumnIsSmaller ? row : i);
261+
rightRowIndices.Append(leftColumnIsSmaller ? i : row);
284262
}
285263
}
286264
}
265+
else
266+
{
267+
foreach (long nullIndex in smallerDataFrameColumnNullIndices)
268+
{
269+
leftRowIndices.Append(leftColumnIsSmaller ? nullIndex : i);
270+
rightRowIndices.Append(leftColumnIsSmaller ? i : nullIndex);
271+
}
272+
}
287273
}
288274
}
289275
else if (joinAlgorithm == JoinAlgorithm.FullOuter)
290276
{
291277
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
292-
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>();
278+
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
293279
Dictionary<TKey, long> intersection = new Dictionary<TKey, long>(EqualityComparer<TKey>.Default);
294280

295281
// Go over the records in this dataframe and match with the dictionary
296282
DataFrameColumn thisColumn = Columns[leftJoinColumn];
283+
Int64DataFrameColumn thisColumnNullIndices = new Int64DataFrameColumn("ThisColumnNullIndices");
297284

298285
for (long i = 0; i < thisColumn.Length; i++)
299286
{
300287
var thisColumnValue = thisColumn[i];
301-
TKey thisColumnValueOrDefault = (TKey)(thisColumnValue == null ? default(TKey) : thisColumnValue);
302-
if (multimap.TryGetValue(thisColumnValueOrDefault, out ICollection<long> rowNumbers))
288+
if (thisColumnValue != null)
303289
{
304-
foreach (long row in rowNumbers)
290+
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
305291
{
306-
if (thisColumnValue == null)
307-
{
308-
// Has to match only with nulls in otherColumn
309-
if (otherColumn[row] == null)
310-
{
311-
leftRowIndices.Append(i);
312-
rightRowIndices.Append(row);
313-
if (!intersection.ContainsKey(thisColumnValueOrDefault))
314-
{
315-
intersection.Add(thisColumnValueOrDefault, rowNumber);
316-
}
317-
}
318-
}
319-
else
292+
foreach (long row in rowNumbers)
320293
{
321-
// Cannot match to nulls in otherColumn
322-
if (otherColumn[row] != null)
294+
leftRowIndices.Append(i);
295+
rightRowIndices.Append(row);
296+
if (!intersection.ContainsKey((TKey)thisColumnValue))
323297
{
324-
leftRowIndices.Append(i);
325-
rightRowIndices.Append(row);
326-
if (!intersection.ContainsKey(thisColumnValueOrDefault))
327-
{
328-
intersection.Add(thisColumnValueOrDefault, rowNumber);
329-
}
298+
intersection.Add((TKey)thisColumnValue, rowNumber);
330299
}
331300
}
332301
}
302+
else
303+
{
304+
leftRowIndices.Append(i);
305+
rightRowIndices.Append(null);
306+
}
333307
}
334308
else
335309
{
336-
leftRowIndices.Append(i);
337-
rightRowIndices.Append(null);
310+
thisColumnNullIndices.Append(i);
338311
}
339312
}
340313
for (long i = 0; i < otherColumn.Length; i++)
341314
{
342-
TKey value = (TKey)(otherColumn[i] ?? default(TKey));
343-
if (!intersection.ContainsKey(value))
315+
var value = otherColumn[i];
316+
if (value != null)
317+
{
318+
if (!intersection.ContainsKey((TKey)value))
319+
{
320+
leftRowIndices.Append(null);
321+
rightRowIndices.Append(i);
322+
}
323+
}
324+
}
325+
326+
// Now handle the null rows
327+
foreach (long? thisColumnNullIndex in thisColumnNullIndices)
328+
{
329+
foreach (long otherColumnNullIndex in otherColumnNullIndices)
330+
{
331+
leftRowIndices.Append(thisColumnNullIndex.Value);
332+
rightRowIndices.Append(otherColumnNullIndex);
333+
}
334+
if (otherColumnNullIndices.Count == 0)
335+
{
336+
leftRowIndices.Append(thisColumnNullIndex.Value);
337+
rightRowIndices.Append(null);
338+
}
339+
}
340+
if (thisColumnNullIndices.Length == 0)
341+
{
342+
foreach (long otherColumnNullIndex in otherColumnNullIndices)
344343
{
345344
leftRowIndices.Append(null);
346-
rightRowIndices.Append(i);
345+
rightRowIndices.Append(otherColumnNullIndex);
347346
}
348347
}
349348
}

src/Microsoft.Data.Analysis/DataFrameColumn.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,12 @@ public virtual DataFrameColumn Sort(bool ascending = true)
203203
return Clone(sortIndices, !ascending, NullCount);
204204
}
205205

206-
public virtual Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>() => throw new NotImplementedException();
206+
/// <summary>
207+
/// Groups the rows of this column by their value.
208+
/// </summary>
209+
/// <typeparam name="TKey">The type of data held by this column</typeparam>
210+
/// <returns>A mapping of value(<typeparamref name="TKey"/>) to the indices containing this value</returns>
211+
public virtual Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>(out HashSet<long> nullIndices) => throw new NotImplementedException();
207212

208213
/// <summary>
209214
/// Returns a DataFrame containing counts of unique values

0 commit comments

Comments
 (0)