Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make regexp_match take scalar pattern and flag #5245

Merged
merged 8 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
248 changes: 225 additions & 23 deletions arrow-string/src/regexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuild
use arrow_array::*;
use arrow_buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};
use arrow_schema::{ArrowError, DataType, Field};
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -152,28 +152,7 @@ pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
Ok(BooleanArray::from(data))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the corresponding index in `regex_array` to string in `array`
///
/// If there is no match, the list element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
pub fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
Expand Down Expand Up @@ -248,6 +227,186 @@ pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(list_builder.finish()))
}

fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
regex_array: &'a dyn Array,
flag_array: Option<&'a dyn Array>,
) -> (Option<&'a str>, Option<&'a str>) {
let regex = regex_array
viirya marked this conversation as resolved.
Show resolved Hide resolved
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
.expect("Unable to downcast to StringArray/LargeStringArray");
let regex = if regex.is_valid(0) {
viirya marked this conversation as resolved.
Show resolved Hide resolved
Some(regex.value(0))
} else {
None
};

if flag_array.is_some() {
viirya marked this conversation as resolved.
Show resolved Hide resolved
let flag = flag_array
.unwrap()
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
viirya marked this conversation as resolved.
Show resolved Hide resolved
.expect("Unable to downcast to StringArray/LargeStringArray");

if flag.is_valid(0) {
let flag = flag.value(0);
(regex, Some(flag))
} else {
(regex, None)
}
} else {
(regex, None)
}
}

fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex: &Regex,
) -> std::result::Result<ArrayRef, ArrowError> {
viirya marked this conversation as resolved.
Show resolved Hide resolved
let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
let mut list_builder = ListBuilder::new(builder);

array
.iter()
.map(|value| {
match value {
// Required for Postgres compatibility:
// SELECT regexp_match('foobarbequebaz', ''); = {""}
Some(_) if regex.as_str() == "" => {
list_builder.values().append_value("");
list_builder.append(true);
}
Some(value) => match regex.captures(value) {
Some(caps) => {
let mut iter = caps.iter();
if caps.len() > 1 {
iter.next();
}
for m in iter.flatten() {
list_builder.values().append_value(m.as_str());
}

list_builder.append(true);
}
None => list_builder.append(false),
},
_ => list_builder.append(false),
}
Ok(())
})
.collect::<Result<Vec<()>, ArrowError>>()?;

Ok(Arc::new(list_builder.finish()))
}

/// Extract all groups matched by a regular expression for a given String array.
///
/// Modelled after the Postgres [regexp_match].
///
/// `regex_array` and `flags_array` are [`Datum`]s: they must either both be scalars (one
/// pattern / flag applied to every element of `array`) or both be arrays (the pattern / flag
/// at each index is applied to the string at the same index of `array`).
///
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
/// match of the pattern against the corresponding string in `array`.
///
/// If there is no match, the list element is NULL. If a scalar pattern is NULL, every list
/// element is NULL.
///
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
/// then the list element is a single-element [`GenericStringArray`] containing the substring
/// matching the whole pattern.
///
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
/// the n'th capturing parenthesized subexpression of the pattern.
///
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] if
/// * `array` and the pattern do not share the same string data type,
/// * the pattern and the flags are not both scalars or both arrays,
/// * the pattern and the flags do not share the same string data type, or
/// * a scalar pattern (combined with its flags) fails to compile.
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
    array: &GenericStringArray<OffsetSize>,
    regex_array: &dyn Datum,
    flags_array: Option<&dyn Datum>,
) -> std::result::Result<ArrayRef, ArrowError> {
    let (rhs, is_rhs_scalar) = regex_array.get();

    if array.data_type() != rhs.data_type() {
        return Err(ArrowError::ComputeError(
            "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8"
                .to_string(),
        ));
    }

    // Validate the flags against the pattern while unpacking them, so that no later code
    // needs to unwrap an `Option`.
    let flags = match flags_array {
        Some(flags_datum) => {
            let (flags, is_flags_scalar) = flags_datum.get();

            if is_rhs_scalar != is_flags_scalar {
                return Err(ArrowError::ComputeError(
                    "regexp_match() requires both pattern and flags to be either scalar or array"
                        .to_string(),
                ));
            }

            if rhs.data_type() != flags.data_type() {
                return Err(ArrowError::ComputeError(
                    "regexp_match() requires both pattern and flags to be either string or largestring"
                        .to_string(),
                ));
            }

            Some(flags)
        }
        None => None,
    };

    if is_rhs_scalar {
        // Pattern and flags are both scalars.
        let (regex, flag) = match rhs.data_type() {
            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
            _ => {
                return Err(ArrowError::ComputeError(
                    "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(),
                ));
            }
        };

        let regex = match regex {
            Some(regex) => regex,
            // A null pattern matches nothing: every list element is NULL.
            None => {
                return Ok(new_null_array(
                    &DataType::List(Arc::new(Field::new(
                        "item",
                        array.data_type().clone(),
                        true,
                    ))),
                    array.len(),
                ));
            }
        };

        let pattern = match flag {
            // An empty flag string means "no flags". Prefixing it anyway would build the
            // group `(?)`, which the regex parser rejects.
            Some(flag) if !flag.is_empty() => format!("(?{flag}){regex}"),
            _ => regex.to_string(),
        };

        let re = Regex::new(pattern.as_str()).map_err(|e| {
            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
        })?;

        regexp_scalar_match(array, &re)
    } else {
        // Pattern and flags are arrays. The data type checks above make these downcasts
        // to the same offset size as `array` succeed.
        let regex_array = rhs
            .as_any()
            .downcast_ref::<GenericStringArray<OffsetSize>>()
            .expect("Unable to downcast to StringArray/LargeStringArray");
        let flags_array = flags.map(|flags| {
            flags
                .as_any()
                .downcast_ref::<GenericStringArray<OffsetSize>>()
                .expect("Unable to downcast to StringArray/LargeStringArray")
        });
        regexp_array_match(array, regex_array, flags_array)
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -304,6 +463,49 @@ mod tests {
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_pattern() {
    // A scalar pattern with a case-insensitive flag: only "X-7-5" contains an `x`
    // followed by a dash-delimited run of digits.
    let input = StringArray::from(vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]);
    let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"]));
    let flags = Scalar::new(StringArray::from(vec!["i"]));
    let matched = regexp_match(&input, &pattern, Some(&flags)).unwrap();

    let mut expected_builder = ListBuilder::new(GenericStringBuilder::<i32>::with_capacity(0, 0));
    expected_builder.append(false);
    expected_builder.values().append_value("7");
    expected_builder.append(true);
    expected_builder.append(false);
    expected_builder.append(false);
    let expected = expected_builder.finish();

    let matched_lists = matched.as_any().downcast_ref::<ListArray>().unwrap();
    assert_eq!(&expected, matched_lists);

    // No flag: the pattern is now case-sensitive, so the matching input uses a lowercase `x`.
    let input = StringArray::from(vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]);
    let matched = regexp_match(&input, &pattern, None).unwrap();
    let matched_lists = matched.as_any().downcast_ref::<ListArray>().unwrap();
    assert_eq!(&expected, matched_lists);
}

#[test]
fn match_scalar_no_pattern() {
    // A null scalar pattern must produce a null list entry for every input element.
    let input = StringArray::from(vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]);
    let null_pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));
    let matched = regexp_match(&input, &null_pattern, None).unwrap();

    let mut expected_builder = ListBuilder::new(GenericStringBuilder::<i32>::with_capacity(0, 0));
    for _ in 0..4 {
        expected_builder.append(false);
    }
    let expected = expected_builder.finish();

    let matched_lists = matched.as_any().downcast_ref::<ListArray>().unwrap();
    assert_eq!(&expected, matched_lists);
}

#[test]
fn test_single_group_not_skip_match() {
let array = StringArray::from(vec![Some("foo"), Some("bar")]);
Expand Down
9 changes: 8 additions & 1 deletion arrow/benches/regexp_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::array::*;
use arrow::compute::kernels::regexp::*;
use arrow::util::bench_util::*;

fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &GenericStringArray<i32>) {
/// Runs `regexp_match` without flags over `arr`, using `regex_array` as the pattern datum.
///
/// The input array is passed through `criterion::black_box` before the call.
fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &dyn Datum) {
    let input = criterion::black_box(arr);
    regexp_match(input, regex_array, None).unwrap();
}

Expand All @@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) {
let pattern = GenericStringArray::<i32>::from(pattern_values);

c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern)));

let pattern_values = vec![r".*-(\d*)-.*"];
let pattern = Scalar::new(GenericStringArray::<i32>::from(pattern_values));

c.bench_function("regexp scalar", |b| {
b.iter(|| bench_regexp(&arr_string, &pattern))
});
}

criterion_group!(benches, add_benchmark);
Expand Down