Skip to content

Commit 43c49f6

Browse files
author
Prashanth Govindarajan
authored
Improvements to the sort routine (dotnet#5776)
* Improvements to the sort routine
* Fix unit test
* Fold into existing API
1 parent 750956d commit 43c49f6

File tree

6 files changed

+92
-27
lines changed

6 files changed

+92
-27
lines changed

src/Microsoft.Data.Analysis/DataFrame.cs

+8-4
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ public DataFrame Sample(int numberOfRows)
335335

336336
int shuffleLowerLimit = 0;
337337
int shuffleUpperLimit = (int)Math.Min(Int32.MaxValue, Rows.Count);
338-
338+
339339
int[] shuffleArray = Enumerable.Range(0, shuffleUpperLimit).ToArray();
340340
Random rand = new Random();
341341
while (shuffleLowerLimit < numberOfRows)
@@ -349,7 +349,7 @@ public DataFrame Sample(int numberOfRows)
349349
ArraySegment<int> segment = new ArraySegment<int>(shuffleArray, 0, shuffleLowerLimit);
350350

351351
PrimitiveDataFrameColumn<int> indices = new PrimitiveDataFrameColumn<int>("indices", segment);
352-
352+
353353
return Clone(indices);
354354
}
355355

@@ -623,12 +623,16 @@ private void OnColumnsChanged()
623623
private DataFrame Sort(string columnName, bool isAscending)
624624
{
625625
DataFrameColumn column = Columns[columnName];
626-
DataFrameColumn sortIndices = column.GetAscendingSortIndices();
626+
PrimitiveDataFrameColumn<long> sortIndices = column.GetAscendingSortIndices(out Int64DataFrameColumn nullIndices);
627+
for (long i = 0; i < nullIndices.Length; i++)
628+
{
629+
sortIndices.Append(nullIndices[i]);
630+
}
627631
List<DataFrameColumn> newColumns = new List<DataFrameColumn>(Columns.Count);
628632
for (int i = 0; i < Columns.Count; i++)
629633
{
630634
DataFrameColumn oldColumn = Columns[i];
631-
DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending, oldColumn.NullCount);
635+
DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending);
632636
Debug.Assert(newColumn.NullCount == oldColumn.NullCount);
633637
newColumns.Add(newColumn);
634638
}

src/Microsoft.Data.Analysis/DataFrameColumn.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public object this[long rowIndex]
199199
/// <param name="ascending"></param>
200200
public virtual DataFrameColumn Sort(bool ascending = true)
201201
{
202-
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
202+
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
203203
return Clone(sortIndices, !ascending, NullCount);
204204
}
205205

@@ -336,7 +336,11 @@ public virtual StringDataFrameColumn Info()
336336
/// </summary>
337337
public virtual DataFrameColumn Description() => throw new NotImplementedException();
338338

339-
internal virtual PrimitiveDataFrameColumn<long> GetAscendingSortIndices() => throw new NotImplementedException();
339+
/// <summary>
340+
/// Returns the indices of non-null values that, when applied, result in this column being sorted in ascending order. Also returns the indices of null values in <paramref name="nullIndices"/>.
341+
/// </summary>
342+
/// <param name="nullIndices">Indices of values that are <see langword="null"/>.</param>
343+
internal virtual PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) => throw new NotImplementedException();
340344

341345
internal delegate long GetBufferSortIndex(int bufferIndex, int sortIndex);
342346
internal delegate ValueTuple<T, int> GetValueAndBufferSortIndexAtBuffer<T>(int bufferIndex, int valueIndex);

src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs

+20-8
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,46 @@ public partial class PrimitiveDataFrameColumn<T> : DataFrameColumn
1414
{
1515
public new PrimitiveDataFrameColumn<T> Sort(bool ascending = true)
1616
{
17-
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
17+
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
1818
return Clone(sortIndices, !ascending, NullCount);
1919
}
2020

21-
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices()
21+
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
2222
{
23-
// The return sortIndices contains only the non null indices.
24-
GetSortIndices(Comparer<T>.Default, out PrimitiveDataFrameColumn<long> sortIndices);
23+
Int64DataFrameColumn sortIndices = GetSortIndices(Comparer<T>.Default, out nullIndices);
2524
return sortIndices;
2625
}
2726

28-
private void GetSortIndices(IComparer<T> comparer, out PrimitiveDataFrameColumn<long> columnSortIndices)
27+
private Int64DataFrameColumn GetSortIndices(IComparer<T> comparer, out Int64DataFrameColumn columnNullIndices)
2928
{
3029
List<List<int>> bufferSortIndices = new List<List<int>>(_columnContainer.Buffers.Count);
30+
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
31+
long nullIndicesSlot = 0;
3132
// Sort each buffer first
3233
for (int b = 0; b < _columnContainer.Buffers.Count; b++)
3334
{
3435
ReadOnlyDataFrameBuffer<T> buffer = _columnContainer.Buffers[b];
3536
ReadOnlySpan<byte> nullBitMapSpan = _columnContainer.NullBitMapBuffers[b].ReadOnlySpan;
3637
int[] sortIndices = new int[buffer.Length];
3738
for (int i = 0; i < buffer.Length; i++)
39+
{
3840
sortIndices[i] = i;
41+
}
3942
IntrospectiveSort(buffer.ReadOnlySpan, buffer.Length, sortIndices, comparer);
4043
// Bug fix: QuickSort is not stable. When PrimitiveDataFrameColumn has null values and default values, they move around
4144
List<int> nonNullSortIndices = new List<int>();
4245
for (int i = 0; i < sortIndices.Length; i++)
4346
{
44-
if (_columnContainer.IsValid(nullBitMapSpan, sortIndices[i]))
47+
int localSortIndex = sortIndices[i];
48+
if (_columnContainer.IsValid(nullBitMapSpan, localSortIndex))
49+
{
4550
nonNullSortIndices.Add(sortIndices[i]);
46-
51+
}
52+
else
53+
{
54+
columnNullIndices[nullIndicesSlot] = localSortIndex + b * _columnContainer.Buffers[0].Length;
55+
nullIndicesSlot++;
56+
}
4757
}
4858
bufferSortIndices.Add(nonNullSortIndices);
4959
}
@@ -90,11 +100,13 @@ ValueTuple<T, int> GetFirstNonNullValueAndBufferIndexStartingAtIndex(int bufferI
90100
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferIndex.Item2, i) });
91101
}
92102
}
93-
columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
103+
Int64DataFrameColumn columnSortIndices = new Int64DataFrameColumn("SortIndices");
94104
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Count);
95105
GetValueAndBufferSortIndexAtBuffer<T> getValueAndBufferSortIndexAtBuffer = new GetValueAndBufferSortIndexAtBuffer<T>((int bufferIndex, int sortIndex) => GetFirstNonNullValueAndBufferIndexStartingAtIndex(bufferIndex, sortIndex));
96106
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Count);
97107
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAndBufferSortIndexAtBuffer, getBufferLengthAtIndex);
108+
109+
return columnSortIndices;
98110
}
99111
}
100112
}

src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ public override double Median()
225225
// Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn)
226226
if (Length == 0)
227227
return 0;
228-
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices();
228+
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
229229
long middle = sortIndices.Length / 2;
230230
double middleValue = (double)Convert.ChangeType(this[sortIndices[middle].Value].Value, typeof(double));
231231
if (Length % 2 == 0)

src/Microsoft.Data.Analysis/StringDataFrameColumn.cs

+13-5
Original file line numberDiff line numberDiff line change
@@ -171,25 +171,32 @@ public IEnumerator<string> GetEnumerator()
171171

172172
public new StringDataFrameColumn Sort(bool ascending = true)
173173
{
174-
PrimitiveDataFrameColumn<long> columnSortIndices = GetAscendingSortIndices();
174+
PrimitiveDataFrameColumn<long> columnSortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
175175
return Clone(columnSortIndices, !ascending, NullCount);
176176
}
177177

178-
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices()
178+
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
179179
{
180-
GetSortIndices(Comparer<string>.Default, out PrimitiveDataFrameColumn<long> columnSortIndices);
180+
PrimitiveDataFrameColumn<long> columnSortIndices = GetSortIndices(Comparer<string>.Default, out nullIndices);
181181
return columnSortIndices;
182182
}
183183

184-
private void GetSortIndices(Comparer<string> comparer, out PrimitiveDataFrameColumn<long> columnSortIndices)
184+
private PrimitiveDataFrameColumn<long> GetSortIndices(Comparer<string> comparer, out Int64DataFrameColumn columnNullIndices)
185185
{
186186
List<int[]> bufferSortIndices = new List<int[]>(_stringBuffers.Count);
187+
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
188+
long nullIndicesSlot = 0;
187189
foreach (List<string> buffer in _stringBuffers)
188190
{
189191
var sortIndices = new int[buffer.Count];
190192
for (int i = 0; i < buffer.Count; i++)
191193
{
192194
sortIndices[i] = i;
195+
if (buffer[i] == null)
196+
{
197+
columnNullIndices[nullIndicesSlot] = i + bufferSortIndices.Count * int.MaxValue;
198+
nullIndicesSlot++;
199+
}
193200
}
194201
// TODO: Refactor the sort routine to also work with IList?
195202
string[] array = buffer.ToArray();
@@ -227,11 +234,12 @@ ValueTuple<string, int> GetFirstNonNullValueStartingAtIndex(int stringBufferInde
227234
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferSortIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferSortIndex.Item2, i) });
228235
}
229236
}
230-
columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
237+
PrimitiveDataFrameColumn<long> columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
231238
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Length);
232239
GetValueAndBufferSortIndexAtBuffer<string> getValueAtBuffer = new GetValueAndBufferSortIndexAtBuffer<string>((int bufferIndex, int sortIndex) => GetFirstNonNullValueStartingAtIndex(bufferIndex, sortIndex));
233240
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Length);
234241
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAtBuffer, getBufferLengthAtIndex);
242+
return columnSortIndices;
235243
}
236244

237245
public new StringDataFrameColumn Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend)

test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs

+44-7
Original file line numberDiff line numberDiff line change
@@ -815,10 +815,10 @@ public void TestOrderBy()
815815

816816
// Sort by "Int" in descending order
817817
sortedDf = df.OrderByDescending("Int");
818-
Assert.Null(sortedDf.Columns["Int"][19]);
819-
Assert.Equal(-1, sortedDf.Columns["Int"][18]);
820-
Assert.Equal(100, sortedDf.Columns["Int"][1]);
821-
Assert.Equal(2000, sortedDf.Columns["Int"][0]);
818+
Assert.Null(sortedDf.Columns["Int"][0]);
819+
Assert.Equal(-1, sortedDf.Columns["Int"][19]);
820+
Assert.Equal(100, sortedDf.Columns["Int"][2]);
821+
Assert.Equal(2000, sortedDf.Columns["Int"][1]);
822822

823823
// Sort by "String" in ascending order
824824
sortedDf = df.OrderBy("String");
@@ -829,9 +829,9 @@ public void TestOrderBy()
829829

830830
// Sort by "String" in descending order
831831
sortedDf = df.OrderByDescending("String");
832-
Assert.Null(sortedDf.Columns["Int"][19]);
833-
Assert.Equal(8, sortedDf.Columns["Int"][1]);
834-
Assert.Equal(9, sortedDf.Columns["Int"][0]);
832+
Assert.Null(sortedDf.Columns["Int"][0]);
833+
Assert.Equal(8, sortedDf.Columns["Int"][2]);
834+
Assert.Equal(9, sortedDf.Columns["Int"][1]);
835835
}
836836

837837
[Fact]
@@ -920,6 +920,43 @@ public void TestPrimitiveColumnSort(int numberOfNulls)
920920
Assert.Null(sortedIntColumn[9]);
921921
}
922922

923+
[Fact]
924+
public void TestSortWithDifferentNullCountsInColumns()
925+
{
926+
DataFrame dataFrame = MakeDataFrameWithAllMutableColumnTypes(10);
927+
dataFrame["Int"][3] = null;
928+
dataFrame["String"][3] = null;
929+
DataFrame sorted = dataFrame.OrderBy("Int");
930+
void Verify(DataFrame sortedDataFrame)
931+
{
932+
Assert.Equal(10, sortedDataFrame.Rows.Count);
933+
DataFrameRow lastRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 1];
934+
DataFrameRow penultimateRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 2];
935+
foreach (object value in lastRow)
936+
{
937+
Assert.Null(value);
938+
}
939+
940+
for (int i = 0; i < sortedDataFrame.Columns.Count; i++)
941+
{
942+
string columnName = sortedDataFrame.Columns[i].Name;
943+
if (columnName != "String" && columnName != "Int")
944+
{
945+
Assert.Equal(dataFrame[columnName][3], penultimateRow[i]);
946+
}
947+
else if (columnName == "String" || columnName == "Int")
948+
{
949+
Assert.Null(penultimateRow[i]);
950+
}
951+
}
952+
}
953+
954+
Verify(sorted);
955+
956+
sorted = dataFrame.OrderBy("String");
957+
Verify(sorted);
958+
}
959+
923960
private void VerifyJoin(DataFrame join, DataFrame left, DataFrame right, JoinAlgorithm joinAlgorithm)
924961
{
925962
Int64DataFrameColumn mapIndices = new Int64DataFrameColumn("map", join.Rows.Count);

0 commit comments

Comments
 (0)