Skip to content

refactor(query): rewrite function call expr to cast expr #17669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
394 changes: 279 additions & 115 deletions src/query/expression/src/expression.rs

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/query/expression/src/filter/select_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl SelectExprBuilder {
.can_reorder(can_reorder)
}
"not" => {
self.not_function = Some((id.clone(), function.clone()));
self.not_function = Some((*id.clone(), function.clone()));
let result = self.build_select_expr(&args[0], not ^ true);
if result.can_push_down_not {
result
Expand Down Expand Up @@ -255,7 +255,7 @@ impl SelectExprBuilder {
let (id, function) = self.not_function.as_ref().unwrap();
Expr::FunctionCall {
span: None,
id: id.clone(),
id: Box::new(id.clone()),
function: function.clone(),
generics: vec![],
args: vec![expr.clone()],
Expand Down
1 change: 1 addition & 0 deletions src/query/expression/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#![feature(trait_upcasting)]
#![feature(alloc_layout_extra)]
#![feature(debug_closure_helpers)]
#![feature(never_type)]

#[allow(dead_code)]
mod block;
Expand Down
144 changes: 114 additions & 30 deletions src/query/expression/src/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::Write;

Expand All @@ -31,10 +32,11 @@ use crate::types::decimal::MAX_DECIMAL256_PRECISION;
use crate::types::DataType;
use crate::types::DecimalDataType;
use crate::types::Number;
use crate::types::NumberScalar;
use crate::visit_expr;
use crate::AutoCastRules;
use crate::ColumnIndex;
use crate::ConstantFolder;
use crate::ExprVisitor;
use crate::FunctionContext;
use crate::Scalar;

Expand Down Expand Up @@ -150,7 +152,7 @@ pub fn check_cast<Index: ColumnIndex>(
is_try: bool,
expr: Expr<Index>,
dest_type: &DataType,
fn_registry: &FunctionRegistry,
_: &FunctionRegistry,
) -> Result<Expr<Index>> {
let wrapped_dest_type = if is_try {
wrap_nullable_for_try_cast(span, dest_type)?
Expand All @@ -168,26 +170,6 @@ pub fn check_cast<Index: ColumnIndex>(
dest_type: wrapped_dest_type,
})
} else {
// fast path to eval function for cast
if let Some(cast_fn) = get_simple_cast_function(is_try, expr.data_type(), dest_type) {
let params = if let DataType::Decimal(ty) = dest_type {
vec![
Scalar::Number(NumberScalar::Int64(ty.precision() as _)),
Scalar::Number(NumberScalar::Int64(ty.scale() as _)),
]
} else {
vec![]
};

if let Ok(cast_expr) =
check_function(span, &cast_fn, &params, &[expr.clone()], fn_registry)
{
if cast_expr.data_type() == &wrapped_dest_type {
return Ok(cast_expr);
}
}
}

if !can_cast_to(expr.data_type(), dest_type) {
return Err(ErrorCode::BadArguments(format!(
"unable to cast type `{}` to type `{}`",
Expand Down Expand Up @@ -301,7 +283,7 @@ pub fn check_function<Index: ColumnIndex>(
let return_type = function.signature.return_type.clone();
return Ok(Expr::FunctionCall {
span,
id,
id: Box::new(id),
function,
generics: vec![],
args: args.to_vec(),
Expand All @@ -312,22 +294,52 @@ pub fn check_function<Index: ColumnIndex>(
let auto_cast_rules = fn_registry.get_auto_cast_rules(name);

let mut fail_reasons = Vec::with_capacity(candidates.len());
for (id, func) in &candidates {
let mut checked_candidates = vec![];
let args_not_const = args
.iter()
.map(Expr::contains_column_ref)
.collect::<Vec<_>>();
let need_sort = candidates.len() > 1 && args_not_const.iter().any(|contain| !*contain);
for (seq, (id, func)) in candidates.iter().enumerate() {
match try_check_function(args, &func.signature, auto_cast_rules, fn_registry) {
Ok((checked_args, return_type, generics)) => {
return Ok(Expr::FunctionCall {
Ok((args, return_type, generics)) => {
let score = if need_sort {
args.iter()
.zip(args_not_const.iter().copied())
.map(|(expr, not_const)| {
// smaller score win
if not_const && expr.is_cast() {
1
} else {
0
}
})
.sum::<usize>()
} else {
0
};
let expr = Expr::FunctionCall {
span,
id: id.clone(),
id: Box::new(id.clone()),
function: func.clone(),
generics,
args: checked_args,
args,
return_type,
});
};
if !need_sort {
return Ok(expr);
}
checked_candidates.push((expr, score, seq));
}
Err(err) => fail_reasons.push(err),
}
}

if !checked_candidates.is_empty() {
checked_candidates.sort_by_key(|(_, score, seq)| std::cmp::Reverse((*score, *seq)));
return Ok(checked_candidates.pop().unwrap().0);
}

let mut msg = if params.is_empty() {
format!(
"no function matches signature `{name}({})`, you might need to add explicit type casts.",
Expand Down Expand Up @@ -783,6 +795,78 @@ pub const ALL_SIMPLE_CAST_FUNCTIONS: &[&str] = &[
"parse_json",
];

pub fn is_simple_cast_function(name: &str) -> bool {
fn is_simple_cast_function(name: &str) -> bool {
ALL_SIMPLE_CAST_FUNCTIONS.contains(&name)
}

pub fn rewrite_function_to_cast<Index: ColumnIndex>(expr: Expr<Index>) -> Expr<Index> {
match visit_expr(&expr, &mut RewriteCast).unwrap() {
None => expr,
Some(expr) => expr,
}
}

struct RewriteCast;

impl<Index: ColumnIndex> ExprVisitor<Index> for RewriteCast {
type Error = !;

fn enter_function_call(
&mut self,
expr: &Expr<Index>,
) -> std::result::Result<Option<Expr<Index>>, Self::Error> {
let expr = match Self::visit_function_call(expr, self)? {
Some(expr) => Cow::Owned(expr),
None => Cow::Borrowed(expr),
};
let Expr::FunctionCall {
span,
function,
generics,
args,
return_type,
..
} = expr.as_ref()
else {
unreachable!();
};
if !generics.is_empty() || args.len() != 1 {
return match expr {
Cow::Borrowed(_) => Ok(None),
Cow::Owned(expr) => Ok(Some(expr)),
};
}
if function.signature.name == "parse_json" {
return Ok(Some(Expr::Cast {
span: *span,
is_try: false,
expr: Box::new(args.first().unwrap().clone()),
dest_type: return_type.clone(),
}));
}
let func_name = format!(
"to_{}",
return_type.remove_nullable().to_string().to_lowercase()
);
if function.signature.name == func_name {
return Ok(Some(Expr::Cast {
span: *span,
is_try: false,
expr: Box::new(args.first().unwrap().clone()),
dest_type: return_type.clone(),
}));
};
if function.signature.name == format!("try_{func_name}") {
return Ok(Some(Expr::Cast {
span: *span,
is_try: true,
expr: Box::new(args.first().unwrap().clone()),
dest_type: return_type.clone(),
}));
}
match expr {
Cow::Borrowed(_) => Ok(None),
Cow::Owned(expr) => Ok(Some(expr)),
}
}
}
Loading