From 8828869632a98fc70f1d83ba7bbe578521a32c11 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Sun, 12 Nov 2023 14:42:21 +0400 Subject: [PATCH] Add late_abort option for tests --- tests/assertions.inc | 9 ++++++++- tests/logit_difference_validator.inc | 2 -- tests/test_tiny_rwkv.c | 8 +++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/assertions.inc b/tests/assertions.inc index df5ba64..25127ae 100644 --- a/tests/assertions.inc +++ b/tests/assertions.inc @@ -3,12 +3,19 @@ #include +bool late_abort = false; +bool must_abort = false; + #define ASSERT(x, ...) {\ if (!(x)) {\ fprintf(stderr, "*** Assertion failed ***\n");\ fprintf(stderr, __VA_ARGS__);\ fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ + if (late_abort) {\ + must_abort = true;\ + } else {\ + abort();\ + }\ }\ } diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc index 8504261..5b83732 100644 --- a/tests/logit_difference_validator.inc +++ b/tests/logit_difference_validator.inc @@ -67,7 +67,6 @@ void test_model(const char * version, const char * format, const float * expecte fprintf(stderr, "Serial difference sum: %f, expected %f\n", diff_sum, max_diff); - // When something breaks, difference would be way more than 10 ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big serial difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); // --- @@ -83,7 +82,6 @@ void test_model(const char * version, const char * format, const float * expecte fprintf(stderr, "Sequence difference sum: %f, expected %f\n", diff_sum, max_diff); - // When something breaks, difference would be way more than 10 ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); // --- diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 196a250..e631938 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -10,6 +10,8 @@ #define FORMAT_COUNT 7 int main(void) { + late_abort = true; + fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); // Silences the overly verbose output during quantization. @@ -87,7 +89,7 @@ int main(void) { +000.065571F, // Q8_0 // 5v1 +119.471931F, // Q4_0 - -027.862976F, // Q4_1 + -008.245888F, // Q4_1 -159.870956F, // Q5_0 -039.117004F, // Q5_1 -000.962695F, // Q8_0 @@ -138,5 +140,9 @@ int main(void) { free(expected_logits); } + if (must_abort) { + abort(); + } + return 0; }