diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 3a1a6a71173e..b003fd22acbb 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -23,6 +23,7 @@ use arrow::compute::cast; use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_functions::regex::regexpcount::regexp_count_func; +use datafusion_functions::regex::regexpinstr::regexp_instr_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -127,6 +128,46 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("regexp_instr_1000 string", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_instr should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_instr_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_instr should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs 
b/datafusion/functions/src/regex/mod.rs index 13fbc049af58..6b3919bd2b75 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,15 +17,20 @@ //! "regex" DataFusion functions +use arrow::error::ArrowError; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::sync::Arc; - pub mod regexpcount; +pub mod regexpinstr; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpinstr::RegexpInstrFunc, regexp_instr); make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); @@ -60,6 +65,34 @@ pub mod expr_fn { super::regexp_match().call(args) } + /// Returns index of regular expression matches in a string. + pub fn regexp_instr( + values: Expr, + regex: Expr, + start: Option, + n: Option, + endoption: Option, + flags: Option, + subexpr: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(n) = n { + args.push(n); + }; + if let Some(endoption) = endoption { + args.push(endoption); + }; + if let Some(flags) = flags { + args.push(flags); + }; + if let Some(subexpr) = subexpr { + args.push(subexpr); + }; + super::regexp_instr().call(args) + } /// Returns true if a has at least one match in a string, false otherwise. 
pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -89,7 +122,44 @@ pub fn functions() -> Vec> { vec![ regexp_count(), regexp_match(), + regexp_instr(), regexp_like(), regexp_replace(), ] } + +pub fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) +} + +pub fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{flags}){regex}") + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) + }) +} diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8f53bf8eb158..9a59cad74b5b 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use crate::regex::{compile_and_cache_regex, compile_regex}; use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ @@ -29,10 +30,10 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use itertools::izip; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; +// Ensure the `compile_and_cache_regex` function is defined in the `regex` module or imported correctly. #[user_doc( doc_section(label = "Regular Expression Functions"), description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.", @@ -550,42 +551,6 @@ where } } -fn compile_and_cache_regex<'strings, 'cache>( - regex: &'strings str, - flags: Option<&'strings str>, - regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, -) -> Result<&'cache Regex, ArrowError> -where - 'strings: 'cache, -{ - let result = match regex_cache.entry((regex, flags)) { - Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), - Entry::Vacant(vacant_entry) => { - let compiled = compile_regex(regex, flags)?; - vacant_entry.insert(compiled) - } - }; - Ok(result) -} - -fn compile_regex(regex: &str, flags: Option<&str>) -> Result { - let pattern = match flags { - None | Some("") => regex.to_string(), - Some(flags) => { - if flags.contains("g") { - return Err(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - format!("(?{flags}){regex}") - } - }; - - Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) - }) -} - fn count_matches( value: Option<&str>, pattern: &Regex, diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs new file mode 100644 index 000000000000..dafd3cdf61d5 --- /dev/null +++ 
b/datafusion/functions/src/regex/regexpinstr.rs @@ -0,0 +1,804 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use datafusion_macros::user_doc; +use itertools::izip; +use regex::Regex; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::regex::compile_and_cache_regex; + +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", + syntax_example = "regexp_instr(str, regexp[, start[, N[, flags]]])", + sql_example = r#"```sql +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); ++---------------------------------------------------------------+ +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | 
++---------------------------------------------------------------+ +| 3 | ++---------------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + standard_argument(name = "regexp", prefix = "Regular"), + argument( + name = "start", + description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1" + ), + argument( + name = "N", + description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"# + ), + argument( + name = "subexpr", + description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match." 
+ ) +)] +#[derive(Debug)] +pub struct RegexpInstrFunc { + signature: Signature, +} + +impl Default for RegexpInstrFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpInstrFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpInstrFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_instr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + let args = &args.args; + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_instr_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if 
!(2..=6).contains(&args_len) { + return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_instr" + ); + } + } + + regexp_instr( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + if args_len > 4 { Some(&args[4]) } else { None }, + if args_len > 5 { Some(&args[5]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_instr` function. +/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `nth_array` (optional): The array of start nth for the search. +/// - `endoption_array` (optional): The array of endoption positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// - `subexpr_array` (optional): The array of subexpr positions for the search. +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. 
+pub fn regexp_instr( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + nth_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, + subexpr_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (nth_array, is_nth_scalar) = nth_array.map_or((None, true), |nth| { + let (nth, is_nth_scalar) = nth.get(); + (Some(nth), is_nth_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + let (subexpr_array, is_subexpr_scalar) = + subexpr_array.map_or((None, true), |subexpr| { + let (subexpr, is_subexpr_scalar) = subexpr.get(); + (Some(subexpr), is_subexpr_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + 
is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8View, Utf8View, None) => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + None, + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + nth_array.map(|nth| nth.as_primitive::()), + is_nth_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + is_subexpr_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +enum ScalarOrArray { + Scalar(T), + Array(Vec), +} + +impl ScalarOrArray { + fn iter(&self, len: usize) -> Box + '_> { + match self { + ScalarOrArray::Scalar(val) => Box::new(std::iter::repeat_n(val.clone(), len)), + ScalarOrArray::Array(arr) => Box::new(arr.iter().cloned()), + } + } +} + 
+#[allow(clippy::too_many_arguments)] +pub fn regexp_instr_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + nth_array: Option<&Int64Array>, + is_nth_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, + subexp_array: Option<&Int64Array>, + is_subexp_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let len = values.len(); + + let regex_input = if is_regex_scalar || regex_array.len() == 1 { + ScalarOrArray::Scalar(Some(regex_array.value(0))) + } else { + let regex_vec: Vec> = regex_array.iter().collect(); + ScalarOrArray::Array(regex_vec) + }; + + let start_input = if let Some(start) = start_array { + if is_start_scalar || start.len() == 1 { + ScalarOrArray::Scalar(start.value(0)) + } else { + let start_vec: Vec = (0..start.len()) + .map(|i| if start.is_null(i) { 0 } else { start.value(i) }) // handle nulls as 0 + .collect(); + + ScalarOrArray::Array(start_vec) + } + } else if len == 1 { + ScalarOrArray::Scalar(1) + } else { + ScalarOrArray::Array(vec![1; len]) + }; + + let nth_input = if let Some(nth) = nth_array { + if is_nth_scalar || nth.len() == 1 { + ScalarOrArray::Scalar(nth.value(0)) + } else { + let nth_vec: Vec = (0..nth.len()) + .map(|i| if nth.is_null(i) { 0 } else { nth.value(i) }) // handle nulls as 0 + .collect(); + ScalarOrArray::Array(nth_vec) + } + } else if len == 1 { + ScalarOrArray::Scalar(1) + } + // Default nth = 0 + else { + ScalarOrArray::Array(vec![1; len]) + }; + + let flags_input = if let Some(ref flags) = flags_array { + if is_flags_scalar || flags.len() == 1 { + ScalarOrArray::Scalar(flags.value(0)) + } else { + let flags_vec: Vec<&str> = flags.iter().map(|v| v.unwrap_or("")).collect(); + ScalarOrArray::Array(flags_vec) + } + } else if len == 1 { + ScalarOrArray::Scalar("") + } + // Default flags = "" + else { + ScalarOrArray::Array(vec![""; len]) + }; + + let subexp_input = if let Some(subexp) = subexp_array { + if 
is_subexp_scalar || subexp.len() == 1 { + ScalarOrArray::Scalar(subexp.value(0)) + } else { + let subexp_vec: Vec = (0..subexp.len()) + .map(|i| { + if subexp.is_null(i) { + 0 + } else { + subexp.value(i) + } + }) // handle nulls as 0 + .collect(); + ScalarOrArray::Array(subexp_vec) + } + } else if len == 1 { + ScalarOrArray::Scalar(0) + } + // Default subexp = 0 + else { + ScalarOrArray::Array(vec![0; len]) + }; + + let mut regex_cache = HashMap::new(); + + let result: Result>, ArrowError> = izip!( + values.iter(), + regex_input.iter(len), + start_input.iter(len), + nth_input.iter(len), + flags_input.iter(len), + subexp_input.iter(len) + ) + .map(|(value, regex, start, nth, flags, subexp)| match regex { + None => Ok(None), + Some("") => Ok(None), + Some(regex) => get_index( + value, + regex, + start, + nth, + subexp, + Some(flags), + &mut regex_cache, + ), + }) + .collect(); + + Ok(Arc::new(Int64Array::from(result?))) +} + +fn get_index<'strings, 'cache>( + value: Option<&str>, + pattern: &'strings str, + start: i64, + n: i64, + subexpr: i64, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result, ArrowError> +where + 'strings: 'cache, +{ + let value = match value { + None => return Ok(None), + Some("") => return Ok(Some(0)), + Some(value) => value, + }; + + let pattern = compile_and_cache_regex(pattern, flags, regex_cache)?; + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_instr() requires start to be 1-based".to_string(), + )); + } + + if n < 1 { + return Err(ArrowError::ComputeError( + "N must be 1 or greater".to_string(), + )); + } + + // --- Simplified byte_start_offset calculation --- + let total_chars = value.chars().count() as i64; + let byte_start_offset = if start > total_chars { + // If start is beyond the total characters, it means we start searching + // after the string effectively. No matches possible. 
+ return Ok(Some(0)); + } else { + // Get the byte offset for the (start - 1)-th character (0-based) + value + .char_indices() + .nth((start - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) // Should not happen if start is valid and <= total_chars + }; + // --- End simplified calculation --- + + let search_slice = &value[byte_start_offset..]; + + // Handle subexpression capturing first, as it takes precedence + if subexpr > 0 { + if let Some(captures) = pattern.captures(search_slice) { + if let Some(matched) = captures.get(subexpr as usize) { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let start_char_offset = + value[..byte_start_offset + matched.start()].chars().count() as i64 + + 1; + return Ok(Some(start_char_offset)); + } + } + return Ok(Some(0)); // Return 0 if the subexpression was not found + } + + // Use nth to get the N-th match (n is 1-based, nth is 0-based) + if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. 
+ let match_start_byte_offset = byte_start_offset + mat.start(); + let match_start_char_offset = + value[..match_start_byte_offset].chars().count() as i64 + 1; + Ok(Some(match_start_char_offset)) + } else { + Ok(Some(0)) // Return 0 if the N-th match was not found + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + #[test] + fn test_regexp_instr() { + test_case_sensitive_regexp_instr_scalar(); + test_case_sensitive_regexp_instr_scalar_start(); + test_case_sensitive_regexp_instr_scalar_nth(); + + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::(); + + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::(); + + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::(); + } + + fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: &Field::new("f", Int64, true), + }) + } + + fn test_case_sensitive_regexp_instr_scalar() { + let values = [ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "abc", + "ДатаФусион数据融合📊🔥", + ]; + let regex = ["o", "d", "123", "z", "gg", "📊"]; + + let expected: Vec = vec![5, 4, 4, 0, 0, 15]; + + // let values = [""]; + // let regex = [""]; + // 
let expected: Vec = vec![0]; + + izip!(values.iter(), regex.iter()) + .enumerate() + .for_each(|(pos, (&v, &r))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + // let res_exp = re.unwrap(); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_start() { + let values = ["abcabcabc", "abcabcabc", ""]; + let regex = ["abc", "abc", "gg"]; + let start = [4, 5, 5]; + let expected: Vec = vec![4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let expected = expected.get(pos).cloned(); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, 
expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_nth() { + let values = ["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]; + let regex = ["abc", "abc", "abc", "abc"]; + let start = [1, 1, 1, 1]; + let nth = [1, 2, 3, 4]; + let expected: Vec = vec![1, 4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter(), nth.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s, &n))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = 
ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec![ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "", + ]); + let regex = A::from(vec!["o", "d", "123", "z", "gg"]); + + let expected = Int64Array::from(vec![5, 4, 4, 0, 0]); + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", ""]); + let regex = A::from(vec!["abc", "abc", "gg"]); + let start = Int64Array::from(vec![4, 5, 5]); + let expected = Int64Array::from(vec![4, 7, 0]); + + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_nth() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]); + let regex 
= A::from(vec!["abc", "abc", "abc", "abc"]); + let start = Int64Array::from(vec![1, 1, 1, 1]); + let nth = Int64Array::from(vec![1, 2, 3, 4]); + let expected = Int64Array::from(vec![1, 4, 7, 0]); + + let re = regexp_instr_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(nth), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt new file mode 100644 index 000000000000..c651422142d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Import common test data +include ./init_data.slt.part + +query I +SELECT regexp_instr('123123123123123', '(12)3'); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 1); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 3); +---- +4 + +query I +SELECT regexp_instr('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, ''); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', -3); + +query I +SELECT regexp_instr(str, pattern) FROM regexp_test_data; +---- +NULL +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_instr(str, pattern, start) FROM regexp_test_data; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + + +statement ok +CREATE TABLE t_stringview AS +SELECT + arrow_cast(str, 'Utf8View') AS str, + arrow_cast(pattern, 'Utf8View') AS pattern, + arrow_cast(start, 'Int64') AS start +FROM regexp_test_data; + +query I +SELECT regexp_instr(str, pattern, start) FROM t_stringview; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr( + arrow_cast(str, 'Utf8'), + arrow_cast(pattern, 'LargeUtf8'), + arrow_cast(start, 'Int32') +) FROM t_stringview; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr(NULL, NULL); +---- +NULL + +query I +SELECT regexp_instr(NULL, 'a'); +---- +NULL + +query I +SELECT regexp_instr('a', NULL); +---- +NULL + +query I +SELECT regexp_instr('😀abcdef', 'abc'); +---- +2 + + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; 
+---- + +statement ok +INSERT INTO empty_table VALUES + ('a', NULL, 1), + (NULL, 'a', 1), + (NULL, NULL, 1), + (NULL, NULL, NULL); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; +---- +NULL +NULL +NULL +NULL + +statement ok +DROP TABLE t_stringview; + +statement ok +DROP TABLE empty_table; \ No newline at end of file diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index cbcec710e267..885c8bbd2b98 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1793,6 +1793,7 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) The following regular expression functions are supported: - [regexp_count](#regexp_count) +- [regexp_instr](#regexp_instr) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) @@ -1828,6 +1829,39 @@ regexp_count(str, regexp[, start, flags]) +---------------------------------------------------------------+ ``` +### `regexp_instr` + +Returns the position in a string where the specified occurrence of a POSIX regular expression is located. + +```sql +regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1. +- **N**: Optional. The N-th occurrence of the pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. 
The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? +- **subexpr**: Optional. Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match. + +#### Example + +```sql +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); ++---------------------------------------------------------------+ +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | ++---------------------------------------------------------------+ +| 3 | ++---------------------------------------------------------------+ +``` + ### `regexp_like` Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.