Skip to content

Commit

Permalink
Reduce amount of computation in tiny RWKV test
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Apr 19, 2023
1 parent 958f53d commit 43d9ced
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
Binary file modified tests/expected_logits.bin
Binary file not shown.
35 changes: 24 additions & 11 deletions tests/test_tiny_rwkv.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// ---

#define N_VOCAB 256
#define N_THREADS 4
#define N_THREADS 2

void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
fprintf(stderr, "Testing %s\n", model_path);
Expand All @@ -34,7 +34,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl
float * state = malloc(sizeof(float) * rwkv_get_state_buffer_element_count(model));
float * logits = malloc(sizeof(float) * n_vocab);

char * prompt = "Describe the structure of an atom.";
char * prompt = "\"in";

const size_t prompt_length = strlen(prompt);

Expand All @@ -50,7 +50,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl

fprintf(stderr, "Difference sum: %f\n", diff_sum);

ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.000001F, "Too big difference %f, expected no more than %f", diff_sum, max_diff);
ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.00001F, "Too big difference %f, expected no more than %f", diff_sum, max_diff);

rwkv_free(model);

Expand All @@ -68,24 +68,37 @@ int main(int argc, const char ** argv) {
ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
fclose(file);

test_model("tiny-rwkv-660K-FP32.bin", expected_logits, -0.000002F);
test_model("tiny-rwkv-660K-FP16.bin", expected_logits, -0.002430F);
float expected_difference_sum[8] = {
0.000000F,
-0.005320F,

-0.501214F,
-1.092427F,
-0.268956F,

-0.501073F,
-1.103214F,
-0.244590F
};

test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
test_model("tiny-rwkv-660K-FP16.bin", expected_logits, expected_difference_sum[1]);

rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", 2);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", 3);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1_O.bin", 4);

test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, -0.038045F);
test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, -0.468718F);
test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, -0.085120F);
test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, expected_difference_sum[4]);

rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", 2);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", 3);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1_O.bin", 4);

test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, -0.034945F);
test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, -0.483789F);
test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, -0.083739F);
test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[5]);
test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[6]);
test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[7]);

free(expected_logits);

Expand Down

0 comments on commit 43d9ced

Please sign in to comment.