Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ datafusion-datasource = { version = "54", default-features = false }
datafusion-execution = { version = "54" }
datafusion-expr = { version = "54" }
datafusion-functions = { version = "54" }
datafusion-functions-nested = { version = "54" }
datafusion-physical-expr = { version = "54" }
datafusion-physical-expr-adapter = { version = "54" }
datafusion-physical-expr-common = { version = "54" }
Expand Down
1 change: 1 addition & 0 deletions vortex-datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ datafusion-datasource = { workspace = true, default-features = false }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-functions-nested = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-adapter = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
Expand Down
171 changes: 163 additions & 8 deletions vortex-datafusion/src/convert/exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use arrow_schema::DataType;
use arrow_schema::Field;
use arrow_schema::Schema;
use datafusion_common::Result as DFResult;
use datafusion_common::ScalarValue;
use datafusion_common::exec_datafusion_err;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_expr::Operator as DFOperator;
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_functions::string::octet_length::OctetLengthFunc;
use datafusion_functions_nested::length::ArrayLength;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr::projection::ProjectionExpr;
Expand All @@ -32,6 +34,7 @@ use vortex::expr::get_item;
use vortex::expr::is_not_null;
use vortex::expr::is_null;
use vortex::expr::list_contains;
use vortex::expr::list_length;
use vortex::expr::lit;
use vortex::expr::nested_case_when;
use vortex::expr::not;
Expand Down Expand Up @@ -155,6 +158,32 @@ impl DefaultExpressionConvertor {
Ok(cast(byte_length(input), return_dtype))
}

/// Attempts to convert DataFusion's `array_length` function (aliased as `list_length`) to
/// Vortex `list_length`.
///
/// Supports the single-argument form `array_length(arr)` and the equivalent two-argument
/// form with an explicit first dimension `array_length(arr, 1)`.
fn try_convert_array_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
let Some(input) = array_length_input(scalar_fn) else {
return Err(exec_datafusion_err!(
"array_length pushdown supports only the one-argument form or an explicit first \
dimension"
));
};

let input = self.convert(input.as_ref())?;
let return_dtype = self
.session
.arrow()
.from_arrow_field(&Field::new(
"",
scalar_fn.return_type().clone(),
scalar_fn.nullable(),
))
.map_err(|e| exec_datafusion_err!("Failed to convert return type to dtype: {e}"))?;
Ok(cast(list_length(input), return_dtype))
}

/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
if let Some(octet_length_fn) =
Expand All @@ -163,6 +192,12 @@ impl DefaultExpressionConvertor {
return self.try_convert_octet_length(octet_length_fn);
}

if let Some(array_length_fn) =
ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
{
return self.try_convert_array_length(array_length_fn);
}

if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
{
// DataFusion's GetFieldFunc flattens nested field access into a single call
Expand Down Expand Up @@ -511,6 +546,7 @@ fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
|| expr.downcast_ref::<ScalarFunctionExpr>().is_some_and(|sf| {
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some()
|| ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(sf).is_some()
|| ScalarFunctionExpr::try_downcast_func::<ArrayLength>(sf).is_some()
})
}

Expand Down Expand Up @@ -572,14 +608,20 @@ fn supported_data_types(dt: &DataType) -> bool {
}

/// Checks if a scalar function can be pushed down.
/// Currently GetFieldFunc and OctetLengthFunc are supported.
/// Currently GetFieldFunc, OctetLengthFunc, and ArrayLength are supported.
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
if ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some() {
return true;
}

ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
if ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
.is_some_and(|octet_length| can_octet_length_be_pushed_down(octet_length, schema))
{
return true;
}

ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
.is_some_and(|array_length| can_array_length_be_pushed_down(array_length, schema))
}

fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
Expand All @@ -598,6 +640,42 @@ fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Sche
}) && can_be_pushed_down_impl(input, schema)
}

fn can_array_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
let Some(input) = array_length_input(scalar_fn) else {
return false;
};

// The argument must resolve to a list type. We gate on the resolved data type rather than
// `can_be_pushed_down_impl`, since list columns are intentionally rejected there. We still
// require the argument to be a convertible expression (e.g. a column or struct field access).
input.data_type(schema).as_ref().is_ok_and(|data_type| {
matches!(
data_type,
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
)
}) && is_convertible_expr(input)
}

/// Returns the list argument of an `array_length` call if the call is a form we can rewrite to
/// `list_length`: either the single-argument form `array_length(arr)`, or the two-argument form
/// with an explicit first dimension `array_length(arr, 1)`, which is equivalent. Higher
/// dimensions recurse into nested lists and are not supported.
fn array_length_input(scalar_fn: &ScalarFunctionExpr) -> Option<&Arc<dyn PhysicalExpr>> {
match scalar_fn.args() {
[input] => Some(input),
[input, dimension] if is_dimension_one(dimension) => Some(input),
_ => None,
}
}

/// Returns true if `expr` is an `Int64` literal equal to 1. DataFusion coerces the `array_length`
/// dimension argument to `Int64`, so that is the only form we need to recognize; any other literal
/// simply isn't pushed down.
fn is_dimension_one(expr: &Arc<dyn PhysicalExpr>) -> bool {
expr.downcast_ref::<df_expr::Literal>()
.is_some_and(|literal| matches!(literal.value(), ScalarValue::Int64(Some(1))))
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -633,7 +711,7 @@ mod tests {
true,
),
Field::new(
"unsupported_list",
"tags",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
Expand All @@ -652,6 +730,21 @@ mod tests {
)
}

fn array_length_expr(
args: Vec<Arc<dyn PhysicalExpr>>,
schema: &Schema,
) -> Arc<dyn PhysicalExpr> {
Arc::new(
ScalarFunctionExpr::try_new(
Arc::new(ScalarUDF::from(ArrayLength::new())),
args,
schema,
Arc::new(ConfigOptions::new()),
)
.unwrap(),
)
}

#[test]
fn test_make_vortex_predicate_empty() {
let expr_convertor = DefaultExpressionConvertor::default();
Expand Down Expand Up @@ -798,6 +891,23 @@ mod tests {
");
}

#[rstest]
fn test_expr_from_df_array_length(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![expr], &test_schema);

let result = DefaultExpressionConvertor::default()
.convert(array_length.as_ref())
.unwrap();

assert_snapshot!(result.display_tree().to_string(), @r"
vortex.cast(u64?)
└── input: vortex.list.length()
└── input: vortex.get_item(tags)
└── input: vortex.root()
");
}

#[rstest]
// Supported types
#[case::null(DataType::Null, true)]
Expand Down Expand Up @@ -861,8 +971,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
let col_expr =
Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let col_expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;

assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
}
Expand Down Expand Up @@ -919,7 +1028,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let left = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let right =
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
Expand All @@ -942,7 +1051,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
"test%".to_string(),
)))) as Arc<dyn PhysicalExpr>;
Expand All @@ -962,7 +1071,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_octet_length_unsupported_operand(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let octet_length = Arc::new(ScalarFunctionExpr::new(
"octet_length",
Arc::new(ScalarUDF::from(OctetLengthFunc::new())),
Expand All @@ -974,6 +1083,52 @@ mod tests {
assert!(!can_be_pushed_down_impl(&octet_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_supported(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![expr], &test_schema);

assert!(can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_unsupported_operand(test_schema: Schema) {
// `array_length` over a non-list column cannot be pushed down.
let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
let array_length = Arc::new(ScalarFunctionExpr::new(
"array_length",
Arc::new(ScalarUDF::from(ArrayLength::new())),
vec![expr],
Arc::new(Field::new("array_length", DataType::UInt64, true)),
Arc::new(ConfigOptions::new()),
)) as Arc<dyn PhysicalExpr>;

assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_dimension_one_supported(test_schema: Schema) {
// `array_length(arr, 1)` is the first-dimension length, equivalent to `list_length`.
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let dimension =
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(1)))) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![list, dimension], &test_schema);

assert!(can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_higher_dimension_not_supported(test_schema: Schema) {
// Dimensions other than 1 recurse into nested lists, which `list_length` does not model,
// so they must not be pushed down.
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let dimension =
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(2)))) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![list, dimension], &test_schema);

assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
}

// https://github.com/vortex-data/vortex/issues/6211
#[tokio::test]
async fn test_cast_int_to_string() -> anyhow::Result<()> {
Expand Down
1 change: 1 addition & 0 deletions vortex-sqllogictest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ anyhow = { workspace = true }
async-trait = { workspace = true }
bigdecimal = { workspace = true }
datafusion = { workspace = true }
datafusion-functions-nested = { workspace = true }
datafusion-sqllogictest = { workspace = true }
indicatif = { workspace = true }
regex = { workspace = true }
Expand Down
8 changes: 6 additions & 2 deletions vortex-sqllogictest/bin/sqllogictests-runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ fn drive_datafusion(path: &Path, work_dir: &Path, mode: Mode) -> anyhow::Result<
Arc::new(DefaultTableFactory::new()),
)
.with_file_formats(vec![factory]);
let session =
SessionContext::new_with_state(session_state_builder.build()).enable_url_table();
// The workspace builds `datafusion` without the `nested_expressions` feature, so array
// functions (e.g. `make_array`, `array_length`) are not registered by default. Register
// them explicitly so SLT files can construct and query list columns.
let mut session_state = session_state_builder.build();
datafusion_functions_nested::register_all(&mut session_state)?;
let session = SessionContext::new_with_state(session_state).enable_url_table();

let mut runner = Runner::new(|| async {
Ok(PathNormalizing::new(
Expand Down
Loading
Loading