Skip to content

Commit

Permalink
Add test for on-the-fly quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Jun 14, 2023
1 parent d3b6749 commit c49d3d8
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

rwkv_add_test(test_ggml_basics.c)
rwkv_add_test(test_tiny_rwkv.c)
rwkv_add_test(test_context_cloning.c)
file(GLOB tests *.c)
foreach (test ${tests})
rwkv_add_test(${test})
endforeach()
2 changes: 2 additions & 0 deletions tests/test_context_cloning.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Tests that after context cloning evaluation gives identical results.

#include "rwkv.h"

#include <stdlib.h>
Expand Down
91 changes: 91 additions & 0 deletions tests/test_quantization_on_the_fly.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Tests that results from on-the-fly quantized model are identical with results of pre-quantized model.

#include "ggml.h"
#include "rwkv.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define N_THREADS 2

int main(void) {
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");

struct rwkv_context * prequantized_ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32-Q5_1.bin", N_THREADS);

if (!prequantized_ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

// ---

struct rwkv_init_from_file_option option = {RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1"};

struct rwkv_context * on_the_fly_quantized_ctx = rwkv_init_from_file_ex("tiny-rwkv-660K-FP32.bin", N_THREADS, &option, 1);

if (!on_the_fly_quantized_ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

// ---

float * state = calloc(rwkv_get_state_len(prequantized_ctx), sizeof(float));

if (!state) {
fprintf(stderr, "Failed to allocate state\n");
return EXIT_FAILURE;
}

float * expected_logits = calloc(rwkv_get_logits_len(prequantized_ctx), sizeof(float));

if (!expected_logits) {
fprintf(stderr, "Failed to allocate logits\n");
return EXIT_FAILURE;
}

const unsigned char prompt[12] = "hello world";

rwkv_eval(prequantized_ctx, prompt[0], NULL, state, expected_logits);

for (int i = 1; prompt[i] != 0; i++) {
rwkv_eval(prequantized_ctx, prompt[i], state, state, expected_logits);
}

// ---

float * actual_logits = calloc(rwkv_get_logits_len(on_the_fly_quantized_ctx), sizeof(float));

if (!actual_logits) {
fprintf(stderr, "Failed to allocate logits\n");
return EXIT_FAILURE;
}

rwkv_eval(on_the_fly_quantized_ctx, prompt[0], NULL, state, actual_logits);

for (int i = 1; prompt[i] != 0; i++) {
rwkv_eval(on_the_fly_quantized_ctx, prompt[i], state, state, actual_logits);
}

// ---

if (memcmp(expected_logits, actual_logits, rwkv_get_logits_len(on_the_fly_quantized_ctx) * sizeof(float))) {
fprintf(stderr, "Results not identical :(\n");
return EXIT_FAILURE;
} else {
fprintf(stdout, "Results identical, success!\n");
}

rwkv_free(on_the_fly_quantized_ctx);
rwkv_free(prequantized_ctx);

free(expected_logits);
free(actual_logits);
free(state);

return 0;
}

0 comments on commit c49d3d8

Please sign in to comment.