Skip to content

Commit

Permalink
Make regexp_match take scalar pattern and flag (#5245)
Browse files Browse the repository at this point in the history
* Make regexp_match take Datum pattern input

* Add more tests

* More

* Update benchmark

* Fix clippy

* For review

* Fix clippy

* Don't expose utility function
  • Loading branch information
viirya committed Jan 1, 2024
1 parent b00f4e0 commit e6395e2
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 24 deletions.
242 changes: 219 additions & 23 deletions arrow-string/src/regexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
//! expression of a \[Large\]StringArray

use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder};
use arrow_array::cast::AsArray;
use arrow_array::*;
use arrow_buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};
use arrow_schema::{ArrowError, DataType, Field};
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -152,28 +153,7 @@ pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
Ok(BooleanArray::from(data))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the corresponding index in `regex_array` to string in `array`
///
/// If there is no match, the list element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
Expand Down Expand Up @@ -248,6 +228,179 @@ pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(list_builder.finish()))
}

fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
regex_array: &'a dyn Array,
flag_array: Option<&'a dyn Array>,
) -> (Option<&'a str>, Option<&'a str>) {
let regex = regex_array.as_string::<OffsetSize>();
let regex = regex.is_valid(0).then(|| regex.value(0));

if let Some(flag_array) = flag_array {
let flag = flag_array.as_string::<OffsetSize>();
(regex, flag.is_valid(0).then(|| flag.value(0)))
} else {
(regex, None)
}
}

fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex: &Regex,
) -> Result<ArrayRef, ArrowError> {
let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
let mut list_builder = ListBuilder::new(builder);

array
.iter()
.map(|value| {
match value {
// Required for Postgres compatibility:
// SELECT regexp_match('foobarbequebaz', ''); = {""}
Some(_) if regex.as_str() == "" => {
list_builder.values().append_value("");
list_builder.append(true);
}
Some(value) => match regex.captures(value) {
Some(caps) => {
let mut iter = caps.iter();
if caps.len() > 1 {
iter.next();
}
for m in iter.flatten() {
list_builder.values().append_value(m.as_str());
}

list_builder.append(true);
}
None => list_builder.append(false),
},
_ => list_builder.append(false),
}
Ok(())
})
.collect::<Result<Vec<()>, ArrowError>>()?;

Ok(Arc::new(list_builder.finish()))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the corresponding index in `regex_array` to string in `array`
///
/// If there is no match, the list element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match(
array: &dyn Array,
regex_array: &dyn Datum,
flags_array: Option<&dyn Datum>,
) -> Result<ArrayRef, ArrowError> {
let (rhs, is_rhs_scalar) = regex_array.get();

if array.data_type() != rhs.data_type() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8"
.to_string(),
));
}

let (flags, is_flags_scalar) = match flags_array {
Some(flags) => {
let (flags, is_flags_scalar) = flags.get();
(Some(flags), Some(is_flags_scalar))
}
None => (None, None),
};

if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both pattern and flags to be either scalar or array"
.to_string(),
));
}

if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
return Err(ArrowError::ComputeError(
"regexp_match() requires both pattern and flags to be either string or largestring"
.to_string(),
));
}

if is_rhs_scalar {
// Regex and flag is scalars
let (regex, flag) = match rhs.data_type() {
DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
_ => {
return Err(ArrowError::ComputeError(
"regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(),
));
}
};

if regex.is_none() {
return Ok(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
array.data_type().clone(),
true,
))),
array.len(),
));
}

let regex = regex.unwrap();

let pattern = if let Some(flag) = flag {
format!("(?{flag}){regex}")
} else {
regex.to_string()
};

let re = Regex::new(pattern.as_str()).map_err(|e| {
ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
})?;

match array.data_type() {
DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
_ => Err(ArrowError::ComputeError(
"regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
)),
}
} else {
match array.data_type() {
DataType::Utf8 => {
let regex_array = rhs.as_string();
let flags_array = flags.map(|flags| flags.as_string());
regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
}
DataType::LargeUtf8 => {
let regex_array = rhs.as_string();
let flags_array = flags.map(|flags| flags.as_string());
regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
}
_ => Err(ArrowError::ComputeError(
"regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -304,6 +457,49 @@ mod tests {
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_pattern() {
let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1]));
let flags = Scalar::new(StringArray::from(vec!["i"; 1]));
let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.append(false);
expected_builder.values().append_value("7");
expected_builder.append(true);
expected_builder.append(false);
expected_builder.append(false);
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);

// No flag
let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let actual = regexp_match(&array, &pattern, None).unwrap();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_no_pattern() {
let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));
let actual = regexp_match(&array, &pattern, None).unwrap();
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.append(false);
expected_builder.append(false);
expected_builder.append(false);
expected_builder.append(false);
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
}

#[test]
fn test_single_group_not_skip_match() {
let array = StringArray::from(vec![Some("foo"), Some("bar")]);
Expand Down
9 changes: 8 additions & 1 deletion arrow/benches/regexp_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::array::*;
use arrow::compute::kernels::regexp::*;
use arrow::util::bench_util::*;

fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &GenericStringArray<i32>) {
fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &dyn Datum) {
regexp_match(criterion::black_box(arr), regex_array, None).unwrap();
}

Expand All @@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) {
let pattern = GenericStringArray::<i32>::from(pattern_values);

c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern)));

let pattern_values = vec![r".*-(\d*)-.*"];
let pattern = Scalar::new(GenericStringArray::<i32>::from(pattern_values));

c.bench_function("regexp scalar", |b| {
b.iter(|| bench_regexp(&arr_string, &pattern))
});
}

criterion_group!(benches, add_benchmark);
Expand Down

0 comments on commit e6395e2

Please sign in to comment.