Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MergeJoinBy also accept functions returning bool #704

Merged
merged 1 commit into from Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/lib.rs
Expand Up @@ -1012,7 +1012,10 @@ pub trait Itertools : Iterator {
/// Create an iterator that merges items from both this and the specified
/// iterator in ascending order.
///
/// It chooses whether to pair elements based on the `Ordering` returned by the
/// The function can either return an `Ordering` variant or a boolean.
///
/// If `cmp_fn` returns `Ordering`,
/// it chooses whether to pair elements based on the `Ordering` returned by the
/// specified compare function. At any point, inspecting the tip of the
/// iterators `I` and `J` as items `i` of type `I::Item` and `j` of type
/// `J::Item` respectively, the resulting iterator will:
Expand All @@ -1028,18 +1031,46 @@ pub trait Itertools : Iterator {
/// use itertools::Itertools;
/// use itertools::EitherOrBoth::{Left, Right, Both};
///
/// let multiples_of_2 = (0..10).step_by(2);
/// let multiples_of_3 = (0..10).step_by(3);
/// let a = vec![0, 2, 4, 6, 1].into_iter();
/// let b = (0..10).step_by(3);
///
/// itertools::assert_equal(
/// a.merge_join_by(b, |i, j| i.cmp(j)),
/// vec![Both(0, 0), Left(2), Right(3), Left(4), Both(6, 6), Left(1), Right(9)]
/// );
/// ```
///
/// If `cmp_fn` returns `bool`,
/// it chooses whether to pair elements based on the boolean returned by the
/// specified function. At any point, inspecting the tip of the
/// iterators `I` and `J` as items `i` of type `I::Item` and `j` of type
/// `J::Item` respectively, the resulting iterator will:
///
/// - Emit `Either::Left(i)` when `true`,
/// and remove `i` from its source iterator
/// - Emit `Either::Right(j)` when `false`,
/// and remove `j` from its source iterator
///
/// It behaves like the `Ordering` case restricted to `Less` and `Greater`: returning `true`
/// means the first argument is considered "less" than the second, and elements are never paired.
///
/// ```
/// use itertools::Itertools;
/// use itertools::Either::{Left, Right};
///
/// let a = vec![0, 2, 4, 6, 1].into_iter();
/// let b = (0..10).step_by(3);
///
/// itertools::assert_equal(
/// multiples_of_2.merge_join_by(multiples_of_3, |i, j| i.cmp(j)),
/// vec![Both(0, 0), Left(2), Right(3), Left(4), Both(6, 6), Left(8), Right(9)]
/// a.merge_join_by(b, |i, j| i <= j),
/// vec![Left(0), Right(0), Left(2), Right(3), Left(4), Left(6), Left(1), Right(6), Right(9)]
/// );
/// ```
#[inline]
fn merge_join_by<J, F>(self, other: J, cmp_fn: F) -> MergeJoinBy<Self, J::IntoIter, F>
fn merge_join_by<J, F, T>(self, other: J, cmp_fn: F) -> MergeJoinBy<Self, J::IntoIter, F>
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
where J: IntoIterator,
F: FnMut(&Self::Item, &J::Item) -> std::cmp::Ordering,
F: FnMut(&Self::Item, &J::Item) -> T,
T: merge_join::OrderingOrBool<Self::Item, J::Item>,
Self: Sized
{
merge_join_by(self, other, cmp_fn)
Expand Down
163 changes: 107 additions & 56 deletions src/merge_join.rs
Expand Up @@ -2,19 +2,23 @@ use std::cmp::Ordering;
use std::iter::Fuse;
use std::fmt;

use either::Either;

use super::adaptors::{PutBack, put_back};
use crate::either_or_both::EitherOrBoth;
use crate::size_hint::{self, SizeHint};
#[cfg(doc)]
use crate::Itertools;

/// Return an iterator adaptor that merge-joins items from the two base iterators in ascending order.
///
/// [`IntoIterator`] enabled version of [`Itertools::merge_join_by`].
pub fn merge_join_by<I, J, F>(left: I, right: J, cmp_fn: F)
pub fn merge_join_by<I, J, F, T>(left: I, right: J, cmp_fn: F)
-> MergeJoinBy<I::IntoIter, J::IntoIter, F>
where I: IntoIterator,
J: IntoIterator,
F: FnMut(&I::Item, &J::Item) -> Ordering
F: FnMut(&I::Item, &J::Item) -> T,
T: OrderingOrBool<I::Item, J::Item>,
{
MergeJoinBy {
left: put_back(left.into_iter().fuse()),
Expand All @@ -30,7 +34,66 @@ pub fn merge_join_by<I, J, F>(left: I, right: J, cmp_fn: F)
pub struct MergeJoinBy<I: Iterator, J: Iterator, F> {
left: PutBack<Fuse<I>>,
right: PutBack<Fuse<J>>,
cmp_fn: F
cmp_fn: F,
}

/// Abstraction over the value returned by the comparison function given to `merge_join_by`.
///
/// It is implemented for `Ordering` (elements are produced as `EitherOrBoth<L, R>`) and
/// for `bool` (elements are produced as `Either<L, R>`). `L` and `R` are the item types
/// of the left and right iterators respectively.
pub trait OrderingOrBool<L, R> {
/// Type of the elements yielded by the merged iterator.
type MergeResult;
/// Wraps an item that comes from the left iterator alone
/// (used when the right iterator has no item left to compare against).
fn left(left: L) -> Self::MergeResult;
/// Wraps an item that comes from the right iterator alone
/// (used when the left iterator has no item left to compare against).
fn right(right: R) -> Self::MergeResult;
/// Consumes the comparison result together with the two items that were compared and
/// returns `(left_to_put_back, right_to_put_back, element_to_emit)`: an item returned as
/// `Some(..)` was not emitted and must be put back on its source iterator by the caller.
// "merge" never returns (Some(...), Some(...), ...) so Option<Either<I::Item, J::Item>>
// is appealing but it is always followed by two put_backs, so we think the compiler is
// smart enough to optimize it. Or we could move put_backs into "merge".
fn merge(self, left: L, right: R) -> (Option<L>, Option<R>, Self::MergeResult);
/// Combines the size hints of the left and right iterators into the size hint
/// of the merged iterator.
fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint;
}

/// Three-way merge: a comparison returning `Ordering` produces `EitherOrBoth`
/// elements, pairing the two items whenever they compare as `Equal`.
impl<L, R> OrderingOrBool<L, R> for Ordering {
    type MergeResult = EitherOrBoth<L, R>;

    /// Wraps an item coming from the left side only.
    fn left(left: L) -> Self::MergeResult {
        EitherOrBoth::Left(left)
    }

    /// Wraps an item coming from the right side only.
    fn right(right: R) -> Self::MergeResult {
        EitherOrBoth::Right(right)
    }

    /// Returns `(left item to put back, right item to put back, element to emit)`.
    fn merge(self, left: L, right: R) -> (Option<L>, Option<R>, Self::MergeResult) {
        match self {
            // Only the left item is emitted; the right one has to be put back.
            Ordering::Less => (None, Some(right), EitherOrBoth::Left(left)),
            // Only the right item is emitted; the left one has to be put back.
            Ordering::Greater => (Some(left), None, EitherOrBoth::Right(right)),
            // Both items are emitted together; nothing is put back.
            Ordering::Equal => (None, None, EitherOrBoth::Both(left, right)),
        }
    }

    /// Lower bound: the larger of the two lower bounds.
    /// Upper bound: the sum of both upper bounds, or `None` when either side is
    /// unbounded or the sum overflows `usize`.
    fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint {
        let (left_min, left_max) = left;
        let (right_min, right_max) = right;
        let lower = left_min.max(right_min);
        let upper = left_max.and_then(|l| right_max.and_then(|r| l.checked_add(r)));
        (lower, upper)
    }
}

/// Two-way merge: a predicate returning `bool` produces `Either` elements,
/// so the two compared items are never emitted together.
impl<L, R> OrderingOrBool<L, R> for bool {
    type MergeResult = Either<L, R>;

    /// Wraps an item coming from the left side only.
    fn left(left: L) -> Self::MergeResult {
        Either::Left(left)
    }

    /// Wraps an item coming from the right side only.
    fn right(right: R) -> Self::MergeResult {
        Either::Right(right)
    }

    /// Returns `(left item to put back, right item to put back, element to emit)`.
    fn merge(self, left: L, right: R) -> (Option<L>, Option<R>, Self::MergeResult) {
        match self {
            // `true`: the left item is emitted and the right one has to be put back.
            true => (None, Some(right), Either::Left(left)),
            // `false`: the right item is emitted and the left one has to be put back.
            false => (Some(left), None, Either::Right(right)),
        }
    }

    /// Delegates to `size_hint::add` to combine the hints of both sides.
    fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint {
        // Not ExactSizeIterator because size may be larger than usize
        size_hint::add(left, right)
    }
}

impl<I, J, F> Clone for MergeJoinBy<I, J, F>
Expand All @@ -52,49 +115,34 @@ impl<I, J, F> fmt::Debug for MergeJoinBy<I, J, F>
debug_fmt_fields!(MergeJoinBy, left, right);
}

impl<I, J, F> Iterator for MergeJoinBy<I, J, F>
impl<I, J, F, T> Iterator for MergeJoinBy<I, J, F>
where I: Iterator,
J: Iterator,
F: FnMut(&I::Item, &J::Item) -> Ordering
F: FnMut(&I::Item, &J::Item) -> T,
T: OrderingOrBool<I::Item, J::Item>,
{
type Item = EitherOrBoth<I::Item, J::Item>;
type Item = T::MergeResult;

fn next(&mut self) -> Option<Self::Item> {
match (self.left.next(), self.right.next()) {
(None, None) => None,
(Some(left), None) =>
Some(EitherOrBoth::Left(left)),
(None, Some(right)) =>
Some(EitherOrBoth::Right(right)),
(Some(left), None) => Some(T::left(left)),
(None, Some(right)) => Some(T::right(right)),
(Some(left), Some(right)) => {
match (self.cmp_fn)(&left, &right) {
Ordering::Equal =>
Some(EitherOrBoth::Both(left, right)),
Ordering::Less => {
self.right.put_back(right);
Some(EitherOrBoth::Left(left))
},
Ordering::Greater => {
self.left.put_back(left);
Some(EitherOrBoth::Right(right))
}
let (left, right, next) = (self.cmp_fn)(&left, &right).merge(left, right);
if let Some(left) = left {
self.left.put_back(left);
}
if let Some(right) = right {
self.right.put_back(right);
}
Some(next)
}
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
let (a_lower, a_upper) = self.left.size_hint();
let (b_lower, b_upper) = self.right.size_hint();

let lower = ::std::cmp::max(a_lower, b_lower);

let upper = match (a_upper, b_upper) {
(Some(x), Some(y)) => x.checked_add(y),
_ => None,
};

(lower, upper)
fn size_hint(&self) -> SizeHint {
T::size_hint(self.left.size_hint(), self.right.size_hint())
}

fn count(mut self) -> usize {
Expand All @@ -106,10 +154,12 @@ impl<I, J, F> Iterator for MergeJoinBy<I, J, F>
(None, Some(_right)) => break count + 1 + self.right.into_parts().1.count(),
(Some(left), Some(right)) => {
count += 1;
match (self.cmp_fn)(&left, &right) {
Ordering::Equal => {}
Ordering::Less => self.right.put_back(right),
Ordering::Greater => self.left.put_back(left),
let (left, right, _) = (self.cmp_fn)(&left, &right).merge(left, right);
if let Some(left) = left {
self.left.put_back(left);
}
if let Some(right) = right {
self.right.put_back(right);
}
}
}
Expand All @@ -122,27 +172,24 @@ impl<I, J, F> Iterator for MergeJoinBy<I, J, F>
match (self.left.next(), self.right.next()) {
(None, None) => break previous_element,
(Some(left), None) => {
break Some(EitherOrBoth::Left(
break Some(T::left(
self.left.into_parts().1.last().unwrap_or(left),
))
}
(None, Some(right)) => {
break Some(EitherOrBoth::Right(
break Some(T::right(
self.right.into_parts().1.last().unwrap_or(right),
))
}
(Some(left), Some(right)) => {
previous_element = match (self.cmp_fn)(&left, &right) {
Ordering::Equal => Some(EitherOrBoth::Both(left, right)),
Ordering::Less => {
self.right.put_back(right);
Some(EitherOrBoth::Left(left))
}
Ordering::Greater => {
self.left.put_back(left);
Some(EitherOrBoth::Right(right))
}
let (left, right, elem) = (self.cmp_fn)(&left, &right).merge(left, right);
if let Some(left) = left {
self.left.put_back(left);
}
if let Some(right) = right {
self.right.put_back(right);
}
previous_element = Some(elem);
}
}
}
Expand All @@ -156,13 +203,17 @@ impl<I, J, F> Iterator for MergeJoinBy<I, J, F>
n -= 1;
match (self.left.next(), self.right.next()) {
(None, None) => break None,
(Some(_left), None) => break self.left.nth(n).map(EitherOrBoth::Left),
(None, Some(_right)) => break self.right.nth(n).map(EitherOrBoth::Right),
(Some(left), Some(right)) => match (self.cmp_fn)(&left, &right) {
Ordering::Equal => {}
Ordering::Less => self.right.put_back(right),
Ordering::Greater => self.left.put_back(left),
},
(Some(_left), None) => break self.left.nth(n).map(T::left),
(None, Some(_right)) => break self.right.nth(n).map(T::right),
(Some(left), Some(right)) => {
let (left, right, _) = (self.cmp_fn)(&left, &right).merge(left, right);
if let Some(left) = left {
self.left.put_back(left);
}
if let Some(right) = right {
self.right.put_back(right);
}
}
}
}
}
Expand Down
25 changes: 25 additions & 0 deletions tests/quick.rs
Expand Up @@ -829,6 +829,31 @@ quickcheck! {
}
}

quickcheck! {
    // Splitting every `Both(l, r)` of the `Ordering`-based merge into `Left(l), Right(r)`
    // must give the same sequence as the `bool`-based merge using `<=`. The property is
    // also accepted once an equal pair has been seen (`has_equal`).
    fn merge_join_by_ordering_vs_bool(a: Vec<u8>, b: Vec<u8>) -> bool {
        use either::Either;
        use itertools::free::merge_join_by;
        let mut has_equal = false;
        let by_ordering = merge_join_by(a.clone(), b.clone(), Ord::cmp).flat_map(|joined| {
            match joined {
                EitherOrBoth::Left(l) => vec![Either::Left(l)],
                EitherOrBoth::Right(r) => vec![Either::Right(r)],
                EitherOrBoth::Both(l, r) => {
                    has_equal = true;
                    vec![Either::Left(l), Either::Right(r)]
                }
            }
        });
        let by_bool = merge_join_by(a, b, PartialOrd::le);
        // The comparison must run first: it drives the lazy iterator that sets `has_equal`.
        let same_sequence = itertools::equal(by_ordering, by_bool);
        same_sequence || has_equal
    }

    // Unwrapping the `Either` elements of the `bool`-based merge must give exactly
    // what `merge_by` yields with the same predicate.
    fn merge_join_by_bool_unwrapped_is_merge_by(a: Vec<u8>, b: Vec<u8>) -> bool {
        use either::Either;
        use itertools::free::merge_join_by;
        let merged = a.clone().into_iter().merge_by(b.clone(), PartialOrd::ge);
        let unwrapped = merge_join_by(a, b, PartialOrd::ge).map(Either::into_inner);
        itertools::equal(merged, unwrapped)
    }
}

quickcheck! {
fn size_tee(a: Vec<u8>) -> bool {
let (mut t1, mut t2) = a.iter().tee();
Expand Down