From d0be44c030b35f369db6f9431903a283c077a1b9 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 27 Jul 2021 13:55:23 -0700 Subject: [PATCH] Make `test_rng` randomized by default in `std` (#35) --- CHANGELOG.md | 2 + Cargo.toml | 2 +- src/rand_helper.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 93 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3da3013..0d0a8d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ### Breaking changes +- [\#35](https://github.com/arkworks-rs/utils/pull/35) Change `test_rng` to return `impl Rng`, and make the output randomized by default when the `std` feature is set. Introduces a `DETERMINISTIC_TEST_RNG` environment variable that forces the old deterministic behavior when `DETERMINISTIC_TEST_RNG=1` is set. + ### Features ### Improvements diff --git a/Cargo.toml b/Cargo.toml index c5e1952..3ff360c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ num-traits = { version = "0.2", default-features = false } [features] default = [ "std" ] -std = [] +std = [ "rand/std" ] parallel = [ "rayon", "std" ] print-trace = [ "std", "colored" ] diff --git a/src/rand_helper.rs b/src/rand_helper.rs index 59ee385..b04ca63 100644 --- a/src/rand_helper.rs +++ b/src/rand_helper.rs @@ -1,8 +1,10 @@ use rand::{ distributions::{Distribution, Standard}, - rngs::StdRng, + prelude::StdRng, Rng, }; +#[cfg(feature = "std")] +use rand::{prelude::ThreadRng, RngCore}; pub use rand; @@ -20,8 +22,7 @@ where } } -/// Should be used only for tests, not for any real world usage. -pub fn test_rng() -> StdRng { +fn test_rng_helper() -> StdRng { use rand::SeedableRng; // arbitrary seed let seed = [ @@ -30,3 +31,89 @@ pub fn test_rng() -> StdRng { ]; rand::rngs::StdRng::from_seed(seed) } + +/// Should be used only for tests, not for any real world usage. +#[cfg(not(feature = "std"))] +pub fn test_rng() -> impl rand::Rng { + test_rng_helper() +} + +/// Should be used only for tests, not for any real world usage. +#[cfg(feature = "std")] +pub fn test_rng() -> impl rand::Rng { + let is_deterministic = + std::env::vars().any(|(key, val)| key == "DETERMINISTIC_TEST_RNG" && val == "1"); + if is_deterministic { + RngWrapper::Deterministic(test_rng_helper()) + } else { + RngWrapper::Randomized(rand::thread_rng()) + } +} + +/// Helper wrapper to enable `test_rng` to return `impl::Rng`. +#[cfg(feature = "std")] +enum RngWrapper { + Deterministic(StdRng), + Randomized(ThreadRng), +} + +#[cfg(feature = "std")] +impl RngCore for RngWrapper { + #[inline(always)] + fn next_u32(&mut self) -> u32 { + match self { + Self::Deterministic(rng) => rng.next_u32(), + Self::Randomized(rng) => rng.next_u32(), + } + } + + #[inline(always)] + fn next_u64(&mut self) -> u64 { + match self { + Self::Deterministic(rng) => rng.next_u64(), + Self::Randomized(rng) => rng.next_u64(), + } + } + + #[inline(always)] + fn fill_bytes(&mut self, dest: &mut [u8]) { + match self { + Self::Deterministic(rng) => rng.fill_bytes(dest), + Self::Randomized(rng) => rng.fill_bytes(dest), + } + } + + #[inline(always)] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + match self { + Self::Deterministic(rng) => rng.try_fill_bytes(dest), + Self::Randomized(rng) => rng.try_fill_bytes(dest), + } + } +} + +#[cfg(all(test, feature = "std"))] +mod test { + #[test] + fn test_deterministic_rng() { + use super::*; + + let mut rng = super::test_rng(); + let a = u128::rand(&mut rng); + + // Reset the rng by sampling a new one. + let mut rng = super::test_rng(); + let b = u128::rand(&mut rng); + assert_ne!(a, b); // should be unequal with high probability. + + // Let's make the rng deterministic. + std::env::set_var("DETERMINISTIC_TEST_RNG", "1"); + let mut rng = super::test_rng(); + let a = u128::rand(&mut rng); + + // Reset the rng by sampling a new one. + let mut rng = super::test_rng(); + let b = u128::rand(&mut rng); + assert_eq!(a, b); // should be unequal with high probability. + } +}