diff --git a/Cargo.lock b/Cargo.lock index 73b52889..ac0d2fa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1043,6 +1043,17 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set", + "regex-automata 0.4.5", + "regex-syntax 0.8.2", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -2604,6 +2615,7 @@ dependencies = [ "egg", "enum_dispatch", "erased-serde", + "fancy-regex", "futures", "futures-async-stream", "glob", diff --git a/Cargo.toml b/Cargo.toml index 39123272..fb7e16d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ prost = "0.12" pyo3 = { version = "0.20", features = ["extension-module"], optional = true } ref-cast = "1.0" regex = "1" +fancy-regex = "0.13" risinglight_proto = "0.2" rust_decimal = "1" rustyline = "13" diff --git a/src/binder/create_function.rs b/src/binder/create_function.rs index 71cfd2fa..0442c019 100644 --- a/src/binder/create_function.rs +++ b/src/binder/create_function.rs @@ -3,6 +3,7 @@ use std::fmt; use std::str::FromStr; +use fancy_regex::Regex; use pretty_xmlish::helper::delegate_fmt; use pretty_xmlish::Pretty; use serde::{Deserialize, Serialize}; @@ -18,6 +19,7 @@ pub struct CreateFunction { pub return_type: crate::types::DataType, pub language: String, pub body: String, + pub is_recursive: bool, } impl fmt::Display for CreateFunction { @@ -45,6 +47,35 @@ impl CreateFunction { } } +/// Find the pattern for recursive sql udf +/// return the exact index where the pattern first appears +/// Source: +fn find_target(input: &str, target: &str) -> Option { + // Regex pattern to find `target` not preceded or followed by an ASCII letter + // The pattern uses negative lookbehind (? bool { + if let Some(_) = find_target(body, func_name) { + true + } else { + false + } +} + impl Binder { pub(super) fn bind_create_function( &mut self, @@ -102,6 +133,8 @@ impl Binder { arg_names.push(arg.name.map_or("".to_string(), |n| n.to_string())); } + let is_recursive = is_recursive(&body, &name); + let f = self.egraph.add(Node::CreateFunction(CreateFunction { schema_name, name, @@ -110,6 +143,7 @@ impl Binder { return_type, language, body, + is_recursive, })); Ok(f) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 36c327cb..8f3a3a94 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -305,6 +305,7 @@ impl Binder { fn bind_function(&mut self, func: Function) -> Result { let mut args = vec![]; for arg in func.args.clone() { + println!("arg: {:#?}", arg); // ignore argument name let arg = match arg { FunctionArg::Named { arg, .. } => arg, @@ -332,6 +333,21 @@ impl Binder { // See if the input function is sql udf if let Some(ref function_catalog) = catalog.get_function_by_name(schema_name, function_name) { + // For recursive sql udf, we will postpone its execution + // until reaching backend. + // a.k.a. this will not be *inlined* during binding phase + if function_catalog.is_recursive { + return Ok(self.egraph.add(Node::Udf(Udf { + // TODO: presumably there could be multiple arguments + // but for simplicity reason, currently only + // a single argument is supported + id: args[0], + name: function_catalog.name.clone(), + body: function_catalog.body.clone(), + return_type: function_catalog.return_type.clone(), + }))); + } + // Create the brand new `udf_context` let Ok(context) = UdfContext::create_udf_context(func.args.as_slice(), function_catalog) diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 1de55cbf..2dfd4d06 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -24,9 +24,11 @@ mod expr; mod insert; mod select; mod table; +mod udf; pub use self::create_function::*; pub use self::create_table::*; +pub use self::udf::*; pub type Result = std::result::Result; diff --git a/src/binder/udf.rs b/src/binder/udf.rs new file mode 100644 index 00000000..2bdc7341 --- /dev/null +++ b/src/binder/udf.rs @@ -0,0 +1,42 @@ +// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. + +use egg::Id; + +use crate::types::DataType; +use pretty_xmlish::helper::delegate_fmt; +use pretty_xmlish::Pretty; +use std::str::FromStr; +use std::fmt; + +/// currently represents recursive sql udf +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct Udf { + pub id: Id, + pub name: String, + pub body: String, + pub return_type: DataType, +} + +impl fmt::Display for Udf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let explainer = Pretty::childless_record("Udf", self.pretty_function()); + delegate_fmt(&explainer, f, String::with_capacity(1000)) + } +} + +impl FromStr for Udf { + type Err = (); + + fn from_str(_s: &str) -> std::result::Result { + Err(()) + } +} + +impl Udf { + pub fn pretty_function<'a>(&self) -> Vec<(&'a str, Pretty<'a>)> { + vec![ + ("name", Pretty::display(&self.name)), + ("body", Pretty::display(&self.body)), + ] + } +} \ No newline at end of file diff --git a/src/catalog/function.rs b/src/catalog/function.rs index 4e4081d7..252b8583 100644 --- a/src/catalog/function.rs +++ b/src/catalog/function.rs @@ -10,6 +10,7 @@ pub struct FunctionCatalog { pub return_type: DataType, pub language: String, pub body: String, + pub is_recursive: bool, } impl FunctionCatalog { @@ -20,6 +21,7 @@ impl FunctionCatalog { return_type: DataType, language: String, body: String, + is_recursive: bool, ) -> Self { Self { name, @@ -28,17 +30,21 @@ impl FunctionCatalog { return_type, language, body, + is_recursive, } } + #[inline] pub fn body(&self) -> String { self.body.clone() } + #[inline] pub fn name(&self) -> String { self.name.clone() } + #[inline] pub fn language(&self) -> String { self.language.clone() } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index cc9dc7d7..3c947c98 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -133,11 +133,20 @@ impl RootCatalog { return_type: DataType, language: String, body: String, + is_recursive: bool, ) { let schema_idx = self.get_schema_id_by_name(&schema_name).unwrap(); let mut inner = self.inner.lock().unwrap(); let schema = inner.schemas.get_mut(&schema_idx).unwrap(); - schema.create_function(name, arg_types, arg_names, return_type, language, body); + schema.create_function( + name, + arg_types, + arg_names, + return_type, + language, + body, + is_recursive, + ); } pub const DEFAULT_SCHEMA_NAME: &'static str = "postgres"; diff --git a/src/catalog/schema.rs b/src/catalog/schema.rs index b33547d3..5d020269 100644 --- a/src/catalog/schema.rs +++ b/src/catalog/schema.rs @@ -119,6 +119,7 @@ impl SchemaCatalog { return_type: DataType, language: String, body: String, + is_recursive: bool, ) { self.functions.insert( name.clone(), @@ -129,6 +130,7 @@ impl SchemaCatalog { return_type, language, body, + is_recursive, }), ); } diff --git a/src/executor/create_function.rs b/src/executor/create_function.rs index cde51bcf..4fd2783e 100644 --- a/src/executor/create_function.rs +++ b/src/executor/create_function.rs @@ -21,6 +21,7 @@ impl CreateFunctionExecutor { return_type, language, body, + is_recursive, } = self.f; self.catalog.create_function( @@ -31,6 +32,7 @@ impl CreateFunctionExecutor { return_type, language, body, + is_recursive, ); } } diff --git a/src/executor/evaluator.rs b/src/executor/evaluator.rs index 08ed55e4..ab1e111a 100644 --- a/src/executor/evaluator.rs +++ b/src/executor/evaluator.rs @@ -7,6 +7,7 @@ use std::fmt; use egg::{Id, Language}; use crate::array::*; +use crate::executor::udf::UdfExecutor; use crate::planner::{Expr, RecExpr}; use crate::types::{ConvertError, DataValue}; @@ -129,6 +130,11 @@ impl<'a> Evaluator<'a> { }; a.replace(from, to) } + // recursive sql udf's actual backend logic + Udf(udf) => UdfExecutor { + udf: udf.clone(), + } + .execute(chunk), e => { if let Some((op, a, b)) = e.binary_op() { let left = self.next(a).eval(chunk)?; diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 889ae951..e22d39e8 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -83,6 +83,7 @@ mod table_scan; mod top_n; mod values; mod window; +mod udf; /// The maximum chunk length produced by executor at a time. const PROCESSING_WINDOW_SIZE: usize = 1024; @@ -352,7 +353,7 @@ impl Builder { CreateFunction(f) => CreateFunctionExecutor { f, - catalog: self.optimizer.catalog().clone(), + catalog: self.catalog().clone(), } .execute(), diff --git a/src/executor/projection.rs b/src/executor/projection.rs index 3f9f36ea..1f1f423b 100644 --- a/src/executor/projection.rs +++ b/src/executor/projection.rs @@ -16,6 +16,8 @@ impl ProjectionExecutor { pub async fn execute(self, child: BoxedExecutor) { #[for_await] for batch in child { + println!("[project]\n{}", batch.clone().unwrap()); + println!("projs: {:#?}", self.projs); yield Evaluator::new(&self.projs).eval_list(&batch?)?; } } diff --git a/src/executor/udf.rs b/src/executor/udf.rs new file mode 100644 index 00000000..546cb59a --- /dev/null +++ b/src/executor/udf.rs @@ -0,0 +1,16 @@ +// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. + +use super::*; +use crate::{array::ArrayImpl, binder::Udf, types::ConvertError}; + +/// The executor of (recursive) sql udf +pub struct UdfExecutor { + pub udf: Udf, +} + +impl UdfExecutor { + pub fn execute(&self, chunk: &DataChunk) -> std::result::Result { + println!("udf\n{}", chunk); + Ok(ArrayImpl::new_null((0..1).map(|_| ()).collect())) + } +} diff --git a/src/planner/explain.rs b/src/planner/explain.rs index 4d32d402..c8be183d 100644 --- a/src/planner/explain.rs +++ b/src/planner/explain.rs @@ -379,6 +379,10 @@ impl<'a> Explain<'a> { vec![].with(cost, rows), vec![self.child(child).pretty()], ), + Udf(udf) => { + let v = udf.pretty_function(); + Pretty::childless_record("Udf", v) + } Empty(_) => Pretty::childless_record("Empty", vec![].with(cost, rows)), } } diff --git a/src/planner/mod.rs b/src/planner/mod.rs index ea29762d..8fcae163 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -3,7 +3,7 @@ use egg::{define_language, Id, Symbol}; use crate::binder::copy::ExtSource; -use crate::binder::{CreateFunction, CreateTable}; +use crate::binder::{CreateFunction, CreateTable, Udf}; use crate::catalog::{ColumnRefId, TableRefId}; use crate::parser::{BinaryOperator, UnaryOperator}; use crate::types::{ColumnIndex, DataType, DataValue, DateTimeField}; @@ -131,6 +131,9 @@ define_language! { // with the same schema as `child` Symbol(Symbol), + + // currently only used by recursive sql udf + Udf(Udf), } } diff --git a/src/planner/rules/type_.rs b/src/planner/rules/type_.rs index 5d98a918..0badd633 100644 --- a/src/planner/rules/type_.rs +++ b/src/planner/rules/type_.rs @@ -175,6 +175,9 @@ pub fn analyze_type(enode: &Expr, x: impl Fn(&Id) -> Type, catalog: &RootCatalog Ok(DataType::Struct(types)) } + // currently for recursive sql udf's type inference + Udf(udf) => Ok(udf.return_type.clone()), + // other plan nodes _ => Err(TypeError::Unavailable(enode.to_string())), }