Skip to content

refactor: use TypeSignature::Coercible for math functions #14872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ impl NativeType {
)
}

#[inline]
pub fn is_float(&self) -> bool {
use NativeType::*;
matches!(self, Float16 | Float32 | Float64)
}

#[inline]
pub fn is_integer(&self) -> bool {
use NativeType::*;
Expand Down
26 changes: 20 additions & 6 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ pub enum TypeSignatureClass {
Interval,
Duration,
Native(LogicalTypeRef),
// TODO:
// Numeric
Numeric,
Float,
Integer,
}

Expand Down Expand Up @@ -252,6 +252,16 @@ impl TypeSignatureClass {
TypeSignatureClass::Duration => {
vec![DataType::Duration(TimeUnit::Nanosecond)]
}
TypeSignatureClass::Numeric => {
vec![
DataType::Int64,
DataType::Float64,
DataType::Decimal256(3, -2),
]
}
TypeSignatureClass::Float => {
vec![DataType::Float64]
}
TypeSignatureClass::Integer => {
vec![DataType::Int64]
}
Expand All @@ -263,16 +273,14 @@ impl TypeSignatureClass {
self: &TypeSignatureClass,
logical_type: &NativeType,
) -> bool {
if logical_type == &NativeType::Null {
return true;
}

match self {
TypeSignatureClass::Native(t) if t.native() == logical_type => true,
TypeSignatureClass::Timestamp if logical_type.is_timestamp() => true,
TypeSignatureClass::Time if logical_type.is_time() => true,
TypeSignatureClass::Interval if logical_type.is_interval() => true,
TypeSignatureClass::Duration if logical_type.is_duration() => true,
TypeSignatureClass::Numeric if logical_type.is_numeric() => true,
TypeSignatureClass::Float if logical_type.is_float() => true,
TypeSignatureClass::Integer if logical_type.is_integer() => true,
_ => false,
}
Expand Down Expand Up @@ -301,6 +309,12 @@ impl TypeSignatureClass {
TypeSignatureClass::Duration if native_type.is_duration() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Numeric if native_type.is_numeric() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Float if native_type.is_float() => {
Ok(origin_type.to_owned())
}
TypeSignatureClass::Integer if native_type.is_integer() => {
Ok(origin_type.to_owned())
}
Expand Down
14 changes: 11 additions & 3 deletions datafusion/functions-nested/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use arrow::datatypes::DataType::{
};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::exec_err;
use datafusion_common::types::logical_string;
use datafusion_common::types::{logical_null, logical_string, NativeType};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
Expand Down Expand Up @@ -255,11 +255,19 @@ impl StringToArray {
vec![
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_string()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::String,
),
]),
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_string()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::String,
),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
]),
],
Expand Down
26 changes: 21 additions & 5 deletions datafusion/functions/src/crypto/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use super::basic::{digest, utf8_or_binary_to_binary_type};
use arrow::datatypes::DataType;
use datafusion_common::{
types::{logical_binary, logical_string},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -72,12 +72,28 @@ impl DigestFunc {
signature: Signature::one_of(
vec![
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_string()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::String,
),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_string()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::String,
),
]),
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_binary())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::Binary,
),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_string()),
vec![TypeSignatureClass::Native(logical_null())],
NativeType::String,
),
]),
],
Volatility::Immutable,
Expand Down
12 changes: 9 additions & 3 deletions datafusion/functions/src/crypto/md5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::crypto::basic::md5;
use arrow::datatypes::DataType;
use datafusion_common::{
plan_err,
types::{logical_binary, logical_string, NativeType},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -62,12 +62,18 @@ impl Md5Func {
vec![
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::String,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_binary())],
vec![
TypeSignatureClass::Native(logical_binary()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::Binary,
)]),
],
Expand Down
12 changes: 9 additions & 3 deletions datafusion/functions/src/crypto/sha224.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use super::basic::{sha224, utf8_or_binary_to_binary_type};
use arrow::datatypes::DataType;
use datafusion_common::{
types::{logical_binary, logical_string, NativeType},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -62,12 +62,18 @@ impl SHA224Func {
vec![
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::String,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_binary())],
vec![
TypeSignatureClass::Native(logical_binary()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::Binary,
)]),
],
Expand Down
12 changes: 9 additions & 3 deletions datafusion/functions/src/crypto/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use super::basic::{sha256, utf8_or_binary_to_binary_type};
use arrow::datatypes::DataType;
use datafusion_common::{
types::{logical_binary, logical_string, NativeType},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -61,12 +61,18 @@ impl SHA256Func {
vec![
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::String,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_binary())],
vec![
TypeSignatureClass::Native(logical_binary()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::Binary,
)]),
],
Expand Down
12 changes: 9 additions & 3 deletions datafusion/functions/src/crypto/sha384.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use super::basic::{sha384, utf8_or_binary_to_binary_type};
use arrow::datatypes::DataType;
use datafusion_common::{
types::{logical_binary, logical_string, NativeType},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -61,12 +61,18 @@ impl SHA384Func {
vec![
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::String,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_binary())],
vec![
TypeSignatureClass::Native(logical_binary()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::Binary,
)]),
],
Expand Down
12 changes: 9 additions & 3 deletions datafusion/functions/src/crypto/sha512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use super::basic::{sha512, utf8_or_binary_to_binary_type};
use arrow::datatypes::DataType;
use datafusion_common::{
types::{logical_binary, logical_string, NativeType},
types::{logical_binary, logical_null, logical_string, NativeType},
Result,
};
use datafusion_expr::{
Expand Down Expand Up @@ -61,12 +61,18 @@ impl SHA512Func {
vec![
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_string())],
vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::String,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_binary()),
vec![TypeSignatureClass::Native(logical_binary())],
vec![
TypeSignatureClass::Native(logical_binary()),
TypeSignatureClass::Native(logical_null()),
],
NativeType::Binary,
)]),
],
Expand Down
18 changes: 13 additions & 5 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,16 @@ macro_rules! make_math_unary_udf {

use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use datafusion_common::types::logical_null;
use datafusion_common::types::NativeType;
use datafusion_common::{exec_err, Result};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
Signature, Volatility,
Signature, TypeSignatureClass, Volatility,
};
use datafusion_expr_common::signature::Coercion;

#[derive(Debug)]
pub struct $UDF {
Expand All @@ -175,11 +178,16 @@ macro_rules! make_math_unary_udf {

impl $UDF {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::uniform(
1,
vec![Float64, Float32],
signature: Signature::coercible(
vec![Coercion::new_implicit(
TypeSignatureClass::Float,
vec![
TypeSignatureClass::Integer,
TypeSignatureClass::Native(logical_null()),
],
NativeType::Float64,
)],
Volatility::Immutable,
),
}
Expand Down
13 changes: 11 additions & 2 deletions datafusion/functions/src/math/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ use arrow::array::{
};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use datafusion_common::types::{logical_null, NativeType};
use datafusion_common::{
internal_datafusion_err, not_impl_err, utils::take_function_args, Result,
};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_expr_common::signature::Coercion;
use datafusion_macros::user_doc;

type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
Expand Down Expand Up @@ -126,7 +128,14 @@ impl Default for AbsFunc {
impl AbsFunc {
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::new(
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Numeric,
vec![TypeSignatureClass::Native(logical_null())],
NativeType::Float64,
)]),
Volatility::Immutable,
),
}
}
}
Expand Down
Loading