Skip to content

Commit bbd54d4

Browse files
committed
add rounding logic and scale zero fix
1 parent 296e0fd commit bbd54d4

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -603,21 +603,18 @@ mod tests {
603603
12300000_i128
604604
);
605605

606-
// `parse_decimal` does not handle scale=0 correctly. will enable it as part of code change PR.
607-
// assert_eq!(parse_decimal::<Decimal128Type>("123.45", 38, 0)?, 123_i128);
606+
assert_eq!(parse_decimal::<Decimal128Type>("123.45", 38, 0)?, 123_i128);
608607
assert_eq!(
609608
parse_decimal::<Decimal128Type>("123.45", 38, 5)?,
610609
12345000_i128
611610
);
612-
613-
//scale = 0 is not handled correctly in parse_decimal, next PR will fix it and enable this.
614-
/*assert_eq!(
611+
assert_eq!(
615612
parse_decimal::<Decimal128Type>("123.4567891", 38, 0)?,
616613
123_i128
617-
);*/
614+
);
618615
assert_eq!(
619616
parse_decimal::<Decimal128Type>("123.4567891", 38, 5)?,
620-
12345678_i128
617+
12345679_i128
621618
);
622619
Ok(())
623620
}

arrow-cast/src/cast/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8453,15 +8453,15 @@ mod tests {
84538453
38,
84548454
2,
84558455
),
8456-
"0.12"
8456+
"0.13"
84578457
);
84588458
assert_eq!(
84598459
Decimal128Type::format_decimal(
84608460
parse_decimal::<Decimal128Type>(".1265", 38, 2).unwrap(),
84618461
38,
84628462
2,
84638463
),
8464-
"0.12"
8464+
"0.13"
84658465
);
84668466

84678467
assert_eq!(
@@ -8502,7 +8502,7 @@ mod tests {
85028502
38,
85038503
3,
85048504
),
8505-
"0.126"
8505+
"0.127"
85068506
);
85078507
}
85088508

arrow-cast/src/parse.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,16 @@ fn parse_e_notation<T: DecimalType>(
850850
}
851851

852852
if exp < 0 {
853-
result = result.div_wrapping(base.pow_wrapping(-exp as _));
853+
let result_with_scale = result.div_wrapping(base.pow_wrapping(-exp as _));
854+
let result_with_one_scale_up =
855+
result.div_wrapping(base.pow_wrapping(-exp.add_wrapping(1) as _));
856+
let rounding_digit =
857+
result_with_one_scale_up.sub_wrapping(result_with_scale.mul_wrapping(base));
858+
if rounding_digit >= T::Native::usize_as(5) {
859+
result = result_with_scale.add_wrapping(T::Native::usize_as(1));
860+
} else {
861+
result = result_with_scale;
862+
}
854863
} else {
855864
result = result.mul_wrapping(base.pow_wrapping(exp as _));
856865
}
@@ -868,6 +877,7 @@ pub fn parse_decimal<T: DecimalType>(
868877
let mut result = T::Native::usize_as(0);
869878
let mut fractionals: i8 = 0;
870879
let mut digits: u8 = 0;
880+
let mut rounding_digit = -1; // to store digit after the scale for rounding
871881
let base = T::Native::usize_as(10);
872882

873883
let bs = s.as_bytes();
@@ -897,6 +907,13 @@ pub fn parse_decimal<T: DecimalType>(
897907
// Ignore leading zeros.
898908
continue;
899909
}
910+
if fractionals == scale && scale != 0 {
911+
// Capture the rounding digit once
912+
if rounding_digit < 0 {
913+
rounding_digit = (b - b'0') as i8;
914+
}
915+
continue;
916+
}
900917
digits += 1;
901918
result = result.mul_wrapping(base);
902919
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
@@ -925,11 +942,17 @@ pub fn parse_decimal<T: DecimalType>(
925942
"can't parse the string value {s} to decimal"
926943
)));
927944
}
928-
if fractionals == scale && scale != 0 {
945+
if fractionals == scale {
946+
// Capture the rounding digit once
947+
if rounding_digit < 0 {
948+
rounding_digit = (b - b'0') as i8;
949+
}
929950
// We have processed all the digits that we need. All that
930951
// is left is to validate that the rest of the string contains
931952
// valid digits.
932-
continue;
953+
if scale != 0 {
954+
continue;
955+
}
933956
}
934957
fractionals += 1;
935958
digits += 1;
@@ -986,6 +1009,13 @@ pub fn parse_decimal<T: DecimalType>(
9861009
"parse decimal overflow ({s})"
9871010
)));
9881011
}
1012+
if scale == 0 {
1013+
result = result.div_wrapping(base.pow_wrapping(fractionals as u32))
1014+
}
1015+
//add one if >=5
1016+
if rounding_digit >= 5 {
1017+
result = result.add_wrapping(T::Native::usize_as(1));
1018+
}
9891019
}
9901020

9911021
Ok(if negative {

0 commit comments

Comments
 (0)