Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions src/common/function/src/aggrs/aggr_wrapper/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
use arrow::record_batch::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::catalog::{Session, TableProvider};
Expand All @@ -32,10 +31,14 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion::prelude::SessionContext;
use datafusion_common::arrow::array::{ArrayRef, AsArray, Float64Array, Int64Array, UInt64Array};
use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type};
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::sqlparser::ast::NullTreatment;
use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan, lit};
use datafusion_expr::{
Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit,
};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
Expand Down Expand Up @@ -649,14 +652,20 @@ async fn test_udaf_correct_eval_result() {
expected_output: None,
expected_fn: Some(|arr| {
let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap();
let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap();
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
let percent = ColumnarValue::Array(percent);
let state = ColumnarValue::Array(arr);
let udd_calc = UddSketchCalcFunction;
let res = udd_calc
.eval(&Default::default(), &[percent, state])
.invoke_with_args(ScalarFunctionArgs {
args: vec![percent, state],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
config_options: Arc::new(Default::default()),
})
.unwrap();
let binding = res.to_arrow_array();
let res_arr = binding.as_any().downcast_ref::<Float64Array>().unwrap();
let binding = res.to_array(1).unwrap();
let res_arr = binding.as_primitive::<Float64Type>();
assert!(res_arr.len() == 1);
assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON);
true
Expand All @@ -683,11 +692,20 @@ async fn test_udaf_correct_eval_result() {
]))],
expected_output: None,
expected_fn: Some(|arr| {
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
let number_rows = arr.len();
let state = ColumnarValue::Array(arr);
let hll_calc = HllCalcFunction;
let res = hll_calc.eval(&Default::default(), &[state]).unwrap();
let binding = res.to_arrow_array();
let res_arr = binding.as_any().downcast_ref::<UInt64Array>().unwrap();
let res = hll_calc
.invoke_with_args(ScalarFunctionArgs {
args: vec![state],
arg_fields: vec![],
number_rows,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(Default::default()),
})
.unwrap();
let binding = res.to_array(1).unwrap();
let res_arr = binding.as_primitive::<UInt64Type>();
assert!(res_arr.len() == 1);
assert_eq!(res_arr.value(0), 3);
true
Expand Down
64 changes: 64 additions & 0 deletions src/common/function/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{ExecuteSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::vectors::VectorRef;
use session::context::{QueryContextBuilder, QueryContextRef};
Expand Down Expand Up @@ -60,6 +65,42 @@ impl Default for FunctionContext {
}
}

/// Lets a [`FunctionContext`] be carried as an extension inside DataFusion's config options.
///
/// The context exposes no string-configurable options: `set` always fails with
/// `NotImplemented` and `entries` is always empty.
impl ExtensionOptions for FunctionContext {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        let copy = self.clone();
        Box::new(copy)
    }

    fn set(&mut self, _key: &str, _value: &str) -> datafusion_common::Result<()> {
        // Nothing in the function context can be set through key/value options.
        let what = String::from("set options for `FunctionContext`");
        Err(DataFusionError::NotImplemented(what))
    }

    fn entries(&self) -> Vec<ConfigEntry> {
        Vec::new()
    }
}

/// Debug output for [`FunctionContext`]; only the `query_ctx` field is printed.
impl Debug for FunctionContext {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("FunctionContext");
        builder.field("query_ctx", &self.query_ctx);
        builder.finish()
    }
}

// Marks `FunctionContext` as a DataFusion config extension.
// NOTE(review): `PREFIX` is presumably the namespace DataFusion uses for this extension's
// option keys — confirm against the `ConfigExtension` documentation.
impl ConfigExtension for FunctionContext {
const PREFIX: &'static str = "FunctionContext";
}

/// Scalar function trait, modified from databend to adapt datafusion
/// TODO(dennis): optimize function by its features such as monotonicity etc.
pub trait Function: fmt::Display + Sync + Send {
Expand Down Expand Up @@ -99,3 +140,26 @@ pub trait Function: fmt::Display + Sync + Send {
}

pub type FunctionRef = Arc<dyn Function>;

/// Looks up the [FunctionContext] carried by the given [ScalarFunctionArgs].
///
/// The context is expected to have been registered when the DataFusion session context was
/// created; DataFusion then passes it all the way down to the function arguments through
/// `config_options`.
///
/// # Errors
///
/// Returns a [DataFusionError::Execution] if no [FunctionContext] extension is present.
pub(crate) fn find_function_context(
    args: &ScalarFunctionArgs,
) -> datafusion_common::Result<&FunctionContext> {
    args.config_options
        .extensions
        .get::<FunctionContext>()
        .ok_or_else(|| DataFusionError::Execution("function context is not set".to_string()))
}

/// Converts every UDF argument in [ScalarFunctionArgs] to an Arrow [ArrayRef] and returns
/// exactly `N` of them as a fixed-size array.
///
/// `name` is the function name handed to `take_function_args`.
///
/// # Errors
///
/// Propagates any error from `ColumnarValue::values_to_arrays` or from `take_function_args`.
pub(crate) fn extract_args<const N: usize>(
    name: &str,
    args: &ScalarFunctionArgs,
) -> datafusion_common::Result<[ArrayRef; N]> {
    let arrays = ColumnarValue::values_to_arrays(&args.args)?;
    datafusion_common::utils::take_function_args(name, arrays)
}
6 changes: 2 additions & 4 deletions src/common/function/src/scalars/date/date_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ use std::fmt;

use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use snafu::ResultExt;

use crate::function::Function;
use crate::function::{Function, extract_args};
use crate::helper;

/// A function that adds an interval value to a Timestamp or Date, and returns the result.
Expand Down Expand Up @@ -63,8 +62,7 @@ impl Function for DateAddFunction {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let [left, right] = extract_args(self.name(), &args)?;

let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
Ok(ColumnarValue::Array(result))
Expand Down
6 changes: 2 additions & 4 deletions src/common/function/src/scalars/date/date_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ use std::fmt;

use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use snafu::ResultExt;

use crate::function::Function;
use crate::function::{Function, extract_args};
use crate::helper;

/// A function that subtracts an interval value from a Timestamp or Date, and returns the result.
Expand Down Expand Up @@ -63,8 +62,7 @@ impl Function for DateSubFunction {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let [left, right] = extract_args(self.name(), &args)?;

let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
Ok(ColumnarValue::Array(result))
Expand Down
55 changes: 25 additions & 30 deletions src/common/function/src/scalars/expression/is_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,12 @@ use std::fmt;
use std::fmt::Display;
use std::sync::Arc;

use common_query::error;
use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::ArrayRef;
use common_query::error::Result;
use datafusion::arrow::compute::is_null;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::VectorRef;
use datatypes::vectors::Helper;
use snafu::{ResultExt, ensure};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};

use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};

const NAME: &str = "isnull";

Expand All @@ -53,35 +48,24 @@ impl Function for IsNullFunction {
Signature::any(1, Volatility::Immutable)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
),
}
);
let values = &columns[0];
let arrow_array = &values.to_arrow_array();
let result = is_null(arrow_array).context(ArrowComputeSnafu)?;
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0] = extract_args(self.name(), &args)?;
let result = is_null(&arg0)?;

Helper::try_into_vector(Arc::new(result) as ArrayRef).context(error::FromArrowArraySnafu)
Ok(ColumnarValue::Array(Arc::new(result)))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow_schema::Field;
use datafusion_common::arrow::array::{AsArray, BooleanArray, Float32Array};
use datafusion_expr::TypeSignature;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{BooleanVector, Float32Vector};

use super::*;
#[test]
Expand All @@ -98,9 +82,20 @@ mod tests {
);
let values = vec![None, Some(3.0), None];

let args: Vec<VectorRef> = vec![Arc::new(Float32Vector::from(values))];
let vector = is_null.eval(&FunctionContext::default(), &args).unwrap();
let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true]));
let result = is_null
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Float32Array::from(values)))],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
})
.unwrap();
let ColumnarValue::Array(result) = result else {
unreachable!()
};
let vector = result.as_boolean();
let expect = &BooleanArray::from(vec![true, false, true]);
assert_eq!(expect, vector);
}
}
10 changes: 4 additions & 6 deletions src/common/function/src/scalars/geo/geohash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ use common_query::error::{self, Result};
use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, utils};
use datafusion_common::DataFusionError;
use datafusion_expr::type_coercion::aggregates::INTEGERS;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use geohash::Coord;
use snafu::ResultExt;

use crate::function::Function;
use crate::function::{Function, extract_args};
use crate::scalars::geo::helpers;

fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
Expand Down Expand Up @@ -77,8 +77,7 @@ impl Function for GeohashFunction {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;

let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
Expand Down Expand Up @@ -169,8 +168,7 @@ impl Function for GeohashNeighboursFunction {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;

let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
Expand Down
Loading
Loading