Skip to content

Commit

Permalink
feat: cast (Large)List to FixedSizeList (#5081)
Browse files Browse the repository at this point in the history
* feat: cast (Large)List to FixedSizeList

* fix: support 'safe' casting of list to FSL

* fix: if target is non-null, use non-null sentinel value

* Use MutableArrayData

* Docs

---------

Co-authored-by: Raphael Taylor-Davies <r.taylordavies@googlemail.com>
  • Loading branch information
wjones127 and tustvold committed Nov 17, 2023
1 parent bfe396e commit dc75a28
Showing 1 changed file with 263 additions and 1 deletion.
264 changes: 263 additions & 1 deletion arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::parse::{
};
use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *};
use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer};
use arrow_data::transform::MutableArrayData;
use arrow_data::ArrayData;
use arrow_schema::*;
use arrow_select::take::take;
Expand Down Expand Up @@ -138,6 +139,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => {
can_cast_types(list_from.data_type(), to_type)
}
(List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => {
can_cast_types(list_from.data_type(), list_to.data_type())
}
(List(_), _) => false,
(FixedSizeList(list_from,_), List(list_to)) => {
list_from.data_type() == list_to.data_type()
Expand Down Expand Up @@ -279,6 +283,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
/// in integer casts return null
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
/// * List to List: the underlying data type is cast
/// * List to FixedSizeList: the underlying data type is cast. If safe is true and a list element
/// has the wrong length it will be replaced with NULL, otherwise an error will be returned
/// * Primitive to List: a list array with 1 value per slot is created
/// * Date32 and Date64: precision lost when going to higher interval
/// * Time32 and Time64: precision lost when going to higher interval
Expand Down Expand Up @@ -799,6 +805,14 @@ pub fn cast_with_options(
cast_list_container::<i64, i32>(array, cast_options)
}
}
(List(_), FixedSizeList(field, size)) => {
let array = array.as_list::<i32>();
cast_list_to_fixed_size_list::<i32>(array, field, *size, cast_options)
}
(LargeList(_), FixedSizeList(field, size)) => {
let array = array.as_list::<i64>();
cast_list_to_fixed_size_list::<i64>(array, field, *size, cast_options)
}
(List(_) | LargeList(_), _) => match to_type {
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
Expand All @@ -824,7 +838,6 @@ pub fn cast_with_options(
cast_fixed_size_list_to_list::<i64>(array)
}
}

(_, List(ref to)) => cast_values_to_list::<i32>(array, to, cast_options),
(_, LargeList(ref to)) => cast_values_to_list::<i64>(array, to, cast_options),
(Decimal128(_, s1), Decimal128(p2, s2)) => {
Expand Down Expand Up @@ -3206,6 +3219,76 @@ where
Ok(Arc::new(list))
}

/// Casts a `List` / `LargeList` array to a `FixedSizeList` with `size` items per slot.
///
/// Every list slot must contain exactly `size` values. A slot with a different length is:
/// * replaced by a null slot (padded with `size` null child values) when
///   `cast_options.safe` is true or the slot is already null;
/// * otherwise reported as an [`ArrowError::CastError`].
///
/// After the child values have been laid out with a fixed stride they are cast to
/// `field.data_type()` using the same `cast_options`.
///
/// # Errors
///
/// Returns an error if `size` is negative, if a non-null slot has the wrong length and
/// `cast_options.safe` is false, or if casting the child values fails.
fn cast_list_to_fixed_size_list<OffsetSize>(
    array: &GenericListArray<OffsetSize>,
    field: &Arc<Field>,
    size: i32,
    cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
    OffsetSize: OffsetSizeTrait,
{
    // Reject negative sizes up front instead of wrapping around in an `as usize` cast.
    let list_size = usize::try_from(size).map_err(|_| {
        ArrowError::CastError(format!(
            "Cannot cast to FixedSizeList({size}): size must be non-negative"
        ))
    })?;
    let cap = array.len() * list_size;

    // The output can contain null slots either because the input already has some, or
    // because a "safe" cast may turn incorrectly-sized lists into nulls.
    let nullable = cast_options.safe || array.null_count() != 0;
    let mut nulls = nullable.then(|| {
        let mut buffer = BooleanBufferBuilder::new(array.len());
        match array.nulls() {
            Some(n) => buffer.append_buffer(n.inner()),
            None => buffer.append_n(array.len(), true),
        }
        buffer
    });

    // Nulls in FixedSizeListArray take up space and so we must pad the values.
    // The builder must be nullable whenever padding can happen, which includes
    // `safe == false` with pre-existing null slots of the wrong length.
    let values = array.values().to_data();
    let mut mutable = MutableArrayData::new(vec![&values], nullable, cap);

    // A sliced list array does not necessarily start at child position 0.
    let first_pos = array.offsets()[0].as_usize();
    // The end position in values of the last incorrectly-sized list slice
    // (or the start of the array's values if none has been seen yet).
    let mut last_pos = first_pos;
    // Tracked explicitly: `last_pos` alone cannot tell "nothing was padded" apart from
    // "only empty leading slots were padded".
    let mut padded = false;

    for (idx, w) in array.offsets().windows(2).enumerate() {
        let start_pos = w[0].as_usize();
        let end_pos = w[1].as_usize();
        let len = end_pos - start_pos;

        if len == list_size {
            continue;
        }
        if !(cast_options.safe || array.is_null(idx)) {
            return Err(ArrowError::CastError(format!(
                "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}",
            )));
        }

        if last_pos != start_pos {
            // Extend with the run of correctly-sized slices preceding this one
            mutable.extend(0, last_pos, start_pos);
        }
        // Pad this slice with nulls
        mutable.extend_nulls(list_size);
        nulls
            .as_mut()
            .expect("null buffer exists whenever a slot can be nulled")
            .set_bit(idx, false);
        // Skip over this slice's own values
        last_pos = end_pos;
        padded = true;
    }

    let values = if padded {
        if mutable.len() != cap {
            // Remaining slices were all correct length
            let remaining = cap - mutable.len();
            mutable.extend(0, last_pos, last_pos + remaining)
        }
        make_array(mutable.freeze())
    } else {
        // All slices were the correct length: the child values can be reused as-is
        array.values().slice(first_pos, cap)
    };

    // Cast the inner values if necessary
    let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?;

    // Construct the FixedSizeListArray
    let nulls = nulls.map(|mut x| x.finish().into());
    let array = FixedSizeListArray::new(field.clone(), size, values, nulls);
    Ok(Arc::new(array))
}

/// Cast the container type of List/Largelist array but not the inner types.
/// This function can leave the value data intact and only has to cast the offset dtypes.
fn cast_list_container<OffsetSizeFrom, OffsetSizeTo>(
Expand Down Expand Up @@ -3274,6 +3357,8 @@ where

#[cfg(test)]
mod tests {
use arrow_buffer::NullBuffer;

use super::*;

macro_rules! generate_cast_test_case {
Expand Down Expand Up @@ -7374,6 +7459,183 @@ mod tests {
assert_eq!(&expected.value(2), &actual.value(2));
}

#[test]
fn test_cast_list_to_fsl() {
    // There are four noteworthy cases we should handle:
    // 1. No nulls
    // 2. Nulls that are always empty
    // 3. Nulls that have varying lengths
    // 4. Nulls that are correctly sized (same as target list size)

    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let fsl_type = DataType::FixedSizeList(field.clone(), 3);

    // Case 1: every slot is valid and already has the target length.
    let rows = vec![
        Some(vec![Some(1), Some(2), Some(3)]),
        Some(vec![Some(4), Some(5), Some(6)]),
    ];
    let source: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(
        rows.clone(),
    ));
    let want: ArrayRef = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(rows, 3),
    );
    let got = cast(source.as_ref(), &fsl_type).unwrap();
    assert_eq!(want.as_ref(), got.as_ref());

    // Cases 2-4 (plus a mix): each input is logically
    // [[1, 2, 3], null, [4, 5, 6], null]; only the number of child values
    // stored behind the null slots differs.
    let validity = NullBuffer::from(vec![true, false, true, false]);
    let want: ArrayRef = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
            vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                None,
                Some(vec![Some(4), Some(5), Some(6)]),
                None,
            ],
            3,
        ),
    );

    let cases = [
        // Zero-length nulls
        (vec![1, 2, 3, 4, 5, 6], vec![3, 0, 3, 0]),
        // Varying-length nulls
        (vec![1, 2, 3, 0, 0, 4, 5, 6, 0], vec![3, 2, 3, 1]),
        // Correctly-sized nulls
        (vec![1, 2, 3, 0, 0, 0, 4, 5, 6, 0, 0, 0], vec![3, 3, 3, 3]),
        // Mixed nulls
        (vec![1, 2, 3, 4, 5, 6, 0, 0, 0], vec![3, 0, 3, 3]),
    ];

    for (child_values, slot_lengths) in cases {
        let source: ArrayRef = Arc::new(ListArray::new(
            field.clone(),
            OffsetBuffer::from_lengths(slot_lengths),
            Arc::new(Int32Array::from(child_values)),
            Some(validity.clone()),
        ));
        let got = cast(source.as_ref(), &fsl_type).unwrap();
        assert_eq!(want.as_ref(), got.as_ref());
    }
}

#[test]
fn test_cast_list_to_fsl_safety() {
    // Slots 1 and 2 do not match the target size of 3 (too short / too long).
    let values = vec![
        Some(vec![Some(1), Some(2), Some(3)]),
        Some(vec![Some(4), Some(5)]),
        Some(vec![Some(6), Some(7), Some(8), Some(9)]),
        Some(vec![Some(3), Some(4), Some(5)]),
    ];
    let array = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(values)) as ArrayRef;
    let target_type =
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3);

    // When safe=false, the first incorrectly-sized list is reported as an error.
    let res = cast_with_options(
        array.as_ref(),
        &target_type,
        &CastOptions {
            safe: false,
            ..Default::default()
        },
    );
    assert!(res.is_err());
    assert!(format!("{:?}", res)
        .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2"));

    // When safe=true (default), every list whose length differs from the target
    // size is replaced with null, whether it is too short or too long. Nothing is
    // truncated or partially padded.
    let res = cast(array.as_ref(), &target_type).unwrap();
    let expected = Arc::new(FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            None, // Too short -> replaced with null
            None, // Too long -> replaced with null
            Some(vec![Some(3), Some(4), Some(5)]),
        ],
        3,
    )) as ArrayRef;
    assert_eq!(expected.as_ref(), res.as_ref());
}

#[test]
fn test_cast_large_list_to_fsl() {
    // A LargeList (i64 offsets) whose slots all hold exactly two items.
    let rows = vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])];
    let source: ArrayRef = Arc::new(
        LargeListArray::from_iter_primitive::<Int32Type, _, _>(rows.clone()),
    );
    let want: ArrayRef = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(rows, 2),
    );

    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let fsl_type = DataType::FixedSizeList(item, 2);
    let got = cast(source.as_ref(), &fsl_type).unwrap();

    assert_eq!(want.as_ref(), got.as_ref());
}

#[test]
fn test_cast_list_to_fsl_subcast() {
    // Source has Int32 items; the targets below use a different item type.
    let source: ArrayRef = Arc::new(
        LargeListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3), Some(i32::MAX)]),
        ]),
    );

    // Widening the items to Int64 succeeds and preserves every value.
    let widened_type =
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2);
    let want: ArrayRef = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(
            vec![
                Some(vec![Some(1), Some(2)]),
                Some(vec![Some(3), Some(i32::MAX as i64)]),
            ],
            2,
        ),
    );
    let got = cast(source.as_ref(), &widened_type).unwrap();
    assert_eq!(want.as_ref(), got.as_ref());

    // Narrowing to Int16 with safe=false must surface the overflow of i32::MAX.
    let narrowed_type =
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int16, true)), 2);
    let strict = CastOptions {
        safe: false,
        ..Default::default()
    };
    let outcome = cast_with_options(source.as_ref(), &narrowed_type, &strict);
    assert!(outcome.is_err());
    let rendered = format!("{:?}", outcome);
    assert!(rendered.contains("Can't cast value 2147483647 to type Int16"));
}

#[test]
fn test_cast_list_to_fsl_empty() {
    // Casting a zero-length List must yield a zero-length FixedSizeList.
    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let source = new_empty_array(&DataType::List(item.clone()));

    let fsl_type = DataType::FixedSizeList(item, 3);
    let want = new_empty_array(&fsl_type);

    let got = cast(source.as_ref(), &fsl_type).unwrap();
    assert_eq!(want.as_ref(), got.as_ref());
}

fn make_list_array() -> ListArray {
// Construct a value array
let value_data = ArrayData::builder(DataType::Int32)
Expand Down

0 comments on commit dc75a28

Please sign in to comment.