Skip to content

Commit

Permalink
Support nested schema projection (#5148) (#5149)
Browse files Browse the repository at this point in the history
* Support nested schema projection

* Tweak doc

* Review feedback
  • Loading branch information
tustvold committed Nov 29, 2023
1 parent 8867a1f commit 6d4b8bb
Showing 1 changed file with 231 additions and 1 deletion.
232 changes: 231 additions & 1 deletion arrow-schema/src/fields.rs
Expand Up @@ -15,10 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
use std::ops::Deref;
use std::sync::Arc;

use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};

/// A cheaply cloneable, owned slice of [`FieldRef`]
///
/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
Expand Down Expand Up @@ -99,6 +100,108 @@ impl Fields {
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
}

/// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
///
/// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`]
/// containing no child [`FieldRef`], a leaf field, along with a count of the number
/// of such leaves encountered so far. Only [`FieldRef`] for which `filter`
/// returned `true` will be included in the result.
///
/// This can therefore be used to select a subset of fields from nested types
/// such as [`DataType::Struct`] or [`DataType::List`].
///
/// ```
/// # use arrow_schema::{DataType, Field, Fields};
/// let fields = Fields::from(vec![
/// Field::new("a", DataType::Int32, true), // Leaf 0
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("c", DataType::Float32, false), // Leaf 1
/// Field::new("d", DataType::Float64, false), // Leaf 2
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false), // Leaf 3
/// Field::new("g", DataType::Float16, false), // Leaf 4
/// ])), true),
/// ])), false)
/// ]);
/// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx));
/// let expected = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("d", DataType::Float64, false),
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false),
/// Field::new("g", DataType::Float16, false),
/// ])), true),
/// ])), false)
/// ]);
/// assert_eq!(filtered, expected);
/// ```
pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
fn filter_field<F: FnMut(&FieldRef) -> bool>(
f: &FieldRef,
filter: &mut F,
) -> Option<FieldRef> {
use DataType::*;

let v = match f.data_type() {
Dictionary(_, v) => v.as_ref(), // Key must be integer
RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer
d => d,
};
let d = match v {
List(child) => List(filter_field(child, filter)?),
LargeList(child) => LargeList(filter_field(child, filter)?),
Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
Struct(fields) => {
let filtered: Fields = fields
.iter()
.filter_map(|f| filter_field(f, filter))
.collect();

if filtered.is_empty() {
return None;
}

Struct(filtered)
}
Union(fields, mode) => {
let filtered: UnionFields = fields
.iter()
.filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
.collect();

if filtered.is_empty() {
return None;
}

Union(filtered, *mode)
}
_ => return filter(f).then(|| f.clone()),
};
let d = match f.data_type() {
Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
RunEndEncoded(v, f) => {
RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
}
_ => d,
};
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
}

let mut leaf_idx = 0;
let mut filter = |f: &FieldRef| {
let t = filter(leaf_idx, f);
leaf_idx += 1;
t
};

self.0
.iter()
.filter_map(|f| filter_field(f, &mut filter))
.collect()
}

/// Remove a field by index and return it.
///
/// # Panic
Expand Down Expand Up @@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
Self(iter.into_iter().collect())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::UnionMode;

#[test]
fn test_filter() {
let floats = Fields::from(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]);
let fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("floats", DataType::Struct(floats.clone()), true),
Field::new("b", DataType::Int16, true),
Field::new(
"c",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
Field::new(
"d",
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Struct(floats.clone())),
),
false,
),
Field::new_list(
"e",
Field::new("floats", DataType::Struct(floats.clone()), true),
true,
),
Field::new(
"f",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
false,
),
Field::new_map(
"g",
"entries",
Field::new("keys", DataType::LargeUtf8, false),
Field::new("values", DataType::Int32, true),
false,
false,
),
Field::new(
"h",
DataType::Union(
UnionFields::new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
),
Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
),
false,
),
]);

let floats_a = DataType::Struct(vec![floats[0].clone()].into());

let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
assert_eq!(r.len(), 2);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);

let r = fields.filter_leaves(|_, f| f.name() == "a");
assert_eq!(r.len(), 5);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);
assert_eq!(
r[2].data_type(),
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
);
assert_eq!(
r[3].as_ref(),
&Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
);
assert_eq!(
r[4].as_ref(),
&Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", floats_a.clone(), true)),
),
false,
)
);

let r = fields.filter_leaves(|_, f| f.name() == "floats");
assert_eq!(r.len(), 0);

let r = fields.filter_leaves(|idx, _| idx == 9);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[6]);

let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[7]);

let union = DataType::Union(
UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
UnionMode::Dense,
);

let r = fields.filter_leaves(|idx, _| idx == 12);
assert_eq!(r.len(), 1);
assert_eq!(r[0].data_type(), &union);

let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[9]);
}
}

0 comments on commit 6d4b8bb

Please sign in to comment.