Skip to content

Commit ff6a6cb

Browse files
refactor: rewrite some UDFs to DataFusion style (part 3)
Signed-off-by: luofucong <[email protected]>
1 parent cbe0cf4 commit ff6a6cb

File tree

12 files changed

+513
-398
lines changed

12 files changed

+513
-398
lines changed

src/common/function/src/function.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,18 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::any::Any;
1516
use std::fmt;
17+
use std::fmt::{Debug, Formatter};
1618
use std::sync::Arc;
1719

1820
use common_error::ext::{BoxedError, PlainError};
1921
use common_error::status_code::StatusCode;
2022
use common_query::error::{ExecuteSnafu, Result};
2123
use datafusion::arrow::datatypes::DataType;
2224
use datafusion::logical_expr::ColumnarValue;
25+
use datafusion_common::DataFusionError;
26+
use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
2327
use datafusion_expr::{ScalarFunctionArgs, Signature};
2428
use datatypes::vectors::VectorRef;
2529
use session::context::{QueryContextBuilder, QueryContextRef};
@@ -60,6 +64,42 @@ impl Default for FunctionContext {
6064
}
6165
}
6266

67+
impl ExtensionOptions for FunctionContext {
68+
fn as_any(&self) -> &dyn Any {
69+
self
70+
}
71+
72+
fn as_any_mut(&mut self) -> &mut dyn Any {
73+
self
74+
}
75+
76+
fn cloned(&self) -> Box<dyn ExtensionOptions> {
77+
Box::new(self.clone())
78+
}
79+
80+
fn set(&mut self, _: &str, _: &str) -> datafusion_common::Result<()> {
81+
Err(DataFusionError::NotImplemented(
82+
"set options for `FunctionContext`".to_string(),
83+
))
84+
}
85+
86+
fn entries(&self) -> Vec<ConfigEntry> {
87+
vec![]
88+
}
89+
}
90+
91+
impl Debug for FunctionContext {
92+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93+
f.debug_struct("FunctionContext")
94+
.field("query_ctx", &self.query_ctx)
95+
.finish()
96+
}
97+
}
98+
99+
impl ConfigExtension for FunctionContext {
100+
const PREFIX: &'static str = "FunctionContext";
101+
}
102+
63103
/// Scalar function trait, modified from Databend to adapt to DataFusion.
64104
/// TODO(dennis): optimize function by its features such as monotonicity etc.
65105
pub trait Function: fmt::Display + Sync + Send {

src/common/function/src/scalars/expression/is_null.rs

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@ use std::fmt;
1616
use std::fmt::Display;
1717
use std::sync::Arc;
1818

19-
use common_query::error;
20-
use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
21-
use datafusion::arrow::array::ArrayRef;
19+
use common_query::error::Result;
2220
use datafusion::arrow::compute::is_null;
2321
use datafusion::arrow::datatypes::DataType;
24-
use datafusion_expr::{Signature, Volatility};
25-
use datatypes::prelude::VectorRef;
26-
use datatypes::vectors::Helper;
27-
use snafu::{ResultExt, ensure};
22+
use datafusion_common::utils;
23+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
2824

29-
use crate::function::{Function, FunctionContext};
25+
use crate::function::Function;
3026

3127
const NAME: &str = "isnull";
3228

@@ -53,35 +49,25 @@ impl Function for IsNullFunction {
5349
Signature::any(1, Volatility::Immutable)
5450
}
5551

56-
fn eval(
52+
fn invoke_with_args(
5753
&self,
58-
_func_ctx: &FunctionContext,
59-
columns: &[VectorRef],
60-
) -> common_query::error::Result<VectorRef> {
61-
ensure!(
62-
columns.len() == 1,
63-
InvalidFuncArgsSnafu {
64-
err_msg: format!(
65-
"The length of the args is not correct, expect exactly one, have: {}",
66-
columns.len()
67-
),
68-
}
69-
);
70-
let values = &columns[0];
71-
let arrow_array = &values.to_arrow_array();
72-
let result = is_null(arrow_array).context(ArrowComputeSnafu)?;
54+
args: ScalarFunctionArgs,
55+
) -> datafusion_common::Result<ColumnarValue> {
56+
let args = ColumnarValue::values_to_arrays(&args.args)?;
57+
let [arg0] = utils::take_function_args(self.name(), args)?;
58+
let result = is_null(&arg0)?;
7359

74-
Helper::try_into_vector(Arc::new(result) as ArrayRef).context(error::FromArrowArraySnafu)
60+
Ok(ColumnarValue::Array(Arc::new(result)))
7561
}
7662
}
7763

7864
#[cfg(test)]
7965
mod tests {
8066
use std::sync::Arc;
8167

68+
use arrow_schema::Field;
69+
use datafusion_common::arrow::array::{AsArray, BooleanArray, Float32Array};
8270
use datafusion_expr::TypeSignature;
83-
use datatypes::scalars::ScalarVector;
84-
use datatypes::vectors::{BooleanVector, Float32Vector};
8571

8672
use super::*;
8773
#[test]
@@ -98,9 +84,20 @@ mod tests {
9884
);
9985
let values = vec![None, Some(3.0), None];
10086

101-
let args: Vec<VectorRef> = vec![Arc::new(Float32Vector::from(values))];
102-
let vector = is_null.eval(&FunctionContext::default(), &args).unwrap();
103-
let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true]));
87+
let result = is_null
88+
.invoke_with_args(ScalarFunctionArgs {
89+
args: vec![ColumnarValue::Array(Arc::new(Float32Array::from(values)))],
90+
arg_fields: vec![],
91+
number_rows: 3,
92+
return_field: Arc::new(Field::new("", DataType::Boolean, false)),
93+
config_options: Arc::new(Default::default()),
94+
})
95+
.unwrap();
96+
let ColumnarValue::Array(result) = result else {
97+
unreachable!()
98+
};
99+
let vector = result.as_boolean();
100+
let expect = &BooleanArray::from(vec![true, false, true]);
104101
assert_eq!(expect, vector);
105102
}
106103
}

src/common/function/src/scalars/geo/helpers.rs

Lines changed: 6 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,14 @@
1515
use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
1616
use datafusion::arrow::compute;
1717

18-
macro_rules! ensure_columns_len {
19-
($columns:ident) => {
20-
snafu::ensure!(
21-
$columns.windows(2).all(|c| c[0].len() == c[1].len()),
22-
common_query::error::InvalidFuncArgsSnafu {
23-
err_msg: "The length of input columns are in different size"
24-
}
25-
)
26-
};
27-
($column_a:ident, $column_b:ident, $($column_n:ident),*) => {
28-
snafu::ensure!(
29-
{
30-
let mut result = $column_a.len() == $column_b.len();
31-
$(
32-
result = result && ($column_a.len() == $column_n.len());
33-
)*
34-
result
35-
}
36-
common_query::error::InvalidFuncArgsSnafu {
37-
err_msg: "The length of input columns are in different size"
38-
}
39-
)
40-
};
41-
}
42-
43-
pub(crate) use ensure_columns_len;
44-
45-
macro_rules! ensure_columns_n {
46-
($columns:ident, $n:literal) => {
47-
snafu::ensure!(
48-
$columns.len() == $n,
49-
common_query::error::InvalidFuncArgsSnafu {
50-
err_msg: format!(
51-
"The length of arguments is not correct, expect {}, provided : {}",
52-
stringify!($n),
53-
$columns.len()
54-
),
55-
}
56-
);
57-
58-
if $n > 1 {
59-
ensure_columns_len!($columns);
60-
}
61-
};
62-
}
63-
64-
pub(crate) use ensure_columns_n;
65-
6618
macro_rules! ensure_and_coerce {
6719
($compare:expr, $coerce:expr) => {{
68-
snafu::ensure!(
69-
$compare,
70-
common_query::error::InvalidFuncArgsSnafu {
71-
err_msg: "Argument was outside of acceptable range "
72-
}
73-
);
74-
Ok($coerce)
20+
if !$compare {
21+
return Err(datafusion_common::DataFusionError::Execution(
22+
"argument out of valid range".to_string(),
23+
));
24+
}
25+
Ok(Some($coerce))
7526
}};
7627
}
7728

src/common/function/src/scalars/geo/measure.rs

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,23 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::sync::Arc;
16+
1517
use common_error::ext::{BoxedError, PlainError};
1618
use common_error::status_code::StatusCode;
1719
use common_query::error::{self, Result};
18-
use datafusion::arrow::datatypes::DataType;
19-
use datafusion_expr::{Signature, Volatility};
20-
use datatypes::scalars::ScalarVectorBuilder;
21-
use datatypes::vectors::{Float64VectorBuilder, MutableVector, VectorRef};
20+
use datafusion_common::arrow::array::{Array, AsArray, Float64Builder};
21+
use datafusion_common::arrow::compute;
22+
use datafusion_common::arrow::datatypes::DataType;
23+
use datafusion_common::utils;
24+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
2225
use derive_more::Display;
2326
use geo::algorithm::line_measures::metric_spaces::Euclidean;
2427
use geo::{Area, Distance, Haversine};
2528
use geo_types::Geometry;
2629
use snafu::ResultExt;
2730

28-
use crate::function::{Function, FunctionContext};
29-
use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
31+
use crate::function::Function;
3032
use crate::scalars::geo::wkt::parse_wkt;
3133

3234
/// Returns the WGS84 (SRID: 4326) Euclidean distance between two geometry objects, in degrees.
@@ -47,33 +49,39 @@ impl Function for STDistance {
4749
Signature::string(2, Volatility::Stable)
4850
}
4951

50-
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
51-
ensure_columns_n!(columns, 2);
52+
fn invoke_with_args(
53+
&self,
54+
args: ScalarFunctionArgs,
55+
) -> datafusion_common::Result<ColumnarValue> {
56+
let args = ColumnarValue::values_to_arrays(&args.args)?;
57+
let [arg0, arg1] = utils::take_function_args(self.name(), args)?;
5258

53-
let wkt_this_vec = &columns[0];
54-
let wkt_that_vec = &columns[1];
59+
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
60+
let wkt_this_vec = arg0.as_string_view();
61+
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
62+
let wkt_that_vec = arg1.as_string_view();
5563

5664
let size = wkt_this_vec.len();
57-
let mut results = Float64VectorBuilder::with_capacity(size);
65+
let mut builder = Float64Builder::with_capacity(size);
5866

5967
for i in 0..size {
60-
let wkt_this = wkt_this_vec.get(i).as_string();
61-
let wkt_that = wkt_that_vec.get(i).as_string();
68+
let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
69+
let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
6270

6371
let result = match (wkt_this, wkt_that) {
6472
(Some(wkt_this), Some(wkt_that)) => {
65-
let geom_this = parse_wkt(&wkt_this)?;
66-
let geom_that = parse_wkt(&wkt_that)?;
73+
let geom_this = parse_wkt(wkt_this)?;
74+
let geom_that = parse_wkt(wkt_that)?;
6775

6876
Some(Euclidean::distance(&geom_this, &geom_that))
6977
}
7078
_ => None,
7179
};
7280

73-
results.push(result);
81+
builder.append_option(result);
7482
}
7583

76-
Ok(results.to_vector())
84+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
7785
}
7886
}
7987

@@ -95,23 +103,29 @@ impl Function for STDistanceSphere {
95103
Signature::string(2, Volatility::Stable)
96104
}
97105

98-
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
99-
ensure_columns_n!(columns, 2);
106+
fn invoke_with_args(
107+
&self,
108+
args: ScalarFunctionArgs,
109+
) -> datafusion_common::Result<ColumnarValue> {
110+
let args = ColumnarValue::values_to_arrays(&args.args)?;
111+
let [arg0, arg1] = utils::take_function_args(self.name(), args)?;
100112

101-
let wkt_this_vec = &columns[0];
102-
let wkt_that_vec = &columns[1];
113+
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
114+
let wkt_this_vec = arg0.as_string_view();
115+
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
116+
let wkt_that_vec = arg1.as_string_view();
103117

104118
let size = wkt_this_vec.len();
105-
let mut results = Float64VectorBuilder::with_capacity(size);
119+
let mut builder = Float64Builder::with_capacity(size);
106120

107121
for i in 0..size {
108-
let wkt_this = wkt_this_vec.get(i).as_string();
109-
let wkt_that = wkt_that_vec.get(i).as_string();
122+
let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
123+
let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
110124

111125
let result = match (wkt_this, wkt_that) {
112126
(Some(wkt_this), Some(wkt_that)) => {
113-
let geom_this = parse_wkt(&wkt_this)?;
114-
let geom_that = parse_wkt(&wkt_that)?;
127+
let geom_this = parse_wkt(wkt_this)?;
128+
let geom_that = parse_wkt(wkt_that)?;
115129

116130
match (geom_this, geom_that) {
117131
(Geometry::Point(this), Geometry::Point(that)) => {
@@ -128,10 +142,10 @@ impl Function for STDistanceSphere {
128142
_ => None,
129143
};
130144

131-
results.push(result);
145+
builder.append_option(result);
132146
}
133147

134-
Ok(results.to_vector())
148+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
135149
}
136150
}
137151

@@ -153,27 +167,32 @@ impl Function for STArea {
153167
Signature::string(1, Volatility::Stable)
154168
}
155169

156-
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
157-
ensure_columns_n!(columns, 1);
170+
fn invoke_with_args(
171+
&self,
172+
args: ScalarFunctionArgs,
173+
) -> datafusion_common::Result<ColumnarValue> {
174+
let args = ColumnarValue::values_to_arrays(&args.args)?;
175+
let [arg0] = utils::take_function_args(self.name(), args)?;
158176

159-
let wkt_vec = &columns[0];
177+
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
178+
let wkt_vec = arg0.as_string_view();
160179

161180
let size = wkt_vec.len();
162-
let mut results = Float64VectorBuilder::with_capacity(size);
181+
let mut builder = Float64Builder::with_capacity(size);
163182

164183
for i in 0..size {
165-
let wkt = wkt_vec.get(i).as_string();
184+
let wkt = wkt_vec.is_valid(i).then(|| wkt_vec.value(i));
166185

167186
let result = if let Some(wkt) = wkt {
168-
let geom = parse_wkt(&wkt)?;
187+
let geom = parse_wkt(wkt)?;
169188
Some(geom.unsigned_area())
170189
} else {
171190
None
172191
};
173192

174-
results.push(result);
193+
builder.append_option(result);
175194
}
176195

177-
Ok(results.to_vector())
196+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
178197
}
179198
}

0 commit comments

Comments
 (0)