Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make regexp_match take scalar pattern and flag #5245

Merged
merged 8 commits into from Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
242 changes: 219 additions & 23 deletions arrow-string/src/regexp.rs
Expand Up @@ -19,10 +19,11 @@
//! expression of a \[Large\]StringArray

use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder};
use arrow_array::cast::AsArray;
use arrow_array::*;
use arrow_buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};
use arrow_schema::{ArrowError, DataType, Field};
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -152,28 +153,7 @@ pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
Ok(BooleanArray::from(data))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the corresponding index in `regex_array` to string in `array`
///
/// If there is no match, the list element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
Expand Down Expand Up @@ -248,6 +228,179 @@ pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(list_builder.finish()))
}

fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
regex_array: &'a dyn Array,
flag_array: Option<&'a dyn Array>,
) -> (Option<&'a str>, Option<&'a str>) {
let regex = regex_array.as_string::<OffsetSize>();
let regex = regex.is_valid(0).then(|| regex.value(0));

if let Some(flag_array) = flag_array {
let flag = flag_array.as_string::<OffsetSize>();
(regex, flag.is_valid(0).then(|| flag.value(0)))
} else {
(regex, None)
}
}

fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex: &Regex,
) -> Result<ArrayRef, ArrowError> {
let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
let mut list_builder = ListBuilder::new(builder);

array
.iter()
.map(|value| {
match value {
// Required for Postgres compatibility:
// SELECT regexp_match('foobarbequebaz', ''); = {""}
Some(_) if regex.as_str() == "" => {
list_builder.values().append_value("");
list_builder.append(true);
}
Some(value) => match regex.captures(value) {
Some(caps) => {
let mut iter = caps.iter();
if caps.len() > 1 {
iter.next();
}
for m in iter.flatten() {
list_builder.values().append_value(m.as_str());
}

list_builder.append(true);
}
None => list_builder.append(false),
},
_ => list_builder.append(false),
}
Ok(())
})
.collect::<Result<Vec<()>, ArrowError>>()?;

Ok(Arc::new(list_builder.finish()))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the corresponding index in `regex_array` to string in `array`
///
/// If there is no match, the list element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match(
array: &dyn Array,
regex_array: &dyn Datum,
flags_array: Option<&dyn Datum>,
) -> Result<ArrayRef, ArrowError> {
let (rhs, is_rhs_scalar) = regex_array.get();

if array.data_type() != rhs.data_type() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8"
.to_string(),
));
}

let (flags, is_flags_scalar) = match flags_array {
Some(flags) => {
let (flags, is_flags_scalar) = flags.get();
(Some(flags), Some(is_flags_scalar))
}
None => (None, None),
};

if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both pattern and flags to be either scalar or array"
.to_string(),
));
}

if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both pattern and flags to be either string or largestring"
.to_string(),
));
}

if is_rhs_scalar {
// Regex and flag is scalars
let (regex, flag) = match rhs.data_type() {
DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
_ => {
return Err(ArrowError::ComputeError(
"regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(),
));
}
};

if regex.is_none() {
return Ok(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
array.data_type().clone(),
true,
))),
array.len(),
));
}

let regex = regex.unwrap();

let pattern = if let Some(flag) = flag {
format!("(?{flag}){regex}")
} else {
regex.to_string()
};

let re = Regex::new(pattern.as_str()).map_err(|e| {
ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
})?;

match array.data_type() {
DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
_ => Err(ArrowError::ComputeError(
"regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
)),
}
} else {
match array.data_type() {
DataType::Utf8 => {
let regex_array = rhs.as_string();
let flags_array = flags.map(|flags| flags.as_string());
regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
}
DataType::LargeUtf8 => {
let regex_array = rhs.as_string();
let flags_array = flags.map(|flags| flags.as_string());
regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
}
_ => Err(ArrowError::ComputeError(
"regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -304,6 +457,49 @@ mod tests {
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_pattern() {
let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1]));
let flags = Scalar::new(StringArray::from(vec!["i"; 1]));
let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.append(false);
expected_builder.values().append_value("7");
expected_builder.append(true);
expected_builder.append(false);
expected_builder.append(false);
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);

// No flag
let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let actual = regexp_match(&array, &pattern, None).unwrap();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_no_pattern() {
let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));
let actual = regexp_match(&array, &pattern, None).unwrap();
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.append(false);
expected_builder.append(false);
expected_builder.append(false);
expected_builder.append(false);
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
}

#[test]
fn test_single_group_not_skip_match() {
let array = StringArray::from(vec![Some("foo"), Some("bar")]);
Expand Down
9 changes: 8 additions & 1 deletion arrow/benches/regexp_kernels.rs
Expand Up @@ -25,7 +25,7 @@ use arrow::array::*;
use arrow::compute::kernels::regexp::*;
use arrow::util::bench_util::*;

fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &GenericStringArray<i32>) {
fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &dyn Datum) {
regexp_match(criterion::black_box(arr), regex_array, None).unwrap();
}

Expand All @@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) {
let pattern = GenericStringArray::<i32>::from(pattern_values);

c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern)));

let pattern_values = vec![r".*-(\d*)-.*"];
let pattern = Scalar::new(GenericStringArray::<i32>::from(pattern_values));

c.bench_function("regexp scalar", |b| {
b.iter(|| bench_regexp(&arr_string, &pattern))
});
}

criterion_group!(benches, add_benchmark);
Expand Down