Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combinations: count and size_hint #729

Merged
merged 14 commits into from
Aug 18, 2023
Merged
41 changes: 41 additions & 0 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ impl<I: Iterator> Combinations<I> {
self.pool.prefill(k);
}
}

/// For a given size `n`, return the count of remaining elements or None if it would overflow.
fn remaining_for(&self, n: usize) -> Option<usize> {
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
let k = self.k();
if self.first {
checked_binomial(n, k)
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
} else {
self.indices
.iter()
.enumerate()
.fold(Some(0), |sum, (k0, n0)| {
sum.and_then(|s| s.checked_add(checked_binomial(n - 1 - *n0, k - k0)?))
})
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

impl<I> Iterator for Combinations<I>
Expand Down Expand Up @@ -120,9 +135,35 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

fn size_hint(&self) -> (usize, Option<usize>) {
let (mut low, mut upp) = self.pool.size_hint();
low = self.remaining_for(low).unwrap_or(usize::MAX);
upp = upp.and_then(|upp| self.remaining_for(upp));
(low, upp)
}

fn count(mut self) -> usize {
self.pool.fill();
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
self.remaining_for(self.n()).expect("Iterator count greater than usize::MAX")
}
}

impl<I> FusedIterator for Combinations<I>
where I: Iterator,
I::Item: Clone
{}

pub(crate) fn checked_binomial(mut n: usize, mut k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// `factorial(n) / factorial(n - k) / factorial(k)` but trying to avoid it overflows:
k = (n - k).min(k); // symmetry
let mut c = 1;
for i in 1..=k {
c = (c / i).checked_mul(n)?.checked_add((c % i).checked_mul(n)? / i)?;
n -= 1;
}
Some(c)
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions src/lazy_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::iter::Fuse;
use std::ops::Index;
use alloc::vec::Vec;

use crate::size_hint::{self, SizeHint};

#[derive(Debug, Clone)]
pub struct LazyBuffer<I: Iterator> {
pub it: Fuse<I>,
Expand All @@ -23,6 +25,10 @@ where
self.buffer.len()
}

pub fn size_hint(&self) -> SizeHint {
size_hint::add_scalar(self.it.size_hint(), self.len())
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}

pub fn get_next(&mut self) -> bool {
if let Some(x) = self.it.next() {
self.buffer.push(x);
Expand All @@ -39,6 +45,10 @@ where
self.buffer.extend(self.it.by_ref().take(delta));
}
}

pub fn fill(&mut self) {
self.buffer.extend(self.it.by_ref());
}
}

impl<I, J> Index<J> for LazyBuffer<I>
Expand Down
15 changes: 15 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,21 @@ fn combinations_zero() {
it::assert_equal((0..0).combinations(0), vec![vec![]]);
}

#[test]
fn combinations_range_count() {
for n in 0..6 {
for k in 0..=n {
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
let len = (n - k + 1..=n).product::<usize>() / (1..=k).product::<usize>();
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
let mut it = (0..n).combinations(k);
for count in (0..=len).rev() {
assert_eq!(it.size_hint(), (count, Some(count)));
assert_eq!(it.clone().count(), count);
assert_eq!(it.next().is_none(), count == 0);
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}
Philippe-Cholet marked this conversation as resolved.
Show resolved Hide resolved
}
}

#[test]
fn permutations_zero() {
it::assert_equal((1..3).permutations(0), vec![vec![]]);
Expand Down