Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combinations: count and size_hint #729

Merged
merged 14 commits into from
Aug 18, 2023
Merged
42 changes: 42 additions & 0 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,51 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

fn size_hint(&self) -> (usize, Option<usize>) {
let (mut low, mut upp) = self.pool.size_hint();
low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX);
upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices));
(low, upp)
}

fn count(self) -> usize {
let Self { indices, pool, first } = self;
let n = pool.len() + pool.it.count();
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
remaining_for(n, first, &indices).expect("Iterator count greater than usize::MAX")
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl<I> FusedIterator for Combinations<I>
where I: Iterator,
I::Item: Clone
{}

pub(crate) fn checked_binomial(mut n: usize, mut k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// `factorial(n) / factorial(n - k) / factorial(k)` but trying to avoid it overflows:
k = (n - k).min(k); // symmetry
let mut c = 1;
for i in 1..=k {
c = (c / i).checked_mul(n)?.checked_add((c % i).checked_mul(n)? / i)?;
n -= 1;
}
Some(c)
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved

/// For a given size `n`, return the count of remaining combinations or None if it would overflow.
fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize> {
let k = indices.len();
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
if first {
checked_binomial(n, k)
} else {
indices
.iter()
.enumerate()
.fold(Some(0), |sum, (k0, n0)| {
sum.and_then(|s| s.checked_add(checked_binomial(n - 1 - *n0, k - k0)?))
})
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}
6 changes: 6 additions & 0 deletions src/lazy_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::iter::Fuse;
use std::ops::Index;
use alloc::vec::Vec;

use crate::size_hint::{self, SizeHint};

#[derive(Debug, Clone)]
pub struct LazyBuffer<I: Iterator> {
pub it: Fuse<I>,
Expand All @@ -23,6 +25,10 @@ where
self.buffer.len()
}

pub fn size_hint(&self) -> SizeHint {
size_hint::add_scalar(self.it.size_hint(), self.len())
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}

pub fn get_next(&mut self) -> bool {
if let Some(x) = self.it.next() {
self.buffer.push(x);
Expand Down
15 changes: 15 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,21 @@ fn combinations_zero() {
it::assert_equal((0..0).combinations(0), vec![vec![]]);
}

#[test]
fn combinations_range_count() {
for n in 0..6 {
for k in 0..=n {
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
let len = (n - k + 1..=n).product::<usize>() / (1..=k).product::<usize>();
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
let mut it = (0..n).combinations(k);
for count in (0..=len).rev() {
assert_eq!(it.size_hint(), (count, Some(count)));
assert_eq!(it.clone().count(), count);
assert_eq!(it.next().is_none(), count == 0);
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}

#[test]
fn permutations_zero() {
it::assert_equal((1..3).permutations(0), vec![vec![]]);
Expand Down