Skip to content

Commit

Permalink
Fix GH#291: Ignore null values in DistinctCountAggregator (#298)
Browse files Browse the repository at this point in the history
Ref internal pr-629
  • Loading branch information
StrongestNumber9 authored Sep 2, 2024
1 parent ae2da24 commit a844b57
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ public Node aggregateMethodDistinctCountEmitCatalyst(DPLParser.AggregateMethodDi
String resultColumnName = String.format("dc(%s)", colName);

// Use aggregator
Column col = new DistinctCountAggregator(colName).toColumn();
Column col = new DistinctCountAggregator(colName, catCtx.nullValue).toColumn();

rv = new ColumnNode(col.as(resultColumnName));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,90 +46,76 @@

package com.teragrep.pth10.ast.commands.aggregate.UDAFs;

import com.teragrep.pth10.ast.NullValue;
import com.teragrep.pth10.ast.commands.aggregate.UDAFs.BufferClasses.CountBuffer;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Aggregator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;

/**
* Aggregator for command dc(); the result is whatever CountBuffer.dc() reports
* for the values collected from the source column.
*
* Aggregator types: IN=Row, BUF=CountBuffer, OUT=Integer
* Serializable
*
* NOTE(review): this listing is a rendered diff, so some declarations below appear
* twice (the pre-commit line followed by its post-commit replacement).
* @author eemhu
*
*/
public class DistinctCountAggregator extends Aggregator<Row, CountBuffer, Integer> implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(DistinctCountAggregator.class);

private static final long serialVersionUID = 1L;
private String colName = null;
private static final boolean debugEnabled = false;
private static final long serialVersionUID = 1L;
private final String colName;
private final NullValue nullValue;

/**
* Constructor used to feed in the column name and the null value marker
* @param colName Column name for source field
* @param nullValue Provider of the value that reduce() treats as "null" and skips
* */
public DistinctCountAggregator(String colName) {
public DistinctCountAggregator(String colName, NullValue nullValue) {
super();
this.colName = colName;
this.nullValue = nullValue;
}

/** Encoder for the buffer (class: CountBuffer), using Java serialization */
@Override
public Encoder<CountBuffer> bufferEncoder() {
if (debugEnabled) LOGGER.info("Buffer encoder");

// TODO using kryo should speed this up
return Encoders.javaSerialization(CountBuffer.class);
}

/** Encoder for the output (Integer, the value returned by finish()) */
@Override
public Encoder<Integer> outputEncoder() {
if (debugEnabled) LOGGER.info("Output encoder");

return Encoders.INT();
}

/** Initialization: every aggregation starts from a new CountBuffer */
@Override
public CountBuffer zero() {
if (debugEnabled) LOGGER.info("zero");

return new CountBuffer();
}

/** Perform at the end of the aggregation: returns the buffer's dc() value */
@Override
public Integer finish(CountBuffer buffer) {
if (debugEnabled) LOGGER.info("finish");

return buffer.dc();
}

/** Merge two buffers into one: buffer2's map is merged into buffer, which is returned */
@Override
public CountBuffer merge(CountBuffer buffer, CountBuffer buffer2) {
if (debugEnabled) LOGGER.info("merge");

buffer.mergeMap(buffer2.getMap());
return buffer;
}

/** Update the buffer with the value of the source column in the input row, skipping null values */
@Override
public CountBuffer reduce(CountBuffer buffer, Row input) {
if (debugEnabled) LOGGER.info("reduce");

// Calls toString() directly on the column value, so a null value would throw here.
String inputString = input.getAs(colName).toString();
buffer.add(inputString);

Object inputObject = input.getAs(colName);
// Reference (identity) comparison against the configured null value; only other values are added.
// NOTE(review): if nullValue.value() can be a non-null sentinel, a Java null in the row
// would still reach toString() below - confirm against NullValue.
if (inputObject != nullValue.value()) {
buffer.add(inputObject.toString());
}
return buffer;
}
}
15 changes: 15 additions & 0 deletions src/test/java/com/teragrep/pth10/statsTransformationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,21 @@ void statsTransform_AggDc_Test() {
assertEquals(Collections.singletonList("11"), destAsList);
});
}

// dc() over a column that contains only NULL values is expected to report 0
@Test
@DisabledIfSystemProperty(named="skipSparkTest", matches="true")
void statsTransform_AggDc_NoData_Test() {
    // rex4j is used to produce nulls: the pattern looks for "koira@<digit>" while raw is "kissa@1"
    final String query = "| makeresults | eval raw=\"kissa@1\"| rex4j field=raw \"koira@(?<koira>\\d)\" | stats dc(koira)";

    streamingTestUtil.performDPLTest(query, testFile, ds -> {
        // The result must consist of exactly the dc(koira) column
        final String actualColumns = Arrays.toString(ds.columns());
        assertEquals("[dc(koira)]", actualColumns);

        // Collect the aggregated value(s) as strings for comparison
        final List<String> distinctCounts = ds
                .select("dc(koira)")
                .collectAsList()
                .stream()
                .map(row -> row.getAs(0).toString())
                .collect(Collectors.toList());

        assertEquals(Collections.singletonList("0"), distinctCounts);
    });
}

// Test estdc()
@Test
Expand Down

0 comments on commit a844b57

Please sign in to comment.