diff --git a/velox/expression/CastExpr-inl.h b/velox/expression/CastExpr-inl.h index d920918f3a2b..e7ddbd5ae56a 100644 --- a/velox/expression/CastExpr-inl.h +++ b/velox/expression/CastExpr-inl.h @@ -15,8 +15,6 @@ */ #pragma once -#include - #include "velox/common/base/CountBits.h" #include "velox/common/base/Exceptions.h" #include "velox/core/CoreTypeSystem.h" @@ -52,66 +50,6 @@ inline std::exception_ptr makeBadCastException( false)); } -/// @brief Convert the unscaled value of a decimal to varchar and write to raw -/// string buffer from start position. -/// @tparam T The type of input value. -/// @param unscaledValue The input unscaled value. -/// @param scale The scale of decimal. -/// @param maxVarcharSize The estimated max size of a varchar. -/// @param startPosition The start position to write from. -/// @return A string view. -template -StringView convertToStringView( - T unscaledValue, - int32_t scale, - int32_t maxVarcharSize, - char* const startPosition) { - char* writePosition = startPosition; - if (unscaledValue == 0) { - *writePosition++ = '0'; - if (scale > 0) { - *writePosition++ = '.'; - // Append leading zeros. - std::memset(writePosition, '0', scale); - writePosition += scale; - } - } else { - if (unscaledValue < 0) { - *writePosition++ = '-'; - unscaledValue = -unscaledValue; - } - auto [position, errorCode] = std::to_chars( - writePosition, - writePosition + maxVarcharSize, - unscaledValue / DecimalUtil::kPowersOfTen[scale]); - VELOX_DCHECK_EQ( - errorCode, - std::errc(), - "Failed to cast decimal to varchar: {}", - std::make_error_code(errorCode).message()); - writePosition = position; - - if (scale > 0) { - *writePosition++ = '.'; - uint128_t fraction = unscaledValue % DecimalUtil::kPowersOfTen[scale]; - // Append leading zeros. - int numLeadingZeros = std::max(scale - countDigits(fraction), 0); - std::memset(writePosition, '0', numLeadingZeros); - writePosition += numLeadingZeros; - // Append remaining fraction digits. 
- auto result = std::to_chars( - writePosition, writePosition + maxVarcharSize, fraction); - VELOX_DCHECK_EQ( - result.ec, - std::errc(), - "Failed to cast decimal to varchar: {}", - std::make_error_code(result.ec).message()); - writePosition = result.ptr; - } - } - return StringView(startPosition, writePosition - startPosition); -} - } // namespace namespace detail { @@ -632,24 +570,14 @@ VectorPtr CastExpr::applyDecimalToVarcharCast( const auto simpleInput = input.as>(); int precision = getDecimalPrecisionScale(*fromType).first; int scale = getDecimalPrecisionScale(*fromType).second; - // A varchar's size is estimated with unscaled value digits, dot, leading - // zero, and possible minus sign. - int32_t rowSize = precision + 1; - if (scale > 0) { - ++rowSize; // A dot. - } - if (precision == scale) { - ++rowSize; // Leading zero. - } - + auto rowSize = DecimalUtil::maxStringViewSize(precision, scale); auto flatResult = result->asFlatVector(); if (StringView::isInline(rowSize)) { char inlined[StringView::kInlineSize]; applyToSelectedNoThrowLocal(context, rows, result, [&](vector_size_t row) { - flatResult->setNoCopy( - row, - convertToStringView( - simpleInput->valueAt(row), scale, rowSize, inlined)); + auto actualSize = DecimalUtil::castToString( + simpleInput->valueAt(row), scale, rowSize, inlined); + flatResult->setNoCopy(row, StringView(inlined, actualSize)); }); return result; } @@ -659,13 +587,13 @@ VectorPtr CastExpr::applyDecimalToVarcharCast( char* rawBuffer = buffer->asMutable() + buffer->size(); applyToSelectedNoThrowLocal(context, rows, result, [&](vector_size_t row) { - auto stringView = convertToStringView( + auto actualSize = DecimalUtil::castToString( simpleInput->valueAt(row), scale, rowSize, rawBuffer); - flatResult->setNoCopy(row, stringView); - if (!stringView.isInline()) { + flatResult->setNoCopy(row, StringView(rawBuffer, actualSize)); + if (!StringView::isInline(actualSize)) { // If string view is inline, corresponding bytes on the raw 
string buffer
       // are not needed.
-      rawBuffer += stringView.size();
+      rawBuffer += actualSize;
     }
   });
   // Update the exact buffer size.
diff --git a/velox/type/DecimalUtil.cpp b/velox/type/DecimalUtil.cpp
index 3b48c3c98a72..f07b34a11528 100644
--- a/velox/type/DecimalUtil.cpp
+++ b/velox/type/DecimalUtil.cpp
@@ -114,4 +114,15 @@ void DecimalUtil::computeAverage(
   }
 }
 
+int32_t DecimalUtil::maxStringViewSize(int precision, int scale) {
+  int32_t rowSize = precision + 1; // All digits plus a possible minus sign.
+  if (scale > 0) {
+    ++rowSize; // Decimal point.
+  }
+  if (precision == scale) {
+    ++rowSize; // Leading '0' written before the decimal point.
+  }
+  return rowSize;
+}
+
 } // namespace facebook::velox
diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h
index d3a4ce8dd00e..dc935b94cdc9 100644
--- a/velox/type/DecimalUtil.h
+++ b/velox/type/DecimalUtil.h
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include
 #include
 #include "velox/common/base/CheckedArithmetic.h"
 #include "velox/common/base/CountBits.h"
@@ -309,6 +310,71 @@ class DecimalUtil {
     return remainder * resultSign;
   }
 
+  /// Returns the maximum number of characters needed to render a decimal of
+  /// the given precision and scale as varchar: the unscaled value's digits, a
+  /// possible minus sign, plus a dot and a leading zero where applicable.
+  static int32_t maxStringViewSize(int precision, int scale);
+
+  /// @brief Writes the decimal given by its unscaled value and scale as text
+  /// into a raw buffer, starting at 'startPosition'. No terminator is written.
+  /// @tparam T The type of input value.
+  /// @param unscaledValue The input unscaled value.
+  /// @param scale The scale of decimal.
+  /// @param maxSize The estimated max size of string; bounds std::to_chars.
+  /// @param startPosition The start position to write from.
+  /// @return The number of characters written.
+  template <typename T>
+  static size_t castToString(
+      T unscaledValue,
+      int32_t scale,
+      int32_t maxSize,
+      char* const startPosition) {
+    // One past the last byte the caller allows us to write. Both to_chars
+    // calls are bounded by this fixed end; 'writePosition + maxSize' would
+    // overstate the remaining capacity once the sign or digits are written.
+    char* const endPosition = startPosition + maxSize;
+    char* writePosition = startPosition;
+    if (unscaledValue == 0) {
+      *writePosition++ = '0';
+      if (scale > 0) {
+        *writePosition++ = '.';
+        // Append trailing zeros.
+        std::memset(writePosition, '0', scale);
+        writePosition += scale;
+      }
+    } else {
+      if (unscaledValue < 0) {
+        *writePosition++ = '-';
+        unscaledValue = -unscaledValue;
+      }
+      // Integer part: the unscaled value with the last 'scale' digits dropped.
+      auto [position, errorCode] = std::to_chars(
+          writePosition,
+          endPosition,
+          unscaledValue / DecimalUtil::kPowersOfTen[scale]);
+      VELOX_DCHECK_EQ(
+          errorCode,
+          std::errc(),
+          "Failed to cast decimal to varchar: {}",
+          std::make_error_code(errorCode).message());
+      writePosition = position;
+
+      if (scale > 0) {
+        *writePosition++ = '.';
+        uint128_t fraction = unscaledValue % DecimalUtil::kPowersOfTen[scale];
+        // Append leading zeros.
+        int numLeadingZeros = std::max(scale - countDigits(fraction), 0);
+        std::memset(writePosition, '0', numLeadingZeros);
+        writePosition += numLeadingZeros;
+        // Append remaining fraction digits.
+        auto result = std::to_chars(writePosition, endPosition, fraction);
+        VELOX_DCHECK_EQ(
+            result.ec,
+            std::errc(),
+            "Failed to cast decimal to varchar: {}",
+            std::make_error_code(result.ec).message());
+        writePosition = result.ptr;
+      }
+    }
+    // The write cursor never moves backwards, so the difference is
+    // non-negative.
+    return static_cast<size_t>(writePosition - startPosition);
+  }
+
   /*
    * sum up and return overflow/underflow.
*/
diff --git a/velox/type/tests/DecimalTest.cpp b/velox/type/tests/DecimalTest.cpp
index d87b1fdfd252..7b2bf2913908 100644
--- a/velox/type/tests/DecimalTest.cpp
+++ b/velox/type/tests/DecimalTest.cpp
@@ -106,6 +106,36 @@ void testToByteArray(int128_t value, int8_t* expected, int32_t size) {
   EXPECT_EQ(std::memcmp(expected, out, length), 0);
 }
 
+// Formats 'unscaledValue' with DecimalUtil::castToString into a buffer of
+// 'maxStringSize' bytes and checks the produced text against 'expected'.
+// 'maxStringSize' must be the size DecimalUtil::maxStringViewSize reports for
+// ('precision', 'scale'), so the test uses the same buffer size as the cast.
+template <typename T>
+void testCastToString(
+    T unscaledValue,
+    int precision,
+    int scale,
+    int maxStringSize,
+    const std::string& expected) {
+  ASSERT_EQ(DecimalUtil::maxStringViewSize(precision, scale), maxStringSize);
+  // std::string instead of a variable-length array, which is not standard C++.
+  std::string out(maxStringSize, '\0');
+  const auto actualSize = DecimalUtil::castToString(
+      unscaledValue, scale, maxStringSize, out.data());
+  ASSERT_LE(actualSize, out.size());
+  out.resize(actualSize);
+  EXPECT_EQ(expected, out);
+}
+
+// Checks the maximum varchar size computed for a decimal type.
+void testMaxStringViewSize(
+    int precision,
+    int scale,
+    int expectedMaxStringSize) {
+  EXPECT_EQ(
+      DecimalUtil::maxStringViewSize(precision, scale), expectedMaxStringSize);
+}
+
 std::string zeros(uint32_t numZeros) {
   return std::string(numZeros, '0');
 }
@@ -490,5 +520,53 @@ TEST(DecimalTest, rescaleReal) {
   assertRescaleRealFail(
       INFINITY, DECIMAL(10, 2), "The input value should be finite.");
 }
+
+TEST(DecimalTest, maxStringViewSize) {
+  testMaxStringViewSize(10, 0, 11);
+  testMaxStringViewSize(10, 1, 12);
+  testMaxStringViewSize(10, 10, 13);
+}
+
+TEST(DecimalTest, castToString) {
+  // Zero takes a dedicated branch in castToString.
+  testCastToString(0, 10, 0, 11, "0");
+  testCastToString(0, 5, 5, 8, "0.00000");
+
+  testCastToString(12, 10, 0, 11, "12");
+  testCastToString(12, 10, 1, 12, "1.2");
+  testCastToString(12, 10, 3, 12, "0.012");
+  testCastToString(-12, 10, 3, 12, "-0.012");
+  testCastToString(12, 5, 5, 8, "0.00012");
+  testCastToString(-12, 5, 5, 8, "-0.00012");
+  testCastToString(
+      DecimalUtil::kShortDecimalMax, 18, 0, 19, std::string(18, '9'));
+  testCastToString(
+      DecimalUtil::kShortDecimalMin, 18, 0, 19, "-" + std::string(18, '9'));
+
+  testCastToString(
+      HugeInt::parse("-18446744073709551616"),
+      20,
+      0,
+      21,
+      "-18446744073709551616");
+
+  testCastToString(
+      HugeInt::parse("-18446744073709551616"),
+      20,
+      3,
+      22,
+      "-18446744073709551.616");
+
+  testCastToString(
+      HugeInt::parse("-12345678901234567890"),
+      20,
+      20,
+      23,
+      "-0.12345678901234567890");
+  testCastToString(
+      DecimalUtil::kLongDecimalMax, 38, 0, 39, std::string(38, '9'));
+  testCastToString(
+      DecimalUtil::kLongDecimalMin, 38, 0, 39, "-" + std::string(38, '9'));
+}
 } // namespace
 } // namespace facebook::velox