Skip to content

Commit 4e2af3e

Browse files
committed
move to shared impl
1 parent 7a1bb68 commit 4e2af3e

File tree

4 files changed

+397
-800
lines changed

4 files changed

+397
-800
lines changed

datafusion/common/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub mod parsers;
5151
pub mod pruning;
5252
pub mod rounding;
5353
pub mod scalar;
54+
pub mod scalar_literal_cast;
5455
pub mod spans;
5556
pub mod stats;
5657
pub mod test_util;
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Utilities for casting scalar literals to different data types
19+
//!
20+
//! This module contains functions for casting ScalarValue literals
21+
//! to different data types, originally extracted from the optimizer's
22+
//! unwrap_cast module to be shared between logical and physical layers.
23+
24+
use std::cmp::Ordering;
25+
26+
use arrow::datatypes::{
27+
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28+
MIN_DECIMAL128_FOR_EACH_PRECISION,
29+
};
30+
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
31+
32+
use crate::ScalarValue;
33+
34+
/// Convert a literal value from one data type to another
35+
pub fn try_cast_literal_to_type(
36+
lit_value: &ScalarValue,
37+
target_type: &DataType,
38+
) -> Option<ScalarValue> {
39+
let lit_data_type = lit_value.data_type();
40+
if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
41+
return None;
42+
}
43+
if lit_value.is_null() {
44+
// null value can be cast to any type of null value
45+
return ScalarValue::try_from(target_type).ok();
46+
}
47+
try_cast_numeric_literal(lit_value, target_type)
48+
.or_else(|| try_cast_string_literal(lit_value, target_type))
49+
.or_else(|| try_cast_dictionary(lit_value, target_type))
50+
.or_else(|| try_cast_binary(lit_value, target_type))
51+
}
52+
53+
/// Returns true if unwrap_cast_in_comparison supports this data type
54+
pub fn is_supported_type(data_type: &DataType) -> bool {
55+
is_supported_numeric_type(data_type)
56+
|| is_supported_string_type(data_type)
57+
|| is_supported_dictionary_type(data_type)
58+
|| is_supported_binary_type(data_type)
59+
}
60+
61+
/// Returns true if unwrap_cast_in_comparison support this numeric type
62+
pub fn is_supported_numeric_type(data_type: &DataType) -> bool {
63+
matches!(
64+
data_type,
65+
DataType::UInt8
66+
| DataType::UInt16
67+
| DataType::UInt32
68+
| DataType::UInt64
69+
| DataType::Int8
70+
| DataType::Int16
71+
| DataType::Int32
72+
| DataType::Int64
73+
| DataType::Decimal128(_, _)
74+
| DataType::Timestamp(_, _)
75+
)
76+
}
77+
78+
/// Returns true if unwrap_cast_in_comparison supports casting this value as a string
79+
pub fn is_supported_string_type(data_type: &DataType) -> bool {
80+
matches!(
81+
data_type,
82+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
83+
)
84+
}
85+
86+
/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
87+
pub fn is_supported_dictionary_type(data_type: &DataType) -> bool {
88+
matches!(data_type,
89+
DataType::Dictionary(_, inner) if is_supported_type(inner))
90+
}
91+
92+
pub fn is_supported_binary_type(data_type: &DataType) -> bool {
93+
matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
94+
}
95+
96+
/// Convert a numeric value from one numeric data type to another
97+
pub fn try_cast_numeric_literal(
98+
lit_value: &ScalarValue,
99+
target_type: &DataType,
100+
) -> Option<ScalarValue> {
101+
let lit_data_type = lit_value.data_type();
102+
if !is_supported_numeric_type(&lit_data_type)
103+
|| !is_supported_numeric_type(target_type)
104+
{
105+
return None;
106+
}
107+
108+
let mul = match target_type {
109+
DataType::UInt8
110+
| DataType::UInt16
111+
| DataType::UInt32
112+
| DataType::UInt64
113+
| DataType::Int8
114+
| DataType::Int16
115+
| DataType::Int32
116+
| DataType::Int64 => 1_i128,
117+
DataType::Timestamp(_, _) => 1_i128,
118+
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
119+
_ => return None,
120+
};
121+
let (target_min, target_max) = match target_type {
122+
DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
123+
DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
124+
DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
125+
DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
126+
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
127+
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
128+
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
129+
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
130+
DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
131+
DataType::Decimal128(precision, _) => (
132+
// Different precision for decimal128 can store different range of value.
133+
// For example, the precision is 3, the max of value is `999` and the min
134+
// value is `-999`
135+
MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
136+
MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
137+
),
138+
_ => return None,
139+
};
140+
let lit_value_target_type = match lit_value {
141+
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
142+
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
143+
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
144+
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
145+
ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
146+
ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
147+
ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
148+
ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
149+
ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
150+
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
151+
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
152+
ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
153+
ScalarValue::Decimal128(Some(v), _, scale) => {
154+
let lit_scale_mul = 10_i128.pow(*scale as u32);
155+
if mul >= lit_scale_mul {
156+
// Example:
157+
// lit is decimal(123,3,2)
158+
// target type is decimal(5,3)
159+
// the lit can be converted to the decimal(1230,5,3)
160+
(*v).checked_mul(mul / lit_scale_mul)
161+
} else if (*v) % (lit_scale_mul / mul) == 0 {
162+
// Example:
163+
// lit is decimal(123000,10,3)
164+
// target type is int32: the lit can be converted to INT32(123)
165+
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
166+
Some(*v / (lit_scale_mul / mul))
167+
} else {
168+
// can't convert the lit decimal to the target data type
169+
None
170+
}
171+
}
172+
_ => None,
173+
};
174+
175+
match lit_value_target_type {
176+
None => None,
177+
Some(value) => {
178+
if value >= target_min && value <= target_max {
179+
// the value casted from lit to the target type is in the range of target type.
180+
// return the target type of scalar value
181+
let result_scalar = match target_type {
182+
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
183+
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
184+
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
185+
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
186+
DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
187+
DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
188+
DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
189+
DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
190+
DataType::Timestamp(TimeUnit::Second, tz) => {
191+
let value = cast_between_timestamp(
192+
&lit_data_type,
193+
&DataType::Timestamp(TimeUnit::Second, tz.clone()),
194+
value,
195+
);
196+
ScalarValue::TimestampSecond(value, tz.clone())
197+
}
198+
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
199+
let value = cast_between_timestamp(
200+
&lit_data_type,
201+
&DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
202+
value,
203+
);
204+
ScalarValue::TimestampMillisecond(value, tz.clone())
205+
}
206+
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
207+
let value = cast_between_timestamp(
208+
&lit_data_type,
209+
&DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
210+
value,
211+
);
212+
ScalarValue::TimestampMicrosecond(value, tz.clone())
213+
}
214+
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
215+
let value = cast_between_timestamp(
216+
&lit_data_type,
217+
&DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
218+
value,
219+
);
220+
ScalarValue::TimestampNanosecond(value, tz.clone())
221+
}
222+
DataType::Decimal128(p, s) => {
223+
ScalarValue::Decimal128(Some(value), *p, *s)
224+
}
225+
_ => {
226+
return None;
227+
}
228+
};
229+
Some(result_scalar)
230+
} else {
231+
None
232+
}
233+
}
234+
}
235+
}
236+
237+
pub fn try_cast_string_literal(
238+
lit_value: &ScalarValue,
239+
target_type: &DataType,
240+
) -> Option<ScalarValue> {
241+
let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
242+
let scalar_value = match target_type {
243+
DataType::Utf8 => ScalarValue::Utf8(string_value),
244+
DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
245+
DataType::Utf8View => ScalarValue::Utf8View(string_value),
246+
_ => return None,
247+
};
248+
Some(scalar_value)
249+
}
250+
251+
/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
252+
pub fn try_cast_dictionary(
253+
lit_value: &ScalarValue,
254+
target_type: &DataType,
255+
) -> Option<ScalarValue> {
256+
let lit_value_type = lit_value.data_type();
257+
let result_scalar = match (lit_value, target_type) {
258+
// Unwrap dictionary when inner type matches target type
259+
(ScalarValue::Dictionary(_, inner_value), _)
260+
if inner_value.data_type() == *target_type =>
261+
{
262+
(**inner_value).clone()
263+
}
264+
// Wrap type when target type is dictionary
265+
(_, DataType::Dictionary(index_type, inner_type))
266+
if **inner_type == lit_value_type =>
267+
{
268+
ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
269+
}
270+
_ => {
271+
return None;
272+
}
273+
};
274+
Some(result_scalar)
275+
}
276+
277+
/// Cast a timestamp value from one unit to another
278+
pub fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
279+
let value = value as i64;
280+
let from_scale = match from {
281+
DataType::Timestamp(TimeUnit::Second, _) => 1,
282+
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
283+
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
284+
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
285+
_ => return Some(value),
286+
};
287+
288+
let to_scale = match to {
289+
DataType::Timestamp(TimeUnit::Second, _) => 1,
290+
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
291+
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
292+
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
293+
_ => return Some(value),
294+
};
295+
296+
match from_scale.cmp(&to_scale) {
297+
Ordering::Less => value.checked_mul(to_scale / from_scale),
298+
Ordering::Greater => Some(value / (from_scale / to_scale)),
299+
Ordering::Equal => Some(value),
300+
}
301+
}
302+
303+
pub fn try_cast_binary(
304+
lit_value: &ScalarValue,
305+
target_type: &DataType,
306+
) -> Option<ScalarValue> {
307+
match (lit_value, target_type) {
308+
(ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
309+
if v.len() == *n as usize =>
310+
{
311+
Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
312+
}
313+
_ => None,
314+
}
315+
}

0 commit comments

Comments
 (0)