From a5e3ff79b1fa9881f9f820931573d3a1d7fb284f Mon Sep 17 00:00:00 2001 From: Jacob Khaliqi Date: Thu, 13 Mar 2025 18:16:51 -0700 Subject: [PATCH] Fix array_top_n not outputting null --- .../presto/operator/scalar/sql/ArraySqlFunctions.java | 4 ++-- .../operator/scalar/sql/TestArraySqlFunctions.java | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index 2fc1218ea2408..904533b00d9e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -163,7 +163,7 @@ public static String removeNulls() @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int")}) @SqlType("array") public static String arrayTopN() - { return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(ARRAY_SORT_DESC(input), 1, n))"; } + { return "RETURN IF(n < 0, NULL, SLICE(ARRAY_SORT_DESC(input), 1, n))"; } @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true) @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") @@ -172,6 +172,6 @@ public static String arrayTopN() @SqlType("array") public static String arrayTopNComparator() { - return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))"; + return "RETURN IF(n < 0, NULL, SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))"; } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 2e0e865f13be4..084a5dbf31c42 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -393,7 +393,6 @@ public void testArrayTopNEdgeAndErrorCase() // Test exceptions assertInvalidFunction("ARRAY_TOP_N(ARRAY [ROW('a', 1), ROW('a', null), null, ROW('a', 0)], 2)", StandardErrorCode.INVALID_FUNCTION_ARGUMENT); assertInvalidFunction("ARRAY_TOP_N(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])], 2)", SemanticErrorCode.FUNCTION_NOT_FOUND); - assertInvalidFunction("ARRAY_TOP_N(ARRAY ['a', 'a', 'd', 'a', 'a', 'a'], -1)", StandardErrorCode.GENERIC_USER_ERROR, "Parameter n: -1 to ARRAY_TOP_N is negative"); // Test edge cases assertFunction("ARRAY_TOP_N(ARRAY [null, null], 3)", new ArrayType(UNKNOWN), asList(null, null)); @@ -401,4 +400,14 @@ public void testArrayTopNEdgeAndErrorCase() assertFunction("ARRAY_TOP_N(ARRAY [], 3)", new ArrayType(UNKNOWN), emptyList()); assertFunction("ARRAY_TOP_N(ARRAY [1, 4], 3)", new ArrayType(INTEGER), ImmutableList.of(4, 1)); } + + @Test + public void testArrayTopNNegativeParameter() + { + assertFunction("ARRAY_TOP_N(ARRAY ['a', 'a', 'd', 'a', 'a', 'a'], -1)", new ArrayType(createVarcharType(1)), null); + assertFunction("ARRAY_TOP_N(ARRAY [1,2,3,4,5,6], -5)", new ArrayType(INTEGER), null); + assertFunction("ARRAY_TOP_N(ARRAY [DOUBLE '1.0', 100, 2, DOUBLE '5.0', DOUBLE '3.0'], -3)", new ArrayType(DOUBLE), null); + assertFunction("ARRAY_TOP_N(ARRAY [true, true, false, true, false], -4)", new ArrayType(BOOLEAN), null); + assertFunction("ARRAY_TOP_N(ARRAY [null, null], -3)", new ArrayType(UNKNOWN), null); + } }