1
+ // Licensed to the Apache Software Foundation (ASF) under one
2
+ // or more contributor license agreements. See the NOTICE file
3
+ // distributed with this work for additional information
4
+ // regarding copyright ownership. The ASF licenses this file
5
+ // to you under the Apache License, Version 2.0 (the
6
+ // "License"); you may not use this file except in compliance
7
+ // with the License. You may obtain a copy of the License at
8
+ //
9
+ // http://www.apache.org/licenses/LICENSE-2.0
10
+ //
11
+ // Unless required by applicable law or agreed to in writing,
12
+ // software distributed under the License is distributed on an
13
+ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ // KIND, either express or implied. See the License for the
15
+ // specific language governing permissions and limitations
16
+ // under the License.
17
+
18
+ //! Utilities for casting scalar literals to different data types
19
+ //!
20
+ //! This module contains functions for casting ScalarValue literals
21
+ //! to different data types, originally extracted from the optimizer's
22
+ //! unwrap_cast module to be shared between logical and physical layers.
23
+
24
+ use std:: cmp:: Ordering ;
25
+
26
+ use arrow:: datatypes:: {
27
+ DataType , TimeUnit , MAX_DECIMAL128_FOR_EACH_PRECISION ,
28
+ MIN_DECIMAL128_FOR_EACH_PRECISION ,
29
+ } ;
30
+ use arrow:: temporal_conversions:: { MICROSECONDS , MILLISECONDS , NANOSECONDS } ;
31
+
32
+ use crate :: ScalarValue ;
33
+
34
+ /// Convert a literal value from one data type to another
35
+ pub fn try_cast_literal_to_type (
36
+ lit_value : & ScalarValue ,
37
+ target_type : & DataType ,
38
+ ) -> Option < ScalarValue > {
39
+ let lit_data_type = lit_value. data_type ( ) ;
40
+ if !is_supported_type ( & lit_data_type) || !is_supported_type ( target_type) {
41
+ return None ;
42
+ }
43
+ if lit_value. is_null ( ) {
44
+ // null value can be cast to any type of null value
45
+ return ScalarValue :: try_from ( target_type) . ok ( ) ;
46
+ }
47
+ try_cast_numeric_literal ( lit_value, target_type)
48
+ . or_else ( || try_cast_string_literal ( lit_value, target_type) )
49
+ . or_else ( || try_cast_dictionary ( lit_value, target_type) )
50
+ . or_else ( || try_cast_binary ( lit_value, target_type) )
51
+ }
52
+
53
+ /// Returns true if unwrap_cast_in_comparison supports this data type
54
+ pub fn is_supported_type ( data_type : & DataType ) -> bool {
55
+ is_supported_numeric_type ( data_type)
56
+ || is_supported_string_type ( data_type)
57
+ || is_supported_dictionary_type ( data_type)
58
+ || is_supported_binary_type ( data_type)
59
+ }
60
+
61
+ /// Returns true if unwrap_cast_in_comparison support this numeric type
62
+ pub fn is_supported_numeric_type ( data_type : & DataType ) -> bool {
63
+ matches ! (
64
+ data_type,
65
+ DataType :: UInt8
66
+ | DataType :: UInt16
67
+ | DataType :: UInt32
68
+ | DataType :: UInt64
69
+ | DataType :: Int8
70
+ | DataType :: Int16
71
+ | DataType :: Int32
72
+ | DataType :: Int64
73
+ | DataType :: Decimal128 ( _, _)
74
+ | DataType :: Timestamp ( _, _)
75
+ )
76
+ }
77
+
78
+ /// Returns true if unwrap_cast_in_comparison supports casting this value as a string
79
+ pub fn is_supported_string_type ( data_type : & DataType ) -> bool {
80
+ matches ! (
81
+ data_type,
82
+ DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Utf8View
83
+ )
84
+ }
85
+
86
+ /// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
87
+ pub fn is_supported_dictionary_type ( data_type : & DataType ) -> bool {
88
+ matches ! ( data_type,
89
+ DataType :: Dictionary ( _, inner) if is_supported_type( inner) )
90
+ }
91
+
92
+ pub fn is_supported_binary_type ( data_type : & DataType ) -> bool {
93
+ matches ! ( data_type, DataType :: Binary | DataType :: FixedSizeBinary ( _) )
94
+ }
95
+
96
+ /// Convert a numeric value from one numeric data type to another
97
+ pub fn try_cast_numeric_literal (
98
+ lit_value : & ScalarValue ,
99
+ target_type : & DataType ,
100
+ ) -> Option < ScalarValue > {
101
+ let lit_data_type = lit_value. data_type ( ) ;
102
+ if !is_supported_numeric_type ( & lit_data_type)
103
+ || !is_supported_numeric_type ( target_type)
104
+ {
105
+ return None ;
106
+ }
107
+
108
+ let mul = match target_type {
109
+ DataType :: UInt8
110
+ | DataType :: UInt16
111
+ | DataType :: UInt32
112
+ | DataType :: UInt64
113
+ | DataType :: Int8
114
+ | DataType :: Int16
115
+ | DataType :: Int32
116
+ | DataType :: Int64 => 1_i128 ,
117
+ DataType :: Timestamp ( _, _) => 1_i128 ,
118
+ DataType :: Decimal128 ( _, scale) => 10_i128 . pow ( * scale as u32 ) ,
119
+ _ => return None ,
120
+ } ;
121
+ let ( target_min, target_max) = match target_type {
122
+ DataType :: UInt8 => ( u8:: MIN as i128 , u8:: MAX as i128 ) ,
123
+ DataType :: UInt16 => ( u16:: MIN as i128 , u16:: MAX as i128 ) ,
124
+ DataType :: UInt32 => ( u32:: MIN as i128 , u32:: MAX as i128 ) ,
125
+ DataType :: UInt64 => ( u64:: MIN as i128 , u64:: MAX as i128 ) ,
126
+ DataType :: Int8 => ( i8:: MIN as i128 , i8:: MAX as i128 ) ,
127
+ DataType :: Int16 => ( i16:: MIN as i128 , i16:: MAX as i128 ) ,
128
+ DataType :: Int32 => ( i32:: MIN as i128 , i32:: MAX as i128 ) ,
129
+ DataType :: Int64 => ( i64:: MIN as i128 , i64:: MAX as i128 ) ,
130
+ DataType :: Timestamp ( _, _) => ( i64:: MIN as i128 , i64:: MAX as i128 ) ,
131
+ DataType :: Decimal128 ( precision, _) => (
132
+ // Different precision for decimal128 can store different range of value.
133
+ // For example, the precision is 3, the max of value is `999` and the min
134
+ // value is `-999`
135
+ MIN_DECIMAL128_FOR_EACH_PRECISION [ * precision as usize ] ,
136
+ MAX_DECIMAL128_FOR_EACH_PRECISION [ * precision as usize ] ,
137
+ ) ,
138
+ _ => return None ,
139
+ } ;
140
+ let lit_value_target_type = match lit_value {
141
+ ScalarValue :: Int8 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
142
+ ScalarValue :: Int16 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
143
+ ScalarValue :: Int32 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
144
+ ScalarValue :: Int64 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
145
+ ScalarValue :: UInt8 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
146
+ ScalarValue :: UInt16 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
147
+ ScalarValue :: UInt32 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
148
+ ScalarValue :: UInt64 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
149
+ ScalarValue :: TimestampSecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
150
+ ScalarValue :: TimestampMillisecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
151
+ ScalarValue :: TimestampMicrosecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
152
+ ScalarValue :: TimestampNanosecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
153
+ ScalarValue :: Decimal128 ( Some ( v) , _, scale) => {
154
+ let lit_scale_mul = 10_i128 . pow ( * scale as u32 ) ;
155
+ if mul >= lit_scale_mul {
156
+ // Example:
157
+ // lit is decimal(123,3,2)
158
+ // target type is decimal(5,3)
159
+ // the lit can be converted to the decimal(1230,5,3)
160
+ ( * v) . checked_mul ( mul / lit_scale_mul)
161
+ } else if ( * v) % ( lit_scale_mul / mul) == 0 {
162
+ // Example:
163
+ // lit is decimal(123000,10,3)
164
+ // target type is int32: the lit can be converted to INT32(123)
165
+ // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
166
+ Some ( * v / ( lit_scale_mul / mul) )
167
+ } else {
168
+ // can't convert the lit decimal to the target data type
169
+ None
170
+ }
171
+ }
172
+ _ => None ,
173
+ } ;
174
+
175
+ match lit_value_target_type {
176
+ None => None ,
177
+ Some ( value) => {
178
+ if value >= target_min && value <= target_max {
179
+ // the value casted from lit to the target type is in the range of target type.
180
+ // return the target type of scalar value
181
+ let result_scalar = match target_type {
182
+ DataType :: Int8 => ScalarValue :: Int8 ( Some ( value as i8 ) ) ,
183
+ DataType :: Int16 => ScalarValue :: Int16 ( Some ( value as i16 ) ) ,
184
+ DataType :: Int32 => ScalarValue :: Int32 ( Some ( value as i32 ) ) ,
185
+ DataType :: Int64 => ScalarValue :: Int64 ( Some ( value as i64 ) ) ,
186
+ DataType :: UInt8 => ScalarValue :: UInt8 ( Some ( value as u8 ) ) ,
187
+ DataType :: UInt16 => ScalarValue :: UInt16 ( Some ( value as u16 ) ) ,
188
+ DataType :: UInt32 => ScalarValue :: UInt32 ( Some ( value as u32 ) ) ,
189
+ DataType :: UInt64 => ScalarValue :: UInt64 ( Some ( value as u64 ) ) ,
190
+ DataType :: Timestamp ( TimeUnit :: Second , tz) => {
191
+ let value = cast_between_timestamp (
192
+ & lit_data_type,
193
+ & DataType :: Timestamp ( TimeUnit :: Second , tz. clone ( ) ) ,
194
+ value,
195
+ ) ;
196
+ ScalarValue :: TimestampSecond ( value, tz. clone ( ) )
197
+ }
198
+ DataType :: Timestamp ( TimeUnit :: Millisecond , tz) => {
199
+ let value = cast_between_timestamp (
200
+ & lit_data_type,
201
+ & DataType :: Timestamp ( TimeUnit :: Millisecond , tz. clone ( ) ) ,
202
+ value,
203
+ ) ;
204
+ ScalarValue :: TimestampMillisecond ( value, tz. clone ( ) )
205
+ }
206
+ DataType :: Timestamp ( TimeUnit :: Microsecond , tz) => {
207
+ let value = cast_between_timestamp (
208
+ & lit_data_type,
209
+ & DataType :: Timestamp ( TimeUnit :: Microsecond , tz. clone ( ) ) ,
210
+ value,
211
+ ) ;
212
+ ScalarValue :: TimestampMicrosecond ( value, tz. clone ( ) )
213
+ }
214
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , tz) => {
215
+ let value = cast_between_timestamp (
216
+ & lit_data_type,
217
+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , tz. clone ( ) ) ,
218
+ value,
219
+ ) ;
220
+ ScalarValue :: TimestampNanosecond ( value, tz. clone ( ) )
221
+ }
222
+ DataType :: Decimal128 ( p, s) => {
223
+ ScalarValue :: Decimal128 ( Some ( value) , * p, * s)
224
+ }
225
+ _ => {
226
+ return None ;
227
+ }
228
+ } ;
229
+ Some ( result_scalar)
230
+ } else {
231
+ None
232
+ }
233
+ }
234
+ }
235
+ }
236
+
237
+ pub fn try_cast_string_literal (
238
+ lit_value : & ScalarValue ,
239
+ target_type : & DataType ,
240
+ ) -> Option < ScalarValue > {
241
+ let string_value = lit_value. try_as_str ( ) ?. map ( |s| s. to_string ( ) ) ;
242
+ let scalar_value = match target_type {
243
+ DataType :: Utf8 => ScalarValue :: Utf8 ( string_value) ,
244
+ DataType :: LargeUtf8 => ScalarValue :: LargeUtf8 ( string_value) ,
245
+ DataType :: Utf8View => ScalarValue :: Utf8View ( string_value) ,
246
+ _ => return None ,
247
+ } ;
248
+ Some ( scalar_value)
249
+ }
250
+
251
+ /// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
252
+ pub fn try_cast_dictionary (
253
+ lit_value : & ScalarValue ,
254
+ target_type : & DataType ,
255
+ ) -> Option < ScalarValue > {
256
+ let lit_value_type = lit_value. data_type ( ) ;
257
+ let result_scalar = match ( lit_value, target_type) {
258
+ // Unwrap dictionary when inner type matches target type
259
+ ( ScalarValue :: Dictionary ( _, inner_value) , _)
260
+ if inner_value. data_type ( ) == * target_type =>
261
+ {
262
+ ( * * inner_value) . clone ( )
263
+ }
264
+ // Wrap type when target type is dictionary
265
+ ( _, DataType :: Dictionary ( index_type, inner_type) )
266
+ if * * inner_type == lit_value_type =>
267
+ {
268
+ ScalarValue :: Dictionary ( index_type. clone ( ) , Box :: new ( lit_value. clone ( ) ) )
269
+ }
270
+ _ => {
271
+ return None ;
272
+ }
273
+ } ;
274
+ Some ( result_scalar)
275
+ }
276
+
277
+ /// Cast a timestamp value from one unit to another
278
+ pub fn cast_between_timestamp ( from : & DataType , to : & DataType , value : i128 ) -> Option < i64 > {
279
+ let value = value as i64 ;
280
+ let from_scale = match from {
281
+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
282
+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
283
+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
284
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
285
+ _ => return Some ( value) ,
286
+ } ;
287
+
288
+ let to_scale = match to {
289
+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
290
+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
291
+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
292
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
293
+ _ => return Some ( value) ,
294
+ } ;
295
+
296
+ match from_scale. cmp ( & to_scale) {
297
+ Ordering :: Less => value. checked_mul ( to_scale / from_scale) ,
298
+ Ordering :: Greater => Some ( value / ( from_scale / to_scale) ) ,
299
+ Ordering :: Equal => Some ( value) ,
300
+ }
301
+ }
302
+
303
+ pub fn try_cast_binary (
304
+ lit_value : & ScalarValue ,
305
+ target_type : & DataType ,
306
+ ) -> Option < ScalarValue > {
307
+ match ( lit_value, target_type) {
308
+ ( ScalarValue :: Binary ( Some ( v) ) , DataType :: FixedSizeBinary ( n) )
309
+ if v. len ( ) == * n as usize =>
310
+ {
311
+ Some ( ScalarValue :: FixedSizeBinary ( * n, Some ( v. clone ( ) ) ) )
312
+ }
313
+ _ => None ,
314
+ }
315
+ }
0 commit comments