From ae31559af533208e25e277aa01950d2ebf04eecb Mon Sep 17 00:00:00 2001 From: Philippe-Cholet Date: Sun, 11 Jun 2023 22:03:56 +0200 Subject: [PATCH] `MergeJoinBy` also accepts functions returning `bool` Done with `trait OrderingOrBool`. Now, `merge_join_by` needs an extra type parameter `T` so this is a breaking change, because any [unlikely] invocations that explicitly provided these parameters are now one parameter short. Documentation updated, two quickcheck tests added. --- src/lib.rs | 45 +++++++++++-- src/merge_join.rs | 163 ++++++++++++++++++++++++++++++---------------- tests/quick.rs | 25 +++++++ 3 files changed, 170 insertions(+), 63 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index daf935f73..340a04120 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1012,7 +1012,10 @@ pub trait Itertools : Iterator { /// Create an iterator that merges items from both this and the specified /// iterator in ascending order. /// - /// It chooses whether to pair elements based on the `Ordering` returned by the + /// The function can either return an `Ordering` variant or a boolean. + /// + /// If `cmp_fn` returns `Ordering`, + /// it chooses whether to pair elements based on the `Ordering` returned by the /// specified compare function. 
At any point, inspecting the tip of the /// iterators `I` and `J` as items `i` of type `I::Item` and `j` of type /// `J::Item` respectively, the resulting iterator will: @@ -1028,18 +1031,46 @@ pub trait Itertools : Iterator { /// use itertools::Itertools; /// use itertools::EitherOrBoth::{Left, Right, Both}; /// - /// let multiples_of_2 = (0..10).step_by(2); - /// let multiples_of_3 = (0..10).step_by(3); + /// let a = vec![0, 2, 4, 6, 1].into_iter(); + /// let b = (0..10).step_by(3); + /// + /// itertools::assert_equal( + /// a.merge_join_by(b, |i, j| i.cmp(j)), + /// vec![Both(0, 0), Left(2), Right(3), Left(4), Both(6, 6), Left(1), Right(9)] + /// ); + /// ``` + /// + /// If `cmp_fn` returns `bool`, + /// it chooses whether to pair elements based on the boolean returned by the + /// specified function. At any point, inspecting the tip of the + /// iterators `I` and `J` as items `i` of type `I::Item` and `j` of type + /// `J::Item` respectively, the resulting iterator will: + /// + /// - Emit `Either::Left(i)` when `true`, + /// and remove `i` from its source iterator + /// - Emit `Either::Right(j)` when `false`, + /// and remove `j` from its source iterator + /// + /// It is similar to the `Ordering` case if the first argument is considered + /// "less" than the second argument. 
+ /// + /// ``` + /// use itertools::Itertools; + /// use itertools::Either::{Left, Right}; + /// + /// let a = vec![0, 2, 4, 6, 1].into_iter(); + /// let b = (0..10).step_by(3); /// /// itertools::assert_equal( - /// multiples_of_2.merge_join_by(multiples_of_3, |i, j| i.cmp(j)), - /// vec![Both(0, 0), Left(2), Right(3), Left(4), Both(6, 6), Left(8), Right(9)] + /// a.merge_join_by(b, |i, j| i <= j), + /// vec![Left(0), Right(0), Left(2), Right(3), Left(4), Left(6), Left(1), Right(6), Right(9)] /// ); /// ``` #[inline] - fn merge_join_by(self, other: J, cmp_fn: F) -> MergeJoinBy + fn merge_join_by(self, other: J, cmp_fn: F) -> MergeJoinBy where J: IntoIterator, - F: FnMut(&Self::Item, &J::Item) -> std::cmp::Ordering, + F: FnMut(&Self::Item, &J::Item) -> T, + T: merge_join::OrderingOrBool, Self: Sized { merge_join_by(self, other, cmp_fn) diff --git a/src/merge_join.rs b/src/merge_join.rs index f2fbdea2c..84f7d0333 100644 --- a/src/merge_join.rs +++ b/src/merge_join.rs @@ -2,19 +2,23 @@ use std::cmp::Ordering; use std::iter::Fuse; use std::fmt; +use either::Either; + use super::adaptors::{PutBack, put_back}; use crate::either_or_both::EitherOrBoth; +use crate::size_hint::{self, SizeHint}; #[cfg(doc)] use crate::Itertools; /// Return an iterator adaptor that merge-joins items from the two base iterators in ascending order. /// /// [`IntoIterator`] enabled version of [`Itertools::merge_join_by`]. 
-pub fn merge_join_by(left: I, right: J, cmp_fn: F) +pub fn merge_join_by(left: I, right: J, cmp_fn: F) -> MergeJoinBy where I: IntoIterator, J: IntoIterator, - F: FnMut(&I::Item, &J::Item) -> Ordering + F: FnMut(&I::Item, &J::Item) -> T, + T: OrderingOrBool, { MergeJoinBy { left: put_back(left.into_iter().fuse()), @@ -30,7 +34,66 @@ pub fn merge_join_by(left: I, right: J, cmp_fn: F) pub struct MergeJoinBy { left: PutBack>, right: PutBack>, - cmp_fn: F + cmp_fn: F, +} + +pub trait OrderingOrBool { + type MergeResult; + fn left(left: L) -> Self::MergeResult; + fn right(right: R) -> Self::MergeResult; + // "merge" never returns (Some(...), Some(...), ...) so Option> + // is appealing but it is always followed by two put_backs, so we think the compiler is + // smart enough to optimize it. Or we could move put_backs into "merge". + fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult); + fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint; +} + +impl OrderingOrBool for Ordering { + type MergeResult = EitherOrBoth; + fn left(left: L) -> Self::MergeResult { + EitherOrBoth::Left(left) + } + fn right(right: R) -> Self::MergeResult { + EitherOrBoth::Right(right) + } + fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult) { + match self { + Ordering::Equal => (None, None, EitherOrBoth::Both(left, right)), + Ordering::Less => (None, Some(right), EitherOrBoth::Left(left)), + Ordering::Greater => (Some(left), None, EitherOrBoth::Right(right)), + } + } + fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint { + let (a_lower, a_upper) = left; + let (b_lower, b_upper) = right; + let lower = ::std::cmp::max(a_lower, b_lower); + let upper = match (a_upper, b_upper) { + (Some(x), Some(y)) => x.checked_add(y), + _ => None, + }; + (lower, upper) + } +} + +impl OrderingOrBool for bool { + type MergeResult = Either; + fn left(left: L) -> Self::MergeResult { + Either::Left(left) + } + fn right(right: R) -> Self::MergeResult { + 
Either::Right(right) + } + fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult) { + if self { + (None, Some(right), Either::Left(left)) + } else { + (Some(left), None, Either::Right(right)) + } + } + fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint { + // Not ExactSizeIterator because size may be larger than usize + size_hint::add(left, right) + } } impl Clone for MergeJoinBy @@ -52,49 +115,34 @@ impl fmt::Debug for MergeJoinBy debug_fmt_fields!(MergeJoinBy, left, right); } -impl Iterator for MergeJoinBy +impl Iterator for MergeJoinBy where I: Iterator, J: Iterator, - F: FnMut(&I::Item, &J::Item) -> Ordering + F: FnMut(&I::Item, &J::Item) -> T, + T: OrderingOrBool, { - type Item = EitherOrBoth; + type Item = T::MergeResult; fn next(&mut self) -> Option { match (self.left.next(), self.right.next()) { (None, None) => None, - (Some(left), None) => - Some(EitherOrBoth::Left(left)), - (None, Some(right)) => - Some(EitherOrBoth::Right(right)), + (Some(left), None) => Some(T::left(left)), + (None, Some(right)) => Some(T::right(right)), (Some(left), Some(right)) => { - match (self.cmp_fn)(&left, &right) { - Ordering::Equal => - Some(EitherOrBoth::Both(left, right)), - Ordering::Less => { - self.right.put_back(right); - Some(EitherOrBoth::Left(left)) - }, - Ordering::Greater => { - self.left.put_back(left); - Some(EitherOrBoth::Right(right)) - } + let (left, right, next) = (self.cmp_fn)(&left, &right).merge(left, right); + if let Some(left) = left { + self.left.put_back(left); + } + if let Some(right) = right { + self.right.put_back(right); } + Some(next) } } } - fn size_hint(&self) -> (usize, Option) { - let (a_lower, a_upper) = self.left.size_hint(); - let (b_lower, b_upper) = self.right.size_hint(); - - let lower = ::std::cmp::max(a_lower, b_lower); - - let upper = match (a_upper, b_upper) { - (Some(x), Some(y)) => x.checked_add(y), - _ => None, - }; - - (lower, upper) + fn size_hint(&self) -> SizeHint { + T::size_hint(self.left.size_hint(), 
self.right.size_hint()) } fn count(mut self) -> usize { @@ -106,10 +154,12 @@ impl Iterator for MergeJoinBy (None, Some(_right)) => break count + 1 + self.right.into_parts().1.count(), (Some(left), Some(right)) => { count += 1; - match (self.cmp_fn)(&left, &right) { - Ordering::Equal => {} - Ordering::Less => self.right.put_back(right), - Ordering::Greater => self.left.put_back(left), + let (left, right, _) = (self.cmp_fn)(&left, &right).merge(left, right); + if let Some(left) = left { + self.left.put_back(left); + } + if let Some(right) = right { + self.right.put_back(right); } } } @@ -122,27 +172,24 @@ impl Iterator for MergeJoinBy match (self.left.next(), self.right.next()) { (None, None) => break previous_element, (Some(left), None) => { - break Some(EitherOrBoth::Left( + break Some(T::left( self.left.into_parts().1.last().unwrap_or(left), )) } (None, Some(right)) => { - break Some(EitherOrBoth::Right( + break Some(T::right( self.right.into_parts().1.last().unwrap_or(right), )) } (Some(left), Some(right)) => { - previous_element = match (self.cmp_fn)(&left, &right) { - Ordering::Equal => Some(EitherOrBoth::Both(left, right)), - Ordering::Less => { - self.right.put_back(right); - Some(EitherOrBoth::Left(left)) - } - Ordering::Greater => { - self.left.put_back(left); - Some(EitherOrBoth::Right(right)) - } + let (left, right, elem) = (self.cmp_fn)(&left, &right).merge(left, right); + if let Some(left) = left { + self.left.put_back(left); + } + if let Some(right) = right { + self.right.put_back(right); } + previous_element = Some(elem); } } } @@ -156,13 +203,17 @@ impl Iterator for MergeJoinBy n -= 1; match (self.left.next(), self.right.next()) { (None, None) => break None, - (Some(_left), None) => break self.left.nth(n).map(EitherOrBoth::Left), - (None, Some(_right)) => break self.right.nth(n).map(EitherOrBoth::Right), - (Some(left), Some(right)) => match (self.cmp_fn)(&left, &right) { - Ordering::Equal => {} - Ordering::Less => self.right.put_back(right), - 
Ordering::Greater => self.left.put_back(left), - }, + (Some(_left), None) => break self.left.nth(n).map(T::left), + (None, Some(_right)) => break self.right.nth(n).map(T::right), + (Some(left), Some(right)) => { + let (left, right, _) = (self.cmp_fn)(&left, &right).merge(left, right); + if let Some(left) = left { + self.left.put_back(left); + } + if let Some(right) = right { + self.right.put_back(right); + } + } } } } diff --git a/tests/quick.rs b/tests/quick.rs index 0adcf1ad7..914960a5f 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -829,6 +829,31 @@ quickcheck! { } } +quickcheck! { + fn merge_join_by_ordering_vs_bool(a: Vec, b: Vec) -> bool { + use either::Either; + use itertools::free::merge_join_by; + let mut has_equal = false; + let it_ord = merge_join_by(a.clone(), b.clone(), Ord::cmp).flat_map(|v| match v { + EitherOrBoth::Both(l, r) => { + has_equal = true; + vec![Either::Left(l), Either::Right(r)] + } + EitherOrBoth::Left(l) => vec![Either::Left(l)], + EitherOrBoth::Right(r) => vec![Either::Right(r)], + }); + let it_bool = merge_join_by(a, b, PartialOrd::le); + itertools::equal(it_ord, it_bool) || has_equal + } + fn merge_join_by_bool_unwrapped_is_merge_by(a: Vec, b: Vec) -> bool { + use either::Either; + use itertools::free::merge_join_by; + let it = a.clone().into_iter().merge_by(b.clone(), PartialOrd::ge); + let it_join = merge_join_by(a, b, PartialOrd::ge).map(Either::into_inner); + itertools::equal(it, it_join) + } +} + quickcheck! { fn size_tee(a: Vec) -> bool { let (mut t1, mut t2) = a.iter().tee();